Commit aaf8b162 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

change the type of alpha and beta from int8 to int32 for quant_dot

parent f0f562b2
...@@ -18,8 +18,8 @@ namespace op { ...@@ -18,8 +18,8 @@ namespace op {
struct quant_dot struct quant_dot
{ {
int8_t alpha = 1; int32_t alpha = 1;
int8_t beta = 1; int32_t beta = 1;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
......
...@@ -119,7 +119,7 @@ void migemm( ...@@ -119,7 +119,7 @@ void migemm(
} }
void migemm( void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, int8_t alpha, int8_t beta) const argument& c_arg, const argument& a_arg, const argument& b_arg, int32_t alpha, int32_t beta)
{ {
migemm_tpl(c_arg, a_arg, b_arg, alpha, beta); migemm_tpl(c_arg, a_arg, b_arg, alpha, beta);
} }
......
...@@ -11,7 +11,7 @@ namespace cpu { ...@@ -11,7 +11,7 @@ namespace cpu {
void migemm( void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta); const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta);
void migemm( void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, int8_t alpha, int8_t beta); const argument& c_arg, const argument& a_arg, const argument& b_arg, int32_t alpha, int32_t beta);
} // namespace cpu } // namespace cpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -563,7 +563,7 @@ struct cpu_quant_gemm ...@@ -563,7 +563,7 @@ struct cpu_quant_gemm
} }
// 2 input arguments // 2 input arguments
int8_t beta = 0; int32_t beta = 0;
migemm(result, arg_0, arg_1, op.alpha, beta); migemm(result, arg_0, arg_1, op.alpha, beta);
return result; return result;
......
...@@ -94,7 +94,7 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -94,7 +94,7 @@ argument miopen_quant_gemm::compute(context& ctx,
} }
bool is_3inputs = (args.size() == 4); bool is_3inputs = (args.size() == 4);
int8_t beta = 0; int32_t beta = 0;
if(is_3inputs) if(is_3inputs)
{ {
beta = op.beta; beta = op.beta;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment