Commit 3e6a9c17 authored by Shucai Xiao

more changes for the quant_dot implementation

parent c47fd466
@@ -111,6 +111,16 @@ void migemm(const argument& c_arg, const argument& a_arg, const argument& b_arg,
         [&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); });
 }
 
+void migemm(const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta)
+{
+    migemm(c_arg, a_arg, b_arg, alpha, beta);
+}
+
+void migemm(const argument& c_arg, const argument& a_arg, const argument& b_arg, int8_t alpha, int8_t beta)
+{
+    migemm(c_arg, a_arg, b_arg, alpha, beta);
+}
+
 } // namespace cpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
@@ -8,8 +8,8 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace cpu {
 
-template <class T>
-void migemm(const argument& c_arg, const argument& a_arg, const argument& b_arg, T alpha, T beta);
+void migemm(const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta);
+void migemm(const argument& c_arg, const argument& a_arg, const argument& b_arg, int8_t alpha, int8_t beta);
 
 } // namespace cpu
 } // namespace MIGRAPHX_INLINE_NS
...
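The two hunks above replace the single templated migemm declaration with concrete float and int8_t overloads. One detail worth noting: if the templated definition shown in the .cpp context keeps the name migemm, a call spelled migemm(c_arg, a_arg, b_arg, alpha, beta) inside the new float overload resolves back to that same non-template overload (an exact non-template match is preferred over a template), not to the template. The standalone sketch below is not MIGraphX source; it only illustrates forwarding such overloads through a separately named templated helper. The helper name gemm_tpl and the trimmed-down signatures are illustrative assumptions, not taken from this commit.

// Standalone sketch, not MIGraphX code: concrete float/int8_t overloads
// forwarding to a distinctly named templated helper.
#include <cstdint>
#include <iostream>

template <class T>
void gemm_tpl(T alpha, T beta) // shared templated implementation (assumed name)
{
    std::cout << "gemm_tpl called with alpha=" << +alpha << " beta=" << +beta << '\n';
}

// Concrete overloads mirroring the float/int8_t pair in the diff. Calling
// gemm(alpha, beta) here instead of gemm_tpl(alpha, beta) would re-select the
// enclosing non-template overload (exact match beats the template) and recurse.
void gemm(float alpha, float beta) { gemm_tpl(alpha, beta); }

void gemm(std::int8_t alpha, std::int8_t beta) { gemm_tpl(alpha, beta); }

int main()
{
    gemm(1.0f, 0.0f);                     // selects the float overload -> gemm_tpl<float>
    gemm(std::int8_t{1}, std::int8_t{0}); // selects the int8_t overload -> gemm_tpl<int8_t>
    return 0;
}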
@@ -7,7 +7,6 @@
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/par_dfor.hpp>
 #include <migraphx/cpu/gemm.hpp>
-#include <migraphx/cpu/quant_gemm.hpp>
 #include <unordered_map>
 #include <utility>
 
@@ -877,7 +876,7 @@ struct cpu_apply
         apply_map["im2col"] = extend_op<cpu_im2col, op::im2col>();
         apply_map["convolution"] = extend_op<cpu_convolution, op::convolution>();
         apply_map["dot"] = extend_op<cpu_gemm, op::dot>();
-        apply_map["quant_dot"] = extend_op<cpu_quant_gemm, op::quant_dot>();
+        //apply_map["quant_dot"] = extend_op<cpu_quant_gemm, op::quant_dot>();
         apply_map["batch_norm_inference"] =
             extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
         apply_map["lrn"] = extend_op<cpu_lrn, op::lrn>();
...