Unverified Commit 36c4d147 authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge branch 'develop' into bert_ops

parents f0b4e8eb 55182aac
...@@ -28,10 +28,32 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -28,10 +28,32 @@ inline namespace MIGRAPHX_INLINE_NS {
#define MIGRAPHX_DRIVER_STATIC static #define MIGRAPHX_DRIVER_STATIC static
#endif #endif
template <class T>
using bare = std::remove_cv_t<std::remove_reference_t<T>>;
namespace detail {
template <class T>
auto is_container(int, T&& x) -> decltype(x.insert(x.end(), *x.begin()), std::true_type{});
template <class T>
std::false_type is_container(float, T&&);
} // namespace detail
template <class T>
struct is_container : decltype(detail::is_container(int(0), std::declval<T>()))
{
};
template <class T>
using is_multi_value =
std::integral_constant<bool, (is_container<T>{} and not std::is_convertible<T, std::string>{})>;
template <class T> template <class T>
struct value_parser struct value_parser
{ {
template <MIGRAPHX_REQUIRES(not std::is_enum<T>{})> template <MIGRAPHX_REQUIRES(not std::is_enum<T>{} and not is_multi_value<T>{})>
static T apply(const std::string& x) static T apply(const std::string& x)
{ {
T result; T result;
...@@ -43,7 +65,7 @@ struct value_parser ...@@ -43,7 +65,7 @@ struct value_parser
return result; return result;
} }
template <MIGRAPHX_REQUIRES(std::is_enum<T>{})> template <MIGRAPHX_REQUIRES(std::is_enum<T>{} and not is_multi_value<T>{})>
static T apply(const std::string& x) static T apply(const std::string& x)
{ {
std::ptrdiff_t i; std::ptrdiff_t i;
...@@ -54,6 +76,15 @@ struct value_parser ...@@ -54,6 +76,15 @@ struct value_parser
throw std::runtime_error("Failed to parse: " + x); throw std::runtime_error("Failed to parse: " + x);
return static_cast<T>(i); return static_cast<T>(i);
} }
template <MIGRAPHX_REQUIRES(is_multi_value<T>{} and not std::is_enum<T>{})>
static T apply(const std::string& x)
{
T result;
using value_type = typename T::value_type;
result.insert(result.end(), value_parser<value_type>::apply(x));
return result;
}
}; };
struct argument_parser struct argument_parser
...@@ -69,6 +100,18 @@ struct argument_parser ...@@ -69,6 +100,18 @@ struct argument_parser
unsigned nargs = 1; unsigned nargs = 1;
}; };
template <class T, MIGRAPHX_REQUIRES(is_multi_value<T>{})>
std::string as_string_value(const T& x)
{
return to_string_range(x);
}
template <class T, MIGRAPHX_REQUIRES(not is_multi_value<T>{})>
std::string as_string_value(const T& x)
{
return to_string(x);
}
template <class T, class... Fs> template <class T, class... Fs>
void operator()(T& x, const std::vector<std::string>& flags, Fs... fs) void operator()(T& x, const std::vector<std::string>& flags, Fs... fs)
{ {
...@@ -81,7 +124,7 @@ struct argument_parser ...@@ -81,7 +124,7 @@ struct argument_parser
argument& arg = arguments.back(); argument& arg = arguments.back();
arg.type = migraphx::get_type_name<T>(); arg.type = migraphx::get_type_name<T>();
arg.default_value = to_string(x); arg.default_value = as_string_value(x);
migraphx::each_args([&](auto f) { f(x, arg); }, fs...); migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
} }
...@@ -127,7 +170,7 @@ struct argument_parser ...@@ -127,7 +170,7 @@ struct argument_parser
MIGRAPHX_DRIVER_STATIC auto append() MIGRAPHX_DRIVER_STATIC auto append()
{ {
return write_action([](auto&, auto& x, auto& params) { return write_action([](auto&, auto& x, auto& params) {
using type = typename decltype(params)::value_type; using type = typename bare<decltype(params)>::value_type;
std::transform(params.begin(), std::transform(params.begin(),
params.end(), params.end(),
std::inserter(x, x.end()), std::inserter(x, x.end()),
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp> #include <migraphx/eliminate_pad.hpp>
...@@ -80,11 +81,13 @@ struct compiler ...@@ -80,11 +81,13 @@ struct compiler
{ {
loader l; loader l;
bool gpu = true; bool gpu = true;
std::vector<std::string> fill1;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
l.parse(ap); l.parse(ap);
ap(gpu, {"--gpu"}, ap.help("Compile on the gpu"), ap.set_value(true)); ap(gpu, {"--gpu"}, ap.help("Compile on the gpu"), ap.set_value(true));
ap(gpu, {"--cpu"}, ap.help("Compile on the cpu"), ap.set_value(false)); ap(gpu, {"--cpu"}, ap.help("Compile on the cpu"), ap.set_value(false));
ap(fill1, {"--fill1"}, ap.help("Fill parameter with 1s"), ap.append());
} }
program compile() program compile()
...@@ -94,7 +97,14 @@ struct compiler ...@@ -94,7 +97,14 @@ struct compiler
return p; return p;
} }
auto params(const program& p) { return create_param_map(p, gpu); } auto params(const program& p)
{
program::parameter_map m;
for(auto&& s : fill1)
m[s] = fill_argument(p.get_parameter_shape(s), 1);
fill_param_map(m, p, gpu);
return m;
}
}; };
struct read : command<read> struct read : command<read>
...@@ -109,6 +119,19 @@ struct read : command<read> ...@@ -109,6 +119,19 @@ struct read : command<read>
} }
}; };
struct params : command<params>
{
loader l;
void parse(argument_parser& ap) { l.parse(ap); }
void run()
{
auto p = l.load();
for(auto&& param : p.get_parameter_shapes())
std::cout << param.first << ": " << param.second << std::endl;
}
};
struct verify : command<verify> struct verify : command<verify>
{ {
loader l; loader l;
......
...@@ -11,6 +11,23 @@ namespace migraphx { ...@@ -11,6 +11,23 @@ namespace migraphx {
namespace driver { namespace driver {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
program::parameter_map fill_param_map(program::parameter_map& m, const program& p, bool gpu)
{
for(auto&& x : p.get_parameter_shapes())
{
argument& arg = m[x.first];
if(arg.empty())
arg = generate_argument(x.second);
#ifdef HAVE_GPU
if(gpu)
arg = gpu::to_gpu(arg);
#else
(void)gpu;
#endif
}
return m;
}
program::parameter_map create_param_map(const program& p, bool gpu) program::parameter_map create_param_map(const program& p, bool gpu)
{ {
program::parameter_map m; program::parameter_map m;
......
...@@ -7,6 +7,7 @@ namespace migraphx { ...@@ -7,6 +7,7 @@ namespace migraphx {
namespace driver { namespace driver {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
program::parameter_map fill_param_map(program::parameter_map& m, const program& p, bool gpu);
program::parameter_map create_param_map(const program& p, bool gpu = true); program::parameter_map create_param_map(const program& p, bool gpu = true);
void compile_program(program& p, bool gpu = true); void compile_program(program& p, bool gpu = true);
......
...@@ -3,6 +3,17 @@ ...@@ -3,6 +3,17 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
argument fill_argument(shape s, unsigned long value)
{
argument result;
s.visit_type([&](auto as) {
using type = typename decltype(as)::type;
auto v = fill_tensor_data<type>(s, value);
result = {s, [v]() mutable { return reinterpret_cast<char*>(v.data()); }};
});
return result;
}
argument generate_argument(shape s, unsigned long seed) argument generate_argument(shape s, unsigned long seed)
{ {
argument result; argument result;
......
...@@ -87,6 +87,16 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed ...@@ -87,6 +87,16 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed
return result; return result;
} }
template <class T>
std::vector<T> fill_tensor_data(const migraphx::shape& s, unsigned long value = 0)
{
std::vector<T> result(s.elements());
std::generate(result.begin(), result.end(), [=] { return value; });
return result;
}
argument fill_argument(shape s, unsigned long value = 0);
argument generate_argument(shape s, unsigned long seed = 0); argument generate_argument(shape s, unsigned long seed = 0);
literal generate_literal(shape s, unsigned long seed = 0); literal generate_literal(shape s, unsigned long seed = 0);
......
...@@ -23,9 +23,10 @@ using bool_c = std::integral_constant<bool, B>; ...@@ -23,9 +23,10 @@ using bool_c = std::integral_constant<bool, B>;
#ifdef CPPCHECK #ifdef CPPCHECK
#define MIGRAPHX_REQUIRES(...) class = void #define MIGRAPHX_REQUIRES(...) class = void
#else #else
#define MIGRAPHX_REQUIRES(...) \ #define MIGRAPHX_REQUIRES(...) \
bool MIGRAPHX_REQUIRES_VAR() = true, \ long MIGRAPHX_REQUIRES_VAR() = __LINE__, \
typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() && (migraphx::and_<__VA_ARGS__>{})), \ typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() == __LINE__ && \
(migraphx::and_<__VA_ARGS__>{})), \
int>::type = 0 int>::type = 0
#endif #endif
......
...@@ -206,6 +206,16 @@ struct onnx_parser ...@@ -206,6 +206,16 @@ struct onnx_parser
return out_lens; return out_lens;
} }
instruction_ref make_contiguous(instruction_ref ins)
{
if(ins->get_shape().standard())
{
return ins;
}
return prog.add_instruction(op::contiguous{}, ins);
}
template <class T> template <class T>
instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x) instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
{ {
...@@ -437,12 +447,7 @@ struct onnx_parser ...@@ -437,12 +447,7 @@ struct onnx_parser
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
} }
if(!args[0]->get_shape().standard()) return prog.add_instruction(op, make_contiguous(args[0]));
{
args[0] = prog.add_instruction(op::contiguous{}, args[0]);
}
return prog.add_instruction(op, args[0]);
} }
instruction_ref instruction_ref
...@@ -490,8 +495,9 @@ struct onnx_parser ...@@ -490,8 +495,9 @@ struct onnx_parser
{ {
axis = parse_value(attributes.at("axis")).at<int>(); axis = parse_value(attributes.at("axis")).at<int>();
} }
op::gather op{axis}; op::gather op{axis};
return prog.add_instruction(op, std::move(args)); return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
} }
instruction_ref instruction_ref
......
...@@ -74,7 +74,8 @@ void quantize(program& prog, const std::vector<std::string>& ins_names) ...@@ -74,7 +74,8 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
// if the input is a convert operator, uses its input // if the input is a convert operator, uses its input
// as its current input // as its current input
instruction_ref input_fp16{}; instruction_ref input_fp16{};
if(input->name() == "convert") if(input->name() == "convert" and
input->inputs().front()->get_shape().type() == shape::half_type)
{ {
input_fp16 = input->inputs().front(); input_fp16 = input->inputs().front();
} }
......
...@@ -8,51 +8,6 @@ namespace migraphx { ...@@ -8,51 +8,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
template <class... Ts>
rocblas_status generic_rocblas_gemm_ex(Ts&&... xs)
{
return rocblas_gemm_ex(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_batched_gemm_ex(Ts&&... xs)
{
return rocblas_gemm_strided_batched_ex(std::forward<Ts>(xs)...);
}
template <class T>
struct compute_rocblas_type
{
using type = T;
};
template <class T>
struct compute_rocblas_type<const T>
{
using type = const typename compute_rocblas_type<T>::type;
};
template <>
struct compute_rocblas_type<half>
{
using type = rocblas_half;
};
template <class T>
using rb_type = typename compute_rocblas_type<T>::type;
template <class T>
rb_type<T> to_rocblas_type(T x)
{
return reinterpret_cast<const rb_type<T>&>(x);
}
template <class T>
rb_type<T>* to_rocblas_type(T* x)
{
return reinterpret_cast<rb_type<T>*>(x);
}
shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{ {
std::vector<shape> in_shapes(inputs); std::vector<shape> in_shapes(inputs);
...@@ -102,13 +57,13 @@ argument rocblas_quant_gemm::compute(context& ctx, ...@@ -102,13 +57,13 @@ argument rocblas_quant_gemm::compute(context& ctx,
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = as(op.alpha);
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = as(beta);
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0]; rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1]; rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1]; rocblas_int k = args[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
assert(k % 4 == 0); assert(k % 4 == 0);
auto num_matrices = std::accumulate( auto num_matrices = std::accumulate(
...@@ -119,36 +74,36 @@ argument rocblas_quant_gemm::compute(context& ctx, ...@@ -119,36 +74,36 @@ argument rocblas_quant_gemm::compute(context& ctx,
// column-major format. When doing a C = A * B, we actually do // column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as // C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm. // A and args[0] as B in calling the rocblas_gemm.
generic_rocblas_gemm_ex(ctx.get_stream().get_rocblas(), rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
n, n,
m, m,
k, k,
&alpha_r, &alpha_r,
to_pointer(args.at(1)), to_pointer(args.at(1)),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
ldb, ldb,
to_pointer(args.at(0)), to_pointer(args.at(0)),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
lda, lda,
&beta_r, &beta_r,
to_pointer(args[2]), to_pointer(args[2]),
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
0, 0,
nullptr, nullptr,
nullptr); nullptr);
} }
else else
{ {
generic_rocblas_batched_gemm_ex( rocblas_gemm_strided_batched_ex(
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -509,6 +510,32 @@ TEST_CASE(shape_gather_test) ...@@ -509,6 +510,32 @@ TEST_CASE(shape_gather_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(transpose_gather_test)
{
migraphx::program p;
auto make_contiguous = [&p](migraphx::instruction_ref ins) {
if(ins->get_shape().standard())
{
return ins;
}
return p.add_instruction(migraphx::op::contiguous{}, ins);
};
auto data = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 5, 4, 6}});
auto ind =
p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 4, 3, 5}});
auto tr_data = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, data);
auto tr_ind = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, ind);
int axis = 1;
p.add_instruction(
migraphx::op::gather{axis}, make_contiguous(tr_data), make_contiguous(tr_ind));
auto prog = migraphx::parse_onnx("transpose_gather.onnx");
EXPECT(p == prog);
}
TEST_CASE(flatten_test) TEST_CASE(flatten_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment