Commit 573a935c authored by Shucai Xiao
Browse files

Fix bugs in the GPU implementation of the quant_dot operator.

parent 54134027
......@@ -87,33 +87,26 @@ argument miopen_quant_gemm::compute(context& ctx,
rocblas_int k = args[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
assert(k % 4 == 0);
assert(transa or (lda % 4 == 0));
assert(!transb or (ldb % 4 == 0));
auto arg_0 = migraphx::gpu::from_gpu(args[0]);
auto arg_1 = migraphx::gpu::from_gpu(args[1]);
auto arg_2 = migraphx::gpu::from_gpu(args[2]);
std::cout << "arg_0 = " << arg_0 << std::endl;
std::cout << "arg_1 = " << arg_1 << std::endl;
std::cout << "arg_2 = " << arg_2 << std::endl;
assert(!transa or (lda % 4 == 0));
assert(transb or (ldb % 4 == 0));
auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1)
{
generic_rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
transa ? rocblas_operation_transpose : rocblas_operation_none,
transb ? rocblas_operation_transpose : rocblas_operation_none,
m,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[0]),
rocblas_datatype_i8_r,
lda,
to_pointer(args[1]),
rocblas_datatype_i8_r,
ldb,
to_pointer(args[0]),
rocblas_datatype_i8_r,
lda,
&beta_r,
to_pointer(args[2]),
rocblas_datatype_i32_r,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment