Commit cb555646 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from int8_quantize

parents 12ccb601 4a10535c
......@@ -215,6 +215,10 @@ argument miopen_gemm::compute(context& ctx,
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
if(num_matrices == 1)
{
// the rocblas_gemm API handles inputs and output matrices as
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
generic_rocblas_gemm(as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
......
......@@ -3,8 +3,6 @@
#include <migraphx/shape.hpp>
#include <migraphx/op/convert.hpp>
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/convert.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -12,7 +10,7 @@ namespace gpu {
struct context;
struct hip_convert : unary_device<hip_convert, device::convert>
struct hip_convert
{
op::convert op;
......@@ -22,13 +20,15 @@ struct hip_convert : unary_device<hip_convert, device::convert>
return migraphx::reflect(self.op, f);
}
hip_convert(op::convert oper) : op(oper) {}
std::string name() const { return "gpu::convert"; }
shape compute_shape(std::vector<shape> inputs) const
shape compute_shape(std::vector<shape> inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
inputs.pop_back();
check_shapes{inputs}.packed();
return op.compute_shape(inputs);
return shapes.size() - 1;
}
};
......
......@@ -11,7 +11,12 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void convert(hipStream_t stream, const argument& result, const argument& arg);
void convert(hipStream_t stream,
const argument& result,
const argument& arg,
float scale,
float shift,
shape::type_t target_type);
} // namespace device
} // namespace gpu
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_PACK_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Device-side packing helpers. Both take the HIP stream to run on, the
// destination buffer (`result`) and the source buffer (`arg`), in that order.
// NOTE(review): the packed layout itself is defined by the device
// implementation, which is not part of this header -- presumably the int8
// layout required by the quantized GEMM path; confirm against the kernels.
void pack_a(hipStream_t stream, const argument& result, const argument& arg);
// Counterpart of pack_a with the same (stream, destination, source) signature.
void pack_b(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -34,11 +34,11 @@ Result make_obj(F f, Ts... xs)
auto status = f(&x, xs...);
Result r{x};
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen call failed");
MIGRAPHX_THROW("MAKE_OBJ: MIOpen call failed");
return r;
}
inline tensor_descriptor make_tensor(const migraphx::shape& s)
inline tensor_descriptor make_tensor(const migraphx::shape& s, bool pack = false)
{
auto t = make_obj<tensor_descriptor>(&miopenCreateTensorDescriptor);
// Convert to ints
......@@ -49,13 +49,31 @@ inline tensor_descriptor make_tensor(const migraphx::shape& s)
d = miopenFloat;
else if(s.type() == shape::half_type)
d = miopenHalf;
else if(s.type() == shape::int8_type)
{
if(pack)
{
// update the lens and corresponding strides
d = miopenInt8x4;
lens[1] = ((lens[1] + 3) / 4) * 4;
strides[0] = strides[1] * lens[1];
}
else
{
d = miopenInt8;
}
}
else
MIGRAPHX_THROW("Unsupported type");
{
MIGRAPHX_THROW("MAKE_TENSOR: unsupported type");
}
miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data());
return t;
}
inline convolution_descriptor make_conv(const migraphx::op::convolution& op)
template <class T>
inline convolution_descriptor make_conv(const T& op)
{
auto c = make_obj<convolution_descriptor>(&miopenCreateConvolutionDescriptor);
miopenConvolutionMode_t c_mode = miopenConvolution;
......
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANT_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANT_CONVOLUTION_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
// GPU operator wrapping op::quant_convolution for execution through MIOpen.
// The member order matters: the struct is brace-initialized from
// {op, convolution descriptor} when the operator is lowered.
struct miopen_quant_convolution
{
// The framework-level quantized convolution this operator implements.
op::quant_convolution op;
// Convolution descriptor handed to the MIOpen calls.
shared<convolution_descriptor> cd;
// Forward algorithm selected by compile().
miopenConvFwdAlgorithm_t algo{};
// MIOpen handle compile() last ran with; finalize() recompiles when it differs.
miopenHandle_t handle = nullptr;
// Scratch buffers holding the vec4-packed input and weights (filled in compile()).
argument arg_vec4_x{};
argument arg_vec4_w{};
// Reflection forwards to the wrapped op only; algo is not reflected yet.
template <class Self, class F>
static auto reflect(Self& self, F f)
{
// TODO: Add algo
return op::quant_convolution::reflect(self.op, f);
}
std::string name() const { return "gpu::quant_convolution"; }
// Output shape for the given input shapes.
shape compute_shape(const std::vector<shape>& inputs) const;
// Runs the convolution and returns the output argument.
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
// Selects the algorithm and returns the workspace shape it requires.
shape compile(context& ctx, const shape& output_shape, std::vector<shape> inputs);
// Re-runs compile() if the context's MIOpen handle changed since compile().
void finalize(context& ctx, const shape& output_shape, std::vector<shape> inputs);
// The result aliases the last argument (the preallocated output buffer).
// NOTE(review): the sibling GPU operators return std::ptrdiff_t here; this one
// narrows size_t to int -- consider aligning the return type.
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
private:
// Builds the shape of the vec4-packed copy of an int8 shape.
shape pack_int8_shape(shape& s);
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANT_GEMM_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANT_GEMM_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/quant_dot.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
// GPU operator wrapping op::quant_dot (int8 matrix multiply).
struct miopen_quant_gemm
{
// The framework-level quantized dot operation this operator implements.
op::quant_dot op;
// Scratch buffers for the packed copies of A and B. They are mutable because
// compute() is const but allocates them lazily on first use.
mutable argument arg_a{};
mutable argument arg_b{};
// Reflection forwards to the wrapped op.
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::quant_gemm"; }
// Output shape for the given input shapes (the last entry is the output buffer).
shape compute_shape(const std::vector<shape>& inputs) const;
// Runs the multiplication and returns the output argument.
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
// The result aliases the last argument (the preallocated output buffer).
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -14,6 +14,7 @@
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/contiguous.hpp>
#include <migraphx/gpu/relu.hpp>
#include <migraphx/gpu/sigmoid.hpp>
......@@ -41,6 +42,7 @@
#include <migraphx/gpu/batchnorm.hpp>
#include <migraphx/gpu/pooling.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/quant_gemm.hpp>
#include <migraphx/gpu/concat.hpp>
#include <migraphx/gpu/pad.hpp>
#include <migraphx/gpu/gather.hpp>
......@@ -97,6 +99,7 @@ struct miopen_apply
add_generic_op<hip_min>("min");
add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op<miopen_quant_gemm, op::quant_dot>("quant_dot");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
add_extend_op<hip_concat, op::concat>("concat");
add_extend_op<miopen_softmax, op::softmax>("softmax");
......@@ -108,6 +111,7 @@ struct miopen_apply
add_lrn_op();
add_convolution_op();
add_quant_convolution_op();
add_pooling_op();
add_batch_norm_inference_op();
}
......@@ -154,6 +158,22 @@ struct miopen_apply
});
}
void add_quant_convolution_op()
{
apply_map.emplace("quant_convolution", [=](instruction_ref ins) {
auto&& op = any_cast<op::quant_convolution>(ins->get_operator());
auto conv = miopen_quant_convolution{op, make_conv(op)};
auto ws = conv.compile(ctx, ins->get_shape(), to_shapes(ins->inputs()));
auto workspace = insert_allocation(ins, ws, "workspace");
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(
ins, conv, ins->inputs().at(0), ins->inputs().at(1), workspace, output);
});
}
void add_pooling_op()
{
apply_map.emplace("pooling", [=](instruction_ref ins) {
......
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Validates the four operator inputs and derives the output shape from the
// first two of them via the wrapped op.
shape miopen_quant_convolution::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs, *this}.has(4).standard();

    const auto& first_input  = inputs.at(0);
    const auto& second_input = inputs.at(1);
    return op.compute_shape({first_input, second_input});
}
// Runs the quantized convolution.
// args[0] and args[1] are the two tensors being convolved (input and weights),
// args[2] is the workspace and args[3] the output buffer, which is returned.
// Both tensors are first transformed into the vec4-packed scratch buffers
// prepared by compile(); the convolution then runs on the packed copies.
// Throws if any MIOpen call reports a failure.
argument miopen_quant_convolution::compute(context& ctx,
                                           const shape& output_shape,
                                           const std::vector<argument>& args) const
{
    const auto& input     = args[0];
    const auto& weights   = args[1];
    const auto& workspace = args[2];
    const auto& output    = args[3];

    auto x_desc      = make_tensor(input.get_shape());
    auto x_desc_vec4 = make_tensor(input.get_shape(), true);
    auto w_desc      = make_tensor(weights.get_shape());
    auto w_desc_vec4 = make_tensor(weights.get_shape(), true);
    auto y_desc      = make_tensor(output_shape);

    float alpha = 1;
    float beta  = 0;

    auto miopen_handle = ctx.get_stream().get_miopen();

    // pack input to vec4 format
    const auto pack_x_status = miopenTransformTensor(miopen_handle,
                                                     &alpha,
                                                     x_desc.get(),
                                                     input.implicit(),
                                                     &beta,
                                                     x_desc_vec4.get(),
                                                     arg_vec4_x.implicit());
    if(pack_x_status != miopenStatusSuccess)
    {
        MIGRAPHX_THROW("QUANT_CONVOLUTION: transform input tensor failed");
    }

    // pack weights to vec4 format
    const auto pack_w_status = miopenTransformTensor(miopen_handle,
                                                     &alpha,
                                                     w_desc.get(),
                                                     weights.implicit(),
                                                     &beta,
                                                     w_desc_vec4.get(),
                                                     arg_vec4_w.implicit());
    if(pack_w_status != miopenStatusSuccess)
    {
        MIGRAPHX_THROW("QUANT_CONVOLUTION: transform weight tensor failed");
    }

    const auto conv_status = miopenConvolutionForward(miopen_handle,
                                                      &alpha,
                                                      x_desc_vec4.get(),
                                                      arg_vec4_x.implicit(),
                                                      w_desc_vec4.get(),
                                                      arg_vec4_w.implicit(),
                                                      cd.get(),
                                                      algo,
                                                      &beta,
                                                      y_desc.get(),
                                                      output.implicit(),
                                                      workspace.implicit(),
                                                      workspace.get_shape().bytes());
    if(conv_status != miopenStatusSuccess)
    {
        MIGRAPHX_THROW("QUANT_CONVOLUTION: run convolution forward failed");
    }

    return output;
}
// Chooses the forward algorithm for this convolution.
// Queries the workspace size, allocates the vec4-packed scratch buffers plus
// temporary output and workspace buffers, and asks MIOpen to find an algorithm.
// Stores the chosen algorithm and the handle it was found with, and returns an
// int8 shape whose size is the memory the chosen algorithm reports needing.
// Throws if the algorithm search fails.
shape miopen_quant_convolution::compile(context& ctx,
                                        const shape& output_shape,
                                        std::vector<shape> inputs)
{
    auto x_desc = make_tensor(inputs[0], true);
    auto w_desc = make_tensor(inputs[1], true);
    auto y_desc = make_tensor(output_shape);

    auto miopen_handle = ctx.get_stream().get_miopen();

    std::size_t workspace_size = 0;
    miopenConvolutionForwardGetWorkSpaceSize(
        miopen_handle, w_desc.get(), x_desc.get(), cd.get(), y_desc.get(), &workspace_size);
    shape workspace_shape{shape::int8_type, {workspace_size}};

    // Packed scratch copies are kept as members so compute() can reuse them.
    arg_vec4_x = to_gpu(generate_argument(pack_int8_shape(inputs[0])));
    arg_vec4_w = to_gpu(generate_argument(pack_int8_shape(inputs[1])));

    auto find_output    = allocate_gpu(output_shape);
    auto find_workspace = allocate_gpu(workspace_shape);

    int found_count = 1;
    miopenConvAlgoPerf_t perf;
    const auto find_status = miopenFindConvolutionForwardAlgorithm(miopen_handle,
                                                                   x_desc.get(),
                                                                   arg_vec4_x.implicit(),
                                                                   w_desc.get(),
                                                                   arg_vec4_w.implicit(),
                                                                   cd.get(),
                                                                   y_desc.get(),
                                                                   find_output.implicit(),
                                                                   1,
                                                                   &found_count,
                                                                   &perf,
                                                                   find_workspace.implicit(),
                                                                   workspace_size,
                                                                   false);
    if(find_status != miopenStatusSuccess)
    {
        MIGRAPHX_THROW("QUANT_CONVOLUTION: find convolution failed");
    }

    handle = miopen_handle;
    algo   = perf.fwd_algo;
    return shape{shape::int8_type, {perf.memory}};
}
// Recompiles when the context's MIOpen handle differs from the one compile()
// last used. inputs.at(2) is the workspace that was already allocated; the
// recompiled algorithm must not need more bytes than that, otherwise throws.
void miopen_quant_convolution::finalize(context& ctx,
                                        const shape& output_shape,
                                        std::vector<shape> inputs)
{
    const bool handle_unchanged = (handle == ctx.get_stream().get_miopen());
    if(handle_unchanged)
    {
        return;
    }

    // Read the allocated size before `inputs` is moved into compile().
    const auto allocated_bytes = inputs.at(2).bytes();
    const auto required_shape  = compile(ctx, output_shape, std::move(inputs));
    if(required_shape.bytes() > allocated_bytes)
    {
        MIGRAPHX_THROW("Workspace has changed during finalization.");
    }
}
// Returns the shape of the vec4-packed version of `s`: dimension 1 is rounded
// up to the next multiple of 4 and the outermost stride is recomputed to match.
// Throws for any element type other than int8.
shape miopen_quant_convolution::pack_int8_shape(shape& s)
{
    if(s.type() != shape::int8_type)
    {
        MIGRAPHX_THROW("PACK_INT8_SHAPE: only process int8_type");
    }

    auto packed_lens    = s.lens();
    auto packed_strides = s.strides();

    const auto padded_dim1 = (packed_lens[1] + 3) / 4 * 4;
    packed_lens[1]         = padded_dim1;
    packed_strides[0]      = packed_strides[1] * padded_dim1;

    return {s.type(), packed_lens, packed_strides};
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/quant_gemm.hpp>
#include <migraphx/gpu/device/pack.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Thin forwarding wrapper around rocblas_gemm_ex; returns its status unchanged.
template <class... Args>
rocblas_status generic_rocblas_gemm_ex(Args&&... args)
{
    const rocblas_status status = rocblas_gemm_ex(std::forward<Args>(args)...);
    return status;
}
// Thin forwarding wrapper around rocblas_gemm_strided_batched_ex; returns its
// status unchanged.
template <class... Args>
rocblas_status generic_rocblas_batched_gemm_ex(Args&&... args)
{
    const rocblas_status status =
        rocblas_gemm_strided_batched_ex(std::forward<Args>(args)...);
    return status;
}
// Type trait mapping an element type to the type rocBLAS expects for it.
// The primary template is the identity mapping.
template <typename T>
struct compute_rocblas_type
{
    typedef T type;
};
// const-qualified types map to the const-qualified mapping of the underlying type.
template <typename T>
struct compute_rocblas_type<const T>
{
    typedef const typename compute_rocblas_type<T>::type type;
};
// half maps to rocblas_half.
template <>
struct compute_rocblas_type<half>
{
    typedef rocblas_half type;
};
// Shorthand for the type compute_rocblas_type maps T to.
template <typename T>
using rb_type = typename compute_rocblas_type<T>::type;
// Reinterprets a value as its mapped rocBLAS type and returns it by value.
template <typename T>
rb_type<T> to_rocblas_type(T x)
{
    const rb_type<T>& as_rocblas = reinterpret_cast<const rb_type<T>&>(x);
    return as_rocblas;
}
// Reinterprets a pointer as a pointer to the mapped rocBLAS type.
template <typename T>
rb_type<T>* to_rocblas_type(T* x)
{
    rb_type<T>* as_rocblas = reinterpret_cast<rb_type<T>*>(x);
    return as_rocblas;
}
// Computes the output shape from the operand shapes. The last entry of
// `inputs` is the output buffer, so it is dropped before the operands are
// checked (none may be broadcasted) and handed to the wrapped op.
shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    auto operand_shapes = inputs;
    operand_shapes.pop_back();

    check_shapes{operand_shapes}.not_broadcasted();
    return op.compute_shape(operand_shapes);
}
// Runs the int8 matrix multiplication through rocBLAS.
//
// args is {A, B, output} or {A, B, C, output}; the output argument (the last
// one) is returned. With three operands the wrapped op's beta is applied to C,
// otherwise beta is 0. A single matrix goes through the plain gemm_ex call;
// any leading batch dimensions go through the strided-batched variant.
//
// Depending on the transposition of the operands, packed copies are written
// into the lazily allocated scratch buffers arg_a / arg_b and used in place of
// the original operands.
argument miopen_quant_gemm::compute(context& ctx,
                                    const shape& output_shape,
                                    const std::vector<argument>& args) const
{
    bool transa     = args[0].get_shape().transposed();
    bool transb     = args[1].get_shape().transposed();
    auto n_dim      = output_shape.lens().size();
    auto dim_1      = n_dim - 1;
    auto dim_0      = n_dim - 2;
    rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
    rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
    rocblas_int ldc = args[2].get_shape().strides()[dim_0];

    // A non-transposed B is packed with pack_a: because of the operand swap
    // described below, B is passed to rocBLAS in the "A" position.
    if(!transb)
    {
        if(arg_b.empty())
        {
            arg_b = allocate_gpu(args[1].get_shape());
        }
        device::pack_a(ctx.get_stream().get(), arg_b, args[1]);
    }

    // A transposed A must be packed as well; since A is passed to rocBLAS in
    // the "B" position, it is packed with pack_b.
    if(transa)
    {
        if(arg_a.empty())
        {
            arg_a = allocate_gpu(args.at(0).get_shape());
        }
        device::pack_b(ctx.get_stream().get(), arg_a, args[0]);
    }

    bool is_3inputs = (args.size() == 4);
    int32_t beta    = 0;
    if(is_3inputs)
    {
        beta = op.beta;
    }

    output_shape.visit_type([&](auto as) {
        auto alpha_r    = to_rocblas_type(as(op.alpha));
        auto beta_r     = to_rocblas_type(as(beta));
        auto out_lens   = output_shape.lens();
        rocblas_int m   = out_lens[dim_0];
        rocblas_int n   = out_lens[dim_1];
        rocblas_int k   = args[0].get_shape().lens()[dim_1];
        auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
        // The reduction dimension is expected to be a multiple of 4.
        assert(k % 4 == 0);

        // Product of all leading (batch) dimensions of the output.
        auto num_matrices = std::accumulate(
            out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
        if(num_matrices == 1)
        {
            // the rocblas_gemm API handles inputs and output matrices as
            // column-major format. When doing a C = A * B, we actually do
            // C^T = (B^T) * (A^T). That is the reason we input args[1] as
            // A and args[0] as B in calling the rocblas_gemm.
            generic_rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
                                    transb ? rocblas_operation_transpose : rocblas_operation_none,
                                    transa ? rocblas_operation_transpose : rocblas_operation_none,
                                    n,
                                    m,
                                    k,
                                    &alpha_r,
                                    (!transb) ? to_pointer(arg_b) : to_pointer(args.at(1)),
                                    rocblas_datatype_i8_r,
                                    ldb,
                                    transa ? to_pointer(arg_a) : to_pointer(args.at(0)),
                                    rocblas_datatype_i8_r,
                                    lda,
                                    &beta_r,
                                    to_pointer(args[2]),
                                    rocblas_datatype_i32_r,
                                    ldc,
                                    is_3inputs ? to_pointer(args.at(3)) : to_pointer(args[2]),
                                    rocblas_datatype_i32_r,
                                    ldc,
                                    rocblas_datatype_i32_r,
                                    rocblas_gemm_algo_standard,
                                    0,
                                    0,
                                    nullptr,
                                    nullptr);
        }
        else
        {
            // Same operand swap as above, with per-matrix strides for the batch.
            generic_rocblas_batched_gemm_ex(
                ctx.get_stream().get_rocblas(),
                transb ? rocblas_operation_transpose : rocblas_operation_none,
                transa ? rocblas_operation_transpose : rocblas_operation_none,
                n,
                m,
                k,
                &alpha_r,
                (!transb) ? to_pointer(arg_b) : to_pointer(args.at(1)),
                rocblas_datatype_i8_r,
                ldb,
                k * n,
                transa ? to_pointer(arg_a) : to_pointer(args.at(0)),
                rocblas_datatype_i8_r,
                lda,
                m * k,
                &beta_r,
                to_pointer(args[2]),
                rocblas_datatype_i32_r,
                ldc,
                m * n,
                is_3inputs ? to_pointer(args.at(3)) : to_pointer(args[2]),
                rocblas_datatype_i32_r,
                ldc,
                m * n,
                num_matrices,
                rocblas_datatype_i32_r,
                rocblas_gemm_algo_standard,
                0,
                0,
                nullptr,
                nullptr);
        }
    });

    return is_3inputs ? args.at(3) : args[2];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -1093,4 +1093,394 @@ TEST_CASE(matmul_mm2)
}
}
// quant_dot with two int8 operands, run on the CPU target for every
// combination of transposed / non-transposed operands.
TEST_CASE(quant_dot_2args_multi4)
{
    // Builds and evaluates one program: both operands are filled with
    // 0, 1, 2, ... and optionally transposed before the multiplication.
    auto run_case = [](migraphx::shape m1_shape,
                       migraphx::shape m2_shape,
                       bool transpose_a,
                       bool transpose_b,
                       std::vector<int> gold) {
        migraphx::program p;
        std::vector<int8_t> data1(4 * 4);
        std::vector<int8_t> data2(4 * 8);
        std::iota(data1.begin(), data1.end(), 0);
        std::iota(data2.begin(), data2.end(), 0);

        auto lhs = p.add_literal(migraphx::literal{m1_shape, data1});
        if(transpose_a)
        {
            lhs = p.add_instruction(migraphx::op::transpose{{1, 0}}, lhs);
        }
        auto rhs = p.add_literal(migraphx::literal{m2_shape, data2});
        if(transpose_b)
        {
            rhs = p.add_instruction(migraphx::op::transpose{{1, 0}}, rhs);
        }
        p.add_instruction(migraphx::op::quant_dot{}, lhs, rhs);

        p.compile(migraphx::cpu::target{});
        auto result = p.eval({});
        std::vector<float> m;
        result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
        EXPECT(migraphx::verify_range(m, gold));
    };

    // A * B
    run_case(migraphx::shape{migraphx::shape::int8_type, {4, 4}},
             migraphx::shape{migraphx::shape::int8_type, {4, 8}},
             false,
             false,
             {112, 118, 124, 130, 136, 142, 148, 154, 304, 326, 348,
              370, 392, 414, 436, 458, 496, 534, 572, 610, 648, 686,
              724, 762, 688, 742, 796, 850, 904, 958, 1012, 1066});

    // A^T * B
    run_case(migraphx::shape{migraphx::shape::int8_type, {4, 4}},
             migraphx::shape{migraphx::shape::int8_type, {4, 8}},
             true,
             false,
             {448, 472, 496, 520, 544, 568, 592, 616, 496, 524, 552,
              580, 608, 636, 664, 692, 544, 576, 608, 640, 672, 704,
              736, 768, 592, 628, 664, 700, 736, 772, 808, 844});

    // A * B^T
    run_case(migraphx::shape{migraphx::shape::int8_type, {4, 4}},
             migraphx::shape{migraphx::shape::int8_type, {8, 4}},
             false,
             true,
             {14, 38, 62, 86, 110, 134, 158, 182, 38, 126, 214,
              302, 390, 478, 566, 654, 62, 214, 366, 518, 670, 822,
              974, 1126, 86, 302, 518, 734, 950, 1166, 1382, 1598});

    // A^T * B^T
    run_case(migraphx::shape{migraphx::shape::int8_type, {4, 4}},
             migraphx::shape{migraphx::shape::int8_type, {8, 4}},
             true,
             true,
             {56, 152, 248, 344, 440, 536, 632, 728, 62, 174, 286,
              398, 510, 622, 734, 846, 68, 196, 324, 452, 580, 708,
              836, 964, 74, 218, 362, 506, 650, 794, 938, 1082});
}
// quant_dot with two int8 operands whose outer dimensions are not multiples
// of four, covering each transpose combination and several operator settings.
TEST_CASE(quant_dot_2args_general)
{
    // Builds and evaluates one program: both operands are filled with
    // 0, 1, 2, ... and optionally transposed before the multiplication.
    auto run_case = [](migraphx::shape m1_shape,
                       migraphx::shape m2_shape,
                       bool transpose_a,
                       bool transpose_b,
                       migraphx::op::quant_dot dot_op,
                       std::vector<int> gold) {
        migraphx::program p;
        std::vector<int8_t> data1(3 * 4);
        std::vector<int8_t> data2(4 * 5);
        std::iota(data1.begin(), data1.end(), 0);
        std::iota(data2.begin(), data2.end(), 0);

        auto lhs = p.add_literal(migraphx::literal{m1_shape, data1});
        if(transpose_a)
        {
            lhs = p.add_instruction(migraphx::op::transpose{{1, 0}}, lhs);
        }
        auto rhs = p.add_literal(migraphx::literal{m2_shape, data2});
        if(transpose_b)
        {
            rhs = p.add_instruction(migraphx::op::transpose{{1, 0}}, rhs);
        }
        p.add_instruction(dot_op, lhs, rhs);

        p.compile(migraphx::cpu::target{});
        auto result = p.eval({});
        std::vector<float> m;
        result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
        EXPECT(migraphx::verify_range(m, gold));
    };

    // A * B, default operator
    run_case(migraphx::shape{migraphx::shape::int8_type, {3, 4}},
             migraphx::shape{migraphx::shape::int8_type, {4, 5}},
             false,
             false,
             migraphx::op::quant_dot{},
             {70, 76, 82, 88, 94, 190, 212, 234, 256, 278, 310, 348, 386, 424, 462});

    // A^T * B, default operator
    run_case(migraphx::shape{migraphx::shape::int8_type, {4, 3}},
             migraphx::shape{migraphx::shape::int8_type, {4, 5}},
             true,
             false,
             migraphx::op::quant_dot{},
             {210, 228, 246, 264, 282, 240, 262, 284, 306, 328, 270, 296, 322, 348, 374});

    // A * B^T, operator built with a single value of 2
    run_case(migraphx::shape{migraphx::shape::int8_type, {3, 4}},
             migraphx::shape{migraphx::shape::int8_type, {5, 4}},
             false,
             true,
             migraphx::op::quant_dot{2},
             {28, 76, 124, 172, 220, 76, 252, 428, 604, 780, 124, 428, 732, 1036, 1340});

    // A^T * B^T, operator built with {3, 2}
    run_case(migraphx::shape{migraphx::shape::int8_type, {4, 3}},
             migraphx::shape{migraphx::shape::int8_type, {5, 4}},
             true,
             true,
             migraphx::op::quant_dot{3, 2},
             {126, 342, 558, 774, 990, 144, 408, 672, 936, 1200, 162, 474, 786, 1098, 1410});
}
// quant_dot with two int8 operands and an int32 third operand, covering each
// transpose combination and several operator settings.
TEST_CASE(quant_dot_3args_general)
{
    // Builds and evaluates one program: the int8 operands are filled with
    // 0, 1, 2, ... and optionally transposed; the int32 operand is filled
    // with 2, 3, 4, ...
    auto run_case = [](migraphx::shape m1_shape,
                       migraphx::shape m2_shape,
                       bool transpose_a,
                       bool transpose_b,
                       migraphx::op::quant_dot dot_op,
                       std::vector<int> gold) {
        migraphx::program p;
        migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
        std::vector<int8_t> data1(2 * 8);
        std::vector<int8_t> data2(8 * 7);
        std::vector<int> data3(2 * 7);
        std::iota(data1.begin(), data1.end(), 0);
        std::iota(data2.begin(), data2.end(), 0);
        std::iota(data3.begin(), data3.end(), 2);

        auto lhs = p.add_literal(migraphx::literal{m1_shape, data1});
        if(transpose_a)
        {
            lhs = p.add_instruction(migraphx::op::transpose{{1, 0}}, lhs);
        }
        auto rhs = p.add_literal(migraphx::literal{m2_shape, data2});
        if(transpose_b)
        {
            rhs = p.add_instruction(migraphx::op::transpose{{1, 0}}, rhs);
        }
        auto acc = p.add_literal(migraphx::literal{m3_shape, data3});
        p.add_instruction(dot_op, lhs, rhs, acc);

        p.compile(migraphx::cpu::target{});
        auto result = p.eval({});
        std::vector<float> m;
        result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
        EXPECT(migraphx::verify_range(m, gold));
    };

    // A * B, default operator
    run_case(migraphx::shape{migraphx::shape::int8_type, {2, 8}},
             migraphx::shape{migraphx::shape::int8_type, {8, 7}},
             false,
             false,
             migraphx::op::quant_dot{},
             {982, 1011, 1040, 1069, 1098, 1127, 1156, 2557, 2650, 2743, 2836, 2929, 3022, 3115});

    // A^T * B, operator built with {1, 3}
    run_case(migraphx::shape{migraphx::shape::int8_type, {8, 2}},
             migraphx::shape{migraphx::shape::int8_type, {8, 7}},
             true,
             false,
             migraphx::op::quant_dot{1, 3},
             {1966, 2025, 2084, 2143, 2202, 2261, 2320, 2183, 2250, 2317, 2384, 2451, 2518, 2585});

    // A * B^T, operator built with {2, 3}
    run_case(migraphx::shape{migraphx::shape::int8_type, {2, 8}},
             migraphx::shape{migraphx::shape::int8_type, {7, 8}},
             false,
             true,
             migraphx::op::quant_dot{2, 3},
             {286, 737, 1188, 1639, 2090, 2541, 2992, 755, 2230, 3705, 5180, 6655, 8130, 9605});

    // A^T * B^T, operator built with {3, 2}
    run_case(migraphx::shape{migraphx::shape::int8_type, {8, 2}},
             migraphx::shape{migraphx::shape::int8_type, {7, 8}},
             true,
             true,
             migraphx::op::quant_dot{3, 2},
             {844, 2190, 3536, 4882, 6228, 7574, 8920, 942, 2480, 4018, 5556, 7094, 8632, 10170});
}
// Batched quant_dot (two leading batch dimensions) with an int32 third
// operand: once with plain operands and once with both operands transposed
// in their last two dimensions.
TEST_CASE(quant_dot_3args_batch)
{
    // Plain operands.
    {
        migraphx::program p;
        migraphx::shape lhs_shape{migraphx::shape::int8_type, {2, 2, 2, 4}};
        migraphx::shape rhs_shape{migraphx::shape::int8_type, {2, 2, 4, 7}};
        migraphx::shape acc_shape{migraphx::shape::int32_type, {2, 2, 2, 7}};

        std::vector<int8_t> lhs_data(4 * 2 * 4);
        std::vector<int8_t> rhs_data(4 * 4 * 7);
        std::vector<int> acc_data(4 * 2 * 7);
        std::iota(lhs_data.begin(), lhs_data.end(), 0);
        std::iota(rhs_data.begin(), rhs_data.end(), 0);
        std::iota(acc_data.begin(), acc_data.end(), 2);

        auto lhs = p.add_literal(migraphx::literal{lhs_shape, lhs_data});
        auto rhs = p.add_literal(migraphx::literal{rhs_shape, rhs_data});
        auto acc = p.add_literal(migraphx::literal{acc_shape, acc_data});
        p.add_instruction(migraphx::op::quant_dot{1, 2}, lhs, rhs, acc);

        std::vector<int> expected = {
            102,   110,   118,   126,   134,   142,   150,   284,   308,   332,   356,   380,
            404,   428,   1530,  1570,  1610,  1650,  1690,  1730,  1770,  2160,  2216,  2272,
            2328,  2384,  2440,  2496,  4750,  4822,  4894,  4966,  5038,  5110,  5182,  5828,
            5916,  6004,  6092,  6180,  6268,  6356,  9762,  9866,  9970,  10074, 10178, 10282,
            10386, 11288, 11408, 11528, 11648, 11768, 11888, 12008};

        p.compile(migraphx::cpu::target{});
        auto result = p.eval({});
        std::vector<float> actual;
        result.visit([&](auto output) { actual.assign(output.begin(), output.end()); });
        EXPECT(migraphx::verify_range(actual, expected));
    }

    // Both operands transposed in their last two dimensions.
    {
        migraphx::program p;
        migraphx::shape lhs_shape{migraphx::shape::int8_type, {2, 2, 4, 3}};
        migraphx::shape rhs_shape{migraphx::shape::int8_type, {2, 2, 6, 4}};
        migraphx::shape acc_shape{migraphx::shape::int32_type, {2, 2, 3, 6}};

        std::vector<int8_t> lhs_data(48);
        std::vector<int8_t> rhs_data(96);
        std::vector<int> acc_data(72);
        std::iota(lhs_data.begin(), lhs_data.end(), 0);
        std::iota(rhs_data.begin(), rhs_data.end(), 0);
        std::iota(acc_data.begin(), acc_data.end(), 2);

        auto lhs   = p.add_literal(migraphx::literal{lhs_shape, lhs_data});
        auto lhs_t = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, lhs);
        auto rhs   = p.add_literal(migraphx::literal{rhs_shape, rhs_data});
        auto rhs_t = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, rhs);
        auto acc   = p.add_literal(migraphx::literal{acc_shape, acc_data});
        p.add_instruction(migraphx::op::quant_dot{2, 3}, lhs_t, rhs_t, acc);

        std::vector<int> expected = {
            90,    237,   384,   531,   678,   825,   120,   299,   478,   657,   836,   1015,
            150,   361,   572,   783,   994,   1205,  3456,  3987,  4518,  5049,  5580,  6111,
            3678,  4241,  4804,  5367,  5930,  6493,  3900,  4495,  5090,  5685,  6280,  6875,
            11430, 12345, 13260, 14175, 15090, 16005, 11844, 12791, 13738, 14685, 15632, 16579,
            12258, 13237, 14216, 15195, 16174, 17153, 24012, 25311, 26610, 27909, 29208, 30507,
            24618, 25949, 27280, 28611, 29942, 31273, 25224, 26587, 27950, 29313, 30676, 32039};

        p.compile(migraphx::cpu::target{});
        auto result = p.eval({});
        std::vector<float> actual;
        result.visit([&](auto output) { actual.assign(output.begin(), output.end()); });
        EXPECT(migraphx::verify_range(actual, expected));
    }
}
// Entry point of the test binary: hands the command line to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
......@@ -1338,6 +1338,177 @@ TEST_CASE(conv2d_padding_stride_test)
EXPECT(migraphx::verify_range(results_vector, s));
}
// quant_convolution with default settings on the CPU target: both int8
// tensors are filled with 0, 1, 2, ... and the result is compared against
// precomputed values.
TEST_CASE(quant_conv2d_test)
{
    migraphx::program p;

    migraphx::shape input_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
    std::vector<int8_t> input_data(2 * 3 * 4 * 4);
    std::iota(input_data.begin(), input_data.end(), 0);
    auto input = p.add_literal(migraphx::literal{input_shape, input_data});

    migraphx::shape weight_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
    std::vector<int8_t> weight_data(2 * 3 * 3 * 3);
    std::iota(weight_data.begin(), weight_data.end(), 0);
    auto weights = p.add_literal(migraphx::literal{weight_shape, weight_data});

    p.add_instruction(migraphx::op::quant_convolution{}, input, weights);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({});

    std::vector<float> expected = {10197, 10548, 11601, 11952, 25506, 26586, 29826, 30906,
                                   27045, 27396, 28449, 28800, 77346, 78426, 81666, 82746};

    std::vector<float> actual;
    result.visit([&](auto output) { actual.assign(output.begin(), output.end()); });
    EXPECT(migraphx::verify_range(actual, expected));
}
// quant_convolution built with the `same` padding mode on the CPU target:
// both int8 tensors are filled with 0, 1, 2, ... and the result is compared
// against precomputed values.
TEST_CASE(quant_conv2d_test_default_mode)
{
    migraphx::program p;

    migraphx::shape input_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
    std::vector<int8_t> input_data(2 * 3 * 4 * 4);
    std::iota(input_data.begin(), input_data.end(), 0);
    auto input = p.add_literal(migraphx::literal{input_shape, input_data});

    migraphx::shape weight_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
    std::vector<int8_t> weight_data(2 * 3 * 3 * 3);
    std::iota(weight_data.begin(), weight_data.end(), 0);
    auto weights = p.add_literal(migraphx::literal{weight_shape, weight_data});

    p.add_instruction(
        migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same},
        input,
        weights);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({});

    std::vector<float> expected = {
        10197, 10548, 6939,  3420,  11601, 11952, 7839,  3852,  7383,  7590,  4953,  2421,  3480,
        3570,  2316,  1125,  25506, 26586, 17874, 9009,  29826, 30906, 20718, 10413, 20505, 21198,
        14187, 7119,  10527, 10860, 7257,  3636,  27045, 27396, 17739, 8604,  28449, 28800, 18639,
        9036,  17319, 17526, 11289, 5445,  7800,  7890,  5052,  2421,  77346, 78426, 52002, 25857,
        81666, 82746, 54846, 27261, 53769, 54462, 36075, 17919, 26511, 26844, 17769, 8820};

    std::vector<float> actual;
    result.visit([&](auto output) { actual.assign(output.begin(), output.end()); });
    EXPECT(migraphx::verify_range(actual, expected));
}
// CPU check of quant_convolution constructed with migraphx::op::valid as its
// fourth attribute. Both int8 operands are filled with 0, 1, 2, ... and the
// 16-element result is compared against precomputed reference values.
TEST_CASE(quant_conv2d_test_valid_mode)
{
    migraphx::program p;

    // First operand: int8, dims {2, 3, 4, 4}.
    std::vector<int8_t> input_data(2 * 3 * 4 * 4);
    std::iota(input_data.begin(), input_data.end(), 0);
    migraphx::shape input_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
    auto input = p.add_literal(migraphx::literal{input_shape, input_data});

    // Second operand: int8, dims {2, 3, 3, 3}.
    std::vector<int8_t> weight_data(2 * 3 * 3 * 3);
    std::iota(weight_data.begin(), weight_data.end(), 0);
    migraphx::shape weight_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
    auto weights = p.add_literal(migraphx::literal{weight_shape, weight_data});

    p.add_instruction(
        migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid},
        input,
        weights);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({});

    // Reference output, 16 values.
    std::vector<float> expected = {10197, 10548, 11601, 11952,
                                   25506, 26586, 29826, 30906,
                                   27045, 27396, 28449, 28800,
                                   77346, 78426, 81666, 82746};

    // Copy the result into a float vector so it can be compared to the reference.
    std::vector<float> actual;
    result.visit([&](auto view) { actual.assign(view.begin(), view.end()); });
    EXPECT(migraphx::verify_range(actual, expected));
}
// CPU check of quant_convolution built from the attribute lists {1, 1} and {1, 1}
// (presumably padding and stride, going by the test name -- TODO confirm against
// the operator definition). Both int8 operands are filled with 0, 1, 2, ... and
// the 64-element result is compared against precomputed reference values.
TEST_CASE(quant_conv2d_padding_test)
{
    migraphx::program p;

    // First operand: int8, dims {2, 3, 4, 4}.
    std::vector<int8_t> input_data(2 * 3 * 4 * 4);
    std::iota(input_data.begin(), input_data.end(), 0);
    migraphx::shape input_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
    auto input = p.add_literal(migraphx::literal{input_shape, input_data});

    // Second operand: int8, dims {2, 3, 3, 3}.
    std::vector<int8_t> weight_data(2 * 3 * 3 * 3);
    std::iota(weight_data.begin(), weight_data.end(), 0);
    migraphx::shape weight_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
    auto weights = p.add_literal(migraphx::literal{weight_shape, weight_data});

    p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{1, 1}}}, input, weights);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({});

    std::vector<float> expected = {
        4521,  6753,  7014,  4635,  6858,  10197, 10548, 6939,  7830,  11601, 11952, 7839,  5007,
        7383,  7590,  4953,  10515, 15987, 16734, 11277, 16821, 25506, 26586, 17874, 19737, 29826,
        30906, 20718, 13593, 20505, 21198, 14187, 13161, 19281, 19542, 12699, 18522, 27045, 27396,
        17739, 19494, 28449, 28800, 18639, 11919, 17319, 17526, 11289, 34707, 51843, 52590, 34893,
        51813, 77346, 78426, 52002, 54729, 81666, 82746, 54846, 36057, 53769, 54462, 36075};

    // Copy the result into a float vector so it can be compared to the reference.
    std::vector<float> actual;
    result.visit([&](auto view) { actual.assign(view.begin(), view.end()); });
    EXPECT(migraphx::verify_range(actual, expected));
}
// CPU check of quant_convolution built from the attribute lists {1, 1} and {2, 2}
// (presumably padding and stride, going by the test name -- TODO confirm against
// the operator definition). Both int8 operands are filled with 0, 1, 2, ... and
// the 16-element result is compared against precomputed reference values.
TEST_CASE(quant_conv2d_padding_stride_test)
{
    migraphx::program p;

    // First operand: int8, dims {2, 3, 4, 4}.
    std::vector<int8_t> input_data(2 * 3 * 4 * 4);
    std::iota(input_data.begin(), input_data.end(), 0);
    migraphx::shape input_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
    auto input = p.add_literal(migraphx::literal{input_shape, input_data});

    // Second operand: int8, dims {2, 3, 3, 3}.
    std::vector<int8_t> weight_data(2 * 3 * 3 * 3);
    std::iota(weight_data.begin(), weight_data.end(), 0);
    migraphx::shape weight_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
    auto weights = p.add_literal(migraphx::literal{weight_shape, weight_data});

    p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{2, 2}}}, input, weights);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({});

    // Reference output, 16 values.
    std::vector<float> expected = {4521,  7014,  7830,  11952,
                                   10515, 16734, 19737, 30906,
                                   13161, 19542, 19494, 28800,
                                   34707, 52590, 54729, 82746};

    // Copy the result into a float vector so it can be compared to the reference.
    std::vector<float> actual;
    result.visit([&](auto view) { actual.assign(view.begin(), view.end()); });
    EXPECT(migraphx::verify_range(actual, expected));
}
TEST_CASE(transpose_test)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 2, 3}};
......
......@@ -82,6 +82,10 @@ auto get_hash(const T& x)
return std::hash<T>{}(x);
}
// Overload of get_hash for int arguments: the std::hash<int> value is
// divided by 64, which keeps the hash small enough to avoid overflow
// in the test examples.
inline auto get_hash(const int& x)
{
    const std::hash<int> hasher{};
    return hasher(x) / 64;
}
void compile_check(migraphx::program& p, const migraphx::target& t)
{
auto name = t.name();
......@@ -1238,6 +1242,115 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
}
};
// Verification program: default-constructed quant_dot applied to
// int8 "a" {2, 8}, int8 "b" {8, 7} and int32 "c" {2, 7}.
struct quant_dot_3args_1 : verify_program<quant_dot_3args_1>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto a = prog.add_parameter("a", migraphx::shape{migraphx::shape::int8_type, {2, 8}});
        auto b = prog.add_parameter("b", migraphx::shape{migraphx::shape::int8_type, {8, 7}});
        auto c = prog.add_parameter("c", migraphx::shape{migraphx::shape::int32_type, {2, 7}});
        prog.add_instruction(migraphx::op::quant_dot{}, a, b, c);
        return prog;
    }
};
// Verification program: quant_dot{1, 3} where the first operand is supplied
// as int8 "a" {8, 2} and transposed to {2, 8} before the multiplication.
// The other operands are int8 "b" {8, 7} and int32 "c" {2, 7}.
struct quant_dot_3args_2 : verify_program<quant_dot_3args_2>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto a = prog.add_parameter("a", migraphx::shape{migraphx::shape::int8_type, {8, 2}});
        // Swap the two axes of "a" so its dims become {2, 8}.
        auto a_t = prog.add_instruction(migraphx::op::transpose{{1, 0}}, a);
        auto b   = prog.add_parameter("b", migraphx::shape{migraphx::shape::int8_type, {8, 7}});
        auto c   = prog.add_parameter("c", migraphx::shape{migraphx::shape::int32_type, {2, 7}});
        prog.add_instruction(migraphx::op::quant_dot{1, 3}, a_t, b, c);
        return prog;
    }
};
// Verification program: quant_dot{2, 3} where the second operand is supplied
// as int8 "b" {7, 8} and transposed to {8, 7} before the multiplication.
// The other operands are int8 "a" {2, 8} and int32 "c" {2, 7}.
struct quant_dot_3args_3 : verify_program<quant_dot_3args_3>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto a = prog.add_parameter("a", migraphx::shape{migraphx::shape::int8_type, {2, 8}});
        auto b = prog.add_parameter("b", migraphx::shape{migraphx::shape::int8_type, {7, 8}});
        // Swap the two axes of "b" so its dims become {8, 7}.
        auto b_t = prog.add_instruction(migraphx::op::transpose{{1, 0}}, b);
        auto c   = prog.add_parameter("c", migraphx::shape{migraphx::shape::int32_type, {2, 7}});
        prog.add_instruction(migraphx::op::quant_dot{2, 3}, a, b_t, c);
        return prog;
    }
};
// Verification program: quant_dot{3, 2} where both matrix operands are
// transposed first: int8 "a" {8, 2} -> {2, 8} and int8 "b" {7, 8} -> {8, 7}.
// The third operand is int32 "c" {2, 7}.
struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto a   = prog.add_parameter("a", migraphx::shape{migraphx::shape::int8_type, {8, 2}});
        auto a_t = prog.add_instruction(migraphx::op::transpose{{1, 0}}, a);
        auto b   = prog.add_parameter("b", migraphx::shape{migraphx::shape::int8_type, {7, 8}});
        auto b_t = prog.add_instruction(migraphx::op::transpose{{1, 0}}, b);
        auto c   = prog.add_parameter("c", migraphx::shape{migraphx::shape::int32_type, {2, 7}});
        prog.add_instruction(migraphx::op::quant_dot{3, 2}, a_t, b_t, c);
        return prog;
    }
};
// Verification program: batched quant_dot{3, 2} on 4-D operands whose two
// innermost axes are transposed first: int8 "a" {3, 2, 8, 2} -> {3, 2, 2, 8}
// and int8 "b" {3, 2, 7, 8} -> {3, 2, 8, 7}. The third operand is
// int32 "c" {3, 2, 2, 7}.
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 8, 2}};
        migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}};
        migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};

        auto l1 = p.add_parameter("a", m1_shape);
        // Permutation {0, 1, 3, 2} swaps only the last two axes.
        auto tl1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
        auto l2  = p.add_parameter("b", m2_shape);
        auto tl2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l2);
        auto l3  = p.add_parameter("c", m3_shape);
        p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3);
        return p;
    }
};
// Verification program: batched quant_dot{1, 3} on 4-D operands without any
// transposes: int8 "a" {3, 2, 2, 8}, int8 "b" {3, 2, 8, 7} and
// int32 "c" {3, 2, 2, 7}.
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto a =
            prog.add_parameter("a", migraphx::shape{migraphx::shape::int8_type, {3, 2, 2, 8}});
        auto b =
            prog.add_parameter("b", migraphx::shape{migraphx::shape::int8_type, {3, 2, 8, 7}});
        auto c =
            prog.add_parameter("c", migraphx::shape{migraphx::shape::int32_type, {3, 2, 2, 7}});
        prog.add_instruction(migraphx::op::quant_dot{1, 3}, a, b, c);
        return prog;
    }
};
struct test_contiguous : verify_program<test_contiguous>
{
migraphx::program create_program() const
......@@ -1367,6 +1480,83 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
}
};
// Verification program: default-constructed quant_convolution applied to
// int8 "a" {2, 3, 4, 4} and int8 "c" {2, 3, 3, 3}.
struct quant_conv : verify_program<quant_conv>
{
    // Declared const to match the other verify_program test cases in this file;
    // building the program does not modify the test object.
    migraphx::program create_program() const
    {
        migraphx::program p;
        migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
        auto pa = p.add_parameter("a", a_shape);
        migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
        auto pc = p.add_parameter("c", c_shape);
        p.add_instruction(migraphx::op::quant_convolution{}, pa, pc);
        return p;
    }
};
// Verification program: quant_convolution constructed with migraphx::op::same
// as its fourth attribute, applied to int8 "a" {2, 3, 4, 4} and
// int8 "c" {2, 3, 3, 3}.
struct quant_conv_default_mode : verify_program<quant_conv_default_mode>
{
    // Declared const to match the other verify_program test cases in this file;
    // building the program does not modify the test object.
    migraphx::program create_program() const
    {
        migraphx::program p;
        migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
        auto pa = p.add_parameter("a", a_shape);
        migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
        auto pc = p.add_parameter("c", c_shape);
        p.add_instruction(
            migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same},
            pa,
            pc);
        return p;
    }
};
// Verification program: quant_convolution constructed with migraphx::op::valid
// as its fourth attribute, applied to int8 "a" {2, 3, 4, 4} and
// int8 "c" {2, 3, 3, 3}.
struct quant_conv_valid_mode : verify_program<quant_conv_valid_mode>
{
    // Declared const to match the other verify_program test cases in this file;
    // building the program does not modify the test object.
    migraphx::program create_program() const
    {
        migraphx::program p;
        migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
        auto pa = p.add_parameter("a", a_shape);
        migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
        auto pc = p.add_parameter("c", c_shape);
        p.add_instruction(
            migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid},
            pa,
            pc);
        return p;
    }
};
// Verification program: quant_convolution built from the attribute lists
// {1, 1} and {1, 1} (presumably padding and stride, going by the struct
// name -- TODO confirm against the operator definition), applied to
// int8 "a" {2, 3, 4, 4} and int8 "c" {2, 3, 3, 3}.
struct quant_conv_padding : verify_program<quant_conv_padding>
{
    // Declared const to match the other verify_program test cases in this file;
    // building the program does not modify the test object.
    migraphx::program create_program() const
    {
        migraphx::program p;
        migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
        auto pa = p.add_parameter("a", a_shape);
        migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
        auto pc = p.add_parameter("c", c_shape);
        p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{1, 1}}}, pa, pc);
        return p;
    }
};
// Verification program: quant_convolution built from the attribute lists
// {1, 1} and {2, 2} (presumably padding and stride, going by the struct
// name -- TODO confirm against the operator definition), applied to
// int8 "a" {2, 3, 4, 4} and int8 "c" {2, 3, 3, 3}.
struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
{
    // Declared const to match the other verify_program test cases in this file;
    // building the program does not modify the test object.
    migraphx::program create_program() const
    {
        migraphx::program p;
        migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
        auto pa = p.add_parameter("a", a_shape);
        migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
        auto pc = p.add_parameter("c", c_shape);
        p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{2, 2}}}, pa, pc);
        return p;
    }
};
struct test_concat : verify_program<test_concat>
{
migraphx::program create_program() const
......
......@@ -76,6 +76,36 @@ TEST_CASE(convolution_shape)
throws_shape(migraphx::op::convolution{}, input2, weights);
}
// Shape-inference checks for quant_convolution: a pair of int8 {4, 3, 3, 3}
// operands yields a float result, while a single operand, 2-D float operands
// and any float {4, 3, 3, 3} operand are expected to be rejected.
TEST_CASE(quant_convolution_shape)
{
    migraphx::shape int8_input{migraphx::shape::int8_type, {4, 3, 3, 3}};
    migraphx::shape int8_weights{migraphx::shape::int8_type, {4, 3, 3, 3}};
    migraphx::shape expected_out{migraphx::shape::float_type, {4, 4, 1, 1}};

    // Accepted: two int8 operands with default attributes.
    expect_shape(expected_out, migraphx::op::quant_convolution{}, int8_input, int8_weights);
    // Rejected: only one operand supplied.
    throws_shape(migraphx::op::quant_convolution{}, int8_input);

    // Rejected: 2-D float shapes, alone or mixed with valid int8 weights.
    migraphx::shape float_2d_input{migraphx::shape::float_type, {3, 3}};
    migraphx::shape float_2d_weights{migraphx::shape::float_type, {3, 3}};
    throws_shape(migraphx::op::quant_convolution{}, float_2d_input, float_2d_weights);
    throws_shape(migraphx::op::quant_convolution{}, float_2d_input, int8_weights);

    // Rejected: float operands with otherwise matching dims, in every combination.
    migraphx::shape float_input{migraphx::shape::float_type, {4, 3, 3, 3}};
    migraphx::shape float_weights{migraphx::shape::float_type, {4, 3, 3, 3}};
    throws_shape(migraphx::op::quant_convolution{}, float_input, int8_weights);
    throws_shape(migraphx::op::quant_convolution{}, int8_input, float_weights);
    throws_shape(migraphx::op::quant_convolution{}, float_input, float_weights);

    // With migraphx::op::same the expected output dims are {4, 4, 3, 3}.
    migraphx::shape expected_out_same{migraphx::shape::float_type, {4, 4, 3, 3}};
    expect_shape(expected_out_same,
                 migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same},
                 int8_input,
                 int8_weights);

    // With migraphx::op::valid the expected output matches the default case.
    expect_shape(expected_out,
                 migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid},
                 int8_input,
                 int8_weights);
}
TEST_CASE(transpose_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 2}};
......@@ -584,6 +614,61 @@ TEST_CASE(gemm)
}
}
// quant_dot
// Shape-inference checks for quant_dot with two int8 operands.
TEST_CASE(quant_dot_2args)
{
    // {2, 4} x {4, 8} with default attributes -> int32 {2, 8}.
    {
        migraphx::shape lhs{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape rhs{migraphx::shape::int8_type, {4, 8}};
        migraphx::shape expected{migraphx::shape::int32_type, {2, 8}};
        expect_shape(expected, migraphx::op::quant_dot{}, lhs, rhs);
    }

    // {3, 8} x {8, 7} with quant_dot{1, 0} -> int32 {3, 7}.
    {
        migraphx::shape lhs{migraphx::shape::int8_type, {3, 8}};
        migraphx::shape rhs{migraphx::shape::int8_type, {8, 7}};
        migraphx::shape expected{migraphx::shape::int32_type, {3, 7}};
        expect_shape(expected, migraphx::op::quant_dot{1, 0}, lhs, rhs);
    }

    // {2, 3} x {3, 8} must be rejected (presumably because the shared
    // dimension 3 is not a multiple of 4 -- TODO confirm against the operator).
    {
        migraphx::shape lhs{migraphx::shape::int8_type, {2, 3}};
        migraphx::shape rhs{migraphx::shape::int8_type, {3, 8}};
        throws_shape(migraphx::op::quant_dot{}, lhs, rhs);
    }

    // {2, 4} x {8, 8} must be rejected: the inner dimensions differ (4 vs 8).
    {
        migraphx::shape lhs{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape rhs{migraphx::shape::int8_type, {8, 8}};
        throws_shape(migraphx::op::quant_dot{}, lhs, rhs);
    }
}
// Shape-inference checks for quant_dot with a third operand.
TEST_CASE(quant_dot_3args)
{
    // int8 {2, 4} x int8 {4, 8} plus an int32 {2, 8} third operand -> int32 {2, 8}.
    {
        migraphx::shape lhs{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape rhs{migraphx::shape::int8_type, {4, 8}};
        migraphx::shape addend{migraphx::shape::int32_type, {2, 8}};
        migraphx::shape expected{migraphx::shape::int32_type, {2, 8}};
        expect_shape(expected, migraphx::op::quant_dot{}, lhs, rhs, addend);
    }

    // The same matrices with an int8 third operand must be rejected.
    {
        migraphx::shape lhs{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape rhs{migraphx::shape::int8_type, {4, 8}};
        migraphx::shape addend{migraphx::shape::int8_type, {2, 8}};
        throws_shape(migraphx::op::quant_dot{1, 2}, lhs, rhs, addend);
    }
}
TEST_CASE(rnn)
{
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment