Commit acfd1369 authored by Shucai Xiao
Browse files

fix bugs in quant_dot implementation

parent da80ceb4
...@@ -96,7 +96,7 @@ struct miopen_apply ...@@ -96,7 +96,7 @@ struct miopen_apply
add_generic_op<hip_min>("min"); add_generic_op<hip_min>("min");
add_extend_op<miopen_gemm, op::dot>("dot"); add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op<miopen_quant_gemm, op::quant_dot>("dot"); add_extend_op<miopen_quant_gemm, op::quant_dot>("quant_dot");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous"); add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
add_extend_op<hip_concat, op::concat>("concat"); add_extend_op<hip_concat, op::concat>("concat");
add_extend_op<miopen_softmax, op::softmax>("softmax"); add_extend_op<miopen_softmax, op::softmax>("softmax");
......
#include <migraphx/gpu/quant_gemm.hpp> #include <migraphx/gpu/quant_gemm.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
namespace migraphx { namespace migraphx {
...@@ -63,7 +62,7 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -63,7 +62,7 @@ argument miopen_quant_gemm::compute(context& ctx,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
bool is_3inputs = (args.size() == 4); bool is_3inputs = (args.size() == 4);
float beta = 0.0f; int8_t beta = 0;
if(is_3inputs) if(is_3inputs)
{ {
beta = op.beta; beta = op.beta;
...@@ -88,8 +87,8 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -88,8 +87,8 @@ argument miopen_quant_gemm::compute(context& ctx,
rocblas_int k = args[0].get_shape().lens()[dim_1]; rocblas_int k = args[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
assert(k % 4 == 0); assert(k % 4 == 0);
assert(transa && (lda % 4 == 0)); assert(transa or (lda % 4 == 0));
assert(!transb && (ldb % 4 == 0)); assert(!transb or (ldb % 4 == 0));
auto num_matrices = std::accumulate( auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment