Commit 5c5115af authored by Paul's avatar Paul
Browse files

Add support for transpose gemm using blaze

parent cd07476b
...@@ -324,7 +324,7 @@ struct gemm ...@@ -324,7 +324,7 @@ struct gemm
auto t = a.type(); auto t = a.type();
if(a.lens()[1] != b.lens()[0]) if(a.lens()[1] != b.lens()[0])
MIGRAPH_THROW("Inner dimensions do not match"); MIGRAPH_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}");
return {t, {a.lens()[0], b.lens()[1]}}; return {t, {a.lens()[0], b.lens()[1]}};
} }
......
...@@ -159,15 +159,17 @@ struct miopen_gemm ...@@ -159,15 +159,17 @@ struct miopen_gemm
{ {
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
rocblas_int lda = args[0].get_shape().lens()[1]; bool transa = args[0].get_shape().transposed();
rocblas_int ldb = args[1].get_shape().lens()[1]; bool transb = args[1].get_shape().transposed();
rocblas_int ldc = args[2].get_shape().lens()[1]; rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? 1 : 0];
rocblas_int ldc = args[2].get_shape().strides()[0];
rocblas_int m = output_shape.lens()[0]; rocblas_int m = output_shape.lens()[0];
rocblas_int n = output_shape.lens()[1]; rocblas_int n = output_shape.lens()[1];
rocblas_int k = args[0].get_shape().lens()[1]; rocblas_int k = args[0].get_shape().lens()[1];
rocblas_sgemm(ctx.rbhandle.get(), rocblas_sgemm(ctx.rbhandle.get(),
rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
n, n,
m, m,
k, k,
......
...@@ -135,6 +135,46 @@ struct test_gemm ...@@ -135,6 +135,46 @@ struct test_gemm
} }
}; };
struct test_gemm_transposeb
{
migraph::program create_program() const
{
migraph::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {3, 5}});
auto bt = p.add_instruction(migraph::transpose{{1, 0}}, b);
p.add_instruction(migraph::gemm{}, a, bt);
return p;
}
};
struct test_gemm_transposea
{
migraph::program create_program() const
{
migraph::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {5, 4}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}});
auto at = p.add_instruction(migraph::transpose{{1, 0}}, a);
p.add_instruction(migraph::gemm{}, at, b);
return p;
}
};
struct test_gemm_transposeab
{
migraph::program create_program() const
{
migraph::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {5, 4}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {3, 5}});
auto at = p.add_instruction(migraph::transpose{{1, 0}}, a);
auto bt = p.add_instruction(migraph::transpose{{1, 0}}, b);
p.add_instruction(migraph::gemm{}, at, bt);
return p;
}
};
struct test_contiguous struct test_contiguous
{ {
migraph::program create_program() const migraph::program create_program() const
...@@ -168,6 +208,9 @@ int main() ...@@ -168,6 +208,9 @@ int main()
verify_program<test_conv_relu>(); verify_program<test_conv_relu>();
verify_program<test_conv_pooling>(); verify_program<test_conv_pooling>();
verify_program<test_gemm>(); verify_program<test_gemm>();
verify_program<test_gemm_transposeb>();
verify_program<test_gemm_transposea>();
verify_program<test_gemm_transposeab>();
verify_program<test_contiguous>(); verify_program<test_contiguous>();
verify_program<test_transpose>(); verify_program<test_transpose>();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment