Commit 1b4216ca authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'op_capture' into int8_quantize

parents 0e7e27cc eab3cafb
...@@ -16,47 +16,19 @@ argument miopen_quant_convolution::compute(context& ctx, ...@@ -16,47 +16,19 @@ argument miopen_quant_convolution::compute(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape(), true);
auto x_desc_vec4 = make_tensor(args[0].get_shape(), true); auto w_desc = make_tensor(args[1].get_shape(), true);
auto w_desc = make_tensor(args[1].get_shape());
auto w_desc_vec4 = make_tensor(args[1].get_shape(), true);
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
float alpha = 1; float alpha = 1;
float beta = 0; float beta = 0;
// pack input to vec4 format auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
auto status = miopenTransformTensor(ctx.get_stream().get_miopen(),
&alpha, &alpha,
x_desc.get(), x_desc.get(),
args[0].implicit(), args[0].implicit(),
&beta,
x_desc_vec4.get(),
arg_vec4_x.implicit());
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: transform input tensor failed");
}
// pack input to vec4 format
status = miopenTransformTensor(ctx.get_stream().get_miopen(),
&alpha,
w_desc.get(), w_desc.get(),
args[1].implicit(), args[1].implicit(),
&beta,
w_desc_vec4.get(),
arg_vec4_w.implicit());
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: transform weight tensor failed");
}
status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc_vec4.get(),
arg_vec4_x.implicit(),
w_desc_vec4.get(),
arg_vec4_w.implicit(),
cd.get(), cd.get(),
algo, algo,
&beta, &beta,
...@@ -90,8 +62,8 @@ shape miopen_quant_convolution::compile(context& ctx, ...@@ -90,8 +62,8 @@ shape miopen_quant_convolution::compile(context& ctx,
&workspace_size); &workspace_size);
workspace_shape = shape{shape::int8_type, {workspace_size}}; workspace_shape = shape{shape::int8_type, {workspace_size}};
arg_vec4_x = to_gpu(generate_argument(pack_int8_shape(inputs[0]))); auto arg_vec4_x = to_gpu(generate_argument(pack_int8_shape(inputs[0])));
arg_vec4_w = to_gpu(generate_argument(pack_int8_shape(inputs[1]))); auto arg_vec4_w = to_gpu(generate_argument(pack_int8_shape(inputs[1])));
auto y = allocate_gpu(output_shape); auto y = allocate_gpu(output_shape);
auto workspace = allocate_gpu(workspace_shape); auto workspace = allocate_gpu(workspace_shape);
...@@ -133,7 +105,7 @@ void miopen_quant_convolution::finalize(context& ctx, ...@@ -133,7 +105,7 @@ void miopen_quant_convolution::finalize(context& ctx,
MIGRAPHX_THROW("Workspace has changed during finalization."); MIGRAPHX_THROW("Workspace has changed during finalization.");
} }
shape miopen_quant_convolution::pack_int8_shape(shape& s) shape miopen_quant_convolution::pack_int8_shape(const shape& s) const
{ {
if(s.type() != shape::int8_type) if(s.type() != shape::int8_type)
{ {
......
#include <migraphx/gpu/quant_gemm.hpp> #include <migraphx/gpu/quant_gemm.hpp>
#include <migraphx/gpu/device/pack.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <fstream> #include <fstream>
...@@ -54,16 +53,33 @@ rb_type<T>* to_rocblas_type(T* x) ...@@ -54,16 +53,33 @@ rb_type<T>* to_rocblas_type(T* x)
return reinterpret_cast<rb_type<T>*>(x); return reinterpret_cast<rb_type<T>*>(x);
} }
shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{ {
std::vector<shape> in_shapes(inputs); std::vector<shape> in_shapes(inputs);
in_shapes.erase(in_shapes.begin() + in_shapes.size() - 3, in_shapes.end()); in_shapes.pop_back();
check_shapes{in_shapes}.not_broadcasted(); check_shapes{in_shapes}.not_broadcasted();
batch_not_transposed(inputs[0].strides());
batch_not_transposed(inputs[1].strides());
return op.compute_shape(in_shapes); return op.compute_shape(in_shapes);
} }
argument miopen_quant_gemm::compute(context& ctx, void rocblas_quant_gemm::batch_not_transposed(const std::vector<std::size_t>& strides) const
{
if(strides.size() <= 2)
return;
auto dim_0 = strides.size() - 2;
auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
return (i < j or i < matrix_size or j < matrix_size);
}) != batch.end())
{
MIGRAPHX_THROW("QUANT_DOT: batch size {" + to_string_range(strides) + "} is transposed!");
}
}
argument rocblas_quant_gemm::compute(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
...@@ -72,25 +88,11 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -72,25 +88,11 @@ argument miopen_quant_gemm::compute(context& ctx,
auto n_dim = output_shape.lens().size(); auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1; auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2; auto dim_0 = n_dim - 2;
auto arg_num = args.size();
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0]; rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0]; rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[arg_num - 1].get_shape().strides()[dim_0]; rocblas_int ldc = args[2].get_shape().strides()[dim_0];
if(!transb)
{
device::pack_a(ctx.get_stream().get(), args[arg_num - 2], args[1]);
}
// need to pack A in this scenario, use the algorithm to pack B in the
// comment of the API
if(transa)
{
device::pack_b(ctx.get_stream().get(), args[arg_num - 3], args[0]);
}
device::sync_stream(ctx.get_stream().get());
bool is_3inputs = (arg_num == 6); bool is_3inputs = (args.size() == 4);
int32_t beta = 0; int32_t beta = 0;
if(is_3inputs) if(is_3inputs)
{ {
...@@ -124,18 +126,17 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -124,18 +126,17 @@ argument miopen_quant_gemm::compute(context& ctx,
m, m,
k, k,
&alpha_r, &alpha_r,
(!transb) ? to_pointer(args[arg_num - 2]) to_pointer(args.at(1)),
: to_pointer(args.at(1)),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
ldb, ldb,
transa ? to_pointer(args[arg_num - 3]) : to_pointer(args.at(0)), to_pointer(args.at(0)),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
lda, lda,
&beta_r, &beta_r,
to_pointer(args[2]), to_pointer(args[2]),
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
to_pointer(args[arg_num - 1]), is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
...@@ -155,11 +156,11 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -155,11 +156,11 @@ argument miopen_quant_gemm::compute(context& ctx,
m, m,
k, k,
&alpha_r, &alpha_r,
(!transb) ? to_pointer(args[arg_num - 2]) : to_pointer(args.at(1)), to_pointer(args.at(1)),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
ldb, ldb,
k * n, k * n,
transa ? to_pointer(args[arg_num - 3]) : to_pointer(args.at(0)), to_pointer(args.at(0)),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
lda, lda,
m * k, m * k,
...@@ -168,7 +169,7 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -168,7 +169,7 @@ argument miopen_quant_gemm::compute(context& ctx,
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
m * n, m * n,
to_pointer(args[arg_num - 1]), is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
m * n, m * n,
...@@ -182,7 +183,7 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -182,7 +183,7 @@ argument miopen_quant_gemm::compute(context& ctx,
} }
}); });
return args[arg_num - 1]; return is_3inputs ? args[3] : args[2];
} }
} // namespace gpu } // namespace gpu
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <migraphx/gpu/concat_gpu_opt.hpp> #include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/schedule_model.hpp> #include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/gpu/adjust_allocation.hpp> #include <migraphx/gpu/adjust_allocation.hpp>
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/eliminate_pad.hpp> #include <migraphx/eliminate_pad.hpp>
#include <migraphx/schedule.hpp> #include <migraphx/schedule.hpp>
...@@ -62,6 +63,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -62,6 +63,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
dead_code_elimination{}, dead_code_elimination{},
adjust_allocation{}, adjust_allocation{},
dead_code_elimination{}, dead_code_elimination{},
pack_int8_args{},
dead_code_elimination{},
fuse_ops{&ctx}, fuse_ops{&ctx},
dead_code_elimination{}, dead_code_elimination{},
write_literals{&ctx}, write_literals{&ctx},
......
...@@ -2019,7 +2019,8 @@ TEST_CASE(op_capture) ...@@ -2019,7 +2019,8 @@ TEST_CASE(op_capture)
migraphx::program p; migraphx::program p;
migraphx::shape s1{migraphx::shape::float_type, {3, 3}}; migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}}; migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
std::vector<float> d1(s1.elements()), d2(s2.elements()); std::vector<float> d1(s1.elements());
std::vector<float> d2(s2.elements());
std::iota(d1.begin(), d1.end(), 0.0f); std::iota(d1.begin(), d1.end(), 0.0f);
std::iota(d2.begin(), d2.end(), 0.0f); std::iota(d2.begin(), d2.end(), 0.0f);
...@@ -2040,7 +2041,8 @@ TEST_CASE(op_capture) ...@@ -2040,7 +2041,8 @@ TEST_CASE(op_capture)
auto cap_res = capture_p.eval({}); auto cap_res = capture_p.eval({});
auto res = p.eval({}); auto res = p.eval({});
std::vector<float> vec, cap_vec; std::vector<float> vec;
std::vector<float> cap_vec;
cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); }); cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
res.visit([&](auto output) { vec.assign(output.begin(), output.end()); }); res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
......
...@@ -663,13 +663,13 @@ struct test_softmax2 : verify_program<test_softmax2> ...@@ -663,13 +663,13 @@ struct test_softmax2 : verify_program<test_softmax2>
} }
}; };
template <int Axis> template <int Axis, migraphx::shape::type_t T>
struct test_softmax : verify_program<test_softmax<Axis>> struct test_softmax : verify_program<test_softmax<Axis, T>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}}; migraphx::shape s{T, {512, 4, 1067, 6}};
auto param = p.add_parameter("0", s); auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::softmax{Axis}, param); p.add_instruction(migraphx::op::softmax{Axis}, param);
...@@ -677,10 +677,38 @@ struct test_softmax : verify_program<test_softmax<Axis>> ...@@ -677,10 +677,38 @@ struct test_softmax : verify_program<test_softmax<Axis>>
} }
}; };
template struct test_softmax<0>; template struct test_softmax<0, migraphx::shape::float_type>;
template struct test_softmax<1>; template struct test_softmax<2, migraphx::shape::float_type>;
template struct test_softmax<2>; template struct test_softmax<1, migraphx::shape::double_type>;
template struct test_softmax<3>; template struct test_softmax<3, migraphx::shape::double_type>;
template struct test_softmax<0, migraphx::shape::half_type>;
template struct test_softmax<1, migraphx::shape::half_type>;
template struct test_softmax<2, migraphx::shape::half_type>;
template struct test_softmax<3, migraphx::shape::half_type>;
template <class T, int Axis>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis>>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 1025}};
auto param = p.add_parameter("data", s);
p.add_instruction(T{Axis}, param);
return p;
}
};
template struct test_arg_ops<migraphx::op::argmax, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3>;
template struct test_arg_ops<migraphx::op::argmin, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3>;
struct test_conv : verify_program<test_conv> struct test_conv : verify_program<test_conv>
{ {
...@@ -3601,32 +3629,13 @@ struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_defaul ...@@ -3601,32 +3629,13 @@ struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_defaul
} }
}; };
template <int Axis> template <int Axis, migraphx::shape::type_t T>
struct test_logsoftmax : verify_program<test_logsoftmax<Axis>> struct test_logsoftmax : verify_program<test_logsoftmax<Axis, T>>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
return p;
}
};
template struct test_logsoftmax<0>;
template struct test_logsoftmax<1>;
template struct test_logsoftmax<2>;
template struct test_logsoftmax<3>;
template <int Axis>
struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{T, {10, 4, 2080, 6}};
auto param = p.add_parameter("0", s); auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param); p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
...@@ -3634,7 +3643,16 @@ struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>> ...@@ -3634,7 +3643,16 @@ struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
} }
}; };
template struct test_logsoftmax_1<0>; template struct test_logsoftmax<0, migraphx::shape::float_type>;
template struct test_logsoftmax<1, migraphx::shape::float_type>;
template struct test_logsoftmax<2, migraphx::shape::float_type>;
template struct test_logsoftmax<3, migraphx::shape::float_type>;
template struct test_logsoftmax<1, migraphx::shape::double_type>;
template struct test_logsoftmax<3, migraphx::shape::double_type>;
template struct test_logsoftmax<1, migraphx::shape::half_type>;
template struct test_logsoftmax<0, migraphx::shape::half_type>;
template struct test_logsoftmax<2, migraphx::shape::half_type>;
template struct test_logsoftmax<3, migraphx::shape::half_type>;
struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall> struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall>
{ {
...@@ -3796,10 +3814,8 @@ struct test_convert : verify_program<test_convert> ...@@ -3796,10 +3814,8 @@ struct test_convert : verify_program<test_convert>
migraphx::shape sb{migraphx::shape::float_type, {24, 6}}; migraphx::shape sb{migraphx::shape::float_type, {24, 6}};
auto pa = p.add_parameter("a", sa); auto pa = p.add_parameter("a", sa);
auto pb = p.add_parameter("b", sb); auto pb = p.add_parameter("b", sb);
auto ia = auto ia = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, pa);
p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type, 16.0f, 1.0f}, pa); auto ib = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, pb);
auto ib =
p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type, 16.0f, 2.0f}, pb);
p.add_instruction(migraphx::op::quant_dot{}, ia, ib); p.add_instruction(migraphx::op::quant_dot{}, ia, ib);
return p; return p;
......
...@@ -204,7 +204,7 @@ TEST_CASE(literal_add) ...@@ -204,7 +204,7 @@ TEST_CASE(literal_add)
TEST_CASE(op_capture) TEST_CASE(op_capture)
{ {
auto test_func = [&](std::size_t ins_index, std::vector<migraphx::argument> args) { auto test_func = [&](std::size_t ins_index, const std::vector<migraphx::argument>& args) {
(void)ins_index; (void)ins_index;
(void)args; (void)args;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment