Commit 5c5115af authored by Paul's avatar Paul
Browse files

Add support for transpose gemm using blaze

parent cd07476b
......@@ -324,7 +324,7 @@ struct gemm
auto t = a.type();
if(a.lens()[1] != b.lens()[0])
MIGRAPH_THROW("Inner dimensions do not match");
MIGRAPH_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}");
return {t, {a.lens()[0], b.lens()[1]}};
}
......
......@@ -159,15 +159,17 @@ struct miopen_gemm
{
float alpha = 1.0f;
float beta = 0.0f;
rocblas_int lda = args[0].get_shape().lens()[1];
rocblas_int ldb = args[1].get_shape().lens()[1];
rocblas_int ldc = args[2].get_shape().lens()[1];
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? 1 : 0];
rocblas_int ldc = args[2].get_shape().strides()[0];
rocblas_int m = output_shape.lens()[0];
rocblas_int n = output_shape.lens()[1];
rocblas_int k = args[0].get_shape().lens()[1];
rocblas_sgemm(ctx.rbhandle.get(),
rocblas_operation_none,
rocblas_operation_none,
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
......
......@@ -135,6 +135,46 @@ struct test_gemm
}
};
struct test_gemm_transposeb
{
migraph::program create_program() const
{
migraph::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {3, 5}});
auto bt = p.add_instruction(migraph::transpose{{1, 0}}, b);
p.add_instruction(migraph::gemm{}, a, bt);
return p;
}
};
struct test_gemm_transposea
{
migraph::program create_program() const
{
migraph::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {5, 4}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}});
auto at = p.add_instruction(migraph::transpose{{1, 0}}, a);
p.add_instruction(migraph::gemm{}, at, b);
return p;
}
};
struct test_gemm_transposeab
{
migraph::program create_program() const
{
migraph::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {5, 4}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {3, 5}});
auto at = p.add_instruction(migraph::transpose{{1, 0}}, a);
auto bt = p.add_instruction(migraph::transpose{{1, 0}}, b);
p.add_instruction(migraph::gemm{}, at, bt);
return p;
}
};
struct test_contiguous
{
migraph::program create_program() const
......@@ -168,6 +208,9 @@ int main()
verify_program<test_conv_relu>();
verify_program<test_conv_pooling>();
verify_program<test_gemm>();
verify_program<test_gemm_transposeb>();
verify_program<test_gemm_transposea>();
verify_program<test_gemm_transposeab>();
verify_program<test_contiguous>();
verify_program<test_transpose>();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment