Commit b7b7314e authored by Shucai Xiao
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into keep_std_shape

parents a98a53f5 c7419a9c
...@@ -15,7 +15,7 @@ p = parse_onnx(input_file, options); ...@@ -15,7 +15,7 @@ p = parse_onnx(input_file, options);
``` ```
## Saving ## Saving
An instantiated migraphx::program object can then be serialized to MessagePack (.msgpack) format and saved so that it can be loaded for future uses. An instantiated migraphx::program object can then be serialized to MessagePack (.mxr) format and saved so that it can be loaded for future uses.
A program can be saved with either of the following: A program can be saved with either of the following:
``` ```
......
...@@ -77,7 +77,7 @@ int main(int argc, char** argv) ...@@ -77,7 +77,7 @@ int main(int argc, char** argv)
std::cout << "Saving program..." << std::endl; std::cout << "Saving program..." << std::endl;
std::string output_file; std::string output_file;
output_file = save_arg == nullptr ? "out" : save_arg; output_file = save_arg == nullptr ? "out" : save_arg;
output_file.append(".msgpack"); output_file.append(".mxr");
migraphx::file_options options; migraphx::file_options options;
options.set_file_format("msgpack"); options.set_file_format("msgpack");
......
...@@ -50,10 +50,10 @@ ...@@ -50,10 +50,10 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"if not os.path.exists(\"yolov4_fp16.msgpack\"):\n", "if not os.path.exists(\"yolov4_fp16.mxr\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16ref --binary -o yolov4_fp16.msgpack\n", " !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16ref --binary -o yolov4_fp16.mxr\n",
"if not os.path.exists(\"yolov4.msgpack\"):\n", "if not os.path.exists(\"yolov4.mxr\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --binary -o yolov4.msgpack" " !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --binary -o yolov4.mxr"
] ]
}, },
{ {
...@@ -115,8 +115,8 @@ ...@@ -115,8 +115,8 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# Load serialized model (either single- or half-precision)\n", "# Load serialized model (either single- or half-precision)\n",
"model = migraphx.load(\"yolov4.msgpack\", format=\"msgpack\")\n", "model = migraphx.load(\"yolov4.mxr\", format=\"msgpack\")\n",
"#model = migraphx.load(\"yolov4_fp16.msgpack\", format=\"msgpack\")\n", "#model = migraphx.load(\"yolov4_fp16.mxr\", format=\"msgpack\")\n",
"\n", "\n",
"# Get the name of the input parameter and convert image data to an MIGraphX argument\n", "# Get the name of the input parameter and convert image data to an MIGraphX argument\n",
"input_name = next(iter(model.get_parameter_shapes()))\n", "input_name = next(iter(model.get_parameter_shapes()))\n",
......
...@@ -37,43 +37,49 @@ struct squeeze ...@@ -37,43 +37,49 @@ struct squeeze
std::string name() const { return "squeeze"; } std::string name() const { return "squeeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; })) if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{ {
MIGRAPHX_THROW("squeeze axis dimension should be equal to 1"); MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
} }
std::vector<std::size_t> new_lens; std::vector<std::size_t> new_lens;
std::vector<std::size_t> new_strides;
if(axes.empty()) if(axes.empty())
{ {
std::copy_if(old_lens.begin(), for(auto i : range(old_lens.size()))
old_lens.end(), {
std::back_inserter(new_lens), if(old_lens[i] != 1)
[](auto len) { return len != 1; }); {
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
}
}
} }
else else
{ {
for(std::size_t i = 0; i < old_lens.size(); i++) for(auto i : range(old_lens.size()))
{ {
if(std::find(axes.begin(), axes.end(), i) == axes.end()) if(std::find(axes.begin(), axes.end(), i) == axes.end())
{ {
new_lens.push_back(old_lens[i]); new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
} }
} }
} }
if(new_lens.empty()) if(new_lens.empty())
{ {
return shape{type}; return shape{type};
} }
else else
{ {
return shape{type, new_lens}; return shape{type, new_lens, new_strides};
} }
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
......
...@@ -6,15 +6,32 @@ ...@@ -6,15 +6,32 @@
namespace migraphx { namespace migraphx {
// Type trait that exposes its argument unchanged as `type`; the partial
// specialization below handles the vec<T, N> case.
template <class T>
struct remove_vec_impl
{
using type = T;
};
// Specialization for vec<T, N>: `type` is the first template argument T of vec.
template <class T, index_int N>
struct remove_vec_impl<vec<T, N>>
{
using type = T;
};
// Convenience alias: remove_vec<vec<T, N>> is T, and remove_vec<T> is T for
// any other T.
template <class T>
using remove_vec = typename remove_vec_impl<T>::type;
template <class T, class... Shapes> template <class T, class... Shapes>
constexpr auto traverse_preload(Shapes... ss) constexpr auto traverse_preload(Shapes... ss)
{ {
return [=](auto f, auto... g) { return [=](auto f, auto... g) {
index_int offset = 0; index_int offset = 0;
auto each = [&](auto x) { auto each = [&](auto x) {
using type = remove_vec<typename decltype(x)::type>;
constexpr auto s = decltype(x.get_shape()){}; constexpr auto s = decltype(x.get_shape()){};
constexpr auto size = _c<s.element_space()>; constexpr auto size = _c<s.element_space()>;
if constexpr(not s.broadcasted() or (s.elements() - size) < 64) if constexpr(not s.broadcasted() or (s.elements() - size) < 64 or
not is_same<T, type>{})
return f(x, offset, false_type{}); return f(x, offset, false_type{});
else else
{ {
...@@ -78,23 +95,23 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs) ...@@ -78,23 +95,23 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
invoke); invoke);
} }
template <class T> template <class T, class Shape>
struct remove_vec struct shape_type : Shape
{ {
using type = T; using type = T;
}; };
template <class T, index_int N> template <class T>
struct remove_vec<vec<T, N>> constexpr auto make_shape_type(T)
{ {
using type = T; return shape_type<typename T::type, typename T::shape_type>{};
}; }
template <class T, class... Ts> template <class T, class... Ts>
__device__ auto preload(index idx, Ts... xs) __device__ auto preload(index idx, Ts... xs)
{ {
using type = typename remove_vec<T>::type; using type = remove_vec<T>;
constexpr auto size = decltype(compute_preload_size<type>(xs.get_shape()...)){}; constexpr auto size = decltype(compute_preload_size<type>(make_shape_type(xs)...)){};
const index_int max_size = 512 * sizeof(type); const index_int max_size = 512 * sizeof(type);
return [=](auto f) { return [=](auto f) {
if constexpr(size > 0 and size < max_size) if constexpr(size > 0 and size < max_size)
......
...@@ -10,6 +10,7 @@ template <class T, class Shape> ...@@ -10,6 +10,7 @@ template <class T, class Shape>
struct tensor_view struct tensor_view
{ {
using type = T; using type = T;
using shape_type = Shape;
constexpr Shape get_shape() const { return Shape{}; } constexpr Shape get_shape() const { return Shape{}; }
constexpr index_int size() const { return get_shape().elements(); } constexpr index_int size() const { return get_shape().elements(); }
......
...@@ -25,6 +25,16 @@ struct is_convertible : bool_constant<__is_convertible(From, To)> ...@@ -25,6 +25,16 @@ struct is_convertible : bool_constant<__is_convertible(From, To)>
{ {
}; };
// Compile-time test for whether two types are identical. The primary template
// covers the case of two different types and derives from false_type.
template <class T, class U>
struct is_same : false_type
{
};
// Specialization chosen when both arguments are the same type; derives from
// true_type.
template <class T>
struct is_same<T, T> : true_type
{
};
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__> #define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
} // namespace migraphx } // namespace migraphx
......
...@@ -44,7 +44,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -44,7 +44,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_POINTWISE_FUSION) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
struct id_pass struct id_pass
{ {
...@@ -100,7 +100,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -100,7 +100,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_reshapes{}, simplify_reshapes{},
propagate_constant{}, propagate_constant{},
dead_code_elimination{}, dead_code_elimination{},
enable_pass(enabled(MIGRAPHX_ENABLE_POINTWISE_FUSION{}), fuse_pointwise{}), enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
dead_code_elimination{}, dead_code_elimination{},
mlir_conv{&ctx}, mlir_conv{&ctx},
lowering{&ctx, options.offload_copy}, lowering{&ctx, options.offload_copy},
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
TEST_CASE(load_save_default) TEST_CASE(load_save_default)
{ {
std::string filename = "migraphx_api_load_save.dat"; std::string filename = "migraphx_api_load_save.mxr";
auto p1 = migraphx::parse_onnx("conv_relu_maxpool_test.onnx"); auto p1 = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
auto s1 = p1.get_output_shapes(); auto s1 = p1.get_output_shapes();
......
#include <test.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/pow.hpp>
#include <migraphx/op/tanh.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/instruction.hpp>
// Builds a program whose main module evaluates, element-wise on a float
// parameter "x" with lens {1, 1, 5}:
//
//   x * (0.5 * tanh(0.797885 * (x + 0.044715 * x^3)) + 0.5)
//
// Each scalar constant is a one-element literal that is multibroadcast to the
// lens of x before use. Instructions are added in the same order as the
// expression is evaluated, and the final product is the module's return value.
migraphx::program create_gelu()
{
    migraphx::program prog;
    auto* mod = prog.get_main_module();

    // One-element payloads for the scalar constants of the expression.
    std::vector<float> cubic_coeff = {0.044715};
    std::vector<float> tanh_scale  = {0.797885};
    std::vector<float> exponent    = {3};
    std::vector<float> one_half    = {0.5};
    migraphx::shape scalar_shape{migraphx::shape::float_type, {1}};
    std::vector<size_t> dims{1, 1, 5};

    auto input     = mod->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
    auto coeff_lit = mod->add_literal(migraphx::literal{scalar_shape, cubic_coeff});
    auto scale_lit = mod->add_literal(migraphx::literal{scalar_shape, tanh_scale});
    auto exp_lit   = mod->add_literal(migraphx::literal{scalar_shape, exponent});
    auto half_lit  = mod->add_literal(migraphx::literal{scalar_shape, one_half});

    // Expands a one-element literal to the lens of the input parameter.
    auto broadcast = [&](auto scalar) {
        return mod->add_instruction(migraphx::op::multibroadcast{dims}, scalar);
    };

    // x^3
    auto exp_b = broadcast(exp_lit);
    auto cubed = mod->add_instruction(migraphx::op::pow{}, input, exp_b);

    // x + 0.044715 * x^3
    auto coeff_b     = broadcast(coeff_lit);
    auto scaled_cube = mod->add_instruction(migraphx::op::mul{}, coeff_b, cubed);
    auto inner       = mod->add_instruction(migraphx::op::add{}, input, scaled_cube);

    // tanh(0.797885 * (x + 0.044715 * x^3))
    auto scale_b  = broadcast(scale_lit);
    auto tanh_arg = mod->add_instruction(migraphx::op::mul{}, scale_b, inner);
    auto tanh_out = mod->add_instruction(migraphx::op::tanh{}, tanh_arg);

    // x * (0.5 * tanh(...) + 0.5)
    auto half_b    = broadcast(half_lit);
    auto half_tanh = mod->add_instruction(migraphx::op::mul{}, half_b, tanh_out);
    auto gate      = mod->add_instruction(migraphx::op::add{}, half_tanh, half_b);
    auto output    = mod->add_instruction(migraphx::op::mul{}, input, gate);

    mod->add_return({output});
    return prog;
}
// Compiles the program from create_gelu() for the GPU target with
// default-constructed compile options and checks that the compiled main module
// contains an instruction named "gpu::gelu".
// NOTE(review): the test name suggests fast math is on by default -- confirm
// against the default value of compile_options::fast_math.
TEST_CASE(enable_fast_gelu)
{
migraphx::program p = create_gelu();
p.compile(migraphx::gpu::target{});
CHECK(any_of(*p.get_main_module(), [&](auto&& i) { return i.name() == "gpu::gelu"; }));
}
// Compiles the program from create_gelu() for the GPU target with
// options.fast_math set to false and checks that the compiled main module
// contains an instruction named "gpu::gelu_new".
TEST_CASE(disable_fast_gelu)
{
migraphx::program p = create_gelu();
migraphx::compile_options options;
// Explicitly turn fast math off before compiling.
options.fast_math = false;
p.compile(migraphx::gpu::target{}, options);
CHECK(any_of(*p.get_main_module(), [&](auto&& i) { return i.name() == "gpu::gelu_new"; }));
}
// Test executable entry point: forwards the command-line arguments to test::run.
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -1446,6 +1446,27 @@ TEST_CASE(test_squeeze_all) ...@@ -1446,6 +1446,27 @@ TEST_CASE(test_squeeze_all)
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {0}}}), s1); expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
} }
// Squeezing axis 2 of a non-standard input (lens {4, 4, 1}, strides {4, 1, 4})
// is expected to give lens {4, 4} with strides {4, 1}.
TEST_CASE(test_squeeze_transpose)
{
    const migraphx::shape input{migraphx::shape::float_type, {4, 4, 1}, {4, 1, 4}};
    const migraphx::shape expected{migraphx::shape::float_type, {4, 4}, {4, 1}};
    const auto squeeze_op = migraphx::make_op("squeeze", {{"axes", {2}}});
    expect_shape(expected, squeeze_op, input);
}
// Squeezing axis 2 of an input with zero strides (lens {2, 3, 1, 4}, strides
// {0, 1, 1, 0}) is expected to keep the remaining strides: lens {2, 3, 4},
// strides {0, 1, 0}.
TEST_CASE(test_squeeze_multibroadcast)
{
    const migraphx::shape input{migraphx::shape::float_type, {2, 3, 1, 4}, {0, 1, 1, 0}};
    const migraphx::shape expected{migraphx::shape::float_type, {2, 3, 4}, {0, 1, 0}};
    const auto squeeze_op = migraphx::make_op("squeeze", {{"axes", {2}}});
    expect_shape(expected, squeeze_op, input);
}
// Squeezing axis 2 of an input whose strides are larger than a packed layout
// (lens {2, 3, 1, 4}, strides {108, 36, 6, 1}) is expected to keep the
// remaining strides: lens {2, 3, 4}, strides {108, 36, 1}.
TEST_CASE(test_squeeze_slice)
{
    const migraphx::shape input{migraphx::shape::float_type, {2, 3, 1, 4}, {108, 36, 6, 1}};
    const migraphx::shape expected{migraphx::shape::float_type, {2, 3, 4}, {108, 36, 1}};
    const auto squeeze_op = migraphx::make_op("squeeze", {{"axes", {2}}});
    expect_shape(expected, squeeze_op, input);
}
TEST_CASE(test_squeeze_negative_axis) TEST_CASE(test_squeeze_negative_axis)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
......
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -13,11 +13,7 @@ TEST_CASE(argmax_test_nonstd_shape) ...@@ -13,11 +13,7 @@ TEST_CASE(argmax_test_nonstd_shape)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, auto dl = mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 3, 4}}));
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
auto dl_trans = auto dl_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl);
mm->add_instruction(migraphx::make_op("argmax", {{"axis", -3}}), dl_trans); mm->add_instruction(migraphx::make_op("argmax", {{"axis", -3}}), dl_trans);
...@@ -36,11 +32,7 @@ TEST_CASE(argmin_test_nonstd_shape) ...@@ -36,11 +32,7 @@ TEST_CASE(argmin_test_nonstd_shape)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, auto dl = mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 3, 4}}));
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
auto dl_trans = auto dl_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl);
mm->add_instruction(migraphx::make_op("argmin", {{"axis", -1}}), dl_trans); mm->add_instruction(migraphx::make_op("argmin", {{"axis", -1}}), dl_trans);
...@@ -55,4 +47,62 @@ TEST_CASE(argmin_test_nonstd_shape) ...@@ -55,4 +47,62 @@ TEST_CASE(argmin_test_nonstd_shape)
EXPECT(migraphx::verify_range(result_vec, res_gold_vec)); EXPECT(migraphx::verify_range(result_vec, res_gold_vec));
} }
// Runs squeeze (no explicit axes) on a transposed literal. The program is
// evaluated once after compiling for the ref target and once from an
// uncompiled copy; the two outputs must agree and the compiled output must
// have lens {3, 4, 3}.
TEST_CASE(squeeze_transpose_test)
{
    migraphx::program prog;
    auto* mod = prog.get_main_module();

    migraphx::shape literal_shape{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
    auto literal_ins = mod->add_literal(migraphx::generate_literal(literal_shape));
    auto transposed  = mod->add_instruction(
        migraphx::make_op("transpose", {{"permutation", {1, 2, 3, 0, 4}}}), literal_ins);
    mod->add_instruction(migraphx::make_op("squeeze"), transposed);

    // Keep an uncompiled copy to produce the reference output.
    auto reference_prog = prog;
    prog.compile(migraphx::ref::target{});

    auto result        = prog.eval({}).back();
    auto reference_out = reference_prog.eval({}).back();

    // The reference output is run through contiguous, using the compiled
    // result's shape, so its values are read in standard shaped order.
    auto contiguous_op       = migraphx::make_op("contiguous");
    auto std_expected_result = contiguous_op.compute(result.get_shape(), {reference_out});

    EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 4, 3}});
    EXPECT(result == std_expected_result);
}
// Runs squeeze (no explicit axes) on a multibroadcast literal. The program is
// evaluated once after compiling for the ref target and once from an
// uncompiled copy; the two outputs must agree and the compiled output must
// have lens {4, 3, 4, 3}.
TEST_CASE(squeeze_multibroadcast_test)
{
    migraphx::program prog;
    auto* mod = prog.get_main_module();

    migraphx::shape literal_shape{migraphx::shape::float_type, {1, 3, 1, 3}};
    auto literal_ins = mod->add_literal(migraphx::generate_literal(literal_shape));
    auto broadcasted = mod->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", {4, 1, 3, 4, 3}}}), literal_ins);
    mod->add_instruction(migraphx::make_op("squeeze"), broadcasted);

    // Keep an uncompiled copy to produce the reference output.
    auto reference_prog = prog;
    prog.compile(migraphx::ref::target{});

    auto result        = prog.eval({}).back();
    auto reference_out = reference_prog.eval({}).back();

    // The reference output is run through contiguous, using the compiled
    // result's shape, before the two are compared.
    auto contiguous_op       = migraphx::make_op("contiguous");
    auto std_expected_result = contiguous_op.compute(result.get_shape(), {reference_out});

    EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {4, 3, 4, 3}});
    EXPECT(result == std_expected_result);
}
// Runs squeeze (no explicit axes) on a sliced literal (axis 2, range [2, 3)).
// The program is evaluated once after compiling for the ref target and once
// from an uncompiled copy; the two outputs must agree and the compiled output
// must have lens {3, 3}.
TEST_CASE(squeeze_slice_test)
{
    migraphx::program prog;
    auto* mod = prog.get_main_module();

    migraphx::shape literal_shape{migraphx::shape::float_type, {1, 3, 4, 3}};
    auto literal_ins = mod->add_literal(migraphx::generate_literal(literal_shape));
    auto sliced      = mod->add_instruction(
        migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {3}}}), literal_ins);
    mod->add_instruction(migraphx::make_op("squeeze"), sliced);

    // Keep an uncompiled copy to produce the reference output.
    auto reference_prog = prog;
    prog.compile(migraphx::ref::target{});

    auto result        = prog.eval({}).back();
    auto reference_out = reference_prog.eval({}).back();

    // The reference output is run through contiguous, using the compiled
    // result's shape, before the two are compared.
    auto contiguous_op       = migraphx::make_op("contiguous");
    auto std_expected_result = contiguous_op.compute(result.get_shape(), {reference_out});

    EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 3}});
    EXPECT(result == std_expected_result);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -48,7 +48,7 @@ TEST_CASE(as_json) ...@@ -48,7 +48,7 @@ TEST_CASE(as_json)
