Unverified Commit faefeef9 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Merge branch 'develop' into dyn_shape_update

parents 97a40ac3 bf0a4713
...@@ -8,9 +8,9 @@ namespace migraphx { ...@@ -8,9 +8,9 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
void sync_device::apply(module& p) const void sync_device::apply(module& m) const
{ {
auto last = std::prev(p.end()); auto last = std::prev(m.end());
if(last->name() == "@return") if(last->name() == "@return")
{ {
auto inputs = last->inputs(); auto inputs = last->inputs();
...@@ -18,10 +18,10 @@ void sync_device::apply(module& p) const ...@@ -18,10 +18,10 @@ void sync_device::apply(module& p) const
return (i->name() == "hip::copy_from_gpu"); return (i->name() == "hip::copy_from_gpu");
})) }))
{ {
auto sync_in = p.insert_instruction(last, make_op("hip::sync_stream"), inputs); auto sync_in = m.insert_instruction(last, make_op("hip::sync_stream"), inputs);
if(not inputs.empty()) if(not inputs.empty())
{ {
p.replace_instruction(inputs.front(), sync_in); m.replace_instruction(inputs.front(), sync_in);
} }
} }
} }
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/eliminate_workspace.hpp> #include <migraphx/gpu/eliminate_workspace.hpp>
#include <migraphx/gpu/fuse_ops.hpp> #include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/lowering.hpp> #include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/mlir_conv.hpp> #include <migraphx/gpu/mlir_conv.hpp>
#include <migraphx/gpu/pack_int8_args.hpp> #include <migraphx/gpu/pack_int8_args.hpp>
...@@ -96,6 +97,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -96,6 +97,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_algebra{}, simplify_algebra{},
simplify_reshapes{}, simplify_reshapes{},
simplify_algebra{}, simplify_algebra{},
prefuse_ops{},
dead_code_elimination{},
auto_contiguous{}, auto_contiguous{},
simplify_reshapes{}, simplify_reshapes{},
propagate_constant{}, propagate_constant{},
......
...@@ -11,25 +11,25 @@ namespace gpu { ...@@ -11,25 +11,25 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_COPY_LITERALS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_COPY_LITERALS)
void write_literals::apply(module& p) const void write_literals::apply(module& m) const
{ {
assert(ctx != nullptr); assert(ctx != nullptr);
std::size_t n = 0; std::size_t n = 0;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
if(ins->name() == "@literal") if(ins->name() == "@literal")
{ {
if(enabled(MIGRAPHX_COPY_LITERALS{})) if(enabled(MIGRAPHX_COPY_LITERALS{}))
{ {
literal l = ins->get_literal(); literal l = ins->get_literal();
auto pre = p.add_literal(l); auto pre = m.add_literal(l);
auto alloc = p.insert_instruction(std::next(pre), hip_allocate{l.get_shape()}); auto alloc = m.insert_instruction(std::next(pre), hip_allocate{l.get_shape()});
p.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc); m.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc);
} }
else else
{ {
std::string id = p.name() + ":@literal:" + std::to_string(n); std::string id = m.name() + ":@literal:" + std::to_string(n);
p.replace_instruction(ins, hip_copy_literal{ins->get_literal(), id}); m.replace_instruction(ins, hip_copy_literal{ins->get_literal(), id});
n++; n++;
} }
} }
......
...@@ -3,23 +3,21 @@ ...@@ -3,23 +3,21 @@
#include <migraphx/migraphx.hpp> #include <migraphx/migraphx.hpp>
#include "test.hpp" #include "test.hpp"
TEST_CASE(add_op) TEST_CASE(add_literals)
{ {
migraphx::program p; migraphx::program p;
migraphx::module m = p.get_main_module(); migraphx::module m = p.get_main_module();
migraphx::shape param_shape{migraphx_shape_float_type, {3, 3}}; migraphx::shape param_shape{migraphx_shape_float_type, {3, 3}};
auto x = m.add_parameter("x", param_shape); std::vector<float> x_values(9, 1);
auto y = m.add_parameter("y", param_shape); auto x = m.add_literal(param_shape, x_values.data());
std::vector<float> y_values(9, -1);
auto y = m.add_literal(param_shape, y_values.data());
auto add_op = migraphx::operation("add"); auto add_op = migraphx::operation("add");
auto r = m.add_instruction(add_op, {x, y}); auto r = m.add_instruction(add_op, {x, y});
m.add_return({r}); m.add_return({r});
// run on ref target // run on ref target
p.compile(migraphx::target("ref")); p.compile(migraphx::target("ref"));
migraphx::program_parameters pp; migraphx::program_parameters pp;
std::vector<float> x_data(9, 1);
std::vector<float> y_data(9, -1);
pp.add("x", migraphx::argument(param_shape, x_data.data()));
pp.add("y", migraphx::argument(param_shape, y_data.data()));
auto outputs = p.eval(pp); auto outputs = p.eval(pp);
auto output = outputs[0]; auto output = outputs[0];
std::vector<float> expected(9, 0); std::vector<float> expected(9, 0);
......
...@@ -180,6 +180,40 @@ TEST_CASE(duplicate_args3) ...@@ -180,6 +180,40 @@ TEST_CASE(duplicate_args3)
EXPECT(result == migraphx::literal{0}); EXPECT(result == migraphx::literal{0});
} }
TEST_CASE(reused_twice)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 2, 2};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, dims});
auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, dims});
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z);
auto epsilon = mm->add_literal(1e-12f);
auto exponent = mm->add_literal(2.0f);
auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), add2);
auto mean_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto sub = mm->add_instruction(migraphx::make_op("sub"), add2, mean_mbcast);
auto exponent_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent);
auto pow = mm->add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast);
auto var = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), pow);
auto epsilon_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, dims.at(1), 1}}}), epsilon);
auto add_epsilon = mm->add_instruction(migraphx::make_op("add"), var, epsilon_mbcast);
mm->add_instruction(migraphx::make_op("sqrt"), add_epsilon);
mm->add_instruction(migraphx::make_op("add"), x, y);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
p.debug_print();
EXPECT(std::distance(mm->begin(), mm->end()) != count);
EXPECT(std::distance(mm->begin(), mm->end()) == 4);
}
TEST_CASE(unused_module) TEST_CASE(unused_module)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -332,7 +332,7 @@ TEST_CASE(match_either_args_any1) ...@@ -332,7 +332,7 @@ TEST_CASE(match_either_args_any1)
match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y"))); match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
auto r = find_match(mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1}); EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
} }
TEST_CASE(match_either_args_any2) TEST_CASE(match_either_args_any2)
...@@ -347,7 +347,7 @@ TEST_CASE(match_either_args_any2) ...@@ -347,7 +347,7 @@ TEST_CASE(match_either_args_any2)
match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y"))); match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
auto r = find_match(mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1}); EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
} }
TEST_CASE(match_either_args_any3) TEST_CASE(match_either_args_any3)
...@@ -362,7 +362,7 @@ TEST_CASE(match_either_args_any3) ...@@ -362,7 +362,7 @@ TEST_CASE(match_either_args_any3)
match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y"))); match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
auto r = find_match(mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1}); EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
} }
TEST_CASE(match_either_args_any4) TEST_CASE(match_either_args_any4)
...@@ -377,7 +377,7 @@ TEST_CASE(match_either_args_any4) ...@@ -377,7 +377,7 @@ TEST_CASE(match_either_args_any4)
match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y"))); match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
auto r = find_match(mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
} }
TEST_CASE(match_either_args_any5) TEST_CASE(match_either_args_any5)
...@@ -392,7 +392,7 @@ TEST_CASE(match_either_args_any5) ...@@ -392,7 +392,7 @@ TEST_CASE(match_either_args_any5)
match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y"))); match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
auto r = find_match(mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
} }
TEST_CASE(match_all_of1) TEST_CASE(match_all_of1)
...@@ -747,10 +747,10 @@ TEST_CASE(match_bind1) ...@@ -747,10 +747,10 @@ TEST_CASE(match_bind1)
match::standard_shape()) match::standard_shape())
.bind("pass"); .bind("pass");
auto r = find_match(mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.instructions.at("one") == one}); EXPECT(bool{r.instructions["one"] == one});
EXPECT(bool{r.instructions.at("two") == two}); EXPECT(bool{r.instructions["two"] == two});
EXPECT(bool{r.instructions.at("sum") == sum}); EXPECT(bool{r.instructions["sum"] == sum});
EXPECT(bool{r.instructions.at("pass") == pass}); EXPECT(bool{r.instructions["pass"] == pass});
EXPECT(bool{r.result == pass}); EXPECT(bool{r.result == pass});
} }
...@@ -795,9 +795,9 @@ TEST_CASE(match_bind_modules2) ...@@ -795,9 +795,9 @@ TEST_CASE(match_bind_modules2)
match::standard_shape()) match::standard_shape())
.bind("pass"); .bind("pass");
auto r = find_match(*child, m); auto r = find_match(*child, m);
EXPECT(bool{r.instructions.at("two") == two}); EXPECT(bool{r.instructions["two"] == two});
EXPECT(bool{r.instructions.at("sum") == sum}); EXPECT(bool{r.instructions["sum"] == sum});
EXPECT(bool{r.instructions.at("pass") == pass}); EXPECT(bool{r.instructions["pass"] == pass});
EXPECT(bool{r.result == pass}); EXPECT(bool{r.result == pass});
} }
......
...@@ -3178,6 +3178,20 @@ def mean_test(): ...@@ -3178,6 +3178,20 @@ def mean_test():
return ([node], data, [mean]) return ([node], data, [mean])
@onnx_test
def mean_integral_test():
data = [
helper.make_tensor_value_info(str(i), TensorProto.INT32, [2, 2, 2])
for i in range(10)
]
data_names = [str(i) for i in range(10)]
mean = helper.make_tensor_value_info('mean', TensorProto.INT32, [2, 2, 2])
node = onnx.helper.make_node("Mean", inputs=data_names, outputs=["mean"])
return ([node], data, [mean])
@onnx_test @onnx_test
def min_test(): def min_test():
a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment