Commit 980ca67d authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Add the quant_gemm operator to support the int8 type in the dot operator

parent 849f7d92
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANT_DOT_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANT_DOT_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/// Quantized (int8) matrix-multiply operator.
///
/// Takes operands A and B of type int8 and, optionally, a third operand C of
/// type int32 whose dimensions equal those of A * B. The result shape is
/// always int32. `alpha` and `beta` are scaling factors exposed through
/// `reflect`; the CPU lowering added in this same change applies them as
/// alpha * A * B + beta * C.
struct quant_dot
{
    int8_t alpha = 1;
    int8_t beta  = 1;

    /// Exposes the operator attributes under the names "alpha" and "beta".
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
    }

    std::string name() const { return "quant_dot"; }

    /// Validates the operand shapes and returns the shape of the result.
    ///
    /// @param inputs  shapes of A, B and, optionally, C (in that order);
    ///                entries past the third are only subject to the rank check
    /// @return        an int32 shape with the dimensions of A, except that the
    ///                last dimension is taken from B
    /// @throws        via MIGRAPHX_THROW when fewer than two operands are
    ///                given, A is not int8, any operand has fewer than 2
    ///                dimensions, the batch dimensions of A and B differ, the
    ///                inner dimensions do not match, or C has the wrong
    ///                dimensions or a type other than int32
    shape compute_shape(std::vector<shape> inputs) const
    {
        // Fail with a descriptive message instead of the bare
        // std::out_of_range that inputs.at() would raise below.
        if(inputs.size() < 2)
        {
            MIGRAPHX_THROW("QUANT_DOT: requires at least two operands (A and B)");
        }

        check_shapes{{inputs.at(0), inputs.at(1)}, *this}.same_type();
        const shape& a = inputs.at(0);
        const shape& b = inputs.at(1);

        // A and B have the same type at this point, so checking A is enough.
        if(a.type() != shape::int8_type)
        {
            MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t");
        }

        // Take each shape by reference: the predicate only inspects the rank,
        // so there is no reason to copy the shape for every element.
        if(!std::all_of(inputs.begin(), inputs.end(), [](const auto& s) {
               return s.lens().size() >= 2;
           }))
        {
            MIGRAPHX_THROW("QUANT_DOT: dot only accept 2 or more dims operands");
        }

        // only handle the case that the batch size of a and b are the same
        // (every dimension except the trailing two must be equal)
        if(!std::equal(
               a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
        {
            MIGRAPHX_THROW("QUANT_DOT: batch size of A and B mismatch: {" +
                           to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}");
        }

        // Indices of the row and column dimensions of the trailing matrices.
        std::size_t dim_0 = a.lens().size() - 2;
        std::size_t dim_1 = a.lens().size() - 1;
        if(a.lens()[dim_1] != b.lens()[dim_0])
        {
            MIGRAPHX_THROW("QUANT_DOT: inner dimensions do not match: {" +
                           to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}");
        }

        // Result keeps A's dimensions, with the column count taken from B.
        auto out_lens   = a.lens();
        out_lens[dim_1] = b.lens()[dim_1];

        if(inputs.size() == 3)
        {
            const shape& c = inputs.at(2);
            if(out_lens != c.lens())
            {
                MIGRAPHX_THROW("QUANT_DOT: dimension mismatch, operand C: {" +
                               to_string_range(c.lens()) +
                               "}, cannot add to operand A * B: {" + to_string_range(out_lens) +
                               "}");
            }
            if(c.type() != shape::int32_type)
            {
                MIGRAPHX_THROW("QUANT_DOT: operand C type must be int32");
            }
        }

        return {shape::int32_type, out_lens};
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -40,6 +40,7 @@ ...@@ -40,6 +40,7 @@
#include <migraphx/op/outline.hpp> #include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp> #include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp> #include <migraphx/op/pooling.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/relu.hpp> #include <migraphx/op/relu.hpp>
#include <migraphx/op/reshape.hpp> #include <migraphx/op/reshape.hpp>
#include <migraphx/op/rnn.hpp> #include <migraphx/op/rnn.hpp>
......
...@@ -44,12 +44,12 @@ struct is_fast_gemm_type<float> : std::true_type ...@@ -44,12 +44,12 @@ struct is_fast_gemm_type<float> : std::true_type
{ {
}; };
template <class T> template <class T, class F>
void migemm_impl(tensor_view<T> cmat, void migemm_impl(tensor_view<T> cmat,
tensor_view<T> amat, tensor_view<T> amat,
tensor_view<T> bmat, tensor_view<T> bmat,
float alpha, F alpha,
float beta, F beta,
std::true_type) std::true_type)
{ {
visit_mat(amat, [&](const auto& a) { visit_mat(amat, [&](const auto& a) {
...@@ -66,12 +66,12 @@ void migemm_impl(tensor_view<T> cmat, ...@@ -66,12 +66,12 @@ void migemm_impl(tensor_view<T> cmat,
}); });
} }
template <class T> template <class T, class F>
void migemm_impl(tensor_view<T> cmat, void migemm_impl(tensor_view<T> cmat,
tensor_view<T> amat, tensor_view<T> amat,
tensor_view<T> bmat, tensor_view<T> bmat,
float alpha, F alpha,
float beta, F beta,
std::false_type) std::false_type)
{ {
std::size_t n_dims = cmat.get_shape().lens().size(); std::size_t n_dims = cmat.get_shape().lens().size();
...@@ -95,9 +95,9 @@ void migemm_impl(tensor_view<T> cmat, ...@@ -95,9 +95,9 @@ void migemm_impl(tensor_view<T> cmat,
}); });
} }
template <class T> template <class T, class F>
void migemm_impl( void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta) tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
{ {
auto lens = amat.get_shape().lens(); auto lens = amat.get_shape().lens();
bool batch_mul = bool batch_mul =
...@@ -113,8 +113,9 @@ void migemm_impl( ...@@ -113,8 +113,9 @@ void migemm_impl(
} }
} }
template<class F>
void migemm( void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta) const argument& c_arg, const argument& a_arg, const argument& b_arg, F alpha, F beta)
{ {
visit_all(c_arg, a_arg, b_arg)( visit_all(c_arg, a_arg, b_arg)(
[&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); }); [&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); });
......
...@@ -8,8 +8,9 @@ namespace migraphx { ...@@ -8,8 +8,9 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace cpu { namespace cpu {
template<class T>
void migemm( void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta); const argument& c_arg, const argument& a_arg, const argument& b_arg, T alpha, T beta);
} // namespace cpu } // namespace cpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp> #include <migraphx/par_dfor.hpp>
#include <migraphx/cpu/gemm.hpp> #include <migraphx/cpu/gemm.hpp>
#include <migraphx/cpu/quant_gemm.hpp>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
...@@ -383,7 +384,7 @@ struct cpu_gemm ...@@ -383,7 +384,7 @@ struct cpu_gemm
{ {
argument result{output_shape}; argument result{output_shape};
// 3 inputs, it is alpha * A * B + beta * C, then // 3 inputs, it is alpha * A * B + beta * C, then
// A and B are matrics, and C is broadcastable to A * B // A and B are matrices, and C is of the same shape as A * B
if(args.size() == 3) if(args.size() == 3)
{ {
// no need to consider the value of args[2] // no need to consider the value of args[2]
...@@ -410,6 +411,68 @@ struct cpu_gemm ...@@ -410,6 +411,68 @@ struct cpu_gemm
} }
}; };
// CPU implementation of op::quant_dot. The int8 operands are widened to
// int32 copies and the product is then delegated to migemm.
struct cpu_quant_gemm
{
    op::quant_dot op;

    std::string name() const { return "cpu::quant_dot"; }

    // When a third operand is supplied it must not be broadcasted; every
    // other check is delegated to the wrapped operator.
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        if(inputs.size() == 3)
        {
            const shape& third = inputs.at(2);
            check_shapes{{third}}.not_broadcasted();
        }
        return op.compute_shape(inputs);
    }

    // With three arguments the result is alpha * A * B + beta * C, where A
    // and B are matrices and C has the same shape as A * B. With two
    // arguments beta is taken to be 0.
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};

        // Builds an int32 argument with the same dimensions as `narrow` and
        // copies the elements across, converting element by element.
        const auto widen = [](const argument& narrow) {
            argument wide{{shape::int32_type, {narrow.get_shape().lens()}}};
            wide.visit([&](auto dst) {
                narrow.visit([&](auto src) { std::copy(src.begin(), src.end(), dst.begin()); });
            });
            return wide;
        };

        argument lhs = widen(args.at(0));
        argument rhs = widen(args.at(1));

        // Two operands only: there is no C term.
        if(args.size() != 3)
        {
            int8_t no_beta = 0;
            migemm(result, lhs, rhs, op.alpha, no_beta);
            return result;
        }

        if(op.beta == 0)
        {
            // The value of args[2] is irrelevant when it is scaled by zero.
            result.visit([&](auto out) { std::fill(out.begin(), out.end(), 0); });
        }
        else
        {
            // Seed the result with C before handing it to migemm.
            visit_all(result, args[2])(
                [&](auto out, auto c) { std::copy(c.begin(), c.end(), out.begin()); });
        }

        migemm(result, lhs, rhs, op.alpha, op.beta);
        return result;
    }
};
struct cpu_gather struct cpu_gather
{ {
op::gather op; op::gather op;
...@@ -816,6 +879,7 @@ struct cpu_apply ...@@ -816,6 +879,7 @@ struct cpu_apply
apply_map["im2col"] = extend_op<cpu_im2col, op::im2col>(); apply_map["im2col"] = extend_op<cpu_im2col, op::im2col>();
apply_map["convolution"] = extend_op<cpu_convolution, op::convolution>(); apply_map["convolution"] = extend_op<cpu_convolution, op::convolution>();
apply_map["dot"] = extend_op<cpu_gemm, op::dot>(); apply_map["dot"] = extend_op<cpu_gemm, op::dot>();
apply_map["quant_dot"] = extend_op<cpu_quant_gemm, op::quant_dot>();
apply_map["batch_norm_inference"] = apply_map["batch_norm_inference"] =
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>(); extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["lrn"] = extend_op<cpu_lrn, op::lrn>(); apply_map["lrn"] = extend_op<cpu_lrn, op::lrn>();
......
...@@ -49,6 +49,8 @@ inline tensor_descriptor make_tensor(const migraphx::shape& s) ...@@ -49,6 +49,8 @@ inline tensor_descriptor make_tensor(const migraphx::shape& s)
d = miopenFloat; d = miopenFloat;
else if(s.type() == shape::half_type) else if(s.type() == shape::half_type)
d = miopenHalf; d = miopenHalf;
else if(s.type() == shape::int8_type)
d = miopenInt8;
else else
MIGRAPHX_THROW("Unsupported type"); MIGRAPHX_THROW("Unsupported type");
miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data()); miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment