"...resnet50_tensorflow.git" did not exist on "96cbd362ce198b7cdf8bc8b0d0dd25b22c64aa6c"
Commit 69582971 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'extend_gemm_op' into seq2seq_example

parents 17984330 102d2eb6
...@@ -78,14 +78,14 @@ void migemm_impl(tensor_view<T> cmat, ...@@ -78,14 +78,14 @@ void migemm_impl(tensor_view<T> cmat,
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]); assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
shape_for_each(cmat.get_shape(), [&](const auto& c_idx) { shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
double s = cmat(c_idx.begin(), c_idx.end()) * beta;
auto a_idx = c_idx; auto a_idx = c_idx;
auto b_idx = c_idx; auto b_idx = c_idx;
double s = 0.0;
dfor(k)([&](auto kk) { dfor(k)([&](auto kk) {
a_idx[dim_1] = b_idx[dim_0] = kk; a_idx[dim_1] = b_idx[dim_0] = kk;
s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end()); s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
}); });
cmat(c_idx.begin(), c_idx.end()) = alpha * s; cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
}); });
} }
......
...@@ -5,6 +5,30 @@ namespace migraphx { ...@@ -5,6 +5,30 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
template <class... Ts>
void generic_rocblas_batched_gemm(shape::as<float>, Ts&&... xs)
{
rocblas_sgemm_strided_batched(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_batched_gemm(shape::as<double>, Ts&&... xs)
{
rocblas_dgemm_strided_batched(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_batched_gemm(shape::as<half>, Ts&&... xs)
{
rocblas_hgemm_strided_batched(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
void generic_rocblas_batched_gemm(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}
template <class... Ts> template <class... Ts>
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs) void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{ {
...@@ -26,7 +50,7 @@ void generic_rocblas_gemm(shape::as<half>, Ts&&... xs) ...@@ -26,7 +50,7 @@ void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
template <class T, class... Ts> template <class T, class... Ts>
void generic_rocblas_gemm(shape::as<T>, Ts&&...) void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{ {
MIGRAPHX_THROW("Type unsupported by rocblas"); MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
} }
template <class T> template <class T>
...@@ -83,28 +107,35 @@ argument miopen_gemm::compute(context& ctx, ...@@ -83,28 +107,35 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0]; rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0]; rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0]; rocblas_int ldc = args[2].get_shape().strides()[dim_0];
rocblas_int m = output_shape.lens()[dim_0]; auto out_lens = output_shape.lens();
rocblas_int n = output_shape.lens()[dim_1]; rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1]; rocblas_int k = args[0].get_shape().lens()[dim_1];
auto batch_num = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(alpha)); auto alpha_r = to_rocblas_type(as(alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_gemm(as, generic_rocblas_batched_gemm(as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
n, n,
m, m,
k, k,
&alpha_r, &alpha_r,
to_pointer(args[1]), to_pointer(args[1]),
ldb, ldb,
to_pointer(args[0]), k * n,
lda, to_pointer(args[0]),
&beta_r, lda,
to_pointer(args[2]), m * k,
ldc); &beta_r,
to_pointer(args[2]),
ldc,
m * n,
batch_num);
}); });
......
...@@ -974,6 +974,144 @@ void gemm_test_ex() ...@@ -974,6 +974,144 @@ void gemm_test_ex()
TEST_CASE_REGISTER(gemm_test_ex<float>) TEST_CASE_REGISTER(gemm_test_ex<float>)
TEST_CASE_REGISTER(gemm_test_ex<double>) TEST_CASE_REGISTER(gemm_test_ex<double>)
TEST_CASE(gemm_mutli_dim_2)
{
migraphx::program p;
std::vector<float> m1 = {-0.76234141,
0.01368910,
-0.86343423,
-0.99465282,
0.76133268,
0.96507140,
-0.55893585,
0.02625652,
0.75171776,
0.23112578,
0.25624787,
-1.50442161};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
std::vector<float> m2 = {-0.15933632, -0.69594712, -0.06198966, -1.23905184, -0.83672704,
-1.06971832, -0.12272917, 1.07094116, -0.08346820, 1.16820693,
-0.95700874, 0.24059691, 0.43326023, 0.78305235, -0.53506601,
-0.69359678, -0.26334436, 1.56292796, -0.33629175, -1.72693469,
0.41435494, 1.52136843, -0.40699791, -1.59839430};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
p.add_instruction(migraphx::op::dot{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211,
0.06239879,
0.74700068,
-0.01570983,
-0.85920856,
-0.59070835,
-1.70729902,
0.40245487,
1.80182751};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_dim_2_3)
{
migraphx::program p;
std::vector<float> m1 = {
-1.93300070, 0.33902698, -0.45173527, -0.72283069, -0.17177134, 1.62199882,
0.87052847, 0.14989811, -0.88969184, -0.18131398, 0.72654339, -0.57123693,
0.03852506, -0.72332085, -1.81844083, -0.33465167, -0.71400352, 0.36883161,
0.08698452, 0.94974586, 0.40087323, -0.05448534, 0.03220677, -1.22494296,
0.97938472, -1.43714454, -0.80430904, -0.08098728, 0.31520301, 0.49642169,
-1.63471091, 0.34390096, 2.81292176, -0.22666528, 1.54559556, -1.51075762};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
std::vector<float> m2 = {
-0.33170529, 2.26325120, -0.50639461, 0.64802947, 0.44748888, 0.33768068,
-0.53621075, 0.34341460, 0.58742520, -1.13995790, -0.99322535, 0.35447353,
0.01977110, -0.10155016, -1.02288245, -0.16575791, -1.47870374, 0.29300008,
-0.39112198, 1.42303608, -0.02853060, 1.52610164, 0.53540909, 0.75618998,
-0.26877787, -1.90886366, 0.30622790, 0.59794535, 1.29795331, -0.37805803,
-1.58167176, -1.26966832, 0.27435891, 0.89430347, 0.22854926, -0.50317658};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
p.add_instruction(migraphx::op::dot{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.26735861, -4.30770895, 1.05257728, -1.19954265, 0.50493170,
-0.18729756, 1.09137941, -1.09298312, 3.42956915, -0.41681939,
0.17833257, 0.26040336, 0.15351280, 1.87632715, -0.63545406,
-0.95467340, -1.74728628, -2.42477030, 0.76262372, 0.15539164,
3.32281958, 0.96769613, 0.43727545, 2.43019906};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_dim1_2_3)
{
migraphx::program p;
std::vector<float> m1 = {
1.23636469, -0.47041261, -0.14375651, -0.48371852, 1.16479301, -0.89361055,
-0.18569086, 1.10700457, -1.02632638, 0.82277012, 0.33525769, 0.52825145,
-1.00141689, 0.45510090, -0.02675039, -0.60454439, 0.38551153, -0.01658514,
0.93059292, -0.54595188, -0.04911005, -0.91397221, -0.83127477, -1.57685603,
-1.36200452, 2.25822236, -1.23416970, 0.12312496, 0.76232760, -0.83594234,
1.67418145, -0.19412936, 1.05261378, 0.66246074, -1.15233398, 0.16429736};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
std::vector<float> m2 = {
-0.87300530, -0.07112838, 0.19196860, -1.04986840, 1.20348200, 0.31966893,
1.04805440, -2.04777729, -0.67906052, -1.17250760, 0.34305044, -1.01957785,
-1.12694862, 0.18431338, -1.63712290, 0.27566931, -1.11282021, 1.41738919,
0.47871283, -1.01980420, 1.00212436, -0.78740444, -1.65636133, 1.51466547,
-0.12470397, 0.70404393, -0.15244797, 0.74288871, 0.07339926, -1.45811623,
0.27185845, 0.08804596, 0.99061977, -1.61752428, 0.29191159, 0.87271953};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
std::vector<float> m3 = {-1.07692443, 0.85223457, -0.37266530, 2.31511577, 0.04227017,
1.13229428, -0.52769242, 0.27307182, -0.47779843, -0.08023168,
-0.22862823, 0.81489871, 1.13139581, 1.13860467, 0.24309065,
0.26533729, 0.49106772, -1.18860493, 0.27842449, 1.03568141,
0.49759611, 0.10021662, 0.00592602, 0.90862000};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
auto m12_alpha = p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2);
auto l_beta = p.add_literal(beta);
auto b_beta = p.add_instruction(migraphx::op::scalar{m12_alpha->get_shape()}, l_beta);
auto m3_beta = p.add_instruction(migraphx::op::mul{}, b_beta, l3);
p.add_instruction(migraphx::op::add{}, m3_beta, m12_alpha);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {-0.91147203, 0.47540785, -0.30313587, 0.43325099, -0.43711586,
0.50928632, 0.06919868, -0.80382802, -0.05125718, -0.06685650,
-0.06972163, 0.32407764, 0.45677396, 0.25909489, 0.56911252,
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(maxpool_test) TEST_CASE(maxpool_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -850,6 +850,38 @@ struct test_gemm_transposeab ...@@ -850,6 +850,38 @@ struct test_gemm_transposeab
} }
}; };
struct gemm_mutli_dim_2
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
struct gemm_mutli_dim_2_3
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
struct test_contiguous struct test_contiguous
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -3010,6 +3042,8 @@ int main() ...@@ -3010,6 +3042,8 @@ int main()
verify_program<test_gemm_transposea>(); verify_program<test_gemm_transposea>();
verify_program<test_gemm_transposea_ex>(); verify_program<test_gemm_transposea_ex>();
verify_program<test_gemm_transposeab>(); verify_program<test_gemm_transposeab>();
verify_program<gemm_mutli_dim_2>();
verify_program<gemm_mutli_dim_2_3>();
verify_program<test_contiguous>(); verify_program<test_contiguous>();
verify_program<test_eliminate_contiguous>(); verify_program<test_eliminate_contiguous>();
verify_program<test_transpose>(); verify_program<test_transpose>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment