Commit 3ccd7e15 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

code backup for the gemm implementation

parent b9e0366d
......@@ -819,10 +819,11 @@ struct gather
// vector input; if A or B is 2-dim, it is a matrix (no case of a batch of
// vectors as input). If A or B is 3 or more dims, it is considered as a
// stack(batch) of matrices.
// Note that, we optimze the scenario of either the Matmul or Gemm operators,
// But for extensional scenarios like GEMM with three inputs, and each arg
is a batch of matrices, the implementation may need further optimization
// later.
// Note that we only support the scenario of either the Matmul or Gemm
// operators. That is, if there are 3 inputs, we consider it is a Gemm, then
// A and B must be matrix inputs, and C is broadcastable to A * B. If there
are only two inputs, A and B can be 1-dim to N-dim; in this case, there
// is no C input.
struct dot
{
float alpha = 1.0;
......@@ -844,25 +845,12 @@ struct dot
if(a.empty())
{
if(is_mutli_broadcast)
{
return b;
}
else
{
MIGRAPHX_THROW("DOT: C is not broadcastable to A * B (scalar)");
}
return b;
}
auto a_size = a.size();
auto b_size = b.size();
if(is_mutli_broadcast && b_size > a_size)
{
MIGRAPHX_THROW("DOT: C {" + to_string_range(b) + "} is not broadcastable to A * b {" +
to_string_range(a) + "}");
}
auto n_dim = std::min(a_size, b_size);
std::vector<std::size_t> out_lens(std::max(a_size, b_size));
for(std::size_t i = 0; i < n_dim; ++i)
......@@ -875,27 +863,15 @@ struct dot
{
out_lens[i] = a[a_size - 1 - i];
}
else if(a[a_size - 1 - i] == 1 && is_mutli_broadcast)
{
out_lens[i] = b[b_size - 1 - i];
}
else
{
if(a[a_size - 1 - i] == 1 && is_mutli_broadcast)
{
out_lens[i] = b[b_size - 1 - i];
}
else
{
if(is_mutli_broadcast)
{
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
to_string_range(a) + "}, and matrix B: {" +
to_string_range(b) + "} are not broadcastable");
}
else
{
MIGRAPHX_THROW("DOT: C {" + to_string_range(b) +
"} is not broadcastable to A * b {" + to_string_range(a) +
"}");
}
}
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
to_string_range(a) + "}, and matrix B: {" +
to_string_range(b) + "} are not broadcastable");
}
}
......@@ -917,7 +893,42 @@ struct dot
std::string name() const { return "dot"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{{inputs[0], inputs[1]}, *this}.has(2).same_type();
// If there are 3 inputs, then A and B must be matrices and
// C is broadcastable to A * B
if (inputs.size() == 3)
{
check_shapes{inputs, *this}.has(3).same_type();
check_shapes{{inputs[0]}, *this}.only_dims(2);
check_shapes{{inputs[1]}, *this}.only_dims(2);
auto a_lens = inputs[0].lens();
auto b_lens = inputs[1].lens();
auto t = inputs[0].type();
if (a_lens[1] != b_lens[0])
{
MIGRAPHX_THROW("DOT : dimension mismatch, operand A: {" + to_string_range(a_lens) +
"}, cannot multiply operand B: {" + to_string_range(b_lens) + "}");
}
auto out_lens = a_lens;
out_lens[0] = b_lens[0];
// check whether C is broadcastable to A * B
auto c_lens = inputs[2].lens();
if (c_lens.size() > 2 ||
(c_lens.size() >= 1 && (c_lens[0] != 1 && c_lens[0] != b_lens[0])) ||
(c_lens.size() == 2 && (c_lens[1] != 1 && c_lens[1] != a_lens[1])))
{
MIGRAPHX_THROW("DOT: C {" + to_string_range(c_lens) +
"} is not broadcastable to A * B {" + to_string_range(out_lens) +
"}");
}
return {t, out_lens};
}
// For the case of two inputs, it is the numpy.matmul
check_shapes{inputs, *this}.has(2).same_type();
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
......@@ -977,21 +988,6 @@ struct dot
out_lens.pop_back();
}
// c is unibroadcastable to A * B
if(inputs.size() == 3)
{
// same type as A and B
check_shapes{{inputs[0], inputs[2]}, *this}.has(2).same_type();
if(out_lens.empty() && (!inputs[2].scalar()))
{
MIGRAPHX_THROW("DOT: C is not broadcastable to A*B (scalar)");
}
// check c is broadcastable to A * B
auto c_lens = inputs[2].lens();
shape_broadcast(out_lens, c_lens, false);
}
if(out_lens.empty())
{
return {t};
......
......@@ -77,19 +77,19 @@ void migemm_impl(tensor_view<T> cmat,
auto b_lens = bmat.get_shape().lens();
auto c_lens = cmat.get_shape().lens();
std::size_t n_dims = c_lens.size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
auto k = a_lens[dim_1];
std::size_t nc_dims = c_lens.size();
std::size_t na_dims = a_lens.size();
std::size_t nb_dims = b_lens.size();
auto k = a_lens[na_dims - 1];
assert(a_lens[dim_1] == b_lens[dim_0]);
assert(c_lens[dim_0] == a_lens[dim_0]);
assert(c_lens[dim_1] == b_lens[dim_1]);
assert(a_lens[na_dims - 1] == b_lens[nb_dims - 1]);
assert(c_lens[nc_dims - 2] == a_lens[na_dims - 2]);
assert(c_lens[nc_dims - 1] == b_lens[nb_dims - 1]);
std::size_t a_len_diff = c_lens.size() - a_lens.size();
std::size_t b_len_diff = c_lens.size() - b_lens.size();
std::vector<std::size_t> a_idx(a_lens.size());
std::vector<std::size_t> b_idx(b_lens.size());
std::size_t a_len_diff = nc_dims - na_dims;
std::size_t b_len_diff = nc_dims - nb_dims;
std::vector<std::size_t> a_idx(na_dims);
std::vector<std::size_t> b_idx(nb_dims);
shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
std::transform(c_lens.begin() + a_len_diff,
......@@ -105,7 +105,7 @@ void migemm_impl(tensor_view<T> cmat,
double s = 0.0;
dfor(k)([&](auto kk) {
a_idx[dim_1] = b_idx[dim_0] = kk;
a_idx[na_dims - 1] = b_idx[nb_dims - 2] = kk;
s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
});
cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
......
......@@ -371,14 +371,82 @@ struct cpu_gemm
std::string name() const { return "cpu::dot"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
void fill_result(argument& result, argument& c) const
{
    // Pre-fill the (m x n) GEMM output with C broadcast to the output shape,
    // so migemm can then accumulate alpha * A * B + beta * C in place.
    // C must be unidirectionally broadcastable to the output shape; the
    // four branches below cover the four legal broadcast patterns.
    // FIX: the original referenced `c_shape`, which is not declared in this
    // method (it exists only in compute()); use c.get_shape() instead.
    auto out_lens = result.get_shape().lens();
    auto c_lens   = c.get_shape().lens();
    if(out_lens == c_lens)
    {
        // Same shape: a bulk memory copy is cheaper than element-by-element.
        visit_all(result, c)([&](auto output, auto input) {
            std::memcpy(output.data(), input.data(), c.get_shape().bytes());
        });
    }
    // C holds a single element: broadcast it to every output element
    else if(c.single())
    {
        visit_all(result, c)([&](auto output, auto input) {
            std::fill(output.begin(), output.end(), input.front());
        });
    }
    // C is one row ({n} or {1, n} with c_lens[1] == out_lens[1]):
    // replicate that row across all m rows of the output
    else if(c_lens.size() == 1 || (c_lens.size() == 2 && (c_lens[1] == out_lens[1])))
    {
        std::size_t m = out_lens[0];
        std::size_t n = out_lens[1];
        visit_all(result, c)([&](auto output, auto input) {
            for(std::size_t i = 0; i < m; i++)
            {
                // c.get_shape().bytes() is exactly one row (n elements) here
                std::memcpy((output.data() + i * n), input.data(), c.get_shape().bytes());
            }
        });
    }
    // Remaining case: C is one column ({m, 1}), broadcast one value per row
    else
    {
        std::size_t m = out_lens[0];
        std::size_t n = out_lens[1];
        visit_all(result, c)([&](auto output, auto input) {
            for(std::size_t i = 0; i < m; i++)
            {
                std::fill(output.begin() + i * n,
                          (i + 1 == m) ? output.end() : output.begin() + ((i + 1) * n),
                          input[i]);
            }
        });
    }
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
// 3 inputs, it is alpha * A * B + beta * C, then
// A and B are matrics, and C is broadcastable to A * B
if (args.size() == 3)
{
// no need to consider the value of args[2]
if (op.beta == 0.0f)
{
result.visit([&](auto output) {
std::memset(output.data(), 0, output_shape.bytes());
});
}
else
{
fill_result(result, args[2]);
}
migemm(result, args[0], args[1], op.alpha, op.beta);
return result;
}
// 2 input cases
// all args are scalar
if(output_shape.scalar())
{
visit_all(result, args[0], args[1], args[2])([&](auto ret, auto a, auto b, auto c) {
ret[0] = op.alpha * a[0] * b[0] + op.beta * c[0];
visit_all(result, args[0], args[1])([&](auto res, auto a, auto b) {
res[0] = op.alpha * a[0] * b[0];
});
return result;
......@@ -406,55 +474,11 @@ struct cpu_gemm
out_lens.push_back(1);
}
// if there is a C input
if(args.size() == 2)
{
result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
migemm({{t, out_lens}, result.data()},
{{t, a_lens}, args[0].data()},
{{t, b_lens}, args[1].data()},
op.alpha,
op.beta);
return result;
}
// 3 input arguments
auto c_shape = args[2].get_shape();
// In GEMM, C is broadcastable to A * B, so we should consider C
// is not the same shape as A * B. If the same shape, copy C to
// the memory of the output
if(c_shape == output_shape)
{
// memory copy is more efficient than doing element by element
result.visit([&](auto output) {
args[2].visit(
[&](auto input) { std::memcpy(output.data(), input.data(), c_shape.bytes()); });
});
}
else
{
auto out_len = output_shape.lens();
auto c_lens = c_shape.lens();
std::size_t len_diff = out_len.size() - c_lens.size();
visit_all(result, args[2])([&](auto output, auto c) {
shape_for_each(output_shape, [&](auto out_idx) {
// compute the input index
std::vector<std::size_t> in_idx(c_lens.size());
std::transform(c_lens.begin(),
c_lens.end(),
out_len.begin() + len_diff,
in_idx.begin(),
[&](auto i, auto j) { return (i == 1) ? 0 : j; });
output(out_idx.begin(), out_idx.end()) = c(in_idx.begin(), in_idx.end());
});
});
}
migemm({{t, out_lens}, result.data()},
{{t, a_lens}, args[0].data()},
{{t, b_lens}, args[1].data()},
op.alpha,
op.beta);
{{t, a_lens}, args[0].data()},
{{t, b_lens}, args[1].data()},
op.alpha,
0.0f);
return result;
}
......
......@@ -171,10 +171,75 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
return op.compute_shape(inputs);
}
std::size_t miopen_gemm::compute_offset(std::vector<std::size_t>& out_lens,
std::size_t index,
std::vector<std::size_t>& data_lens) const
void miopen_gemm::fill_result(context& ctx, const shape& output_shape,
                              const argument& result, const argument& c) const
{
    // Pre-fill the GEMM output buffer with C broadcast to the output shape,
    // so the subsequent rocBLAS gemm can accumulate alpha*A*B + beta*C.
    // FIX: the original body referenced `args[3]`/`args[2]`, which are not in
    // scope in this method; it is rewritten in terms of the `result`/`c`
    // parameters. It also used an undeclared `n` in the last branch.
    (void)ctx; // copies below use hipMemcpy directly; ctx currently unused
    auto out_lens = output_shape.lens();
    auto c_lens   = c.get_shape().lens();
    output_shape.visit_type([&](auto as) {
        // Typed device pointer at an ELEMENT offset. NOTE(review): the
        // original added the offset to the raw data() pointer (a byte
        // offset); applying it to the typed pointer matches the element
        // offsets the callers pass — confirm against argument::data().
        auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
            return to_rocblas_type(as.from(arg.data())) + offset;
        };
        if(output_shape == c.get_shape())
        {
            // Same shape: one bulk device-to-device copy
            hipMemcpy(to_pointer(result),
                      to_pointer(c),
                      output_shape.bytes(),
                      hipMemcpyDeviceToDevice);
        }
        else if(c.single())
        {
            // Single-element C: replicate it to every output element.
            // TODO(review): one hipMemcpy per element is very slow; a
            // device-side fill kernel would be preferable.
            for(std::size_t i = 0; i < output_shape.elements(); ++i)
            {
                hipMemcpy(to_pointer(result, i),
                          to_pointer(c),
                          c.get_shape().bytes(),
                          hipMemcpyDeviceToDevice);
            }
        }
        else if(c_lens.size() == 1 || (c_lens.size() == 2 && c_lens[1] == out_lens[1]))
        {
            // C is one row ({n} or {1, n}): replicate it across all m rows
            auto m = out_lens[0];
            auto n = out_lens[1];
            for(std::size_t i = 0; i < m; ++i)
            {
                hipMemcpy(to_pointer(result, i * n),
                          to_pointer(c),
                          c.get_shape().bytes(),
                          hipMemcpyDeviceToDevice);
            }
        }
        // remaining case: c_lens.size() == 2 && c_lens[0] == out_lens[0],
        // i.e. C is one column ({m, 1}); broadcast one value per row
        else
        {
            auto n = out_lens[1]; // FIX: `n` was used but never declared
            for(std::size_t i = 0; i < output_shape.elements(); ++i)
            {
                hipMemcpy(to_pointer(result, i),
                          to_pointer(c, i / n),
                          c.get_shape().type_size(),
                          hipMemcpyDeviceToDevice);
            }
        }
    });
}
argument miopen_gemm::compute(context& ctx,
......@@ -182,12 +247,51 @@ argument miopen_gemm::compute(context& ctx,
const std::vector<argument>& args) const
{
bool is_3inputs = (args.size() == 4);
if (is_3inputs)
{
fill_result(ctx, output_shape, args[3], args[2]);
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(op.beta));
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? 1 : 0];
rocblas_int ldc = args[2].get_shape().strides()[0];
auto out_lens = output_shape.lens();
rocblas_int m = out_lens[0];
rocblas_int n = out_lens[1];
rocblas_int k = args[0].get_shape().lens()[1];
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_gemm(as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1]),
ldb,
to_pointer(args[0]),
lda,
&beta_r,
to_pointer(args[2]),
ldc);
});
return args[3];
}
// 2 input arguments cases
// vector inner product
if(output_shape.elements() == 1)
{
assert(args[0].get_shape().elements() == args[1].get_shape().elements());
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(op.beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_dot(as,
ctx.get_stream().get_rocblas(),
......@@ -196,129 +300,172 @@ argument miopen_gemm::compute(context& ctx,
1,
to_pointer(args[1]),
1,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]));
to_pointer(args[2]));
generic_rocblas_scal(as,
ctx.get_stream().get_rocblas(),
1,
&alpha_r,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]));
1);
if(is_3inputs)
{
generic_rocblas_axpy(as,
ctx.get_stream().get_rocblas(),
1,
&beta_r,
to_pointer(args[2]),
1,
to_pointer(args[3]),
1);
}
to_pointer(args[2]));
1);
});
return is_3inputs ? args[3] : args[2];
}
// b is a vector, so the computation is matrix * vector
// could not be the case of inner product of vectors since
// it is already processed above
if(args[1].get_shape().lens().size() == 1)
// matrix * vector
else if (args[1].get_shape().lens().size() == 1)
{
// considering the batch input, so A could be a batch
// of matrices
auto a_lens = args[0].get_shape().lens();
std::size_t n_dims = a_lens.size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
bool transa = args[0].get_shape().transposed();
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int m = a_lens[dim_0];
rocblas_int k = a_lens[dim_1];
auto batch_num = std::accumulate(
a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
auto a_lens = args[0].get_shape().lens();
std::size_t dim_0 = a_lens.size() - 2;
std::size_t dim_1 = a_lens.size() - 1;
bool trans = args[0].get_shape().transposed();
rocblas_int m = a_lens[trans ? dim_1 : dim_0];
rocblas_int n = a_lens[trans ? dim_0 : dim_1];
float beta = 0.0f;
rocblas_int lda = args[0].get_shape().strides()[trans ? dim_1 : dim_0];
assert(a_lens.back() == args[1].get_shape().elements());
std::size_t batch_num = std::accumulate(a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(op.beta));
auto to_pointer = [&](auto&& arg, std::size_t offset) {
return to_rocblas_type(as.from(arg.data() + offset));
};
for(std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
auto beta_r = = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { return to_rocblas_type(as.from(arg.data() + offset)); };
for (std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
{
if(is_3inputs)
hipMemcpy(to_pointer(args[3] + batch_no * m),
to_pointer(args[2]),
output_shape.bytes(),
hipMemcpyDeviceToDevice);
else
hipMemset(to_pointer(args[2]), 0, output_shape.bytes());
generic_rocblas_gemv(as,
ctx.get_stream().get_rocblas(),
trans ? rocblas_operation_transpose : rocblas_operation_none,
m,
n,
&alpha_r,
to_pointer(args[0], batch_no * m * n),
lda,
to_pointer(args[1]),
1,
&beta_r,
to_pointer(args[2], batch_no * n)
1);
}
});
}
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
std::size_t n_dims = args[0].get_shape().lens().size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
auto batch_num = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
bool is_3inputs = (args.size() == 4);
// two input arguments
if(!is_3inputs)
// vector * matrix
else if (args[0].get_shape().lens().size() == 1)
{
}
output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset));
};
if(is_3inputs)
hipMemcpy(to_pointer(args[3]),
to_pointer(args[2]),
output_shape.bytes(),
hipMemcpyDeviceToDevice);
else
hipMemset(to_pointer(args[2]), 0, output_shape.bytes());
});
output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset));
};
generic_rocblas_batched_gemm(as,
auto b_lens = args[1].get_shape().lens();
std::size_t dim_0 = b_lens.size() - 2;
std::size_t dim_1 = b_lens.size() - 1;
bool trans = !args[1].get_shape().transposed();
rocblas_int m = b_lens[trans ? dim_1 : dim_0];
rocblas_int n = b_lens[trans ? dim_0 : dim_1];
float beta = 0.0f;
rocblas_int lda = args[1].get_shape().strides()[trans ? dim_1 : dim_0];
assert(b_lens.back() == args[0].get_shape().elements());
std::size_t batch_num = std::accumulate(b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { return to_rocblas_type(as.from(arg.data() + offset)); };
for (std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
{
generic_rocblas_gemv(as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
trans ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1]),
ldb,
k * n,
to_pointer(args[0]),
lda,
m * k,
to_pointer(args[1], batch_no * m * n),
1,
&beta_r,
to_pointer(args[2]),
ldc,
m * n,
batch_num);
});
to_pointer(args[2], batch_no * m)
1);
}
});
}
// (batch) matrix multiplication
else
{
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
auto out_lens = output_shape.lens();
rocblas_int lda = args[0].get_shape().strides()[transa ? a_lens.size() - 1 : a_lens.size() - 2];
rocblas_int ldb = args[1].get_shape().strides()[transb ? b_lens.size() - 1 : b_lens.size() - 2];
rocblas_int ldc = args[2].get_shape().strides()[out_lens.size() - 2];
rocblas_int m = out_lens[out_lens.size() - 2];
rocblas_int n = out_lens[out_lens.size() - 1];
rocblas_int k = args[0].get_shape().lens()[a_lens.size() - 1];
auto input_dims = std::min(a_lens.size(), b_lens.size());
std::size_t axis{0};
for (axis = 2; axis < input_dims; ++axis)
{
if (a_lens[a_lens.size() - axis] != b_lens[b_lens.size() - axis])
{
break;
}
}
// The number of matrices that can be computed in one call
// batch_num > 1, we need to call the batch_gemm function,
// otherwise, call the gemm function directly
std::size_t num_matrices = std::accumulate(a_lens.rbegin() + 2,
(axis == a_lens.size() ? a_lens.rend() : a_lens.rbegin() + axis),
std::size_t{1}, std::multiplies<std::size_t>());
std::size_t a_len_diff = out_lens.size() - a_lens.size();
std::size_t b_len_diff = out_lens.size() - b_lens.size();
std::vector<std::size_t> a_batch_lens(a_lens.begin(), a_lens.begin() + a_lens.size() - axis);
std::vector<std::size_t> b_batch_lens(b_lens.begin(), b_lens.begin() + b_lens.size() - axis);
std::vector<std::size_t> out_batch_lens(out_lens.begin(), out_lens.begin() + out_lens.size() - axis);
shape::type_t t = output_shape.type();
shape a_batch_shape{t, a_batch_lens};
shape b_batch_shape{t, b_batch_lens};
shape out_diff_shape{t, out_batch_lens};
shape_for_each(out_diff_shape, [&](auto out_idx) {
std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end());
std::vector<std::size_t> a_idx(a_lens.size() - axis);
std::vector<std::size_t> b_idx(b_lens.size() - axis);
std::transform(out_idx.begin() + a_len_diff, out_idx.end(), a_batch_lens.begin(), a_idx.begin(), [&](auto i, auto j) {
return (j == 1) ? 0 : i;
});
std::transform(out_idx.begin() + b_len_diff, out_idx.end(), b_batch_lens.begin(), b_idx.begin(), [&](auto i, auto j) {
return (j == 1) ? 0 : i;
});
std::size_t a_ind = a_batch_shape.index(a_idx.begin(), b_idx.end());
std::size_t b_ind = b_batch_shape.index(b_idx.begin(), b_idx.end());
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { return to_rocblas_type(as.from(arg.data() + offset)); };
generic_rocblas_batched_gemm(as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1], k * n * num_matrices * b_ind),
ldb,
k * n,
to_pointer(args[0], m * k * num_matrices * a_ind),
lda,
m * k,
&beta_r,
to_pointer(args[2], m * n * num_matrices * out_ind),
ldc,
m * n,
num_matrices);
});
});
}
return (is_3inputs ? args[3] : args[2]);
return args[2];
}
} // namespace gpu
......
......@@ -20,9 +20,7 @@ struct miopen_gemm
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
private:
std::size_t compute_offset(std::vector<std::size_t>& out_lens,
std::size_t index,
std::vector<std::size_t>& data_lens) const;
void fill_result(context& ctx, const shape& output_shape, const argument& result, const argument& c) const;
};
} // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment