Commit 6be5c8fe authored by Paul
Browse files

Merger

parents 73e64f33 b606ed4f
......@@ -9,19 +9,60 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Decides whether `ins` (and everything that consumes it) still produces an
// acceptable shape when `ins` is fed `inputs` instead of its current inputs.
//
// The operator of `ins` recomputes its output shape from `inputs`:
//   - a standard shape, or a shape equal to the current one, is accepted;
//   - a changed, non-standard shape on an instruction with no outputs is rejected;
//   - otherwise every output is checked recursively with the new shape
//     substituted wherever that output uses `ins` as an argument.
// Any exception thrown along the way is treated as "not acceptable".
//
// Returns true when the substitution is acceptable, false otherwise.
//
// NOTE(review): the signature on the next line and the `compute_shape(op, args);`
// call inside the try-block look like the pre-change (operation-based) version
// shown next to its replacement -- confirm which lines are live.
bool try_compute_shape(const operation& op, const std::vector<instruction_ref>& args)
static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inputs)
{
try
{
compute_shape(op, args);
shape new_shape = ins->get_operator().compute_shape(inputs);
// If the output shape is a standard shape, no need to try its output
if(new_shape.standard())
{
return true;
}
// if no changes for the shape, the contiguous can also be removed
if(new_shape == ins->get_shape())
{
return true;
}
auto outputs = ins->outputs();
// If the current instruction has no output, it means it is the last
// instruction and generates a non-standard output shape, and the last
// output shape is different from the case with the contiguous operator
if(outputs.empty())
{
return false;
}
for(auto output : outputs)
{
// Build the shapes this consumer would see: the new shape in every slot
// occupied by `ins`, the existing shape for all other arguments.
auto args = output->inputs();
std::vector<shape> input_shapes(args.size());
std::transform(args.begin(), args.end(), input_shapes.begin(), [&](auto& arg) {
return (arg == ins) ? new_shape : arg->get_shape();
});
// One failing consumer is enough to reject the change.
if(!try_compute_shape(output, input_shapes))
{
return false;
}
}
}
catch(...)
{
// A throwing compute_shape means the operator rejects these input shapes.
return false;
}
return true;
}
// Overload taking instruction arguments: converts `args` to shapes with
// to_shapes and defers to the shape-based try_compute_shape for `ins`.
static bool try_compute_shape(instruction_ref ins, const std::vector<instruction_ref>& args)
{
    return try_compute_shape(ins, to_shapes(args));
}
void eliminate_contiguous::apply(program& p) const
{
for(auto ins : iterator_for(p))
......@@ -44,7 +85,7 @@ void eliminate_contiguous::apply(program& p) const
auto new_args = args;
auto prev = arg->inputs().front();
replace(new_args, arg, prev);
if(try_compute_shape(ins->get_operator(), new_args))
if(try_compute_shape(ins, new_args))
{
instruction::replace_argument(ins, arg, prev);
}
......
......@@ -13,10 +13,16 @@ struct binary : op_name<Derived>
// Computes the output shape of a binary op from its two input shapes.
// check_shapes is asked for exactly two inputs (has(2)) with same_type() and
// same_dims().
//
// NOTE(review): this body appears to contain both the pre-change logic (based
// on `s`) and the post-change logic (based on `s0`/`s1`). As written, the first
// part always returns, so the `s0`/`s1` part is never reached -- confirm which
// version is intended to be live.
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_dims();
const auto& s = inputs.front();
// A scalar input with a single element yields a shape built from the type alone.
if(s.scalar() and s.elements() == 1)
return {s.type()};
// Otherwise build the shape from the first input's type and lens only.
return {s.type(), s.lens()};
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
// Identical, packed inputs: reuse the input shape unchanged.
if(s0 == s1 and s0.packed())
{
return s0;
}
else
{
// Fall back to a shape built from the first input's type and lens.
return {s0.type(), s0.lens()};
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
......
......@@ -30,7 +30,7 @@ struct gather
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
check_shapes{inputs, *this}.has(2).standard();
auto lens = inputs[0].lens();
int n_dim = static_cast<int>(lens.size());
if(axis >= n_dim || axis < -n_dim)
......
......@@ -29,7 +29,7 @@ struct logsoftmax
// Name string reported for this operator.
std::string name() const
{
    return "logsoftmax";
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs}.has(1).standard();
if(axis < 0 || axis > inputs[0].lens().size())
{
MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
......
......@@ -13,7 +13,15 @@ struct unary : op_name<Derived>
// Computes the output shape of a unary op; check_shapes is asked for exactly
// one input (has(1)).
//
// NOTE(review): `return inputs.at(0);` is unconditional, so as written the
// packed / non-packed branch below is never reached. This looks like the
// pre-change line shown next to its replacement -- confirm which is live.
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
auto s = inputs.at(0);
// A packed input shape is returned unchanged.
if(s.packed())
{
return s;
}
else
{
// Otherwise build a shape from the input's type and lens only.
return {s.type(), s.lens()};
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
......
......@@ -481,13 +481,35 @@ struct cpu_unary
return migraphx::reflect(self.op.op, f);
}
// Reports the name of the wrapped operator `op`.
std::string name() const
{
    return op.name();
}
// Output shape is the first input shape, returned as-is.
// NOTE(review): another compute_shape with the same signature follows directly
// after this one in the listing; presumably only one of the two is meant to be
// live -- confirm.
shape compute_shape(const std::vector<shape>& inputs) const
{
    return inputs.front();
}
// Computes the output shape from the single input shape (check_shapes is
// asked for exactly one input): a packed input shape is returned unchanged,
// anything else yields a shape built from the input's type and lens only.
shape compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs}.has(1);
    const auto& in = inputs.at(0);
    return in.packed() ? in : shape{in.type(), in.lens()};
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
result.visit([&](auto output) {
args[0].visit([&](auto input) {
if(input.get_shape().standard())
{
std::transform(input.begin(), input.end(), output.begin(), op.fcn());
}
else
{
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = op.fcn()(input(idx.begin(), idx.end()));
});
}
});
});
......
......@@ -7,7 +7,7 @@ namespace gpu {
// Validates the inputs with check_shapes (two entries expected) and returns
// the first input shape as the output shape.
//
// NOTE(review): two check_shapes statements run back to back here
// (not_broadcasted(), then packed()). This looks like the pre-change line shown
// next to its replacement -- confirm whether both checks are intended.
shape miopen_abs::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).not_broadcasted();
check_shapes{inputs, *this}.has(2).packed();
return inputs.at(0);
}
......
......@@ -171,10 +171,10 @@ auto reflect(miopenActivationDescriptor_t ad, F f)
double beta = 0.0;
double gamma = 0.0;
miopenGetActivationDescriptor(ad, &mode, &alpha, &beta, &gamma);
return pack(f(std::move(mode), "mode"),
f(std::move(alpha), "alpha"),
f(std::move(beta), "beta"),
f(std::move(gamma), "gamma"));
return pack(f(std::move(mode), "mode"), // NOLINT
f(std::move(alpha), "alpha"), // NOLINT
f(std::move(beta), "beta"), // NOLINT
f(std::move(gamma), "gamma")); // NOLINT
}
template <class F>
......@@ -187,11 +187,11 @@ auto reflect(miopenLRNDescriptor_t lrnd, F f)
double beta = 0.0;
double k = 0.0;
miopenGetLRNDescriptor(lrnd, &mode, &n, &alpha, &beta, &k);
return pack(f(std::move(mode), "mode"),
f(std::move(n), "n"),
f(std::move(alpha), "alpha"),
f(std::move(beta), "beta"),
f(std::move(k), "k"));
return pack(f(std::move(mode), "mode"), // NOLINT
f(std::move(n), "n"), // NOLINT
f(std::move(alpha), "alpha"), // NOLINT
f(std::move(beta), "beta"), // NOLINT
f(std::move(k), "k")); // NOLINT
}
} // namespace gpu
......
......@@ -45,7 +45,15 @@ struct unary_device : oper<Derived>
// Computes the output shape for a unary device op; check_shapes is asked for
// exactly two entries in `inputs` (has(2)).
//
// NOTE(review): `return inputs.at(1);` is unconditional, so as written the
// branch on `inputs.at(0)` below is never reached. This looks like the
// pre-change line shown next to its replacement -- confirm which is live.
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2);
return inputs.at(1);
auto s = inputs.at(0);
// A packed first input shape is returned unchanged.
if(s.packed())
{
return s;
}
else
{
// Otherwise build a shape from the first input's type and lens only.
return {s.type(), s.lens()};
}
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
......@@ -66,7 +74,16 @@ struct binary_device : oper<Derived>
// Computes the output shape for a binary device op; check_shapes is asked for
// exactly three entries in `inputs` (has(3)).
//
// NOTE(review): `return inputs.at(2);` is unconditional, so as written the
// comparison of the first two inputs below is never reached. This looks like
// the pre-change line shown next to its replacement -- confirm which is live.
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(3);
return inputs.at(2);
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
// Identical, packed inputs: reuse the input shape unchanged.
if(s0 == s1 and s0.packed())
{
return s0;
}
else
{
// Otherwise build a shape from the first input's type and lens only.
return {s0.type(), s0.lens()};
}
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
......
......@@ -7,7 +7,7 @@ namespace gpu {
// Validates the inputs with check_shapes (two entries expected) and returns
// the first input shape as the output shape.
//
// NOTE(review): two check_shapes statements run back to back here
// (not_broadcasted(), then packed()). This looks like the pre-change line shown
// next to its replacement -- confirm whether both checks are intended.
shape miopen_tanh::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).not_broadcasted();
check_shapes{inputs, *this}.has(2).packed();
return inputs.at(0);
}
......
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/sin.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/contiguous.hpp>
#include <basic_ops.hpp>
......@@ -36,7 +40,46 @@ TEST_CASE(non_standard_op)
p.add_instruction(pass_op{}, c);
auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == count);
}
// Builds literal -> transpose -> contiguous -> identity -> dot and expects the
// program to be exactly one instruction shorter after compiling with
// eliminate_contiguous_target.
TEST_CASE(transpose_gemm)
{
    migraphx::program p;

    auto lit        = p.add_literal(get_2x2());
    auto transposed = p.add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    auto contig     = p.add_instruction(migraphx::op::contiguous{}, transposed);
    auto passthru   = p.add_instruction(migraphx::op::identity{}, contig);
    p.add_instruction(migraphx::op::dot{}, passthru, lit);

    // Instruction count before the pass runs.
    const auto count = std::distance(p.begin(), p.end());
    p.compile(eliminate_contiguous_target{});

    EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
}
// Builds literal -> transpose -> contiguous -> sin -> pass_standard_op and
// expects the instruction count to be unchanged after compiling with
// eliminate_contiguous_target.
TEST_CASE(transpose_standard_op)
{
    migraphx::program p;

    auto lit        = p.add_literal(get_2x2());
    auto transposed = p.add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    auto contig     = p.add_instruction(migraphx::op::contiguous{}, transposed);
    auto sine       = p.add_instruction(migraphx::op::sin{}, contig);
    p.add_instruction(pass_standard_op{}, sine);

    // Instruction count before the pass runs.
    const auto count = std::distance(p.begin(), p.end());
    p.compile(eliminate_contiguous_target{});

    EXPECT(std::distance(p.begin(), p.end()) == count);
}
// Builds literal -> slice -> contiguous -> sin -> pass_standard_op and expects
// the program to be exactly one instruction shorter after compiling with
// eliminate_contiguous_target.
TEST_CASE(no_packed_unary_op)
{
    migraphx::program p;

    auto lit    = p.add_literal(get_2x2());
    auto sliced = p.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, lit);
    auto contig = p.add_instruction(migraphx::op::contiguous{}, sliced);
    auto sine   = p.add_instruction(migraphx::op::sin{}, contig);
    p.add_instruction(pass_standard_op{}, sine);

    // Instruction count before the pass runs.
    const auto count = std::distance(p.begin(), p.end());
    p.compile(eliminate_contiguous_target{});

    EXPECT(std::distance(p.begin(), p.end()) == count - 1);
}
// Entry point of the test binary: hands the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
......@@ -8,6 +8,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/op/tanh.hpp>
......@@ -37,7 +38,8 @@ TEST_CASE(tanh_shape)
auto x = p.add_parameter("x", s);
auto tx = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto txh = p.add_instruction(migraphx::op::tanh{}, tx);
p.add_instruction(migraphx::op::add{}, txh, txh);
auto sum = p.add_instruction(migraphx::op::add{}, txh, txh);
p.add_instruction(migraphx::op::contiguous{}, sum);
return p;
};
......@@ -55,8 +57,8 @@ TEST_CASE(tanh_shape)
{
if(ins->name() == "hip::allocate")
{
migraphx::shape wrong_s{migraphx::shape::float_type, {3, 2}, {1, 3}};
ins->replace(migraphx::gpu::hip_allocate{wrong_s});
migraphx::shape new_s{migraphx::shape::float_type, {3, 2}, {1, 3}};
migraphx::instruction::replace(ins, ins->get_operator(), new_s, ins->inputs());
}
}
EXPECT(p1 != p2);
......
......@@ -333,7 +333,22 @@ struct test_trans_tanh : verify_program<test_trans_tanh>
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto tanhx = p.add_instruction(migraphx::op::tanh{}, tx);
p.add_instruction(migraphx::op::add{}, tanhx, tanhx);
auto r = p.add_instruction(migraphx::op::add{}, tanhx, tanhx);
p.add_instruction(migraphx::op::contiguous{}, r);
return p;
}
};
// Verification program: applies a slice op (built from {{1}, {1}, {2}}) to a
// 2x2 float parameter "x" and then takes sin of the result.
struct test_slice_sin : verify_program<test_slice_sin>
{
    // Builds and returns the program described above.
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape input_shape{migraphx::shape::float_type, {2, 2}};

        auto input  = prog.add_parameter("x", input_shape);
        auto sliced = prog.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
        prog.add_instruction(migraphx::op::sin{}, sliced);

        return prog;
    }
};
......@@ -692,8 +707,10 @@ struct test_trans_abs : verify_program<test_trans_abs>
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto tanhx = p.add_instruction(migraphx::op::abs{}, tx);
p.add_instruction(migraphx::op::add{}, tanhx, tanhx);
auto absx = p.add_instruction(migraphx::op::abs{}, tx);
auto r = p.add_instruction(migraphx::op::add{}, absx, absx);
p.add_instruction(migraphx::op::contiguous{}, r);
return p;
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment