Commit 1005a693 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 0369e974 636bce89
......@@ -12,6 +12,8 @@ using index_int = std::uint32_t;
template <class T, index_int N>
using vec = T __attribute__((ext_vector_type(N)));
using half = _Float16;
} // namespace migraphx
#endif
......@@ -183,6 +183,8 @@ struct miopen_apply
add_extend_op("softmax");
add_extend_op("topk");
add_precompile_op("pointwise");
add_batch_norm_inference_op();
add_convolution_op();
add_deconvolution_op();
......@@ -381,6 +383,21 @@ struct miopen_apply
});
}
void add_precompile_op(const std::string& name)
{
apply_map.emplace(name, [=](instruction_ref ins) {
auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> refs = ins->inputs();
refs.push_back(output);
return mod->replace_instruction(
ins,
make_op("gpu::precompile_op", {{"op", to_value(ins->get_operator())}}),
refs,
ins->module_inputs());
});
}
void add_batch_norm_inference_op()
{
apply_map.emplace("batch_norm_inference", [=](instruction_ref ins) {
......
......@@ -9,6 +9,7 @@
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/memory_coloring.hpp>
......@@ -25,6 +26,7 @@
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/eliminate_workspace.hpp>
......@@ -42,6 +44,20 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_POINTWISE_FUSION)
struct id_pass
{
std::string name() const { return "id"; }
void apple(const module&) const {}
};
pass enable_pass(bool enabled, pass p)
{
if(enabled)
return p;
return id_pass{};
}
std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const
{
......@@ -84,6 +100,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_reshapes{},
propagate_constant{},
dead_code_elimination{},
enable_pass(enabled(MIGRAPHX_ENABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
dead_code_elimination{},
mlir_conv{&ctx},
lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"},
......@@ -96,6 +114,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
fuse_ops{&ctx, options.fast_math},
dead_code_elimination{},
compile_ops{&ctx},
dead_code_elimination{},
write_literals{&ctx},
schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, not enabled(MIGRAPHX_DISABLE_SCHEDULE_PASS{})},
memory_coloring{"hip::allocate"},
......
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <pointwise.hpp>
#include <test.hpp>
void run_pass(migraphx::module& m)
......@@ -159,4 +161,25 @@ TEST_CASE(standard_flatten_op)
EXPECT(std::distance(m.begin(), m.end()) == (count - 1));
}
TEST_CASE(contiguous_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3, 8, 8}};
migraphx::program p;
auto* mm = p.get_main_module();
{
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {3}});
auto yb = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 3, 8, 8}}}), y);
auto yc = mm->add_instruction(migraphx::make_op("contiguous"), yb);
auto add = add_pointwise(p, "main:pointwise0", {x, yc}, single_pointwise("add"));
mm->add_instruction(pass_op{}, add);
}
auto count = std::distance(mm->begin(), mm->end());
run_pass(*mm);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
EXPECT(std::none_of(
mm->begin(), mm->end(), [](auto&& ins) { return ins.name() == "contiguous"; }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -7,38 +7,13 @@
#include <migraphx/make_op.hpp>
#include <test.hpp>
#include <pointwise.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}});
}
template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
F f)
{
auto* pm = p.create_module(name);
auto* mm = p.get_main_module();
pm->set_bypass();
std::vector<migraphx::instruction_ref> params;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) {
return pm->add_parameter("x" + std::to_string(params.size()),
migraphx::shape{input->get_shape().type()});
});
auto r = f(pm, params);
pm->add_return({r});
return mm->add_instruction(migraphx::make_op("pointwise"), inputs, {pm});
}
auto single_pointwise(const std::string& name)
{
return [=](auto* pm, const auto& inputs) {
return pm->add_instruction(migraphx::make_op(name), inputs);
};
}
TEST_CASE(single)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
......@@ -60,9 +35,9 @@ TEST_CASE(single)
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto add1 = add_pointwise(p2, "pointwise0", {x, y}, single_pointwise("add"));
auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add"));
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = add_pointwise(p2, "pointwise1", {pass, z}, single_pointwise("add"));
auto add2 = add_pointwise(p2, "main:pointwise1", {pass, z}, single_pointwise("add"));
mm->add_return({add2});
}
EXPECT(p1 == p2);
......@@ -84,14 +59,15 @@ TEST_CASE(double_add)
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto fadd = add_pointwise(p2, "pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]);
});
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto fadd =
add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]);
});
mm->add_return({fadd});
}
EXPECT(p1.sort() == p2.sort());
......@@ -117,10 +93,10 @@ TEST_CASE(used_twice_not_fused)
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = add_pointwise(p2, "pointwise0", {x, y}, single_pointwise("add"));
auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add"));
auto pass = mm->add_instruction(pass_op{}, add1);
auto fadd =
add_pointwise(p2, "pointwise1", {add1, y, pass}, [=](auto* pm, const auto& inputs) {
auto fadd = add_pointwise(
p2, "main:pointwise1", {add1, y, pass}, [=](auto* pm, const auto& inputs) {
auto add2 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), inputs[2], add2);
});
......@@ -149,7 +125,7 @@ TEST_CASE(used_twice_fused)
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto fadd = add_pointwise(p2, "pointwise0", {x, y}, [=](auto* pm, const auto& inputs) {
auto fadd = add_pointwise(p2, "main:pointwise0", {x, y}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
auto add2 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[0]);
auto add3 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[1]);
......@@ -179,11 +155,11 @@ TEST_CASE(duplicate_inputs)
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = add_pointwise(p2, "pointwise0", {x}, [=](auto* pm, const auto& inputs) {
auto add1 = add_pointwise(p2, "main:pointwise0", {x}, [=](auto* pm, const auto& inputs) {
return pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[0]);
});
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = add_pointwise(p2, "pointwise1", {pass, y}, single_pointwise("add"));
auto add2 = add_pointwise(p2, "main:pointwise1", {pass, y}, single_pointwise("add"));
mm->add_return({add2});
}
EXPECT(p1.sort() == p2.sort());
......@@ -207,7 +183,35 @@ TEST_CASE(scalar_input)
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto add1 = add_pointwise(p2, "pointwise0", {x}, [=](auto* pm, const auto& inputs) {
auto add1 = add_pointwise(p2, "main:pointwise0", {x}, [=](auto* pm, const auto& inputs) {
auto y = pm->add_literal(1.0f);
return pm->add_instruction(migraphx::make_op("add"), inputs[0], y);
});
mm->add_return({add1});
}
EXPECT(p1 == p2);
}
TEST_CASE(contiguous_input)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto one = mm->add_literal(1.0f);
auto yb =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), one);
auto y = mm->add_instruction(migraphx::make_op("contiguous"), yb);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_return({add1});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto add1 = add_pointwise(p2, "main:pointwise0", {x}, [=](auto* pm, const auto& inputs) {
auto y = pm->add_literal(1.0f);
return pm->add_instruction(migraphx::make_op("add"), inputs[0], y);
});
......@@ -216,4 +220,32 @@ TEST_CASE(scalar_input)
EXPECT(p1 == p2);
}
TEST_CASE(all_scalar_input)
{
migraphx::shape s{migraphx::shape::float_type};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_return({add1});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, [=](auto* pm, const auto& inputs) {
return pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
});
mm->add_return({add1});
}
EXPECT(p1.get_output_shapes().size() == 1);
EXPECT(p1.get_output_shapes().front().scalar());
EXPECT(p1.get_output_shapes() == p2.get_output_shapes());
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
File mode changed from 100644 to 100755
#ifndef MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP
#define MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP
#include <migraphx/program.hpp>
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
F f)
{
auto* pm = p.create_module(name);
auto* mm = p.get_main_module();
pm->set_bypass();
std::vector<migraphx::instruction_ref> params;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) {
return pm->add_parameter("x" + std::to_string(params.size()),
migraphx::shape{input->get_shape().type()});
});
auto r = f(pm, params);
pm->add_return({r});
return mm->add_instruction(migraphx::make_op("pointwise"), inputs, {pm});
}
inline auto single_pointwise(const std::string& name)
{
return [=](auto* pm, const auto& inputs) {
return pm->add_instruction(migraphx::make_op(name), inputs);
};
}
#endif // MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP
#include <migraphx/stringutils.hpp>
#include <test.hpp>
TEST_CASE(interpolate_string_simple1)
{
std::string input = "Hello ${w}!";
auto s = migraphx::interpolate_string(input, {{"w", "world"}});
EXPECT(s == "Hello world!");
}
TEST_CASE(interpolate_string_simple2)
{
std::string input = "${hello}";
auto s = migraphx::interpolate_string(input, {{"hello", "bye"}});
EXPECT(s == "bye");
}
TEST_CASE(interpolate_string_unbalanced)
{
std::string input = "${hello";
EXPECT(test::throws([&] { migraphx::interpolate_string(input, {{"hello", "bye"}}); }));
}
TEST_CASE(interpolate_string_extra_space)
{
std::string input = "${ hello }";
auto s = migraphx::interpolate_string(input, {{"hello", "bye"}});
EXPECT(s == "bye");
}
TEST_CASE(interpolate_string_multiple)
{
std::string input = "${h} ${w}!";
auto s = migraphx::interpolate_string(input, {{"w", "world"}, {"h", "Hello"}});
EXPECT(s == "Hello world!");
}
TEST_CASE(interpolate_string_next)
{
std::string input = "${hh}${ww}!";
auto s = migraphx::interpolate_string(input, {{"ww", "world"}, {"hh", "Hello"}});
EXPECT(s == "Helloworld!");
}
TEST_CASE(interpolate_string_dollar_sign)
{
std::string input = "$hello";
auto s = migraphx::interpolate_string(input, {{"hello", "bye"}});
EXPECT(s == "$hello");
}
TEST_CASE(interpolate_string_missing)
{
std::string input = "${hello}";
EXPECT(test::throws([&] { migraphx::interpolate_string(input, {{"h", "bye"}}); }));
}
TEST_CASE(interpolate_string_custom1)
{
std::string input = "****{{a}}****";
auto s = migraphx::interpolate_string(input, {{"a", "b"}}, "{{", "}}");
EXPECT(s == "****b****");
}
TEST_CASE(interpolate_string_custom2)
{
std::string input = "****{{{a}}}****";
auto s = migraphx::interpolate_string(input, {{"a", "b"}}, "{{{", "}}}");
EXPECT(s == "****b****");
}
TEST_CASE(interpolate_string_custom3)
{
std::string input = "****{{{{a}}}}****";
auto s = migraphx::interpolate_string(input, {{"a", "b"}}, "{{{{", "}}}}");
EXPECT(s == "****b****");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -45,14 +45,6 @@ int main(int argc, const char* argv[])
run_verify rv;
rv.add_validation_for("gpu", &validate_gpu);
rv.disable_test_for("cpu", {"test_if_lp", "test_if_param", "test_if_literal"});
rv.disable_test_for("gpu",
{"batch_quant_dot_2",
"batch_quant_dot_3",
"batch_quant_dot_5",
"quant_dot_3args_1",
"quant_dot_3args_2",
"quant_dot_3args_3",
"quant_dot_3args_4",
"quant_dot_3args_5"});
rv.disable_test_for("gpu", {"test_conv_bn_add"});
rv.run(argc, argv);
}
......@@ -2,44 +2,44 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/make_op.hpp>
// struct test_conv_bn_add : verify_program<test_conv_bn_add>
// {
// static migraphx::instruction_ref add_bn(migraphx::program& p,
// migraphx::instruction_ref x,
// std::size_t channels,
// std::size_t seed = 1)
// {
// migraphx::shape vars{migraphx::shape::float_type, {channels}};
// auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 +
// seed))); auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2
// + seed))); auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars,
// 3 + seed))); auto variance =
// mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + seed))); return
// mm->add_instruction(
// migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance);
// }
struct test_conv_bn_add : verify_program<test_conv_bn_add>
{
static migraphx::instruction_ref add_bn(migraphx::module& m,
migraphx::instruction_ref x,
std::size_t channels,
std::size_t seed = 1)
{
migraphx::shape vars{migraphx::shape::float_type, {channels}};
auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + seed)));
auto bias = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + seed)));
auto mean = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + seed)));
auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + seed)));
return m.add_instruction(
migraphx::make_op("batch_norm_inference"), x, scale, bias, mean, variance);
}
// migraphx::program create_program() const
// {
// migraphx::program p;
// std::size_t ichannels = 64;
// std::size_t ochannels = 256;
// auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, ichannels, 56,
// 56}}); auto w = mm->add_literal(migraphx::generate_literal(
// {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 1));
// auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, ichannels, 56,
// 56}}); auto v = mm->add_literal(migraphx::generate_literal(
// {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 2));
// auto relu1 = mm->add_instruction(migraphx::op::relu{}, x);
// auto conv1 = mm->add_instruction(migraphx::op::convolution{}, relu1, w);
// auto bn1 = add_bn(p, conv1, ochannels, 1);
// auto relu2 = mm->add_instruction(migraphx::op::relu{}, y);
// auto conv2 = mm->add_instruction(migraphx::op::convolution{}, relu2, v);
// auto bn2 = add_bn(p, conv2, ochannels, 1);
// auto sum = mm->add_instruction(migraphx::op::add{}, bn1, bn2);
// mm->add_instruction(migraphx::op::relu{}, sum);
// return p;
// }
// };
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
std::size_t ichannels = 64;
std::size_t ochannels = 256;
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, ichannels, 56, 56}});
auto w = mm->add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 1));
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, ichannels, 56, 56}});
auto v = mm->add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 2));
auto relu1 = mm->add_instruction(migraphx::make_op("relu"), x);
auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), relu1, w);
auto bn1 = add_bn(*mm, conv1, ochannels, 1);
auto relu2 = mm->add_instruction(migraphx::make_op("relu"), y);
auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), relu2, v);
auto bn2 = add_bn(*mm, conv2, ochannels, 1);
auto sum = mm->add_instruction(migraphx::make_op("add"), bn1, bn2);
mm->add_instruction(migraphx::make_op("relu"), sum);
return p;
}
};
......@@ -103,79 +103,69 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_operators
template <class T>
auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs))
{
dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].lens());
return any_cast<T>(y).normalize_compute_shape(inputs);
}
template <class T>
shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
auto compute_shape_op(rank<3>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.compute_shape(inputs))
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
return x.compute_shape(inputs);
}
template <class T>
shape normalize_compute_shape_op(const T& x, const std::vector<shape>& inputs)
auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs))
{
return normalize_compute_shape_op(rank<1>{}, x, inputs);
dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].lens());
return any_cast<T>(y).normalize_compute_shape(inputs);
}
template <class T>
auto compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(x.compute_shape(inputs, mod_args))
auto compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.compute_shape(inputs, {}))
{
return x.compute_shape(inputs, mod_args);
return x.compute_shape(inputs, {});
}
template <class T>
shape
compute_shape_op(rank<0>, const T& x, const std::vector<shape>&, const std::vector<module_ref>&)
shape compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape compute_shape_op(const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
shape compute_shape_op(const T& x, const std::vector<shape>& inputs)
{
return compute_shape_op(rank<1>{}, x, inputs, mod_args);
return compute_shape_op(rank<3>{}, x, inputs);
}
template <class T>
auto normalize_compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
std::vector<module_ref>& mod_args)
-> decltype(x.normalize_compute_shape(inputs, mod_args))
auto mod_compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(x.compute_shape(inputs, mod_args))
{
return x.normalize_compute_shape(inputs, mod_args);
return x.compute_shape(inputs, mod_args);
}
template <class T>
shape normalize_compute_shape_op(rank<0>,
const T& x,
const std::vector<shape>&,
const std::vector<module_ref>&)
shape mod_compute_shape_op(rank<0>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
if(mod_args.empty())
return compute_shape_op(x, inputs);
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape normalize_compute_shape_op(const T& x,
const std::vector<shape>& inputs,
std::vector<module_ref>& mod_args)
shape mod_compute_shape_op(const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
return normalize_compute_shape_op(rank<1>{}, x, inputs, mod_args);
return mod_compute_shape_op(rank<1>{}, x, inputs, mod_args);
}
template <class T>
......@@ -488,13 +478,13 @@ lifetime get_lifetime_op(const T&)
returns = 'shape',
input = 'const std::vector<shape>&',
const = True,
default = 'detail::normalize_compute_shape_op'),
default = 'detail::compute_shape_op'),
virtual('compute_shape',
returns = 'shape',
inputs = 'const std::vector<shape>&',
mod_args = 'const std::vector<module_ref>&',
const = True,
default = 'detail::compute_shape_op'),
default = 'detail::mod_compute_shape_op'),
virtual('compute',
returns = 'argument',
ctx = 'context&',
......@@ -582,7 +572,7 @@ template <class T>
inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
-> decltype(op.normalize_compute_shape(inputs))
{
return detail::normalize_compute_shape_op(op, inputs);
return detail::compute_shape_op(op, inputs);
}
inline shape compute_shape(const operation& op,
......@@ -607,7 +597,7 @@ inline auto compute_shape(const T& op,
const std::vector<module_ref>& mod_args)
-> decltype(op.normalize_compute_shape(inputs, mod_args))
{
return detail::normalize_compute_shape_op(op, inputs, mod_args);
return detail::compute_shape_op(op, inputs, mod_args);
}
inline bool is_context_free(const operation& op) { return op.is_context_free(); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment