Unverified Commit faefeef9 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Merge branch 'develop' into dyn_shape_update

parents 97a40ac3 bf0a4713
......@@ -8,9 +8,9 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
void sync_device::apply(module& p) const
void sync_device::apply(module& m) const
{
auto last = std::prev(p.end());
auto last = std::prev(m.end());
if(last->name() == "@return")
{
auto inputs = last->inputs();
......@@ -18,10 +18,10 @@ void sync_device::apply(module& p) const
return (i->name() == "hip::copy_from_gpu");
}))
{
auto sync_in = p.insert_instruction(last, make_op("hip::sync_stream"), inputs);
auto sync_in = m.insert_instruction(last, make_op("hip::sync_stream"), inputs);
if(not inputs.empty())
{
p.replace_instruction(inputs.front(), sync_in);
m.replace_instruction(inputs.front(), sync_in);
}
}
}
......
......@@ -31,6 +31,7 @@
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/eliminate_workspace.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/mlir_conv.hpp>
#include <migraphx/gpu/pack_int8_args.hpp>
......@@ -96,6 +97,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_algebra{},
simplify_reshapes{},
simplify_algebra{},
prefuse_ops{},
dead_code_elimination{},
auto_contiguous{},
simplify_reshapes{},
propagate_constant{},
......
......@@ -11,25 +11,25 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_COPY_LITERALS)
void write_literals::apply(module& p) const
void write_literals::apply(module& m) const
{
assert(ctx != nullptr);
std::size_t n = 0;
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(m))
{
if(ins->name() == "@literal")
{
if(enabled(MIGRAPHX_COPY_LITERALS{}))
{
literal l = ins->get_literal();
auto pre = p.add_literal(l);
auto alloc = p.insert_instruction(std::next(pre), hip_allocate{l.get_shape()});
p.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc);
auto pre = m.add_literal(l);
auto alloc = m.insert_instruction(std::next(pre), hip_allocate{l.get_shape()});
m.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc);
}
else
{
std::string id = p.name() + ":@literal:" + std::to_string(n);
p.replace_instruction(ins, hip_copy_literal{ins->get_literal(), id});
std::string id = m.name() + ":@literal:" + std::to_string(n);
m.replace_instruction(ins, hip_copy_literal{ins->get_literal(), id});
n++;
}
}
......
......@@ -3,23 +3,21 @@
#include <migraphx/migraphx.hpp>
#include "test.hpp"
TEST_CASE(add_op)
TEST_CASE(add_literals)
{
migraphx::program p;
migraphx::module m = p.get_main_module();
migraphx::shape param_shape{migraphx_shape_float_type, {3, 3}};
auto x = m.add_parameter("x", param_shape);
auto y = m.add_parameter("y", param_shape);
std::vector<float> x_values(9, 1);
auto x = m.add_literal(param_shape, x_values.data());
std::vector<float> y_values(9, -1);
auto y = m.add_literal(param_shape, y_values.data());
auto add_op = migraphx::operation("add");
auto r = m.add_instruction(add_op, {x, y});
m.add_return({r});
// run on ref target
p.compile(migraphx::target("ref"));
migraphx::program_parameters pp;
std::vector<float> x_data(9, 1);
std::vector<float> y_data(9, -1);
pp.add("x", migraphx::argument(param_shape, x_data.data()));
pp.add("y", migraphx::argument(param_shape, y_data.data()));
auto outputs = p.eval(pp);
auto output = outputs[0];
std::vector<float> expected(9, 0);
......
......@@ -180,6 +180,40 @@ TEST_CASE(duplicate_args3)
EXPECT(result == migraphx::literal{0});
}
TEST_CASE(reused_twice)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 2, 2};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, dims});
auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, dims});
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z);
auto epsilon = mm->add_literal(1e-12f);
auto exponent = mm->add_literal(2.0f);
auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), add2);
auto mean_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto sub = mm->add_instruction(migraphx::make_op("sub"), add2, mean_mbcast);
auto exponent_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent);
auto pow = mm->add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast);
auto var = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), pow);
auto epsilon_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, dims.at(1), 1}}}), epsilon);
auto add_epsilon = mm->add_instruction(migraphx::make_op("add"), var, epsilon_mbcast);
mm->add_instruction(migraphx::make_op("sqrt"), add_epsilon);
mm->add_instruction(migraphx::make_op("add"), x, y);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
p.debug_print();
EXPECT(std::distance(mm->begin(), mm->end()) != count);
EXPECT(std::distance(mm->begin(), mm->end()) == 4);
}
TEST_CASE(unused_module)
{
migraphx::program p;
......
......@@ -332,7 +332,7 @@ TEST_CASE(match_either_args_any1)
match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_either_args_any2)
......@@ -347,7 +347,7 @@ TEST_CASE(match_either_args_any2)
match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_either_args_any3)
......@@ -362,7 +362,7 @@ TEST_CASE(match_either_args_any3)
match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_either_args_any4)
......@@ -377,7 +377,7 @@ TEST_CASE(match_either_args_any4)
match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_either_args_any5)
......@@ -392,7 +392,7 @@ TEST_CASE(match_either_args_any5)
match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_all_of1)
......@@ -747,10 +747,10 @@ TEST_CASE(match_bind1)
match::standard_shape())
.bind("pass");
auto r = find_match(mm, m);
EXPECT(bool{r.instructions.at("one") == one});
EXPECT(bool{r.instructions.at("two") == two});
EXPECT(bool{r.instructions.at("sum") == sum});
EXPECT(bool{r.instructions.at("pass") == pass});
EXPECT(bool{r.instructions["one"] == one});
EXPECT(bool{r.instructions["two"] == two});
EXPECT(bool{r.instructions["sum"] == sum});
EXPECT(bool{r.instructions["pass"] == pass});
EXPECT(bool{r.result == pass});
}
......@@ -795,9 +795,9 @@ TEST_CASE(match_bind_modules2)
match::standard_shape())
.bind("pass");
auto r = find_match(*child, m);
EXPECT(bool{r.instructions.at("two") == two});
EXPECT(bool{r.instructions.at("sum") == sum});
EXPECT(bool{r.instructions.at("pass") == pass});
EXPECT(bool{r.instructions["two"] == two});
EXPECT(bool{r.instructions["sum"] == sum});
EXPECT(bool{r.instructions["pass"] == pass});
EXPECT(bool{r.result == pass});
}
......
......@@ -3178,6 +3178,20 @@ def mean_test():
return ([node], data, [mean])
@onnx_test
def mean_integral_test():
data = [
helper.make_tensor_value_info(str(i), TensorProto.INT32, [2, 2, 2])
for i in range(10)
]
data_names = [str(i) for i in range(10)]
mean = helper.make_tensor_value_info('mean', TensorProto.INT32, [2, 2, 2])
node = onnx.helper.make_node("Mean", inputs=data_names, outputs=["mean"])
return ([node], data, [mean])
@onnx_test
def min_test():
a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment