Commit cced8ba3 authored by Shucai Xiao

Improve the scenarios in which contiguous instructions can be eliminated.

parent 2a3042b1
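
This commit lets elementwise operators report a packed but non-standard output layout (e.g. a transposed view) instead of always reporting a standard shape. The reference (CPU) and GPU backends get matching compute_shape logic plus a per-index fallback for non-packed inputs, the MIOpen abs/tanh wrappers now require packed shapes, and the affected verify tests append a contiguous instruction so the final result is standard again.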
@@ -21,9 +21,14 @@ struct binary
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs}.has(2).same_type().same_dims();
-        auto t = inputs.at(0).type();
-        auto lens = inputs.at(0).lens();
-        return {t, lens};
+        if (inputs.at(0) == inputs.at(1) and inputs.at(0).packed() and inputs.at(1).packed())
+        {
+            return inputs.at(0);
+        }
+        else
+        {
+            return {inputs.at(0).type(), inputs.at(0).lens()};
+        }
     }
 };
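
All of these compute_shape changes follow one rule: if the input layout is packed (its elements fill the buffer with no gaps, even if the dimensions are permuted), propagate it; otherwise report a standard shape built from the type and lengths. Below is a rough standalone sketch of the binary case; toy_shape, make_standard, and binary_compute_shape are illustrative toys, not MIGraphX's real shape or operator API.

// toy_shape models just enough of a tensor shape (lengths + strides) to show
// the rule; the packed() test here is illustrative, not MIGraphX's API.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

struct toy_shape
{
    std::vector<std::size_t> lens;
    std::vector<std::size_t> strides;

    // Packed: sorted by stride, each stride equals the product of the lengths
    // of all faster-varying dimensions, so the elements tile the buffer densely.
    bool packed() const
    {
        std::vector<std::pair<std::size_t, std::size_t>> dims;
        for(std::size_t i = 0; i < lens.size(); i++)
            dims.emplace_back(strides[i], lens[i]);
        std::sort(dims.begin(), dims.end());
        std::size_t expected = 1;
        for(const auto& d : dims)
        {
            if(d.first != expected)
                return false;
            expected *= d.second;
        }
        return true;
    }
};

// Row-major ("standard") strides for the given lengths.
toy_shape make_standard(std::vector<std::size_t> lens)
{
    std::vector<std::size_t> strides(lens.size(), 1);
    for(std::size_t i = lens.size(); i > 1; i--)
        strides[i - 2] = strides[i - 1] * lens[i - 1];
    return {std::move(lens), std::move(strides)};
}

// The rule the hunk above implements for binary elementwise ops: if both
// inputs have the same, packed layout, propagate it; otherwise fall back
// to a standard shape with the same lengths.
toy_shape binary_compute_shape(const toy_shape& a, const toy_shape& b)
{
    if(a.lens == b.lens and a.strides == b.strides and a.packed() and b.packed())
        return a;
    return make_standard(a.lens);
}

int main()
{
    // A 2x3 standard tensor viewed transposed as 3x2: packed, but not standard.
    toy_shape t{{3, 2}, {1, 3}};
    assert(t.packed());
    auto out = binary_compute_shape(t, t);
    assert(out.strides == t.strides); // the transposed layout survives the op
}

The unary hunk further down applies the same rule with the two-input equality check dropped.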
@@ -29,7 +29,7 @@ struct logsoftmax
     std::string name() const { return "logsoftmax"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs}.has(1);
+        check_shapes{inputs}.has(1).standard();
         if(axis < 0 || axis > inputs[0].lens().size())
         {
             MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
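
The logsoftmax check moves in the opposite direction: it is tightened from has(1) to has(1).standard(), presumably because its implementation indexes the input assuming a standard (row-major, non-transposed) layout, which a merely packed shape no longer guarantees.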
@@ -21,8 +21,15 @@ struct unary
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs}.has(1);
-        return inputs.at(0);
+        if (inputs.front().packed())
+        {
+            return inputs.at(0);
+        }
+        else
+        {
+            return {inputs.at(0).type(), inputs.at(0).lens()};
+        }
     }
 };
 } // namespace op
@@ -593,13 +593,33 @@ struct cpu_unary
 {
     Op op;
     std::string name() const { return op.name(); }
-    shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        if (inputs.at(0).packed())
+        {
+            return inputs.at(0);
+        }
+        else
+        {
+            return {inputs.at(0).type(), inputs.at(0).lens()};
+        }
+    }
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
         result.visit([&](auto output) {
             args[0].visit([&](auto input) {
-                std::transform(input.begin(), input.end(), output.begin(), op.fcn());
+                if(input.get_shape().packed())
+                {
+                    std::transform(input.begin(), input.end(), output.begin(), op.fcn());
+                }
+                else
+                {
+                    shape_for_each(output.get_shape(), [&](const auto& idx) {
+                        output(idx.begin(), idx.end()) = op.fcn()(input(idx.begin(), idx.end()));
+                    });
+                }
             });
         });
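
std::transform over the linear begin()/end() iterators is only correct when the elements sit contiguously in memory, so the non-packed case now visits each multi-dimensional index and maps it through each tensor's own strides. A self-contained toy of that fallback follows; for_each_index and offset are hypothetical stand-ins for MIGraphX's shape_for_each and stride-based indexing, and the example uses a strided slice as one kind of non-packed view.

// Toy version of the fallback: walk every multi-dimensional index, map it to
// a buffer offset through each tensor's own strides, and apply the op per
// element instead of iterating linearly.
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

using index_t = std::vector<std::size_t>;

// Odometer-style walk over all indices of a tensor with the given lengths.
void for_each_index(const std::vector<std::size_t>& lens,
                    const std::function<void(const index_t&)>& f)
{
    if(lens.empty())
        return;
    index_t idx(lens.size(), 0);
    for(;;)
    {
        f(idx);
        std::size_t d = lens.size();
        while(d > 0)
        {
            d--;
            if(++idx[d] < lens[d])
                break;
            idx[d] = 0;
            if(d == 0)
                return;
        }
    }
}

// Linear buffer offset of a multi-dimensional index under the given strides.
std::size_t offset(const index_t& idx, const std::vector<std::size_t>& strides)
{
    std::size_t r = 0;
    for(std::size_t i = 0; i < idx.size(); i++)
        r += idx[i] * strides[i];
    return r;
}

int main()
{
    // Input: a non-packed 3x2 view of a larger buffer, with strides {4, 1};
    // output: a standard (row-major) 3x2 buffer with strides {2, 1}.
    std::vector<float> in = {-1, 2, 0, 0, -3, 4, 0, 0, -5, 6, 0, 0};
    std::vector<std::size_t> lens        = {3, 2};
    std::vector<std::size_t> in_strides  = {4, 1};
    std::vector<std::size_t> out_strides = {2, 1};
    std::vector<float> out(6);
    for_each_index(lens, [&](const index_t& idx) {
        out[offset(idx, out_strides)] = std::fabs(in[offset(idx, in_strides)]);
    });
    // out == {1, 2, 3, 4, 5, 6}: the op skipped the gaps in the strided input.
}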
@@ -773,12 +793,25 @@ struct cpu_binary
 {
     Op op;
     std::string name() const { return op.name(); }
-    shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        if (inputs.at(0) == inputs.at(1) and inputs.at(0).packed() and inputs.at(1).packed())
+        {
+            return inputs.at(0);
+        }
+        else
+        {
+            return {inputs.at(0).type(), inputs.at(0).lens()};
+        }
+    }
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
-            if(input1.get_shape().packed() and input2.get_shape().packed())
+            auto s1 = input1.get_shape();
+            auto s2 = input2.get_shape();
+            if(s1 == s2 and s1.packed() and s2.packed())
             {
                 std::transform(
                     input1.begin(), input1.end(), input2.begin(), output.begin(), op.fcn());

@@ -791,6 +824,7 @@ struct cpu_binary
             });
         }
     });
     return result;
 }
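
Note that the fast path now also requires s1 == s2, not just that both inputs are packed: two packed shapes with different stride permutations enumerate their elements in different orders under linear iteration, so std::transform would pair up the wrong elements.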
@@ -7,7 +7,7 @@ namespace gpu {
 shape miopen_abs::compute_shape(const std::vector<shape>& inputs) const
 {
-    check_shapes{inputs, *this}.has(2).not_broadcasted();
+    check_shapes{inputs, *this}.has(2).packed();
     return inputs.at(0);
 }
@@ -45,7 +45,14 @@ struct unary_device : oper<Derived>
     shape compute_shape(const std::vector<shape>& inputs) const
     {
         check_shapes{inputs, *this}.has(2);
-        return inputs.at(1);
+        if (inputs.at(0).packed())
+        {
+            return inputs.at(0);
+        }
+        else
+        {
+            return {inputs.at(0).type(), inputs.at(0).lens()};
+        }
     }
     argument compute(context& ctx, const shape&, const std::vector<argument>& args) const

@@ -63,7 +70,14 @@ struct binary_device : oper<Derived>
     shape compute_shape(const std::vector<shape>& inputs) const
     {
         check_shapes{inputs, *this}.has(3);
-        return inputs.at(2);
+        if (inputs.at(0) == inputs.at(1) and inputs.at(0).packed() and inputs.at(1).packed())
+        {
+            return inputs.at(0);
+        }
+        else
+        {
+            return {inputs.at(0).type(), inputs.at(0).lens()};
+        }
     }
     argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
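
As on the CPU side, these device wrappers previously took the output shape from the extra pre-allocated output argument (inputs.at(1) for unary, inputs.at(2) for binary); they now derive it from the first input with the same packed-propagation rule.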
@@ -7,7 +7,7 @@ namespace gpu {
 shape miopen_tanh::compute_shape(const std::vector<shape>& inputs) const
 {
-    check_shapes{inputs, *this}.has(2).not_broadcasted();
+    check_shapes{inputs, *this}.has(2).packed();
     return inputs.at(0);
 }
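
For both miopen_abs and miopen_tanh the requirement changes from not_broadcasted() to the stronger packed(), matching the shapes the new compute_shape overloads can now produce: presumably the MIOpen kernels can treat a packed tensor as a flat buffer, but not an arbitrary non-broadcast strided view.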
@@ -335,7 +335,9 @@ struct test_trans_tanh : verify_program<test_trans_tanh>
         auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
         auto tx = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
         auto tanhx = p.add_instruction(migraphx::op::tanh{}, tx);
-        p.add_instruction(migraphx::op::add{}, tanhx, tanhx);
+        auto r = p.add_instruction(migraphx::op::add{}, tanhx, tanhx);
+        p.add_instruction(migraphx::op::contiguous{}, r);
         return p;
     }
 };
@@ -694,8 +696,10 @@ struct test_trans_abs : verify_program<test_trans_abs>
         migraphx::program p;
         auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
         auto tx = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
-        auto tanhx = p.add_instruction(migraphx::op::abs{}, tx);
-        p.add_instruction(migraphx::op::add{}, tanhx, tanhx);
+        auto absx = p.add_instruction(migraphx::op::abs{}, tx);
+        auto r = p.add_instruction(migraphx::op::add{}, absx, absx);
+        p.add_instruction(migraphx::op::contiguous{}, r);
         return p;
     }
 };
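
Because add now propagates the transposed layout of its inputs, each of these test programs would otherwise end with a non-standard result; the appended contiguous instruction restores a standard output for verification and gives the eliminate_contiguous pass a case to exercise. The tanhx-to-absx rename in test_trans_abs also fixes a copy-pasted variable name.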