Commit 52b9cf14 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 3ccd7e15
...@@ -819,7 +819,7 @@ struct gather ...@@ -819,7 +819,7 @@ struct gather
// vector input; if A or B is 2-dim, it is a matrix (no case of a batch of // vector input; if A or B is 2-dim, it is a matrix (no case of a batch of
// vectors as input). If A or B is 3 or more dims, it is considered as a // vectors as input). If A or B is 3 or more dims, it is considered as a
// stack(batch) of matrices. // stack(batch) of matrices.
// Note that we only support the scenario of either the Matmul or Gemm // Note that we only support the scenario of either the Matmul or Gemm
// operators. That is, if there are 3 inputs, we consider it is a Gemm, then // operators. That is, if there are 3 inputs, we consider it is a Gemm, then
// A and B must be matrix inputs, and C is broadcastable to A * B. If there // A and B must be matrix inputs, and C is broadcastable to A * B. If there
// is only two inputs, A and B can be 1-dim to N-dim, in this case, there // is only two inputs, A and B can be 1-dim to N-dim, in this case, there
...@@ -869,9 +869,9 @@ struct dot ...@@ -869,9 +869,9 @@ struct dot
} }
else else
{ {
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a) +
to_string_range(a) + "}, and matrix B: {" + "}, and matrix B: {" + to_string_range(b) +
to_string_range(b) + "} are not broadcastable"); "} are not broadcastable");
} }
} }
...@@ -895,7 +895,7 @@ struct dot ...@@ -895,7 +895,7 @@ struct dot
{ {
// If there are 3 inputs, then A and B must be matrices and // If there are 3 inputs, then A and B must be matrices and
// C is broadcastable to A * B // C is broadcastable to A * B
if (inputs.size() == 3) if(inputs.size() == 3)
{ {
check_shapes{inputs, *this}.has(3).same_type(); check_shapes{inputs, *this}.has(3).same_type();
check_shapes{{inputs[0]}, *this}.only_dims(2); check_shapes{{inputs[0]}, *this}.only_dims(2);
...@@ -903,25 +903,25 @@ struct dot ...@@ -903,25 +903,25 @@ struct dot
auto a_lens = inputs[0].lens(); auto a_lens = inputs[0].lens();
auto b_lens = inputs[1].lens(); auto b_lens = inputs[1].lens();
auto t = inputs[0].type(); auto t = inputs[0].type();
if (a_lens[1] != b_lens[0]) if(a_lens[1] != b_lens[0])
{ {
MIGRAPHX_THROW("DOT : dimension mismatch, operand A: {" + to_string_range(a_lens) + MIGRAPHX_THROW("DOT : dimension mismatch, operand A: {" + to_string_range(a_lens) +
"}, cannot multiply operand B: {" + to_string_range(b_lens) + "}"); "}, cannot multiply operand B: {" + to_string_range(b_lens) + "}");
} }
auto out_lens = a_lens; auto out_lens = a_lens;
out_lens[0] = b_lens[0]; out_lens[0] = b_lens[0];
// check whether C is broadcastable to A * B // check whether C is broadcastable to A * B
auto c_lens = inputs[2].lens(); auto c_lens = inputs[2].lens();
if (c_lens.size() > 2 || if(c_lens.size() > 2 ||
(c_lens.size() >= 1 && (c_lens[0] != 1 && c_lens[0] != b_lens[0])) || (c_lens.size() >= 1 && (c_lens[0] != 1 && c_lens[0] != b_lens[0])) ||
(c_lens.size() == 2 && (c_lens[1] != 1 && c_lens[1] != a_lens[1]))) (c_lens.size() == 2 && (c_lens[1] != 1 && c_lens[1] != a_lens[1])))
{ {
MIGRAPHX_THROW("DOT: C {" + to_string_range(c_lens) + MIGRAPHX_THROW("DOT: C {" + to_string_range(c_lens) +
"} is not broadcastable to A * B {" + to_string_range(out_lens) + "} is not broadcastable to A * B {" + to_string_range(out_lens) +
"}"); "}");
} }
return {t, out_lens}; return {t, out_lens};
......
...@@ -80,7 +80,7 @@ void migemm_impl(tensor_view<T> cmat, ...@@ -80,7 +80,7 @@ void migemm_impl(tensor_view<T> cmat,
std::size_t nc_dims = c_lens.size(); std::size_t nc_dims = c_lens.size();
std::size_t na_dims = a_lens.size(); std::size_t na_dims = a_lens.size();
std::size_t nb_dims = b_lens.size(); std::size_t nb_dims = b_lens.size();
auto k = a_lens[na_dims - 1]; auto k = a_lens[na_dims - 1];
assert(a_lens[na_dims - 1] == b_lens[nb_dims - 1]); assert(a_lens[na_dims - 1] == b_lens[nb_dims - 1]);
assert(c_lens[nc_dims - 2] == a_lens[na_dims - 2]); assert(c_lens[nc_dims - 2] == a_lens[na_dims - 2]);
......
...@@ -374,29 +374,28 @@ struct cpu_gemm ...@@ -374,29 +374,28 @@ struct cpu_gemm
void fill_result(argument& result, argument& c) const void fill_result(argument& result, argument& c) const
{ {
auto out_lens = result.get_shape().lens(); auto out_lens = result.get_shape().lens();
auto c_lens = c.get_shape().lens(); auto c_lens = c.get_shape().lens();
if (out_lens == c_lens) if(out_lens == c_lens)
{ {
visit_all(result, c)([&](auto output, auto input) { visit_all(result, c)([&](auto output, auto input) {
std::memcpy(output.data(), input.data(), c_shape.bytes()); std::memcpy(output.data(), input.data(), c_shape.bytes());
}); });
} }
// need broadcast // need broadcast
else if (c.single()) else if(c.single())
{ {
visit_all(result, c)([&](auto output, auto input) { visit_all(result, c)([&](auto output, auto input) {
std::fill(output.begin(), output.end(), input.front()); std::fill(output.begin(), output.end(), input.front());
}); });
} }
// must be c_lens[0] == output_lens[1] // must be c_lens[0] == output_lens[1]
else if (c_lens.size() == 1 || else if(c_lens.size() == 1 || (c_lens.size() == 2 && (c_lens[1] == out_lens[1])))
(c_lens.size() == 2 && (c_lens[1] == out_lens[1])))
{ {
std::size_t m = out_lens[0]; std::size_t m = out_lens[0];
std::size_t n = out_lens[1]; std::size_t n = out_lens[1];
visit_all(result, c)([&](auto output, auto input) { visit_all(result, c)([&](auto output, auto input) {
for (std::size_t i = 0; i < m; i++) for(std::size_t i = 0; i < m; i++)
{ {
std::memcpy((output.data() + i * n), input.data(), c_shape.bytes()); std::memcpy((output.data() + i * n), input.data(), c_shape.bytes());
} }
...@@ -408,12 +407,13 @@ struct cpu_gemm ...@@ -408,12 +407,13 @@ struct cpu_gemm
std::size_t m = out_lens[0]; std::size_t m = out_lens[0];
std::size_t n = out_lens[1]; std::size_t n = out_lens[1];
visit_all(result, c)([&](auto output, auto input) { visit_all(result, c)([&](auto output, auto input) {
for (std::size_t i = 0; i < m; i++) for(std::size_t i = 0; i < m; i++)
{ {
std::fill(output.begin() + i * n, std::fill(output.begin() + i * n,
(i + 1 == m) ? output.end() : output.begin() + ((i + 1) * n), input[i]); (i + 1 == m) ? output.end() : output.begin() + ((i + 1) * n),
input[i]);
} }
}); });
} }
} }
...@@ -422,18 +422,17 @@ struct cpu_gemm ...@@ -422,18 +422,17 @@ struct cpu_gemm
argument result{output_shape}; argument result{output_shape};
// 3 inputs, it is alpha * A * B + beta * C, then // 3 inputs, it is alpha * A * B + beta * C, then
// A and B are matrics, and C is broadcastable to A * B // A and B are matrics, and C is broadcastable to A * B
if (args.size() == 3) if(args.size() == 3)
{ {
// no need to consider the value of args[2] // no need to consider the value of args[2]
if (op.beta == 0.0f) if(op.beta == 0.0f)
{ {
result.visit([&](auto output) { result.visit(
std::memset(output.data(), 0, output_shape.bytes()); [&](auto output) { std::memset(output.data(), 0, output_shape.bytes()); });
});
} }
else else
{ {
fill_result(result, args[2]); fill_result(result, args[2]);
} }
migemm(result, args[0], args[1], op.alpha, op.beta); migemm(result, args[0], args[1], op.alpha, op.beta);
...@@ -445,9 +444,8 @@ struct cpu_gemm ...@@ -445,9 +444,8 @@ struct cpu_gemm
// all args are scalar // all args are scalar
if(output_shape.scalar()) if(output_shape.scalar())
{ {
visit_all(result, args[0], args[1])([&](auto res, auto a, auto b) { visit_all(result, args[0], args[1])(
res[0] = op.alpha * a[0] * b[0]; [&](auto res, auto a, auto b) { res[0] = op.alpha * a[0] * b[0]; });
});
return result; return result;
} }
...@@ -475,10 +473,10 @@ struct cpu_gemm ...@@ -475,10 +473,10 @@ struct cpu_gemm
} }
migemm({{t, out_lens}, result.data()}, migemm({{t, out_lens}, result.data()},
{{t, a_lens}, args[0].data()}, {{t, a_lens}, args[0].data()},
{{t, b_lens}, args[1].data()}, {{t, b_lens}, args[1].data()},
op.alpha, op.alpha,
0.0f); 0.0f);
return result; return result;
} }
......
...@@ -171,24 +171,24 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const ...@@ -171,24 +171,24 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
return op.compute_shape(inputs); return op.compute_shape(inputs);
} }
void miopen_gemm::fill_result(context& ctx, const shape& output_shape, void miopen_gemm::fill_result(context& ctx,
const argument& result, const argument& c) const const shape& output_shape,
const argument& result,
const argument& c) const
{ {
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
auto c_lens = c.get_shape().lens(); auto c_lens = c.get_shape().lens();
if (output_shape == c.get_shape()) if(output_shape == c.get_shape())
{ {
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg) { auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
return to_rocblas_type(as.from(arg.data()));
};
hipMemcpy(to_pointer(args[3]), hipMemcpy(to_pointer(args[3]),
to_pointer(args[2]), to_pointer(args[2]),
output_shape.bytes(), output_shape.bytes(),
hipMemcpyDeviceToDevice); hipMemcpyDeviceToDevice);
}); });
} }
else if (c.single()) else if(c.single())
{ {
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg, std::size_t offset) { auto to_pointer = [&](auto&& arg, std::size_t offset) {
...@@ -198,14 +198,13 @@ void miopen_gemm::fill_result(context& ctx, const shape& output_shape, ...@@ -198,14 +198,13 @@ void miopen_gemm::fill_result(context& ctx, const shape& output_shape,
for(std::size_t i = 0; i < output_shape.elements(); ++i) for(std::size_t i = 0; i < output_shape.elements(); ++i)
{ {
hipMemcpy(to_pointer(args[3], i), hipMemcpy(to_pointer(args[3], i),
to_pointer(args[2]), to_pointer(args[2]),
args[2].get_shape().bytes(), args[2].get_shape().bytes(),
hipMemcpyDeviceToDevice); hipMemcpyDeviceToDevice);
} }
}); });
} }
else if (c_lens.size() == 1 || else if(c_lens.size() == 1 || (c_lens.size() == 2 && c_lens[1] == out_lens[1]))
(c_lens.size() == 2 && c_lens[1] == out_lens[1]))
{ {
auto m = out_lens[0]; auto m = out_lens[0];
auto n = out_lens[1]; auto n = out_lens[1];
...@@ -217,9 +216,9 @@ void miopen_gemm::fill_result(context& ctx, const shape& output_shape, ...@@ -217,9 +216,9 @@ void miopen_gemm::fill_result(context& ctx, const shape& output_shape,
for(std::size_t i = 0; i < m; ++i) for(std::size_t i = 0; i < m; ++i)
{ {
hipMemcpy(to_pointer(args[3], i * n), hipMemcpy(to_pointer(args[3], i * n),
to_pointer(args[2]), to_pointer(args[2]),
args[2].get_shape().bytes(), args[2].get_shape().bytes(),
hipMemcpyDeviceToDevice); hipMemcpyDeviceToDevice);
} }
}); });
} }
...@@ -234,9 +233,9 @@ void miopen_gemm::fill_result(context& ctx, const shape& output_shape, ...@@ -234,9 +233,9 @@ void miopen_gemm::fill_result(context& ctx, const shape& output_shape,
for(std::size_t i = 0; i < output_shape.elements(); ++i) for(std::size_t i = 0; i < output_shape.elements(); ++i)
{ {
hipMemcpy(to_pointer(args[3], i), hipMemcpy(to_pointer(args[3], i),
to_pointer(args[2], i / n), to_pointer(args[2], i / n),
args[2].get_shape().type_size(), args[2].get_shape().type_size(),
hipMemcpyDeviceToDevice); hipMemcpyDeviceToDevice);
} }
}); });
} }
...@@ -247,22 +246,22 @@ argument miopen_gemm::compute(context& ctx, ...@@ -247,22 +246,22 @@ argument miopen_gemm::compute(context& ctx,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
bool is_3inputs = (args.size() == 4); bool is_3inputs = (args.size() == 4);
if (is_3inputs) if(is_3inputs)
{ {
fill_result(ctx, output_shape, args[3], args[2]); fill_result(ctx, output_shape, args[3], args[2]);
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(op.beta)); auto beta_r = to_rocblas_type(as(op.beta));
bool transa = args[0].get_shape().transposed(); bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed(); bool transb = args[1].get_shape().transposed();
rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0]; rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? 1 : 0]; rocblas_int ldb = args[1].get_shape().strides()[transb ? 1 : 0];
rocblas_int ldc = args[2].get_shape().strides()[0]; rocblas_int ldc = args[2].get_shape().strides()[0];
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
rocblas_int m = out_lens[0]; rocblas_int m = out_lens[0];
rocblas_int n = out_lens[1]; rocblas_int n = out_lens[1];
rocblas_int k = args[0].get_shape().lens()[1]; rocblas_int k = args[0].get_shape().lens()[1];
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_gemm(as, generic_rocblas_gemm(as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
...@@ -302,33 +301,33 @@ argument miopen_gemm::compute(context& ctx, ...@@ -302,33 +301,33 @@ argument miopen_gemm::compute(context& ctx,
1, 1,
to_pointer(args[2])); to_pointer(args[2]));
generic_rocblas_scal(as, generic_rocblas_scal(
ctx.get_stream().get_rocblas(), as, ctx.get_stream().get_rocblas(), 1, &alpha_r, to_pointer(args[2]));
1,
&alpha_r,
to_pointer(args[2]));
1); 1);
}); });
} }
// matrix * vector // matrix * vector
else if (args[1].get_shape().lens().size() == 1) else if(args[1].get_shape().lens().size() == 1)
{ {
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
std::size_t dim_0 = a_lens.size() - 2; std::size_t dim_0 = a_lens.size() - 2;
std::size_t dim_1 = a_lens.size() - 1; std::size_t dim_1 = a_lens.size() - 1;
bool trans = args[0].get_shape().transposed(); bool trans = args[0].get_shape().transposed();
rocblas_int m = a_lens[trans ? dim_1 : dim_0]; rocblas_int m = a_lens[trans ? dim_1 : dim_0];
rocblas_int n = a_lens[trans ? dim_0 : dim_1]; rocblas_int n = a_lens[trans ? dim_0 : dim_1];
float beta = 0.0f; float beta = 0.0f;
rocblas_int lda = args[0].get_shape().strides()[trans ? dim_1 : dim_0]; rocblas_int lda = args[0].get_shape().strides()[trans ? dim_1 : dim_0];
assert(a_lens.back() == args[1].get_shape().elements()); assert(a_lens.back() == args[1].get_shape().elements());
std::size_t batch_num = std::accumulate(a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); std::size_t batch_num = std::accumulate(
a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = = to_rocblas_type(as(beta)); auto beta_r = = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { return to_rocblas_type(as.from(arg.data() + offset)); }; auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
for (std::size_t batch_no = 0; batch_no < batch_num; ++batch_no) return to_rocblas_type(as.from(arg.data() + offset));
};
for(std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
{ {
generic_rocblas_gemv(as, generic_rocblas_gemv(as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
...@@ -341,30 +340,32 @@ argument miopen_gemm::compute(context& ctx, ...@@ -341,30 +340,32 @@ argument miopen_gemm::compute(context& ctx,
to_pointer(args[1]), to_pointer(args[1]),
1, 1,
&beta_r, &beta_r,
to_pointer(args[2], batch_no * n) to_pointer(args[2], batch_no * n) 1);
1);
} }
}); });
} }
// vector * matrix // vector * matrix
else if (args[0].get_shape().lens().size() == 1) else if(args[0].get_shape().lens().size() == 1)
{ {
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
std::size_t dim_0 = b_lens.size() - 2; std::size_t dim_0 = b_lens.size() - 2;
std::size_t dim_1 = b_lens.size() - 1; std::size_t dim_1 = b_lens.size() - 1;
bool trans = !args[1].get_shape().transposed(); bool trans = !args[1].get_shape().transposed();
rocblas_int m = b_lens[trans ? dim_1 : dim_0]; rocblas_int m = b_lens[trans ? dim_1 : dim_0];
rocblas_int n = b_lens[trans ? dim_0 : dim_1]; rocblas_int n = b_lens[trans ? dim_0 : dim_1];
float beta = 0.0f; float beta = 0.0f;
rocblas_int lda = args[1].get_shape().strides()[trans ? dim_1 : dim_0]; rocblas_int lda = args[1].get_shape().strides()[trans ? dim_1 : dim_0];
assert(b_lens.back() == args[0].get_shape().elements()); assert(b_lens.back() == args[0].get_shape().elements());
std::size_t batch_num = std::accumulate(b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); std::size_t batch_num = std::accumulate(
b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = = to_rocblas_type(as(beta)); auto beta_r = = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { return to_rocblas_type(as.from(arg.data() + offset)); }; auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
for (std::size_t batch_no = 0; batch_no < batch_num; ++batch_no) return to_rocblas_type(as.from(arg.data() + offset));
};
for(std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
{ {
generic_rocblas_gemv(as, generic_rocblas_gemv(as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
...@@ -377,47 +378,53 @@ argument miopen_gemm::compute(context& ctx, ...@@ -377,47 +378,53 @@ argument miopen_gemm::compute(context& ctx,
to_pointer(args[1], batch_no * m * n), to_pointer(args[1], batch_no * m * n),
1, 1,
&beta_r, &beta_r,
to_pointer(args[2], batch_no * m) to_pointer(args[2], batch_no * m) 1);
1);
} }
}); });
} }
// (batch) matrix multiplication // (batch) matrix multiplication
else else
{ {
bool transa = args[0].get_shape().transposed(); bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed(); bool transb = args[1].get_shape().transposed();
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
rocblas_int lda = args[0].get_shape().strides()[transa ? a_lens.size() - 1 : a_lens.size() - 2]; rocblas_int lda =
rocblas_int ldb = args[1].get_shape().strides()[transb ? b_lens.size() - 1 : b_lens.size() - 2]; args[0].get_shape().strides()[transa ? a_lens.size() - 1 : a_lens.size() - 2];
rocblas_int ldc = args[2].get_shape().strides()[out_lens.size() - 2]; rocblas_int ldb =
rocblas_int m = out_lens[out_lens.size() - 2]; args[1].get_shape().strides()[transb ? b_lens.size() - 1 : b_lens.size() - 2];
rocblas_int n = out_lens[out_lens.size() - 1]; rocblas_int ldc = args[2].get_shape().strides()[out_lens.size() - 2];
rocblas_int k = args[0].get_shape().lens()[a_lens.size() - 1]; rocblas_int m = out_lens[out_lens.size() - 2];
rocblas_int n = out_lens[out_lens.size() - 1];
rocblas_int k = args[0].get_shape().lens()[a_lens.size() - 1];
auto input_dims = std::min(a_lens.size(), b_lens.size()); auto input_dims = std::min(a_lens.size(), b_lens.size());
std::size_t axis{0}; std::size_t axis{0};
for (axis = 2; axis < input_dims; ++axis) for(axis = 2; axis < input_dims; ++axis)
{ {
if (a_lens[a_lens.size() - axis] != b_lens[b_lens.size() - axis]) if(a_lens[a_lens.size() - axis] != b_lens[b_lens.size() - axis])
{ {
break; break;
} }
} }
// The number of matrices that can be computed in one call // The number of matrices that can be computed in one call
// batch_num > 1, we need to call the batch_gemm function, // batch_num > 1, we need to call the batch_gemm function,
// otherwise, call the gemm function directly // otherwise, call the gemm function directly
std::size_t num_matrices = std::accumulate(a_lens.rbegin() + 2, std::size_t num_matrices =
(axis == a_lens.size() ? a_lens.rend() : a_lens.rbegin() + axis), std::accumulate(a_lens.rbegin() + 2,
std::size_t{1}, std::multiplies<std::size_t>()); (axis == a_lens.size() ? a_lens.rend() : a_lens.rbegin() + axis),
std::size_t{1},
std::multiplies<std::size_t>());
std::size_t a_len_diff = out_lens.size() - a_lens.size(); std::size_t a_len_diff = out_lens.size() - a_lens.size();
std::size_t b_len_diff = out_lens.size() - b_lens.size(); std::size_t b_len_diff = out_lens.size() - b_lens.size();
std::vector<std::size_t> a_batch_lens(a_lens.begin(), a_lens.begin() + a_lens.size() - axis); std::vector<std::size_t> a_batch_lens(a_lens.begin(),
std::vector<std::size_t> b_batch_lens(b_lens.begin(), b_lens.begin() + b_lens.size() - axis); a_lens.begin() + a_lens.size() - axis);
std::vector<std::size_t> out_batch_lens(out_lens.begin(), out_lens.begin() + out_lens.size() - axis); std::vector<std::size_t> b_batch_lens(b_lens.begin(),
b_lens.begin() + b_lens.size() - axis);
std::vector<std::size_t> out_batch_lens(out_lens.begin(),
out_lens.begin() + out_lens.size() - axis);
shape::type_t t = output_shape.type(); shape::type_t t = output_shape.type();
shape a_batch_shape{t, a_batch_lens}; shape a_batch_shape{t, a_batch_lens};
...@@ -428,39 +435,46 @@ argument miopen_gemm::compute(context& ctx, ...@@ -428,39 +435,46 @@ argument miopen_gemm::compute(context& ctx,
std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end()); std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end());
std::vector<std::size_t> a_idx(a_lens.size() - axis); std::vector<std::size_t> a_idx(a_lens.size() - axis);
std::vector<std::size_t> b_idx(b_lens.size() - axis); std::vector<std::size_t> b_idx(b_lens.size() - axis);
std::transform(out_idx.begin() + a_len_diff, out_idx.end(), a_batch_lens.begin(), a_idx.begin(), [&](auto i, auto j) { std::transform(out_idx.begin() + a_len_diff,
return (j == 1) ? 0 : i; out_idx.end(),
}); a_batch_lens.begin(),
std::transform(out_idx.begin() + b_len_diff, out_idx.end(), b_batch_lens.begin(), b_idx.begin(), [&](auto i, auto j) { a_idx.begin(),
return (j == 1) ? 0 : i; [&](auto i, auto j) { return (j == 1) ? 0 : i; });
}); std::transform(out_idx.begin() + b_len_diff,
out_idx.end(),
b_batch_lens.begin(),
b_idx.begin(),
[&](auto i, auto j) { return (j == 1) ? 0 : i; });
std::size_t a_ind = a_batch_shape.index(a_idx.begin(), b_idx.end()); std::size_t a_ind = a_batch_shape.index(a_idx.begin(), b_idx.end());
std::size_t b_ind = b_batch_shape.index(b_idx.begin(), b_idx.end()); std::size_t b_ind = b_batch_shape.index(b_idx.begin(), b_idx.end());
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = = to_rocblas_type(as(beta)); auto beta_r = = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { return to_rocblas_type(as.from(arg.data() + offset)); }; auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
generic_rocblas_batched_gemm(as, return to_rocblas_type(as.from(arg.data() + offset));
ctx.get_stream().get_rocblas(), };
transb ? rocblas_operation_transpose : rocblas_operation_none, generic_rocblas_batched_gemm(
transa ? rocblas_operation_transpose : rocblas_operation_none, as,
n, ctx.get_stream().get_rocblas(),
m, transb ? rocblas_operation_transpose : rocblas_operation_none,
k, transa ? rocblas_operation_transpose : rocblas_operation_none,
&alpha_r, n,
to_pointer(args[1], k * n * num_matrices * b_ind), m,
ldb, k,
k * n, &alpha_r,
to_pointer(args[0], m * k * num_matrices * a_ind), to_pointer(args[1], k * n * num_matrices * b_ind),
lda, ldb,
m * k, k * n,
&beta_r, to_pointer(args[0], m * k * num_matrices * a_ind),
to_pointer(args[2], m * n * num_matrices * out_ind), lda,
ldc, m * k,
m * n, &beta_r,
num_matrices); to_pointer(args[2], m * n * num_matrices * out_ind),
ldc,
m * n,
num_matrices);
}); });
}); });
} }
......
...@@ -20,7 +20,10 @@ struct miopen_gemm ...@@ -20,7 +20,10 @@ struct miopen_gemm
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; } int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
private: private:
void fill_result(context& ctx, const shape& output_shape, const argument& result, const argument& c) const; void fill_result(context& ctx,
const shape& output_shape,
const argument& result,
const argument& c) const;
}; };
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment