Commit 6ec90d65 authored by Shucai Xiao
Browse files

code backup after testing a few scenarios.

parent b9d45e76
...@@ -894,7 +894,7 @@ struct dot ...@@ -894,7 +894,7 @@ struct dot
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
// If there are 3 inputs, then A and B must be matrices and // If there are 3 inputs, then A and B must be matrices and
// C is broadcastable to A * B // C should be broadcastable to A * B
if(inputs.size() == 3) if(inputs.size() == 3)
{ {
check_shapes{inputs, *this}.has(3).same_type(); check_shapes{inputs, *this}.has(3).same_type();
...@@ -911,13 +911,14 @@ struct dot ...@@ -911,13 +911,14 @@ struct dot
} }
auto out_lens = a_lens; auto out_lens = a_lens;
out_lens[0] = b_lens[0]; out_lens[1] = b_lens[1];
// check whether C is broadcastable to A * B // check whether C is broadcastable to A * B
auto c_lens = inputs[2].lens(); auto c_lens = inputs[2].lens();
if(c_lens.size() > 2 || if(c_lens.size() > 2 ||
(c_lens.size() >= 1 && (c_lens[0] != 1 && c_lens[0] != b_lens[0])) || (c_lens.size() == 1 && (c_lens[0] != 1 && c_lens[0] != b_lens[1])) ||
(c_lens.size() == 2 && (c_lens[1] != 1 && c_lens[1] != a_lens[1]))) (c_lens.size() == 2 && (c_lens[0] != 1 && c_lens[0] != a_lens[0])) ||
(c_lens.size() == 2 && (c_lens[1] != 1 && c_lens[1] != b_lens[1])))
{ {
MIGRAPHX_THROW("DOT: C {" + to_string_range(c_lens) + MIGRAPHX_THROW("DOT: C {" + to_string_range(c_lens) +
"} is not broadcastable to A * B {" + to_string_range(out_lens) + "} is not broadcastable to A * B {" + to_string_range(out_lens) +
...@@ -955,9 +956,9 @@ struct dot ...@@ -955,9 +956,9 @@ struct dot
is_b_appended = true; is_b_appended = true;
} }
std::size_t dim_0 = a_lens.size() - 1; std::size_t dim_0 = a_lens.size() - 2;
std::size_t dim_1 = b_lens.size() - 2; std::size_t dim_1 = b_lens.size() - 1;
if(a_lens[dim_0] != b_lens[dim_1]) if(a_lens[dim_1] != b_lens[dim_0])
{ {
MIGRAPHX_THROW("DOT : dimension mismatch, operand A: {" + to_string_range(a.lens()) + MIGRAPHX_THROW("DOT : dimension mismatch, operand A: {" + to_string_range(a.lens()) +
"}, cannot multiply operand B: {" + to_string_range(b.lens()) + "}"); "}, cannot multiply operand B: {" + to_string_range(b.lens()) + "}");
......
...@@ -168,7 +168,8 @@ rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_hal ...@@ -168,7 +168,8 @@ rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_hal
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{ {
return op.compute_shape(inputs); std::vector<shape> orig_inputs(inputs.begin(), inputs.begin() + inputs.size() - 1);
return op.compute_shape(orig_inputs);
} }
void miopen_gemm::fill_result(const shape& output_shape, void miopen_gemm::fill_result(const shape& output_shape,
...@@ -177,6 +178,7 @@ void miopen_gemm::fill_result(const shape& output_shape, ...@@ -177,6 +178,7 @@ void miopen_gemm::fill_result(const shape& output_shape,
{ {
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
auto c_lens = c.get_shape().lens(); auto c_lens = c.get_shape().lens();
auto type_size = output_shape.type_size();
if(output_shape == c.get_shape()) if(output_shape == c.get_shape())
{ {
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
...@@ -188,13 +190,13 @@ void miopen_gemm::fill_result(const shape& output_shape, ...@@ -188,13 +190,13 @@ void miopen_gemm::fill_result(const shape& output_shape,
else if(c.single()) else if(c.single())
{ {
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { auto to_pointer = [&](auto&& arg, std::size_t offset_byte = 0) {
return to_rocblas_type(as.from(arg.data() + offset)); return to_rocblas_type(as.from(arg.data() + offset_byte));
}; };
for(std::size_t i = 0; i < output_shape.elements(); ++i) for(std::size_t i = 0; i < output_shape.elements(); ++i)
{ {
hipMemcpy(to_pointer(result, i), hipMemcpy(to_pointer(result, i * type_size),
to_pointer(c), to_pointer(c),
c.get_shape().bytes(), c.get_shape().bytes(),
hipMemcpyDeviceToDevice); hipMemcpyDeviceToDevice);
...@@ -212,7 +214,7 @@ void miopen_gemm::fill_result(const shape& output_shape, ...@@ -212,7 +214,7 @@ void miopen_gemm::fill_result(const shape& output_shape,
for(std::size_t i = 0; i < m; ++i) for(std::size_t i = 0; i < m; ++i)
{ {
hipMemcpy(to_pointer(result, i * n), hipMemcpy(to_pointer(result, i * n * type_size),
to_pointer(c), to_pointer(c),
c.get_shape().bytes(), c.get_shape().bytes(),
hipMemcpyDeviceToDevice); hipMemcpyDeviceToDevice);
...@@ -229,9 +231,9 @@ void miopen_gemm::fill_result(const shape& output_shape, ...@@ -229,9 +231,9 @@ void miopen_gemm::fill_result(const shape& output_shape,
for(std::size_t i = 0; i < output_shape.elements(); ++i) for(std::size_t i = 0; i < output_shape.elements(); ++i)
{ {
hipMemcpy(to_pointer(result, i), hipMemcpy(to_pointer(result, i * type_size),
to_pointer(c, i / out_lens[0]), to_pointer(c, i / out_lens[1] * type_size),
c.get_shape().type_size(), type_size,
hipMemcpyDeviceToDevice); hipMemcpyDeviceToDevice);
} }
}); });
...@@ -254,12 +256,28 @@ argument miopen_gemm::compute(context& ctx, ...@@ -254,12 +256,28 @@ argument miopen_gemm::compute(context& ctx,
bool transb = args[1].get_shape().transposed(); bool transb = args[1].get_shape().transposed();
rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0]; rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? 1 : 0]; rocblas_int ldb = args[1].get_shape().strides()[transb ? 1 : 0];
rocblas_int ldc = args[2].get_shape().strides()[0]; rocblas_int ldc = args[3].get_shape().strides()[0];
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
rocblas_int m = out_lens[0]; rocblas_int m = out_lens[0];
rocblas_int n = out_lens[1]; rocblas_int n = out_lens[1];
rocblas_int k = args[0].get_shape().lens()[1]; rocblas_int k = args[0].get_shape().lens()[1];
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
auto cpu_a = migraphx::gpu::from_gpu(args[0]);
auto cpu_b = migraphx::gpu::from_gpu(args[1]);
auto cpu_res = migraphx::gpu::from_gpu(args[3]);
std::cout << "gpu::gemm, cpu_a = " << cpu_a << std::endl;
std::cout << "gpu::gemm, cpu_b = " << cpu_b << std::endl;
std::cout << "gpu::gemm, cpu_res = " << cpu_res << std::endl;
std::cout << "gpu::gemm, transb = " << transb << std::endl;
std::cout << "gpu::gemm, transa = " << transb << std::endl;
std::cout << "gpu::gemm, m = " << m << std::endl;
std::cout << "gpu::gemm, n = " << n << std::endl;
std::cout << "gpu::gemm, k = " << k << std::endl;
std::cout << "gpu::gemm, lda = " << lda << std::endl;
std::cout << "gpu::gemm, ldb = " << ldb << std::endl;
std::cout << "gpu::gemm, ldc = " << ldc << std::endl;
generic_rocblas_gemm(as, generic_rocblas_gemm(as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
...@@ -273,7 +291,7 @@ argument miopen_gemm::compute(context& ctx, ...@@ -273,7 +291,7 @@ argument miopen_gemm::compute(context& ctx,
to_pointer(args[0]), to_pointer(args[0]),
lda, lda,
&beta_r, &beta_r,
to_pointer(args[2]), to_pointer(args[3]),
ldc); ldc);
}); });
...@@ -283,6 +301,7 @@ argument miopen_gemm::compute(context& ctx, ...@@ -283,6 +301,7 @@ argument miopen_gemm::compute(context& ctx,
// 2 input arguments cases // 2 input arguments cases
// vector inner product // vector inner product
std::size_t type_size = output_shape.type_size();
if(output_shape.elements() == 1) if(output_shape.elements() == 1)
{ {
assert(args[0].get_shape().elements() == args[1].get_shape().elements()); assert(args[0].get_shape().elements() == args[1].get_shape().elements());
...@@ -331,12 +350,12 @@ argument miopen_gemm::compute(context& ctx, ...@@ -331,12 +350,12 @@ argument miopen_gemm::compute(context& ctx,
m, m,
n, n,
&alpha_r, &alpha_r,
to_pointer(args[0], batch_no * m * n), to_pointer(args[0], batch_no * m * n * type_size),
lda, lda,
to_pointer(args[1]), to_pointer(args[1]),
1, 1,
&beta_r, &beta_r,
to_pointer(args[2], batch_no * n), to_pointer(args[2], batch_no * n * type_size),
1); 1);
} }
}); });
...@@ -344,41 +363,74 @@ argument miopen_gemm::compute(context& ctx, ...@@ -344,41 +363,74 @@ argument miopen_gemm::compute(context& ctx,
// vector * matrix // vector * matrix
else if(args[0].get_shape().lens().size() == 1) else if(args[0].get_shape().lens().size() == 1)
{ {
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
std::size_t dim_0 = b_lens.size() - 2; std::size_t dim_0 = b_lens.size() - 2;
std::size_t dim_1 = b_lens.size() - 1; std::size_t dim_1 = b_lens.size() - 1;
bool trans = !args[1].get_shape().transposed(); bool transb = args[1].get_shape().transposed();
rocblas_int m = b_lens[trans ? dim_1 : dim_0]; bool transa = false;
rocblas_int n = b_lens[trans ? dim_0 : dim_1]; rocblas_int lda = a_lens[0];
rocblas_int ldb = args[1].get_shape().strides()[(transb ? dim_1 : dim_0)];
rocblas_int ldc = b_lens[dim_1];
rocblas_int m = 1;
rocblas_int n = args[1].get_shape().lens()[dim_1];
rocblas_int k = a_lens[0];
float beta = 0.0f; float beta = 0.0f;
rocblas_int lda = args[1].get_shape().strides()[trans ? dim_1 : dim_0]; assert(b_lens[dim_0] == args[0].get_shape().elements());
assert(b_lens.back() == args[0].get_shape().elements());
std::size_t batch_num = std::accumulate( std::size_t batch_num = std::accumulate(
b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
auto cpu_a = migraphx::gpu::from_gpu(args[0]);
auto cpu_b = migraphx::gpu::from_gpu(args[1]);
auto cpu_res = migraphx::gpu::from_gpu(args[2]);
std::cout << "gpu::gemm, cpu_a = " << cpu_a << std::endl;
std::cout << "gpu::gemm, cpu_b = " << cpu_b << std::endl;
std::cout << "gpu::gemm, cpu_res = " << cpu_res << std::endl;
std::cout << "gpu::gemm, transb = " << transb << std::endl;
std::cout << "gpu::gemm, transa = " << transb << std::endl;
std::cout << "gpu::gemm, m = " << m << std::endl;
std::cout << "gpu::gemm, n = " << n << std::endl;
std::cout << "gpu::gemm, k = " << k << std::endl;
std::cout << "gpu::gemm, lda = " << lda << std::endl;
std::cout << "gpu::gemm, ldb = " << ldb << std::endl;
std::cout << "gpu::gemm, ldc = " << ldc << std::endl;
output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
hipMemset(to_pointer(args[2]), 0, output_shape.bytes());
});
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { auto to_pointer = [&](auto&& arg) {
return to_rocblas_type(as.from(arg.data() + offset)); return to_rocblas_type(as.from(arg.data()));
}; };
for(std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
{ generic_rocblas_batched_gemm(
generic_rocblas_gemv(as, as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
trans ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
n, transa ? rocblas_operation_transpose : rocblas_operation_none,
m, n,
&alpha_r, m,
to_pointer(args[0]), k,
lda, &alpha_r,
to_pointer(args[1], batch_no * m * n), to_pointer(args[1]),
1, ldb,
&beta_r, k * n,
to_pointer(args[2], batch_no * m), to_pointer(args[0]),
1); lda,
} 0,
&beta_r,
to_pointer(args[2]),
ldc,
m * n,
batch_num);
}); });
return args[2];
} }
// (batch) matrix multiplication // (batch) matrix multiplication
else else
...@@ -465,14 +517,14 @@ argument miopen_gemm::compute(context& ctx, ...@@ -465,14 +517,14 @@ argument miopen_gemm::compute(context& ctx,
m, m,
k, k,
&alpha_r, &alpha_r,
to_pointer(args[1], k * n * num_matrices * b_ind), to_pointer(args[1], k * n * num_matrices * b_ind * type_size),
ldb, ldb,
k * n, k * n,
to_pointer(args[0], m * k * num_matrices * a_ind), to_pointer(args[0], m * k * num_matrices * a_ind * type_size),
lda, lda,
m * k, m * k,
&beta_r, &beta_r,
to_pointer(args[2], m * n * num_matrices * out_ind), to_pointer(args[2], m * n * num_matrices * out_ind * type_size),
ldc, ldc,
m * n, m * n,
num_matrices); num_matrices);
...@@ -489,12 +541,12 @@ argument miopen_gemm::compute(context& ctx, ...@@ -489,12 +541,12 @@ argument miopen_gemm::compute(context& ctx,
m, m,
k, k,
&alpha_r, &alpha_r,
to_pointer(args[1], k * n * num_matrices * b_ind), to_pointer(args[1], k * n * b_ind * type_size),
ldb, ldb,
to_pointer(args[0], m * k * num_matrices * a_ind), to_pointer(args[0], m * k * a_ind * type_size),
lda, lda,
&beta_r, &beta_r,
to_pointer(args[2], m * n * num_matrices * out_ind), to_pointer(args[2], m * n * out_ind * type_size),
ldc); ldc);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment