Commit 980ca67d authored by Shucai Xiao
Browse files

Add the quant_gemm operator to support the int8 type in the dot operator

parent 849f7d92
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANT_DOT_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANT_DOT_HPP
#include <algorithm>
#include <array>
#include <cmath>
#include <utility>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct quant_dot
{
    int8_t alpha = 1; // scales the A * B product
    int8_t beta  = 1; // scales the optional accumulator operand C

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
    }

    std::string name() const { return "quant_dot"; }

    // Validates the operand shapes of an int8 matrix multiply and returns the
    // int32 result shape.  Operands are A and B (both int8, rank >= 2, equal
    // batch dimensions, matching inner dimensions) plus an optional C, which
    // must be int32 and exactly the shape of A * B.  Throws MIGRAPHX_THROW on
    // any violation.
    shape compute_shape(std::vector<shape> inputs) const
    {
        // Guard the arity first so a malformed program gets an operator
        // diagnostic instead of std::out_of_range from inputs.at() below.
        if(inputs.size() != 2 && inputs.size() != 3)
        {
            MIGRAPHX_THROW("QUANT_DOT: operator needs 2 or 3 input operands");
        }
        // Only A and B must share a type; C is checked separately (int32).
        check_shapes{{inputs.at(0), inputs.at(1)}, *this}.same_type();
        const shape& a = inputs.at(0);
        const shape& b = inputs.at(1);
        auto t         = a.type();
        if(t != shape::int8_type)
        {
            MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t");
        }

        if(!std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
        {
            MIGRAPHX_THROW("QUANT_DOT: dot only accept 2 or more dims operands");
        }

        // Only handle the case where the batch dimensions of A and B match;
        // comparing the reversed ranges (all but the last two dims) also
        // guarantees A and B have the same rank.
        if(!std::equal(
               a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
        {
            MIGRAPHX_THROW("QUANT_DOT: batch size of A and B mismatch: {" +
                           to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}");
        }

        std::size_t dim_0 = a.lens().size() - 2; // row dimension index
        std::size_t dim_1 = a.lens().size() - 1; // column dimension index
        if(a.lens()[dim_1] != b.lens()[dim_0])
        {
            MIGRAPHX_THROW("QUANT_DOT: inner dimensions do not match: {" +
                           to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}");
        }

        // Result keeps A's leading dimensions and takes B's last dimension.
        auto out_lens   = a.lens();
        out_lens[dim_1] = b.lens()[dim_1];
        if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
        {
            MIGRAPHX_THROW("QUANT_DOT: dimension mismatch, operand C: {" +
                           to_string_range(inputs.at(2).lens()) +
                           "}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
        }
        // int8 products accumulate into int32, so C must already be int32.
        if(inputs.size() == 3 && inputs.at(2).type() != shape::int32_type)
        {
            MIGRAPHX_THROW("QUANT_DOT: operand C type must be int32");
        }
        return {shape::int32_type, out_lens};
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -40,6 +40,7 @@
#include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/relu.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/rnn.hpp>
......
......@@ -44,12 +44,12 @@ struct is_fast_gemm_type<float> : std::true_type
{
};
template <class T>
template <class T, class F>
void migemm_impl(tensor_view<T> cmat,
tensor_view<T> amat,
tensor_view<T> bmat,
float alpha,
float beta,
F alpha,
F beta,
std::true_type)
{
visit_mat(amat, [&](const auto& a) {
......@@ -66,12 +66,12 @@ void migemm_impl(tensor_view<T> cmat,
});
}
template <class T>
template <class T, class F>
void migemm_impl(tensor_view<T> cmat,
tensor_view<T> amat,
tensor_view<T> bmat,
float alpha,
float beta,
F alpha,
F beta,
std::false_type)
{
std::size_t n_dims = cmat.get_shape().lens().size();
......@@ -95,9 +95,9 @@ void migemm_impl(tensor_view<T> cmat,
});
}
template <class T>
template <class T, class F>
void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
{
auto lens = amat.get_shape().lens();
bool batch_mul =
......@@ -113,8 +113,9 @@ void migemm_impl(
}
}
template<class F>
void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta)
const argument& c_arg, const argument& a_arg, const argument& b_arg, F alpha, F beta)
{
visit_all(c_arg, a_arg, b_arg)(
[&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); });
......
......@@ -8,8 +8,9 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
template<class T>
void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta);
const argument& c_arg, const argument& a_arg, const argument& b_arg, T alpha, T beta);
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -7,6 +7,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/cpu/gemm.hpp>
#include <migraphx/cpu/quant_gemm.hpp>
#include <unordered_map>
#include <utility>
......@@ -383,7 +384,7 @@ struct cpu_gemm
{
argument result{output_shape};
// 3 inputs, it is alpha * A * B + beta * C, then
// A and B are matrics, and C is broadcastable to A * B
// A and B are matrices, and C is of the same shape as A * B
if(args.size() == 3)
{
// no need to consider the value of args[2]
......@@ -410,6 +411,68 @@ struct cpu_gemm
}
};
struct cpu_quant_gemm
{
    op::quant_dot op;

    std::string name() const { return "cpu::quant_dot"; }

    shape compute_shape(const std::vector<shape>& inputs) const
    {
        if(inputs.size() == 3)
        {
            // C is copied element-wise into the result below, so it must be a
            // packed, non-broadcasted buffer.
            auto c_shape = inputs.at(2);
            check_shapes{{c_shape}}.not_broadcasted();
        }
        return op.compute_shape(inputs);
    }

    // Widens an int8 argument into a fresh int32 buffer with the same
    // dimensions so the generic int32 gemm can be reused for quantized inputs.
    static argument to_int32(const argument& arg)
    {
        argument converted{{shape::int32_type, arg.get_shape().lens()}};
        converted.visit([&](auto output) {
            arg.visit(
                [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
        });
        return converted;
    }

    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        // Computes alpha * A * B + beta * C, where A and B are int8 matrices
        // and C (optional) has the same shape as A * B.
        // First convert args[0] and args[1] from int8_t to int32_t.
        argument arg_0 = to_int32(args.at(0));
        argument arg_1 = to_int32(args.at(1));

        if(args.size() == 3)
        {
            // no need to consider the value of args[2]
            if(op.beta == 0)
            {
                // beta == 0 makes C irrelevant; zero-init the accumulator so
                // migemm never reads indeterminate memory.
                result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
            }
            else
            {
                // Seed the accumulator with C; migemm applies beta itself.
                visit_all(result, args[2])([&](auto output, auto input) {
                    std::copy(input.begin(), input.end(), output.begin());
                });
            }
            migemm(result, arg_0, arg_1, op.alpha, op.beta);
            return result;
        }

        // 2 input arguments: no accumulator, so force beta to 0.  Zero-fill
        // the result first in case migemm reads the existing value when
        // applying beta (with beta == 0 this cannot change the final result).
        result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
        int8_t beta = 0;
        migemm(result, arg_0, arg_1, op.alpha, beta);
        return result;
    }
};
struct cpu_gather
{
op::gather op;
......@@ -816,6 +879,7 @@ struct cpu_apply
apply_map["im2col"] = extend_op<cpu_im2col, op::im2col>();
apply_map["convolution"] = extend_op<cpu_convolution, op::convolution>();
apply_map["dot"] = extend_op<cpu_gemm, op::dot>();
apply_map["quant_dot"] = extend_op<cpu_quant_gemm, op::quant_dot>();
apply_map["batch_norm_inference"] =
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["lrn"] = extend_op<cpu_lrn, op::lrn>();
......
......@@ -49,6 +49,8 @@ inline tensor_descriptor make_tensor(const migraphx::shape& s)
d = miopenFloat;
else if(s.type() == shape::half_type)
d = miopenHalf;
else if(s.type() == shape::int8_type)
d = miopenInt8;
else
MIGRAPHX_THROW("Unsupported type");
miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment