Commit 7c885dbb authored by Paul's avatar Paul
Browse files

Add test for miopen gemm

parent 84af2e9e
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <rtg/shape_for_each.hpp> #include <rtg/shape_for_each.hpp>
#include <rtg/miopen/miopen.hpp> #include <rtg/miopen/miopen.hpp>
#include <rtg/miopen/hip.hpp> #include <rtg/miopen/hip.hpp>
#include <rtg/dfor.hpp>
namespace rtg { namespace rtg {
namespace miopen { namespace miopen {
...@@ -140,6 +141,29 @@ struct miopen_add ...@@ -140,6 +141,29 @@ struct miopen_add
} }
}; };
struct miopen_gemm
{
gemm op;
std::string name() const { return "miopen::convolution"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(4);
return op.compute_shape({inputs.at(1), inputs.at(2)});
}
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, from_gpu(args[1]), from_gpu(args[2]))(
[&](auto output, auto input1, auto input2) {
dfor(input1.get_shape().lens()[0], input2.get_shape().lens()[1], input2.get_shape().lens()[0])([&](auto i, auto j, auto k) {
output(i, j) += input1(i, k) * input2(k, j);
});
});
return to_gpu(result);
}
};
struct miopen_relu struct miopen_relu
{ {
shared<activation_descriptor> ad; shared<activation_descriptor> ad;
...@@ -194,6 +218,10 @@ struct miopen_apply ...@@ -194,6 +218,10 @@ struct miopen_apply
{ {
apply_add(it); apply_add(it);
} }
else if(it->op.name() == "gemm")
{
apply_gemm(it);
}
} }
} }
...@@ -253,6 +281,14 @@ struct miopen_apply ...@@ -253,6 +281,14 @@ struct miopen_apply
prog->replace_instruction( prog->replace_instruction(
ins, miopen_add{}, handle, ins->arguments.at(0), ins->arguments.at(1), output); ins, miopen_add{}, handle, ins->arguments.at(0), ins->arguments.at(1), output);
} }
void apply_gemm(instruction_ref ins)
{
auto&& op = any_cast<gemm>(ins->op);
auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(
ins, miopen_gemm{op}, handle, ins->arguments.at(0), ins->arguments.at(1), output);
}
}; };
std::string miopen_target::name() const { return "miopen"; } std::string miopen_target::name() const { return "miopen"; }
......
...@@ -136,10 +136,31 @@ struct test_conv_pooling ...@@ -136,10 +136,31 @@ struct test_conv_pooling
} }
}; };
struct test_gemm
{
rtg::program create_program() const
{
rtg::program p;
auto a = p.add_parameter("a", rtg::shape{rtg::shape::float_type, {4, 5}});
auto b = p.add_parameter("b", rtg::shape{rtg::shape::float_type, {5, 3}});
p.add_instruction(rtg::gemm{}, a, b);
return p;
}
rtg::program::parameter_map create_params() const
{
rtg::program::parameter_map m;
m["a"] = rtg::generate_argument({rtg::shape::float_type, {4, 5}});
m["b"] = rtg::generate_argument({rtg::shape::float_type, {5, 3}});
return m;
}
};
int main() int main()
{ {
// verify_program<test_add>(); verify_program<test_add>();
verify_program<test_add_broadcast>(); verify_program<test_add_broadcast>();
// verify_program<test_conv_relu>(); verify_program<test_conv_relu>();
// verify_program<test_conv_pooling>(); verify_program<test_conv_pooling>();
verify_program<test_gemm>();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment