Commit e29a2396 authored by wsttiger's avatar wsttiger
Browse files

Added and fixed up transpose and contiguous and test

parent c34bdf3b
...@@ -173,11 +173,31 @@ struct miopen_gemm ...@@ -173,11 +173,31 @@ struct miopen_gemm
} }
}; };
struct miopen_transpose
{
transpose op;
std::string name() const { return "miopen::transpose"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(0)});
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
};
struct miopen_contiguous struct miopen_contiguous
{ {
contiguous op; contiguous op;
std::string name() const { return "miopen::contiguous"; } std::string name() const { return "miopen::contiguous"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); } shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(0)});
}
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
...@@ -248,6 +268,10 @@ struct miopen_apply ...@@ -248,6 +268,10 @@ struct miopen_apply
{ {
apply_gemm(it); apply_gemm(it);
} }
else if(it->op.name() == "transpose")
{
apply_transpose(it);
}
else if(it->op.name() == "contiguous") else if(it->op.name() == "contiguous")
{ {
apply_contiguous(it); apply_contiguous(it);
...@@ -319,6 +343,13 @@ struct miopen_apply ...@@ -319,6 +343,13 @@ struct miopen_apply
ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output); ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output);
} }
void apply_transpose(instruction_ref ins)
{
auto&& op = any_cast<transpose>(ins->op);
auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(ins, miopen_transpose{op}, ins->arguments.at(0), output);
}
void apply_contiguous(instruction_ref ins) void apply_contiguous(instruction_ref ins)
{ {
auto&& op = any_cast<contiguous>(ins->op); auto&& op = any_cast<contiguous>(ins->op);
......
...@@ -163,7 +163,7 @@ struct test_contiguous ...@@ -163,7 +163,7 @@ struct test_contiguous
migraph::program create_program() const migraph::program create_program() const
{ {
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {32, 16, 128, 128}, {262144, 128, 1, 16384}}; migraph::shape s{migraph::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
p.add_instruction(migraph::contiguous{}, x); p.add_instruction(migraph::contiguous{}, x);
return p; return p;
...@@ -172,30 +172,32 @@ struct test_contiguous ...@@ -172,30 +172,32 @@ struct test_contiguous
migraph::program::parameter_map create_params() const migraph::program::parameter_map create_params() const
{ {
migraph::program::parameter_map m; migraph::program::parameter_map m;
m["x"] = migraph::generate_argument({migraph::shape::float_type, {32, 16, 128, 128}}); m["x"] =
migraph::generate_argument({migraph::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}});
return m; return m;
} }
}; };
void contiguous_test() struct test_transpose
{ {
migraph::shape a_shape{migraph::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}}; migraph::program create_program() const
std::vector<float> data(12); {
std::iota(data.begin(), data.end(), 0); migraph::program p;
migraph::shape s{migraph::shape::float_type, {4, 3, 4, 4}};
auto x = p.add_parameter("x", s);
std::vector<int64_t> perm = {0, 2, 3, 1};
auto l = p.add_instruction(migraph::transpose{perm}, x);
p.add_instruction(migraph::contiguous{}, l);
return p;
}
migraph::program p; migraph::program::parameter_map create_params() const
auto l = p.add_literal(migraph::literal{a_shape, data}); {
p.add_instruction(migraph::contiguous{}, l); migraph::program::parameter_map m;
p.compile(migraph::miopen::miopen_target{}); m["x"] = migraph::generate_argument({migraph::shape::float_type, {4, 3, 4, 4}});
auto result = p.eval({}); return m;
}
std::vector<float> results_vector(12); };
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<size_t> new_strides = {12, 1, 6, 3};
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(test::verify_range(results_vector, gold));
}
int main() int main()
{ {
...@@ -204,6 +206,6 @@ int main() ...@@ -204,6 +206,6 @@ int main()
verify_program<test_conv_relu>(); verify_program<test_conv_relu>();
verify_program<test_conv_pooling>(); verify_program<test_conv_pooling>();
verify_program<test_gemm>(); verify_program<test_gemm>();
// verify_program<test_contiguous>(); verify_program<test_contiguous>();
contiguous_test(); verify_program<test_transpose>();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment