Commit 5e24fdf9 authored by Shucai Xiao

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into int8_miopen_call

parents 3934afa3 b606ed4f
@@ -9,19 +9,60 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-bool try_compute_shape(const operation& op, const std::vector<instruction_ref>& args)
+static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inputs)
 {
     try
     {
-        compute_shape(op, args);
+        shape new_shape = ins->get_operator().compute_shape(inputs);
+        // If the output shape is standard, there is no need to check its outputs
+        if(new_shape.standard())
+        {
+            return true;
+        }
+
+        // If the shape is unchanged, the contiguous can also be removed
+        if(new_shape == ins->get_shape())
+        {
+            return true;
+        }
+
+        auto outputs = ins->outputs();
+        // If the current instruction has no outputs, it is the last instruction
+        // and produces a non-standard output shape, so the final output would
+        // differ from the result with the contiguous operator kept
+        if(outputs.empty())
+        {
+            return false;
+        }
+
+        for(auto output : outputs)
+        {
+            auto args = output->inputs();
+            std::vector<shape> input_shapes(args.size());
+            std::transform(args.begin(), args.end(), input_shapes.begin(), [&](auto& arg) {
+                return (arg == ins) ? new_shape : arg->get_shape();
+            });
+
+            if(!try_compute_shape(output, input_shapes))
+            {
+                return false;
+            }
+        }
     }
     catch(...)
     {
         return false;
     }
     return true;
 }

+static bool try_compute_shape(instruction_ref ins, const std::vector<instruction_ref>& args)
+{
+    auto inputs = to_shapes(args);
+    return try_compute_shape(ins, inputs);
+}
+
 void eliminate_contiguous::apply(program& p) const
 {
     for(auto ins : iterator_for(p))
@@ -44,7 +85,7 @@ void eliminate_contiguous::apply(program& p) const
             auto new_args = args;
             auto prev     = arg->inputs().front();
             replace(new_args, arg, prev);
-            if(try_compute_shape(ins->get_operator(), new_args))
+            if(try_compute_shape(ins, new_args))
             {
                 instruction::replace_argument(ins, arg, prev);
             }
...
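The change above turns the pass's safety check from a single local shape computation into a recursive one: before a contiguous is removed, the producer's (possibly non-standard) shape is proposed to every consumer, and the check recurses until it reaches an instruction whose output shape is unaffected, or fails when it falls off the end of the program. A minimal sketch of that recursion using a toy graph model rather than MIGraphX's `instruction_ref`/`shape` types (`toy_shape`, `toy_ins`, and the scenario in `main` are all hypothetical):

```cpp
#include <cstdio>
#include <functional>
#include <vector>

// Toy stand-ins for migraphx::shape and instruction nodes
struct toy_shape
{
    std::vector<std::size_t> lens;
    std::vector<std::size_t> strides;

    // Standard: densely packed, row-major strides
    bool standard() const
    {
        std::size_t expected = 1;
        for(std::size_t i = lens.size(); i > 0; i--)
        {
            if(strides[i - 1] != expected)
                return false;
            expected *= lens[i - 1];
        }
        return true;
    }

    bool operator==(const toy_shape& rhs) const
    {
        return lens == rhs.lens and strides == rhs.strides;
    }
};

struct toy_ins
{
    // Computes the output shape from input shapes; throws on invalid inputs
    std::function<toy_shape(const std::vector<toy_shape>&)> compute_shape;
    std::vector<toy_ins*> inputs;
    std::vector<toy_ins*> outputs;
    toy_shape shape; // shape the instruction currently produces
};

// Mirrors the recursion added in eliminate_contiguous above: propose new
// input shapes to ins and verify every downstream instruction still works
bool try_compute_shape(toy_ins* ins, const std::vector<toy_shape>& inputs)
{
    try
    {
        toy_shape new_shape = ins->compute_shape(inputs);
        // Standard or unchanged output: consumers are unaffected
        if(new_shape.standard() or new_shape == ins->shape)
            return true;
        // Last instruction: the program's output layout would change
        if(ins->outputs.empty())
            return false;
        for(auto* out : ins->outputs)
        {
            std::vector<toy_shape> shapes;
            for(auto* arg : out->inputs)
                shapes.push_back(arg == ins ? new_shape : arg->shape);
            if(!try_compute_shape(out, shapes))
                return false;
        }
        return true;
    }
    catch(...)
    {
        return false;
    }
}

int main()
{
    // A final elementwise instruction that would pass a transposed (3x2,
    // strides {1, 3}) input straight through instead of the standard copy
    // the contiguous currently provides
    toy_ins last;
    last.compute_shape = [](const std::vector<toy_shape>& in) { return in.at(0); };
    last.shape         = {{3, 2}, {2, 1}};

    bool removable = try_compute_shape(&last, {toy_shape{{3, 2}, {1, 3}}});
    std::printf("contiguous removable: %s\n", removable ? "yes" : "no"); // "no"
}
```

In the sketch the final instruction has no consumers, so handing it a transposed shape would change the program's observable output layout and the check fails, which is why the `non_standard_op` test below now expects the contiguous to be kept.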
@@ -13,10 +13,16 @@ struct binary : op_name<Derived>
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs}.has(2).same_type().same_dims();
-        const auto& s = inputs.front();
-        if(s.scalar() and s.elements() == 1)
-            return {s.type()};
-        return {s.type(), s.lens()};
+        auto s0 = inputs.at(0);
+        auto s1 = inputs.at(1);
+        if(s0 == s1 and s0.packed())
+        {
+            return s0;
+        }
+        else
+        {
+            return {s0.type(), s0.lens()};
+        }
     }

     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
...
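For reference, the distinction the new `compute_shape` overloads rely on: packed means the elements occupy memory densely under some permutation of the dimensions (a transpose of a dense tensor is packed but not standard), while standard additionally requires row-major order. A minimal sketch under those assumed definitions (`toy_shape` is hypothetical, not the MIGraphX class):

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

struct toy_shape
{
    std::vector<std::size_t> lens;
    std::vector<std::size_t> strides;

    // Standard: exactly row-major, densely packed
    bool standard() const
    {
        std::size_t expected = 1;
        for(std::size_t i = lens.size(); i > 0; i--)
        {
            if(strides[i - 1] != expected)
                return false;
            expected *= lens[i - 1];
        }
        return true;
    }

    // Packed: dense under some ordering of the dimensions
    bool packed() const
    {
        std::vector<std::size_t> order(lens.size());
        std::iota(order.begin(), order.end(), 0);
        // Visit dimensions from smallest stride to largest
        std::sort(order.begin(), order.end(), [&](auto a, auto b) {
            return strides[a] < strides[b];
        });
        std::size_t expected = 1;
        for(auto d : order)
        {
            if(strides[d] != expected)
                return false;
            expected *= lens[d];
        }
        return true;
    }
};

int main()
{
    toy_shape dense{{2, 3}, {3, 1}};      // row-major 2x3
    toy_shape transposed{{3, 2}, {1, 3}}; // transposed view of it
    toy_shape sliced{{2, 1}, {2, 1}};     // one column of a 2x2

    assert(dense.standard() and dense.packed());
    // A transpose stays packed, so ops may pass its layout through
    assert(not transposed.standard() and transposed.packed());
    // A slice leaves gaps, so it is not packed; ops then fall back
    // to returning a standard shape and materializing the result
    assert(not sliced.packed());
}
```

With these definitions, an op on a transposed input can simply pass the packed layout through, while a sliced input forces the op to produce and fill a standard output, which is what the `no_packed_unary_op` test below exercises.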
@@ -30,7 +30,7 @@ struct gather
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(2);
+        check_shapes{inputs, *this}.has(2).standard();
         auto lens  = inputs[0].lens();
         int n_dim  = static_cast<int>(lens.size());
         if(axis >= n_dim || axis < -n_dim)
...
@@ -29,7 +29,7 @@ struct logsoftmax
     std::string name() const { return "logsoftmax"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs}.has(1);
+        check_shapes{inputs}.has(1).standard();
         if(axis < 0 || axis > inputs[0].lens().size())
         {
             MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
...
@@ -13,7 +13,15 @@ struct unary : op_name<Derived>
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs}.has(1);
-        return inputs.at(0);
+        auto s = inputs.at(0);
+        if(s.packed())
+        {
+            return s;
+        }
+        else
+        {
+            return {s.type(), s.lens()};
+        }
     }

     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
...
@@ -697,13 +697,35 @@ struct cpu_unary
 {
     Op op;
     std::string name() const { return op.name(); }
-    shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs}.has(1);
+        auto s = inputs.at(0);
+        if(s.packed())
+        {
+            return s;
+        }
+        else
+        {
+            return {s.type(), s.lens()};
+        }
+    }
+
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
         result.visit([&](auto output) {
             args[0].visit([&](auto input) {
-                std::transform(input.begin(), input.end(), output.begin(), op.fcn());
+                if(input.get_shape().standard())
+                {
+                    std::transform(input.begin(), input.end(), output.begin(), op.fcn());
+                }
+                else
+                {
+                    shape_for_each(output.get_shape(), [&](const auto& idx) {
+                        output(idx.begin(), idx.end()) = op.fcn()(input(idx.begin(), idx.end()));
+                    });
+                }
             });
         });
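The non-standard branch above is what lets a unary op consume a sliced or transposed view directly: instead of walking raw memory linearly, it applies `op.fcn()` at every multi-dimensional index. A minimal self-contained sketch of that fallback, with toy stand-ins (`toy_view`, `for_each_index`) for MIGraphX's tensor views and `shape_for_each`:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Toy stand-in for a strided tensor view
struct toy_view
{
    float* data;
    std::vector<std::size_t> lens;
    std::vector<std::size_t> strides;

    float& operator()(const std::vector<std::size_t>& idx)
    {
        std::size_t offset = 0;
        for(std::size_t i = 0; i < idx.size(); i++)
            offset += idx[i] * strides[i];
        return data[offset];
    }
};

// Visit every multi-dimensional index of lens in row-major order
template <class F>
void for_each_index(const std::vector<std::size_t>& lens, F f)
{
    std::size_t total = 1;
    for(auto l : lens)
        total *= l;
    std::vector<std::size_t> idx(lens.size());
    for(std::size_t n = 0; n < total; n++)
    {
        // Decode the linear counter into a multi-dimensional index
        std::size_t r = n;
        for(std::size_t i = lens.size(); i > 0; i--)
        {
            idx[i - 1] = r % lens[i - 1];
            r /= lens[i - 1];
        }
        f(idx);
    }
}

int main()
{
    // A 2x3 row-major buffer viewed as its 3x2 transpose (packed, not standard)
    std::vector<float> in = {1, 2, 3, 4, 5, 6};
    std::vector<float> out(6, 0);
    toy_view input{in.data(), {3, 2}, {1, 3}};   // transposed view
    toy_view output{out.data(), {3, 2}, {2, 1}}; // standard 3x2 result

    // The fallback path: apply the op element by element over indices
    for_each_index(output.lens, [&](const std::vector<std::size_t>& idx) {
        output(idx) = std::sin(input(idx));
    });

    for(float v : out)
        std::printf("%g ", v);
    std::printf("\n");
}
```

The standard-shape branch can stay a plain `std::transform` because, for a standard layout, linear traversal and row-major index traversal visit the same elements in the same order.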
@@ -877,12 +899,28 @@ struct cpu_binary
 {
     Op op;
     std::string name() const { return "cpu::" + op.name(); }
-    shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs}.has(2).same_type().same_dims();
+        auto s0 = inputs.at(0);
+        auto s1 = inputs.at(1);
+        if(s0 == s1 and s0.packed())
+        {
+            return s0;
+        }
+        else
+        {
+            return {s0.type(), s0.lens()};
+        }
+    }
+
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
         visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
-            if(input1.get_shape().packed() and input2.get_shape().packed())
+            auto s1 = input1.get_shape();
+            auto s2 = input2.get_shape();
+            if(s1 == s2 and s1.standard())
             {
                 std::transform(
                     input1.begin(), input1.end(), input2.begin(), output.begin(), op.fcn());
@@ -895,6 +933,7 @@ struct cpu_binary
                 });
             }
         });
+
         return result;
     }
 };
...
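Note why the fast path was tightened from `packed()` to `s1 == s2 and s1.standard()`: a packed-but-transposed input stores its elements in a different logical order than a standard output, so a linear `std::transform` would pair up the wrong elements. A small demonstration under toy definitions (hypothetical, not MIGraphX code):

```cpp
#include <cstdio>
#include <vector>

int main()
{
    // A dense 2x3 matrix in row-major memory: [[1, 2, 3], [4, 5, 6]]
    std::vector<int> data = {1, 2, 3, 4, 5, 6};

    // Its 3x2 transpose is packed (same six elements, no gaps) but not
    // standard: element (i, j) lives at data[i * 1 + j * 3]
    std::printf("linear memory order:     ");
    for(int v : data)
        std::printf("%d ", v);

    std::printf("\nlogical row-major order: ");
    for(int i = 0; i < 3; i++)
        for(int j = 0; j < 2; j++)
            std::printf("%d ", data[i * 1 + j * 3]);
    std::printf("\n");

    // The two orders differ (1 2 3 4 5 6 vs 1 4 2 5 3 6), so a linear
    // std::transform against a standard output would scramble a packed,
    // non-standard input; hence the stricter standard() fast path
}
```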
@@ -7,7 +7,7 @@ namespace gpu {

 shape miopen_abs::compute_shape(const std::vector<shape>& inputs) const
 {
-    check_shapes{inputs, *this}.has(2).not_broadcasted();
+    check_shapes{inputs, *this}.has(2).packed();
     return inputs.at(0);
 }
...
@@ -45,7 +45,15 @@ struct unary_device : oper<Derived>
     shape compute_shape(const std::vector<shape>& inputs) const
     {
         check_shapes{inputs, *this}.has(2);
-        return inputs.at(1);
+        auto s = inputs.at(0);
+        if(s.packed())
+        {
+            return s;
+        }
+        else
+        {
+            return {s.type(), s.lens()};
+        }
     }

     argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
@@ -66,7 +74,16 @@ struct binary_device : oper<Derived>
     shape compute_shape(const std::vector<shape>& inputs) const
     {
         check_shapes{inputs, *this}.has(3);
-        return inputs.at(2);
+        auto s0 = inputs.at(0);
+        auto s1 = inputs.at(1);
+        if(s0 == s1 and s0.packed())
+        {
+            return s0;
+        }
+        else
+        {
+            return {s0.type(), s0.lens()};
+        }
     }

     argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
...
@@ -7,7 +7,7 @@ namespace gpu {

 shape miopen_tanh::compute_shape(const std::vector<shape>& inputs) const
 {
-    check_shapes{inputs, *this}.has(2).not_broadcasted();
+    check_shapes{inputs, *this}.has(2).packed();
     return inputs.at(0);
 }
...
 #include <migraphx/eliminate_contiguous.hpp>
 #include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/op/identity.hpp>
+#include <migraphx/op/dot.hpp>
+#include <migraphx/op/sin.hpp>
+#include <migraphx/op/slice.hpp>
 #include <migraphx/op/transpose.hpp>
 #include <migraphx/op/contiguous.hpp>
 #include <basic_ops.hpp>
@@ -36,7 +40,46 @@ TEST_CASE(non_standard_op)
     p.add_instruction(pass_op{}, c);
     auto count = std::distance(p.begin(), p.end());
     p.compile(eliminate_contiguous_target{});
+    EXPECT(std::distance(p.begin(), p.end()) == count);
+}
+
+TEST_CASE(transpose_gemm)
+{
+    migraphx::program p;
+    auto l  = p.add_literal(get_2x2());
+    auto t  = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
+    auto c  = p.add_instruction(migraphx::op::contiguous{}, t);
+    auto ic = p.add_instruction(migraphx::op::identity{}, c);
+    p.add_instruction(migraphx::op::dot{}, ic, l);
+    auto count = std::distance(p.begin(), p.end());
+    p.compile(eliminate_contiguous_target{});
     EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
 }
+
+TEST_CASE(transpose_standard_op)
+{
+    migraphx::program p;
+    auto l  = p.add_literal(get_2x2());
+    auto t  = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
+    auto c  = p.add_instruction(migraphx::op::contiguous{}, t);
+    auto sn = p.add_instruction(migraphx::op::sin{}, c);
+    p.add_instruction(pass_standard_op{}, sn);
+    auto count = std::distance(p.begin(), p.end());
+    p.compile(eliminate_contiguous_target{});
+    EXPECT(std::distance(p.begin(), p.end()) == count);
+}
+
+TEST_CASE(no_packed_unary_op)
+{
+    migraphx::program p;
+    auto l  = p.add_literal(get_2x2());
+    auto t  = p.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, l);
+    auto c  = p.add_instruction(migraphx::op::contiguous{}, t);
+    auto sn = p.add_instruction(migraphx::op::sin{}, c);
+    p.add_instruction(pass_standard_op{}, sn);
+    auto count = std::distance(p.begin(), p.end());
+    p.compile(eliminate_contiguous_target{});
+    EXPECT(std::distance(p.begin(), p.end()) == count - 1);
+}

 int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -8,6 +8,7 @@
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/op/add.hpp>
 #include <migraphx/op/transpose.hpp>
+#include <migraphx/op/contiguous.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/pass_manager.hpp>
 #include <migraphx/op/tanh.hpp>
@@ -37,7 +38,8 @@ TEST_CASE(tanh_shape)
         auto x   = p.add_parameter("x", s);
         auto tx  = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
         auto txh = p.add_instruction(migraphx::op::tanh{}, tx);
-        p.add_instruction(migraphx::op::add{}, txh, txh);
+        auto sum = p.add_instruction(migraphx::op::add{}, txh, txh);
+        p.add_instruction(migraphx::op::contiguous{}, sum);
         return p;
     };
@@ -55,8 +57,8 @@ TEST_CASE(tanh_shape)
     {
         if(ins->name() == "hip::allocate")
         {
-            migraphx::shape wrong_s{migraphx::shape::float_type, {3, 2}, {1, 3}};
-            ins->replace(wrong_s);
+            migraphx::shape new_s{migraphx::shape::float_type, {3, 2}, {1, 3}};
+            migraphx::instruction::replace(ins, ins->get_operator(), new_s, ins->inputs());
         }
     }
     EXPECT(p1 != p2);
...
@@ -333,7 +333,22 @@ struct test_trans_tanh : verify_program<test_trans_tanh>
         auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
         auto tx    = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
         auto tanhx = p.add_instruction(migraphx::op::tanh{}, tx);
-        p.add_instruction(migraphx::op::add{}, tanhx, tanhx);
+        auto r     = p.add_instruction(migraphx::op::add{}, tanhx, tanhx);
+        p.add_instruction(migraphx::op::contiguous{}, r);
+        return p;
+    }
+};
+
+struct test_slice_sin : verify_program<test_slice_sin>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        auto l = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
+        auto t = p.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, l);
+        p.add_instruction(migraphx::op::sin{}, t);
         return p;
     }
 };
@@ -692,8 +707,10 @@ struct test_trans_abs : verify_program<test_trans_abs>
         migraphx::program p;
         auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
         auto tx   = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
-        auto tanhx = p.add_instruction(migraphx::op::abs{}, tx);
-        p.add_instruction(migraphx::op::add{}, tanhx, tanhx);
+        auto absx = p.add_instruction(migraphx::op::abs{}, tx);
+        auto r    = p.add_instruction(migraphx::op::add{}, absx, absx);
+        p.add_instruction(migraphx::op::contiguous{}, r);
         return p;
     }
 };
...