TEST_CASE(as_file) TEST_CASE(as_file)
{ {
std::string filename = "migraphx_program.dat"; std::string filename = "migraphx_program.mxr";
migraphx::program p1 = create_program(); migraphx::program p1 = create_program();
migraphx::save(p1, filename); migraphx::save(p1, filename);
migraphx::program p2 = migraphx::load(filename); migraphx::program p2 = migraphx::load(filename);
......
...@@ -128,7 +128,7 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con ...@@ -128,7 +128,7 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con
std::future<std::pair<migraphx::program, std::vector<migraphx::argument>>>; std::future<std::pair<migraphx::program, std::vector<migraphx::argument>>>;
auto_print::set_terminate_handler(name); auto_print::set_terminate_handler(name);
if(migraphx::enabled(MIGRAPHX_DUMP_TEST{})) if(migraphx::enabled(MIGRAPHX_DUMP_TEST{}))
migraphx::save(p, name + ".mx"); migraphx::save(p, name + ".mxr");
std::vector<std::pair<std::string, result_future>> results; std::vector<std::pair<std::string, result_future>> results;
std::vector<std::string> target_names; std::vector<std::string> target_names;
for(const auto& tname : migraphx::get_targets()) for(const auto& tname : migraphx::get_targets())
......
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// Verification program for integer subtraction with a broadcast operand:
// computes y - multibroadcast(x), where x is an int16 parameter with lens
// {4, 5} and y is an int16 parameter with lens {2, 3, 4, 5}.
struct test_sub_int : verify_program<test_sub_int>
{
    // Builds the program: x is multibroadcast to the lens of y and subtracted
    // from it; the difference is the module's return value.
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto x   = mm->add_parameter("x", {migraphx::shape::int16_type, {4, 5}});
        auto y   = mm->add_parameter("y", {migraphx::shape::int16_type, {2, 3, 4, 5}});
        auto xb  = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), x);
        auto diff = mm->add_instruction(migraphx::make_op("sub"), y, xb);
        mm->add_return({diff});
        return p;
    }
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment