"vscode:/vscode.git/clone" did not exist on "3107fda5d0d5c2ee41374466f3088d6cd93d4abf"
Commit cced8ba3 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Improve the scenarios in which the eliminate_contiguous pass can remove contiguous ops.

parent 2a3042b1
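
Editor's note for context (not part of the commit): previously these element-wise `compute_shape` overloads always produced a standard, densely strided row-major shape, so a `transpose` feeding them had to be materialized by a `contiguous` copy first. The change below instead forwards the input shape whenever it is packed, i.e. its elements fill the underlying buffer with no gaps even though the stride order may be permuted. That is the property that lets eliminate_contiguous drop the copies. A minimal sketch of the rule, using a hypothetical `toy_shape` stand-in rather than the real `migraphx::shape`:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-in for migraphx::shape: per-dimension lengths and strides.
struct toy_shape
{
    std::vector<std::size_t> lens;
    std::vector<std::size_t> strides;
};

// Dense row-major ("standard") strides for the given lengths.
toy_shape make_standard(const std::vector<std::size_t>& lens)
{
    std::vector<std::size_t> strides(lens.size(), 1);
    for(std::size_t i = lens.size(); i-- > 1;)
        strides[i - 1] = strides[i] * lens[i];
    return {lens, strides};
}

// "Packed": the strides are some permutation of a dense layout, so the
// elements tile the buffer with no gaps. A transpose of a standard shape is
// packed but not standard. (Simplified: ignores broadcast/zero strides.)
bool packed(const toy_shape& s)
{
    std::vector<std::pair<std::size_t, std::size_t>> dims; // (stride, len)
    for(std::size_t i = 0; i < s.lens.size(); i++)
        dims.emplace_back(s.strides[i], s.lens[i]);
    std::sort(dims.begin(), dims.end(), std::greater<>{});
    std::size_t expected = 1;
    for(auto it = dims.rbegin(); it != dims.rend(); ++it)
    {
        if(it->first != expected)
            return false;
        expected *= it->second;
    }
    return true;
}

// The rule the commit applies to element-wise ops: forward the input shape
// when it is packed (even with permuted strides); otherwise fall back to a
// standard shape with the same lengths.
toy_shape elementwise_output_shape(const toy_shape& in)
{
    return packed(in) ? in : make_standard(in.lens);
}

int main()
{
    // Transposing a standard {4, 3, 3, 3} tensor with permutation
    // {0, 1, 3, 2} gives strides {27, 9, 1, 3}: a permutation of the dense
    // {27, 9, 3, 1}, hence packed, hence forwarded without a copy.
    toy_shape t{{4, 3, 3, 3}, {27, 9, 1, 3}};
    assert(packed(t));
    assert(elementwise_output_shape(t).strides == t.strides);
}
```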
@@ -21,9 +21,14 @@ struct binary
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs}.has(2).same_type().same_dims();
-        auto t = inputs.at(0).type();
-        auto lens = inputs.at(0).lens();
-        return {t, lens};
+        if(inputs.at(0) == inputs.at(1) and inputs.at(0).packed() and inputs.at(1).packed())
+        {
+            return inputs.at(0);
+        }
+        else
+        {
+            return {inputs.at(0).type(), inputs.at(0).lens()};
+        }
     }
 };
......
@@ -29,7 +29,7 @@ struct logsoftmax
     std::string name() const { return "logsoftmax"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs}.has(1);
+        check_shapes{inputs}.has(1).standard();
         if(axis < 0 || axis > inputs[0].lens().size())
         {
             MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
......
@@ -21,8 +21,15 @@ struct unary
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs}.has(1);
+        if(inputs.front().packed())
+        {
+            return inputs.at(0);
+        }
+        else
+        {
+            return {inputs.at(0).type(), inputs.at(0).lens()};
+        }
     }
 };
 } // namespace op
......
@@ -593,13 +593,33 @@ struct cpu_unary
 {
     Op op;
     std::string name() const { return op.name(); }
-    shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        if(inputs.at(0).packed())
+        {
+            return inputs.at(0);
+        }
+        else
+        {
+            return {inputs.at(0).type(), inputs.at(0).lens()};
+        }
+    }
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
         result.visit([&](auto output) {
             args[0].visit([&](auto input) {
-                std::transform(input.begin(), input.end(), output.begin(), op.fcn());
+                if(input.get_shape().packed())
+                {
+                    std::transform(input.begin(), input.end(), output.begin(), op.fcn());
+                }
+                else
+                {
+                    shape_for_each(output.get_shape(), [&](const auto& idx) {
+                        output(idx.begin(), idx.end()) = op.fcn()(input(idx.begin(), idx.end()));
+                    });
+                }
             });
         });
@@ -773,12 +793,25 @@ struct cpu_binary
 {
     Op op;
     std::string name() const { return op.name(); }
-    shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        if(inputs.at(0) == inputs.at(1) and inputs.at(0).packed() and inputs.at(1).packed())
+        {
+            return inputs.at(0);
+        }
+        else
+        {
+            return {inputs.at(0).type(), inputs.at(0).lens()};
+        }
+    }
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
         visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
-            if(input1.get_shape().packed() and input2.get_shape().packed())
+            auto s1 = input1.get_shape();
+            auto s2 = input2.get_shape();
+            if(s1 == s2 and s1.packed() and s2.packed())
             {
                 std::transform(
                     input1.begin(), input1.end(), input2.begin(), output.begin(), op.fcn());
@@ -791,6 +824,7 @@ struct cpu_binary
                 });
             }
         });
+        return result;
     }
 };
......
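
Editor's aside on why the flat `std::transform` fast path in `cpu_binary` (and `cpu_unary`) is sound: when both inputs share one packed shape, every buffer offset is visited exactly once and corresponds to the same logical coordinate in both tensors, so iterating the raw buffers in storage order agrees with the stride-aware `shape_for_each` loop. A small self-contained check, using a hand-rolled 2x3 column-major layout rather than MIGraphX types:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    // One packed, non-standard layout shared by both inputs:
    // lens {2, 3}, strides {1, 2} (a transposed 2x3 tensor).
    const std::size_t lens[2]    = {2, 3};
    const std::size_t strides[2] = {1, 2};

    std::vector<float> a = {1, 2, 3, 4, 5, 6};
    std::vector<float> b = {10, 20, 30, 40, 50, 60};
    std::vector<float> flat(6), strided(6);

    // Fast path: ignore strides, walk the buffers in storage order.
    std::transform(a.begin(), a.end(), b.begin(), flat.begin(),
                   [](float x, float y) { return x + y; });

    // Stride-aware path: compute the offset of each logical coordinate.
    // Because the shape is packed, the offsets cover 0..5 exactly once.
    for(std::size_t i = 0; i < lens[0]; i++)
        for(std::size_t j = 0; j < lens[1]; j++)
        {
            const auto off = i * strides[0] + j * strides[1];
            strided[off]   = a[off] + b[off];
        }

    assert(flat == strided); // both paths fill the buffer identically
    return 0;
}
```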
@@ -7,7 +7,7 @@ namespace gpu {
 shape miopen_abs::compute_shape(const std::vector<shape>& inputs) const
 {
-    check_shapes{inputs, *this}.has(2).not_broadcasted();
+    check_shapes{inputs, *this}.has(2).packed();
     return inputs.at(0);
 }
......
@@ -45,7 +45,14 @@ struct unary_device : oper<Derived>
     shape compute_shape(const std::vector<shape>& inputs) const
     {
         check_shapes{inputs, *this}.has(2);
-        return inputs.at(1);
+        if(inputs.at(0).packed())
+        {
+            return inputs.at(0);
+        }
+        else
+        {
+            return {inputs.at(0).type(), inputs.at(0).lens()};
+        }
     }
     argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
@@ -63,7 +70,14 @@ struct binary_device : oper<Derived>
     shape compute_shape(const std::vector<shape>& inputs) const
     {
         check_shapes{inputs, *this}.has(3);
-        return inputs.at(2);
+        if(inputs.at(0) == inputs.at(1) and inputs.at(0).packed() and inputs.at(1).packed())
+        {
+            return inputs.at(0);
+        }
+        else
+        {
+            return {inputs.at(0).type(), inputs.at(0).lens()};
+        }
     }
     argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
......
@@ -7,7 +7,7 @@ namespace gpu {
 shape miopen_tanh::compute_shape(const std::vector<shape>& inputs) const
 {
-    check_shapes{inputs, *this}.has(2).not_broadcasted();
+    check_shapes{inputs, *this}.has(2).packed();
     return inputs.at(0);
 }
......
@@ -335,7 +335,9 @@ struct test_trans_tanh : verify_program<test_trans_tanh>
         auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
         auto tx = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
         auto tanhx = p.add_instruction(migraphx::op::tanh{}, tx);
-        p.add_instruction(migraphx::op::add{}, tanhx, tanhx);
+        auto r = p.add_instruction(migraphx::op::add{}, tanhx, tanhx);
+        p.add_instruction(migraphx::op::contiguous{}, r);
         return p;
     }
 };
@@ -694,8 +696,10 @@ struct test_trans_abs : verify_program<test_trans_abs>
         migraphx::program p;
         auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
         auto tx = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
-        auto tanhx = p.add_instruction(migraphx::op::abs{}, tx);
-        p.add_instruction(migraphx::op::add{}, tanhx, tanhx);
+        auto absx = p.add_instruction(migraphx::op::abs{}, tx);
+        auto r = p.add_instruction(migraphx::op::add{}, absx, absx);
+        p.add_instruction(migraphx::op::contiguous{}, r);
         return p;
     }
 };
......
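
The two verify tests change for the same reason (again an editor's note): `add` now keeps the transposed-but-packed layout of its inputs instead of forcing a standard one, so each program appends an explicit `contiguous` to standardize the final result before CPU and GPU outputs are compared. Between the transpose and the element-wise ops, however, a `contiguous` becomes unnecessary. One simple sufficient condition for removing such an instruction is sketched below with the same hypothetical `toy_shape` idea; the real eliminate_contiguous pass instead checks whether the consuming instructions still compute valid shapes without the copy:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in, as in the earlier sketches.
struct toy_shape
{
    std::vector<std::size_t> lens;
    std::vector<std::size_t> strides;
    bool operator==(const toy_shape& other) const
    {
        return lens == other.lens and strides == other.strides;
    }
};

// Dense row-major strides for the given lengths.
toy_shape make_standard(const std::vector<std::size_t>& lens)
{
    std::vector<std::size_t> strides(lens.size(), 1);
    for(std::size_t i = lens.size(); i-- > 1;)
        strides[i - 1] = strides[i] * lens[i];
    return {lens, strides};
}

// A contiguous op copies its input into standard layout; if the input is
// already standard, the copy is a no-op and the instruction can be removed.
// With the shape rules in this commit, ops that accept packed inputs no
// longer require a standard input at all, widening what can be eliminated.
bool contiguous_is_redundant(const toy_shape& input)
{
    return input == make_standard(input.lens);
}
```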