"vscode:/vscode.git/clone" did not exist on "fce16b00f4ce223b31ca7b203ee3298b8459f0a5"
Commit 2d9e620b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into...

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into test_runner_match_input_output
parents 2a73d9a9 19f65e7e
...@@ -23,7 +23,7 @@ __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs) ...@@ -23,7 +23,7 @@ __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
template <class F, class... Ts> template <class F, class... Ts>
__device__ void pointwise(F f, Ts*... ps) __device__ void pointwise(F f, Ts*... ps)
{ {
auto t = transform_args(make_tensors(), rotate_last(), auto_vectorize()); auto t = transform_args(make_tensors(), rotate_last());
t(ps...)([&](auto... xs) { t(ps...)([&](auto... xs) {
auto idx = make_index(); auto idx = make_index();
pointwise_tensor(idx, f, xs...); pointwise_tensor(idx, f, xs...);
......
...@@ -12,6 +12,8 @@ using index_int = std::uint32_t; ...@@ -12,6 +12,8 @@ using index_int = std::uint32_t;
template <class T, index_int N> template <class T, index_int N>
using vec = T __attribute__((ext_vector_type(N))); using vec = T __attribute__((ext_vector_type(N)));
using half = _Float16;
} // namespace migraphx } // namespace migraphx
#endif #endif
...@@ -183,6 +183,8 @@ struct miopen_apply ...@@ -183,6 +183,8 @@ struct miopen_apply
add_extend_op("softmax"); add_extend_op("softmax");
add_extend_op("topk"); add_extend_op("topk");
add_precompile_op("pointwise");
add_batch_norm_inference_op(); add_batch_norm_inference_op();
add_convolution_op(); add_convolution_op();
add_deconvolution_op(); add_deconvolution_op();
...@@ -381,6 +383,21 @@ struct miopen_apply ...@@ -381,6 +383,21 @@ struct miopen_apply
}); });
} }
// Registers a lowering for `name` that wraps the original operator in a
// gpu::precompile_op, deferring actual kernel compilation to a later pass.
void add_precompile_op(const std::string& name)
{
    apply_map.emplace(name, [=](instruction_ref ins) {
        // The precompiled kernel writes into an explicitly allocated buffer,
        // appended as the trailing argument (output-as-last-input convention).
        auto alloc = insert_allocation(ins, ins->get_shape());
        auto args  = ins->inputs();
        args.push_back(alloc);
        auto wrapped = make_op("gpu::precompile_op", {{"op", to_value(ins->get_operator())}});
        return mod->replace_instruction(ins, wrapped, args, ins->module_inputs());
    });
}
void add_batch_norm_inference_op() void add_batch_norm_inference_op()
{ {
apply_map.emplace("batch_norm_inference", [=](instruction_ref ins) { apply_map.emplace("batch_norm_inference", [=](instruction_ref ins) {
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraphx/eliminate_data_type.hpp> #include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp> #include <migraphx/eliminate_pad.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/inline_module.hpp> #include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp> #include <migraphx/insert_pad.hpp>
#include <migraphx/memory_coloring.hpp> #include <migraphx/memory_coloring.hpp>
...@@ -25,6 +26,7 @@ ...@@ -25,6 +26,7 @@
#include <migraphx/simplify_qdq.hpp> #include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraphx/gpu/allocation_model.hpp> #include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp> #include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/eliminate_workspace.hpp> #include <migraphx/gpu/eliminate_workspace.hpp>
...@@ -42,6 +44,20 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -42,6 +44,20 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_POINTWISE_FUSION)
// No-op pass used as a placeholder when an optional pass is disabled.
struct id_pass
{
    std::string name() const { return "id"; }
    // Fix: the hook must be named `apply` (was misspelled `apple`), otherwise
    // the pass interface never finds it and the type does not model a pass.
    // NOTE(review): pass::apply typically takes a mutable module& — confirm
    // the const qualifier matches the project's pass concept.
    void apply(const module&) const {}
};
// Wraps an optional pass: yields `p` when enabled, otherwise a do-nothing pass.
pass enable_pass(bool enabled, pass p)
{
    if(not enabled)
        return id_pass{};
    return p;
}
std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const
{ {
...@@ -84,6 +100,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -84,6 +100,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_reshapes{}, simplify_reshapes{},
propagate_constant{}, propagate_constant{},
dead_code_elimination{}, dead_code_elimination{},
enable_pass(enabled(MIGRAPHX_ENABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
dead_code_elimination{},
mlir_conv{&ctx}, mlir_conv{&ctx},
lowering{&ctx, options.offload_copy}, lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"}, eliminate_contiguous{"gpu::contiguous"},
...@@ -96,6 +114,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -96,6 +114,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
fuse_ops{&ctx, options.fast_math}, fuse_ops{&ctx, options.fast_math},
dead_code_elimination{}, dead_code_elimination{},
compile_ops{&ctx},
dead_code_elimination{},
write_literals{&ctx}, write_literals{&ctx},
schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, not enabled(MIGRAPHX_DISABLE_SCHEDULE_PASS{})}, schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, not enabled(MIGRAPHX_DISABLE_SCHEDULE_PASS{})},
memory_coloring{"hip::allocate"}, memory_coloring{"hip::allocate"},
......
...@@ -60,9 +60,9 @@ TEST_CASE(single) ...@@ -60,9 +60,9 @@ TEST_CASE(single)
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s); auto z = mm->add_parameter("z", s);
auto add1 = add_pointwise(p2, "pointwise0", {x, y}, single_pointwise("add")); auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add"));
auto pass = mm->add_instruction(pass_op{}, add1); auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = add_pointwise(p2, "pointwise1", {pass, z}, single_pointwise("add")); auto add2 = add_pointwise(p2, "main:pointwise1", {pass, z}, single_pointwise("add"));
mm->add_return({add2}); mm->add_return({add2});
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
...@@ -84,14 +84,15 @@ TEST_CASE(double_add) ...@@ -84,14 +84,15 @@ TEST_CASE(double_add)
run_pass(p1); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
auto* mm = p2.get_main_module(); auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s); auto z = mm->add_parameter("z", s);
auto fadd = add_pointwise(p2, "pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) { auto fadd =
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) {
return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
}); return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]);
});
mm->add_return({fadd}); mm->add_return({fadd});
} }
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
...@@ -117,10 +118,10 @@ TEST_CASE(used_twice_not_fused) ...@@ -117,10 +118,10 @@ TEST_CASE(used_twice_not_fused)
auto* mm = p2.get_main_module(); auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto add1 = add_pointwise(p2, "pointwise0", {x, y}, single_pointwise("add")); auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add"));
auto pass = mm->add_instruction(pass_op{}, add1); auto pass = mm->add_instruction(pass_op{}, add1);
auto fadd = auto fadd = add_pointwise(
add_pointwise(p2, "pointwise1", {add1, y, pass}, [=](auto* pm, const auto& inputs) { p2, "main:pointwise1", {add1, y, pass}, [=](auto* pm, const auto& inputs) {
auto add2 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); auto add2 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), inputs[2], add2); return pm->add_instruction(migraphx::make_op("add"), inputs[2], add2);
}); });
...@@ -149,7 +150,7 @@ TEST_CASE(used_twice_fused) ...@@ -149,7 +150,7 @@ TEST_CASE(used_twice_fused)
auto* mm = p2.get_main_module(); auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto fadd = add_pointwise(p2, "pointwise0", {x, y}, [=](auto* pm, const auto& inputs) { auto fadd = add_pointwise(p2, "main:pointwise0", {x, y}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
auto add2 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[0]); auto add2 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[0]);
auto add3 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[1]); auto add3 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[1]);
...@@ -179,11 +180,11 @@ TEST_CASE(duplicate_inputs) ...@@ -179,11 +180,11 @@ TEST_CASE(duplicate_inputs)
auto* mm = p2.get_main_module(); auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto add1 = add_pointwise(p2, "pointwise0", {x}, [=](auto* pm, const auto& inputs) { auto add1 = add_pointwise(p2, "main:pointwise0", {x}, [=](auto* pm, const auto& inputs) {
return pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[0]); return pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[0]);
}); });
auto pass = mm->add_instruction(pass_op{}, add1); auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = add_pointwise(p2, "pointwise1", {pass, y}, single_pointwise("add")); auto add2 = add_pointwise(p2, "main:pointwise1", {pass, y}, single_pointwise("add"));
mm->add_return({add2}); mm->add_return({add2});
} }
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
...@@ -207,7 +208,35 @@ TEST_CASE(scalar_input) ...@@ -207,7 +208,35 @@ TEST_CASE(scalar_input)
{ {
auto* mm = p2.get_main_module(); auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto add1 = add_pointwise(p2, "pointwise0", {x}, [=](auto* pm, const auto& inputs) { auto add1 = add_pointwise(p2, "main:pointwise0", {x}, [=](auto* pm, const auto& inputs) {
auto y = pm->add_literal(1.0f);
return pm->add_instruction(migraphx::make_op("add"), inputs[0], y);
});
mm->add_return({add1});
}
EXPECT(p1 == p2);
}
TEST_CASE(contiguous_input)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto one = mm->add_literal(1.0f);
auto yb =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), one);
auto y = mm->add_instruction(migraphx::make_op("contiguous"), yb);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_return({add1});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto add1 = add_pointwise(p2, "main:pointwise0", {x}, [=](auto* pm, const auto& inputs) {
auto y = pm->add_literal(1.0f); auto y = pm->add_literal(1.0f);
return pm->add_instruction(migraphx::make_op("add"), inputs[0], y); return pm->add_instruction(migraphx::make_op("add"), inputs[0], y);
}); });
...@@ -216,4 +245,32 @@ TEST_CASE(scalar_input) ...@@ -216,4 +245,32 @@ TEST_CASE(scalar_input)
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
// Fusing an add whose operands are all scalars must preserve the scalar
// output shape of the program.
TEST_CASE(all_scalar_input)
{
    migraphx::shape s{migraphx::shape::float_type};
    migraphx::program p1;
    {
        auto* m  = p1.get_main_module();
        auto a   = m->add_parameter("x", s);
        auto b   = m->add_parameter("y", s);
        auto sum = m->add_instruction(migraphx::make_op("add"), a, b);
        m->add_return({sum});
    }
    run_pass(p1);
    // Expected form: the add collapsed into a single pointwise sub-module.
    migraphx::program p2;
    {
        auto* m = p2.get_main_module();
        auto a  = m->add_parameter("x", s);
        auto b  = m->add_parameter("y", s);
        auto fused =
            add_pointwise(p2, "main:pointwise0", {a, b}, [=](auto* pm, const auto& ins) {
                return pm->add_instruction(migraphx::make_op("add"), ins[0], ins[1]);
            });
        m->add_return({fused});
    }
    EXPECT(p1.get_output_shapes().size() == 1);
    EXPECT(p1.get_output_shapes().front().scalar());
    EXPECT(p1.get_output_shapes() == p2.get_output_shapes());
    EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -19,6 +19,7 @@ TEST_CASE(perf_report) ...@@ -19,6 +19,7 @@ TEST_CASE(perf_report)
std::string output = ss.str(); std::string output = ss.str();
EXPECT(migraphx::contains(output, "Summary:")); EXPECT(migraphx::contains(output, "Summary:"));
EXPECT(migraphx::contains(output, "Batch size:"));
EXPECT(migraphx::contains(output, "Rate:")); EXPECT(migraphx::contains(output, "Rate:"));
EXPECT(migraphx::contains(output, "Total time:")); EXPECT(migraphx::contains(output, "Total time:"));
EXPECT(migraphx::contains(output, "Total instructions time:")); EXPECT(migraphx::contains(output, "Total instructions time:"));
......
#include <migraphx/stringutils.hpp>
#include <test.hpp>
// A single ${var} placeholder is substituted in place.
TEST_CASE(interpolate_string_simple1)
{
    const std::string templ = "Hello ${w}!";
    auto result = migraphx::interpolate_string(templ, {{"w", "world"}});
    EXPECT(result == "Hello world!");
}
// A template that is nothing but a placeholder becomes just the value.
TEST_CASE(interpolate_string_simple2)
{
    const std::string templ = "${hello}";
    EXPECT(migraphx::interpolate_string(templ, {{"hello", "bye"}}) == "bye");
}
// An unterminated placeholder must throw rather than pass through silently.
TEST_CASE(interpolate_string_unbalanced)
{
    const std::string templ = "${hello";
    EXPECT(test::throws([&] { migraphx::interpolate_string(templ, {{"hello", "bye"}}); }));
}
// Whitespace inside the delimiters is ignored when matching the key.
TEST_CASE(interpolate_string_extra_space)
{
    const std::string templ = "${ hello }";
    EXPECT(migraphx::interpolate_string(templ, {{"hello", "bye"}}) == "bye");
}
// Several placeholders in one template are each resolved independently.
TEST_CASE(interpolate_string_multiple)
{
    const std::string templ = "${h} ${w}!";
    auto result = migraphx::interpolate_string(templ, {{"w", "world"}, {"h", "Hello"}});
    EXPECT(result == "Hello world!");
}
// Back-to-back placeholders with no separator are both expanded.
TEST_CASE(interpolate_string_next)
{
    const std::string templ = "${hh}${ww}!";
    auto result = migraphx::interpolate_string(templ, {{"ww", "world"}, {"hh", "Hello"}});
    EXPECT(result == "Helloworld!");
}
// A bare '$' without braces is not a placeholder and is left untouched.
TEST_CASE(interpolate_string_dollar_sign)
{
    const std::string templ = "$hello";
    EXPECT(migraphx::interpolate_string(templ, {{"hello", "bye"}}) == "$hello");
}
// A placeholder whose key is absent from the map must throw.
TEST_CASE(interpolate_string_missing)
{
    const std::string templ = "${hello}";
    EXPECT(test::throws([&] { migraphx::interpolate_string(templ, {{"h", "bye"}}); }));
}
// Custom start/end delimiters replace the default "${" / "}".
TEST_CASE(interpolate_string_custom1)
{
    const std::string templ = "****{{a}}****";
    EXPECT(migraphx::interpolate_string(templ, {{"a", "b"}}, "{{", "}}") == "****b****");
}
// Triple-brace delimiters also work; the key is matched between them.
TEST_CASE(interpolate_string_custom2)
{
    const std::string templ = "****{{{a}}}****";
    EXPECT(migraphx::interpolate_string(templ, {{"a", "b"}}, "{{{", "}}}") == "****b****");
}
// Quadruple-brace delimiters: longest-delimiter matching still substitutes once.
TEST_CASE(interpolate_string_custom3)
{
    const std::string templ = "****{{{{a}}}}****";
    EXPECT(migraphx::interpolate_string(templ, {{"a", "b"}}, "{{{{", "}}}}") == "****b****");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -45,14 +45,6 @@ int main(int argc, const char* argv[]) ...@@ -45,14 +45,6 @@ int main(int argc, const char* argv[])
run_verify rv; run_verify rv;
rv.add_validation_for("gpu", &validate_gpu); rv.add_validation_for("gpu", &validate_gpu);
rv.disable_test_for("cpu", {"test_if_lp", "test_if_param", "test_if_literal"}); rv.disable_test_for("cpu", {"test_if_lp", "test_if_param", "test_if_literal"});
rv.disable_test_for("gpu", rv.disable_test_for("gpu", {"test_conv_bn_add"});
{"batch_quant_dot_2",
"batch_quant_dot_3",
"batch_quant_dot_5",
"quant_dot_3args_1",
"quant_dot_3args_2",
"quant_dot_3args_3",
"quant_dot_3args_4",
"quant_dot_3args_5"});
rv.run(argc, argv); rv.run(argc, argv);
} }
...@@ -2,44 +2,44 @@ ...@@ -2,44 +2,44 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/make_op.hpp>
// struct test_conv_bn_add : verify_program<test_conv_bn_add> struct test_conv_bn_add : verify_program<test_conv_bn_add>
// { {
// static migraphx::instruction_ref add_bn(migraphx::program& p, static migraphx::instruction_ref add_bn(migraphx::module& m,
// migraphx::instruction_ref x, migraphx::instruction_ref x,
// std::size_t channels, std::size_t channels,
// std::size_t seed = 1) std::size_t seed = 1)
// { {
// migraphx::shape vars{migraphx::shape::float_type, {channels}}; migraphx::shape vars{migraphx::shape::float_type, {channels}};
// auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + seed)));
// seed))); auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 auto bias = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + seed)));
// + seed))); auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, auto mean = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + seed)));
// 3 + seed))); auto variance = auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + seed)));
// mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + seed))); return return m.add_instruction(
// mm->add_instruction( migraphx::make_op("batch_norm_inference"), x, scale, bias, mean, variance);
// migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance); }
// }
// migraphx::program create_program() const migraphx::program create_program() const
// { {
// migraphx::program p; migraphx::program p;
// std::size_t ichannels = 64; auto* mm = p.get_main_module();
// std::size_t ochannels = 256; std::size_t ichannels = 64;
// auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, ichannels, 56, std::size_t ochannels = 256;
// 56}}); auto w = mm->add_literal(migraphx::generate_literal( auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, ichannels, 56, 56}});
// {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 1)); auto w = mm->add_literal(migraphx::generate_literal(
// auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, ichannels, 56, {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 1));
// 56}}); auto v = mm->add_literal(migraphx::generate_literal( auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, ichannels, 56, 56}});
// {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 2)); auto v = mm->add_literal(migraphx::generate_literal(
// auto relu1 = mm->add_instruction(migraphx::op::relu{}, x); {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 2));
// auto conv1 = mm->add_instruction(migraphx::op::convolution{}, relu1, w); auto relu1 = mm->add_instruction(migraphx::make_op("relu"), x);
// auto bn1 = add_bn(p, conv1, ochannels, 1); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), relu1, w);
// auto relu2 = mm->add_instruction(migraphx::op::relu{}, y); auto bn1 = add_bn(*mm, conv1, ochannels, 1);
// auto conv2 = mm->add_instruction(migraphx::op::convolution{}, relu2, v); auto relu2 = mm->add_instruction(migraphx::make_op("relu"), y);
// auto bn2 = add_bn(p, conv2, ochannels, 1); auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), relu2, v);
// auto sum = mm->add_instruction(migraphx::op::add{}, bn1, bn2); auto bn2 = add_bn(*mm, conv2, ochannels, 1);
// mm->add_instruction(migraphx::op::relu{}, sum); auto sum = mm->add_instruction(migraphx::make_op("add"), bn1, bn2);
// return p; mm->add_instruction(migraphx::make_op("relu"), sum);
// } return p;
// }; }
};
...@@ -103,7 +103,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) ...@@ -103,7 +103,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_operators } // namespace operation_operators
template <class T> template <class T>
auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs) auto normalize_compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs)) -> decltype(x.normalize_compute_shape(inputs))
{ {
dependent_type<operation, T> y = x; dependent_type<operation, T> y = x;
...@@ -111,6 +111,13 @@ auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& i ...@@ -111,6 +111,13 @@ auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& i
return any_cast<T>(y).normalize_compute_shape(inputs); return any_cast<T>(y).normalize_compute_shape(inputs);
} }
// rank<1> overload: selected via SFINAE (the trailing decltype return type)
// only when the operator exposes compute_shape(inputs, {}); the second
// argument is a default-constructed list — presumably the sub-module inputs,
// passed empty here — TODO confirm against the operation interface.
// Loses overload resolution to the rank<2> variant when
// normalize_compute_shape(inputs) exists on T.
template <class T>
auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.compute_shape(inputs, {}))
{
return x.compute_shape(inputs, {});
}
template <class T> template <class T>
shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&) shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
{ {
...@@ -121,7 +128,7 @@ shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&) ...@@ -121,7 +128,7 @@ shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
template <class T> template <class T>
shape normalize_compute_shape_op(const T& x, const std::vector<shape>& inputs) shape normalize_compute_shape_op(const T& x, const std::vector<shape>& inputs)
{ {
return normalize_compute_shape_op(rank<1>{}, x, inputs); return normalize_compute_shape_op(rank<2>{}, x, inputs);
} }
template <class T> template <class T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment