Commit acfd1369 authored by Shucai Xiao
Browse files

fix bugs in quant_dot implementation

parent da80ceb4
......@@ -96,7 +96,7 @@ struct miopen_apply
add_generic_op<hip_min>("min");
add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op<miopen_quant_gemm, op::quant_dot>("dot");
add_extend_op<miopen_quant_gemm, op::quant_dot>("quant_dot");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
add_extend_op<hip_concat, op::concat>("concat");
add_extend_op<miopen_softmax, op::softmax>("softmax");
......
#include <migraphx/gpu/quant_gemm.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
......@@ -63,7 +62,7 @@ argument miopen_quant_gemm::compute(context& ctx,
const std::vector<argument>& args) const
{
bool is_3inputs = (args.size() == 4);
float beta = 0.0f;
int8_t beta = 0;
if(is_3inputs)
{
beta = op.beta;
......@@ -88,8 +87,8 @@ argument miopen_quant_gemm::compute(context& ctx,
rocblas_int k = args[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
assert(k % 4 == 0);
assert(transa && (lda % 4 == 0));
assert(!transb && (ldb % 4 == 0));
assert(transa or (lda % 4 == 0));
assert(!transb or (ldb % 4 == 0));
auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment