Commit 1b4216ca authored by Shucai Xiao

Merge branch 'op_capture' into int8_quantize

parents 0e7e27cc eab3cafb
@@ -16,54 +16,26 @@ argument miopen_quant_convolution::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
auto x_desc = make_tensor(args[0].get_shape());
auto x_desc_vec4 = make_tensor(args[0].get_shape(), true);
auto w_desc = make_tensor(args[1].get_shape());
auto w_desc_vec4 = make_tensor(args[1].get_shape(), true);
auto y_desc = make_tensor(output_shape);
auto x_desc = make_tensor(args[0].get_shape(), true);
auto w_desc = make_tensor(args[1].get_shape(), true);
auto y_desc = make_tensor(output_shape);
float alpha = 1;
float beta = 0;
// pack input to vec4 format
auto status = miopenTransformTensor(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
args[0].implicit(),
&beta,
x_desc_vec4.get(),
arg_vec4_x.implicit());
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: transform input tensor failed");
}
// pack weights to vec4 format
status = miopenTransformTensor(ctx.get_stream().get_miopen(),
&alpha,
w_desc.get(),
args[1].implicit(),
&beta,
w_desc_vec4.get(),
arg_vec4_w.implicit());
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: transform weight tensor failed");
}
status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc_vec4.get(),
arg_vec4_x.implicit(),
w_desc_vec4.get(),
arg_vec4_w.implicit(),
cd.get(),
algo,
&beta,
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
args[2].get_shape().bytes());
auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
args[0].implicit(),
w_desc.get(),
args[1].implicit(),
cd.get(),
algo,
&beta,
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
args[2].get_shape().bytes());
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: run convolution forward failed");
@@ -90,10 +62,10 @@ shape miopen_quant_convolution::compile(context& ctx,
&workspace_size);
workspace_shape = shape{shape::int8_type, {workspace_size}};
arg_vec4_x = to_gpu(generate_argument(pack_int8_shape(inputs[0])));
arg_vec4_w = to_gpu(generate_argument(pack_int8_shape(inputs[1])));
auto y = allocate_gpu(output_shape);
auto workspace = allocate_gpu(workspace_shape);
auto arg_vec4_x = to_gpu(generate_argument(pack_int8_shape(inputs[0])));
auto arg_vec4_w = to_gpu(generate_argument(pack_int8_shape(inputs[1])));
auto y = allocate_gpu(output_shape);
auto workspace = allocate_gpu(workspace_shape);
int algo_count = 1;
miopenConvAlgoPerf_t perf;
@@ -133,7 +105,7 @@ void miopen_quant_convolution::finalize(context& ctx,
MIGRAPHX_THROW("Workspace has changed during finalization.");
}
shape miopen_quant_convolution::pack_int8_shape(shape& s)
shape miopen_quant_convolution::pack_int8_shape(const shape& s) const
{
if(s.type() != shape::int8_type)
{
......
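Referring to the pack_int8_shape hunk just above: only the signature changed (the shape is now taken by const reference and the method is const); the body is elided. For intuition, here is a sketch of what vec4 packing typically does to an int8 NCHW shape. The round-up-to-4 rule and the helper name are assumptions, not copied from this diff:

#include <migraphx/shape.hpp>

// Hypothetical sketch: round the channel dimension (lens[1]) of an int8 shape
// up to a multiple of 4 so the data fits MIOpen's vec4-packed layout. Only
// the batch stride grows; the inner strides are unchanged.
migraphx::shape pack_int8_shape_sketch(const migraphx::shape& s)
{
    if(s.type() != migraphx::shape::int8_type)
        return s; // the real function guards on int8 here (body elided above)
    auto lens    = s.lens();
    auto strides = s.strides();
    lens[1]      = (lens[1] + 3) / 4 * 4; // pad channels to a multiple of 4
    strides[0]   = strides[1] * lens[1];  // batch stride covers padded channels
    return {migraphx::shape::int8_type, lens, strides};
}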
#include <migraphx/gpu/quant_gemm.hpp>
#include <migraphx/gpu/device/pack.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
#include <fstream>
@@ -54,43 +53,46 @@ rb_type<T>* to_rocblas_type(T* x)
return reinterpret_cast<rb_type<T>*>(x);
}
shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{
std::vector<shape> in_shapes(inputs);
in_shapes.erase(in_shapes.begin() + in_shapes.size() - 3, in_shapes.end());
in_shapes.pop_back();
check_shapes{in_shapes}.not_broadcasted();
batch_not_transposed(inputs[0].strides());
batch_not_transposed(inputs[1].strides());
return op.compute_shape(in_shapes);
}
argument miopen_quant_gemm::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
void rocblas_quant_gemm::batch_not_transposed(const std::vector<std::size_t>& strides) const
{
if(strides.size() <= 2)
return;
auto dim_0 = strides.size() - 2;
auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
return (i < j or i < matrix_size or j < matrix_size);
}) != batch.end())
{
MIGRAPHX_THROW("QUANT_DOT: batch size {" + to_string_range(strides) + "} is transposed!");
}
}
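The check above rejects batched shapes whose batch dimensions are transposed relative to the matrix dimensions. A standalone rendition of the same predicate, with hypothetical stride values, makes the behavior concrete:

#include <algorithm>
#include <cstddef>
#include <vector>

// Standalone copy of the predicate logic above, returning a bool instead of
// throwing. For a {2, 3, 4, 5} tensor in standard layout the strides are
// {60, 20, 5, 1}: matrix_size = max(5, 1) = 5, the batch strides {60, 20}
// are non-increasing and >= 5, so nothing is flagged. Swapping the batch
// dims to strides {20, 60, 5, 1} yields the adjacent pair (20, 60) with
// 20 < 60, which the real op reports as a transposed batch.
bool batch_transposed(const std::vector<std::size_t>& strides)
{
    if(strides.size() <= 2)
        return false;
    auto dim_0       = strides.size() - 2;
    auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
    std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
    return std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
               return (i < j or i < matrix_size or j < matrix_size);
           }) != batch.end();
}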
argument rocblas_quant_gemm::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
auto arg_num = args.size();
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[arg_num - 1].get_shape().strides()[dim_0];
if(!transb)
{
device::pack_a(ctx.get_stream().get(), args[arg_num - 2], args[1]);
}
// A needs to be packed in this scenario; use the algorithm for packing B
// given in the comment of the API
if(transa)
{
device::pack_b(ctx.get_stream().get(), args[arg_num - 3], args[0]);
}
device::sync_stream(ctx.get_stream().get());
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
bool is_3inputs = (arg_num == 6);
bool is_3inputs = (args.size() == 4);
int32_t beta = 0;
if(is_3inputs)
{
@@ -124,18 +126,17 @@ argument miopen_quant_gemm::compute(context& ctx,
m,
k,
&alpha_r,
(!transb) ? to_pointer(args[arg_num - 2])
: to_pointer(args.at(1)),
to_pointer(args.at(1)),
rocblas_datatype_i8_r,
ldb,
transa ? to_pointer(args[arg_num - 3]) : to_pointer(args.at(0)),
to_pointer(args.at(0)),
rocblas_datatype_i8_r,
lda,
&beta_r,
to_pointer(args[2]),
rocblas_datatype_i32_r,
ldc,
to_pointer(args[arg_num - 1]),
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
rocblas_datatype_i32_r,
ldc,
rocblas_datatype_i32_r,
@@ -155,11 +156,11 @@ argument miopen_quant_gemm::compute(context& ctx,
m,
k,
&alpha_r,
(!transb) ? to_pointer(args[arg_num - 2]) : to_pointer(args.at(1)),
to_pointer(args.at(1)),
rocblas_datatype_i8_r,
ldb,
k * n,
transa ? to_pointer(args[arg_num - 3]) : to_pointer(args.at(0)),
to_pointer(args.at(0)),
rocblas_datatype_i8_r,
lda,
m * k,
@@ -168,7 +169,7 @@ argument miopen_quant_gemm::compute(context& ctx,
rocblas_datatype_i32_r,
ldc,
m * n,
to_pointer(args[arg_num - 1]),
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
rocblas_datatype_i32_r,
ldc,
m * n,
@@ -182,7 +183,7 @@ argument miopen_quant_gemm::compute(context& ctx,
}
});
return args[arg_num - 1];
return is_3inputs ? args[3] : args[2];
}
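One detail worth noting about the call above: rocBLAS assumes column-major storage while MIGraphX tensors are row-major, so the row-major product C = A*B is evaluated as the column-major product C^T = B^T * A^T; that is why args.at(1) is passed as the first matrix with leading dimension ldb and args.at(0) second with lda. The arithmetic being performed is plain int8-by-int8 multiplication accumulated into int32; a reference loop version under those assumptions (hypothetical helper name, row-major buffers):

#include <cstdint>

// Reference semantics of the int8 GEMM above: 8-bit operands, 32-bit
// accumulation, row-major layout. lda/ldb/ldc are the row strides.
void gemm_int8_reference(int m, int n, int k,
                         const int8_t* a, int lda, // m x k
                         const int8_t* b, int ldb, // k x n
                         int32_t* c, int ldc)      // m x n
{
    for(int i = 0; i < m; ++i)
        for(int j = 0; j < n; ++j)
        {
            int32_t sum = 0;
            for(int p = 0; p < k; ++p)
                sum += int32_t(a[i * lda + p]) * int32_t(b[p * ldb + j]);
            c[i * ldc + j] = sum;
        }
}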
} // namespace gpu
......
@@ -21,6 +21,7 @@
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/gpu/adjust_allocation.hpp>
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/schedule.hpp>
@@ -62,6 +63,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
dead_code_elimination{},
adjust_allocation{},
dead_code_elimination{},
pack_int8_args{},
dead_code_elimination{},
fuse_ops{&ctx},
dead_code_elimination{},
write_literals{&ctx},
......
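On the pass-list change above: pack_int8_args is scheduled after adjust_allocation and before fuse_ops, so the vec4 packing that compute() no longer performs is inserted as explicit instructions ahead of the int8 ops. A skeletal sketch of what such a pass looks like; the name()/apply() interface follows the existing MIGraphX pass convention, but the body here is a hypothetical outline:

#include <migraphx/program.hpp>
#include <string>

// Hypothetical outline of the new pass; only the interface shape is taken
// from MIGraphX's pass convention.
struct pack_int8_args
{
    std::string name() const { return "gpu::pack_int8_args"; }
    void apply(migraphx::program& p) const
    {
        // Outline: walk the instructions; for each int8 quant_convolution or
        // quant_dot whose inputs are not yet vec4-packed, insert packing
        // instructions (and any needed allocations) ahead of the op.
    }
};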
@@ -2019,7 +2019,8 @@ TEST_CASE(op_capture)
migraphx::program p;
migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
std::vector<float> d1(s1.elements()), d2(s2.elements());
std::vector<float> d1(s1.elements());
std::vector<float> d2(s2.elements());
std::iota(d1.begin(), d1.end(), 0.0f);
std::iota(d2.begin(), d2.end(), 0.0f);
@@ -2040,7 +2041,8 @@ TEST_CASE(op_capture)
auto cap_res = capture_p.eval({});
auto res = p.eval({});
std::vector<float> vec, cap_vec;
std::vector<float> vec;
std::vector<float> cap_vec;
cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
......
@@ -663,13 +663,13 @@ struct test_softmax2 : verify_program<test_softmax2>
}
};
template <int Axis>
struct test_softmax : verify_program<test_softmax<Axis>>
template <int Axis, migraphx::shape::type_t T>
struct test_softmax : verify_program<test_softmax<Axis, T>>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
migraphx::shape s{T, {512, 4, 1067, 6}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::softmax{Axis}, param);
@@ -677,10 +677,38 @@ struct test_softmax : verify_program<test_softmax<Axis>>
}
};
template struct test_softmax<0>;
template struct test_softmax<1>;
template struct test_softmax<2>;
template struct test_softmax<3>;
template struct test_softmax<0, migraphx::shape::float_type>;
template struct test_softmax<2, migraphx::shape::float_type>;
template struct test_softmax<1, migraphx::shape::double_type>;
template struct test_softmax<3, migraphx::shape::double_type>;
template struct test_softmax<0, migraphx::shape::half_type>;
template struct test_softmax<1, migraphx::shape::half_type>;
template struct test_softmax<2, migraphx::shape::half_type>;
template struct test_softmax<3, migraphx::shape::half_type>;
template <class T, int Axis>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis>>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 1025}};
auto param = p.add_parameter("data", s);
p.add_instruction(T{Axis}, param);
return p;
}
};
template struct test_arg_ops<migraphx::op::argmax, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3>;
template struct test_arg_ops<migraphx::op::argmin, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3>;
struct test_conv : verify_program<test_conv>
{
@@ -3601,32 +3629,13 @@ struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_defaul
}
};
template <int Axis>
struct test_logsoftmax : verify_program<test_logsoftmax<Axis>>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
return p;
}
};
template struct test_logsoftmax<0>;
template struct test_logsoftmax<1>;
template struct test_logsoftmax<2>;
template struct test_logsoftmax<3>;
template <int Axis>
struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
template <int Axis, migraphx::shape::type_t T>
struct test_logsoftmax : verify_program<test_logsoftmax<Axis, T>>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
migraphx::shape s{T, {10, 4, 2080, 6}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
@@ -3634,7 +3643,16 @@ struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
}
};
template struct test_logsoftmax_1<0>;
template struct test_logsoftmax<0, migraphx::shape::float_type>;
template struct test_logsoftmax<1, migraphx::shape::float_type>;
template struct test_logsoftmax<2, migraphx::shape::float_type>;
template struct test_logsoftmax<3, migraphx::shape::float_type>;
template struct test_logsoftmax<1, migraphx::shape::double_type>;
template struct test_logsoftmax<3, migraphx::shape::double_type>;
template struct test_logsoftmax<1, migraphx::shape::half_type>;
template struct test_logsoftmax<0, migraphx::shape::half_type>;
template struct test_logsoftmax<2, migraphx::shape::half_type>;
template struct test_logsoftmax<3, migraphx::shape::half_type>;
struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall>
{
@@ -3796,10 +3814,8 @@ struct test_convert : verify_program<test_convert>
migraphx::shape sb{migraphx::shape::float_type, {24, 6}};
auto pa = p.add_parameter("a", sa);
auto pb = p.add_parameter("b", sb);
auto ia =
p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type, 16.0f, 1.0f}, pa);
auto ib =
p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type, 16.0f, 2.0f}, pb);
auto ia = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, pa);
auto ib = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, pb);
p.add_instruction(migraphx::op::quant_dot{}, ia, ib);
return p;
......
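On the test_convert change above: the convert op no longer carries scale and shift parameters; under the capture-based int8 flow this branch merges in, the scaling is presumably derived from captured value ranges instead of being baked into convert. For intuition, scale-based float-to-int8 quantization with saturation looks like this (standalone sketch, not a MIGraphX API):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Standalone sketch: quantize floats to int8 with a given scale, saturating
// to [-128, 127]. In a capture-based flow the scale would come from observed
// value ranges rather than being hard-coded.
std::vector<int8_t> quantize_int8(const std::vector<float>& x, float scale)
{
    std::vector<int8_t> q(x.size());
    std::transform(x.begin(), x.end(), q.begin(), [=](float v) {
        float r = std::round(v * scale);
        return static_cast<int8_t>(std::max(-128.0f, std::min(127.0f, r)));
    });
    return q;
}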
@@ -204,7 +204,7 @@ TEST_CASE(literal_add)
TEST_CASE(op_capture)
{
auto test_func = [&](std::size_t ins_index, std::vector<migraphx::argument> args) {
auto test_func = [&](std::size_t ins_index, const std::vector<migraphx::argument>& args) {
(void)ins_index;
(void)args;
};
......