"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "9b219f046061e0246a5c3fd022383963f9d57743"
Commit 573a935c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix bugs in the gpu implementation of the quant_dot operator.

parent 54134027
...@@ -87,33 +87,26 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -87,33 +87,26 @@ argument miopen_quant_gemm::compute(context& ctx,
rocblas_int k = args[0].get_shape().lens()[dim_1]; rocblas_int k = args[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
assert(k % 4 == 0); assert(k % 4 == 0);
assert(transa or (lda % 4 == 0)); assert(!transa or (lda % 4 == 0));
assert(!transb or (ldb % 4 == 0)); assert(transb or (ldb % 4 == 0));
auto arg_0 = migraphx::gpu::from_gpu(args[0]);
auto arg_1 = migraphx::gpu::from_gpu(args[1]);
auto arg_2 = migraphx::gpu::from_gpu(args[2]);
std::cout << "arg_0 = " << arg_0 << std::endl;
std::cout << "arg_1 = " << arg_1 << std::endl;
std::cout << "arg_2 = " << arg_2 << std::endl;
auto num_matrices = std::accumulate( auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1) if(num_matrices == 1)
{ {
generic_rocblas_gemm_ex(ctx.get_stream().get_rocblas(), generic_rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
transa ? rocblas_operation_transpose : rocblas_operation_none,
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
m, transa ? rocblas_operation_transpose : rocblas_operation_none,
n, n,
m,
k, k,
&alpha_r, &alpha_r,
to_pointer(args[0]),
rocblas_datatype_i8_r,
lda,
to_pointer(args[1]), to_pointer(args[1]),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
ldb, ldb,
to_pointer(args[0]),
rocblas_datatype_i8_r,
lda,
&beta_r, &beta_r,
to_pointer(args[2]), to_pointer(args[2]),
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment