Commit 90cfe474 authored by Shucai Xiao

refactor code to make the packing for int8 gemm and conv a separate operator

parent af656671
@@ -37,7 +37,7 @@ add_library(migraphx_device
    device/pad.cpp
    device/gather.cpp
    device/sub.cpp
-   device/pack.cpp
+   device/int8_gemm_pack.cpp
    device/div.cpp
    device/clip.cpp
    device/reduce_sum.cpp
@@ -87,6 +87,8 @@ add_library(migraphx_gpu
    clip.cpp
    reduce_sum.cpp
    reduce_mean.cpp
+   int8_gemm_pack.cpp
+   int8_conv_pack.cpp
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
rocm_clang_tidy_check(migraphx_gpu)
......
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
-#include <migraphx/gpu/device/pack.hpp>
-#include <migraphx/gpu/device/tensor.hpp>
+#include <migraphx/gpu/device/int8_gemm_pack.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
+#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
@@ -11,7 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

-void pack_a(hipStream_t stream, const argument& result, const argument& arg)
+void int8_gemm_pack_a(hipStream_t stream, const argument& result, const argument& arg)
{
    auto comp_shape = arg.get_shape();
    auto out_lens   = comp_shape.lens();
@@ -38,7 +38,7 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg)
    });
}

-void pack_b(hipStream_t stream, const argument& result, const argument& arg)
+void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument& arg)
{
    auto trans_shape = arg.get_shape();
    auto out_lens    = trans_shape.lens();
......
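For background: rocBLAS's int8 `gemm_ex` expects its int8 operands in a packed "vec4" layout, where each group of four consecutive values along the k dimension is stored contiguously; that is what the `int8_gemm_pack_a`/`int8_gemm_pack_b` kernels above produce on the GPU. A minimal host-side sketch of that layout, assuming a column-major m x k matrix with `lda` and `k` both multiples of 4 (the device kernels' exact index arithmetic may differ):

```cpp
#include <cstdint>
#include <vector>

// Reference packing: element a(im, ik) moves so that four consecutive
// k-values become adjacent in memory. Sketch only; assumes a.size() == lda * k.
std::vector<int8_t> pack_int8(const std::vector<int8_t>& a, int m, int k, int lda)
{
    std::vector<int8_t> packed(a.size(), 0);
    for(int ik = 0; ik < k; ++ik)
        for(int im = 0; im < m; ++im)
            packed[ik % 4 + im * 4 + (ik / 4) * lda * 4] = a[im + ik * lda];
    return packed;
}
```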
-#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_PACK_HPP
-#define MIGRAPHX_GUARD_RTGLIB_DEVICE_PACK_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_INT8_GEMM_PACK_HPP
+#define MIGRAPHX_GUARD_RTGLIB_DEVICE_INT8_GEMM_PACK_HPP

#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
@@ -10,9 +10,10 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

-void pack_a(hipStream_t stream, const argument& result, const argument& arg);
-void pack_b(hipStream_t stream, const argument& result, const argument& arg);
+void int8_gemm_pack_a(hipStream_t stream, const argument& result, const argument& arg);
+void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument& arg);

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
#ifndef MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP

#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

struct context;

struct miopen_int8_conv_pack
{
    std::string name() const { return "gpu::int8_conv_pack"; }
    shape compute_shape(const std::vector<shape>& inputs) const;
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP

#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

struct context;

struct hip_int8_gemm_pack_a
{
    std::string name() const { return "gpu::int8_gemm_pack_a"; }
    shape compute_shape(const std::vector<shape>& inputs) const;
    argument
    compute(context& ctx, const shape&, const std::vector<argument>& args) const;
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};

struct hip_int8_gemm_pack_b
{
    std::string name() const { return "gpu::int8_gemm_pack_b"; }
    shape compute_shape(const std::vector<shape>& inputs) const;
    argument
    compute(context& ctx, const shape&, const std::vector<argument>& args) const;
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
@@ -17,8 +17,6 @@ struct miopen_quant_convolution
    shared<convolution_descriptor> cd;
    miopenConvFwdAlgorithm_t algo{};
    miopenHandle_t handle = nullptr;
-   argument arg_vec4_x{};
-   argument arg_vec4_w{};

    template <class Self, class F>
    static auto reflect(Self& self, F f)
@@ -38,8 +36,7 @@ struct miopen_quant_convolution
        return shapes.size() - 1;
    }

    private:
-   shape pack_int8_shape(shape& s);
+   shape pack_int8_shape(const shape& s);
};

} // namespace gpu
......
@@ -10,7 +10,7 @@ namespace gpu {

struct context;

-struct miopen_quant_gemm
+struct rocblas_quant_gemm
{
    op::quant_dot op;
......
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/context.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

shape miopen_int8_conv_pack::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{{inputs.at(0)}, *this}.has(1).standard();
    return inputs.at(0);
}

argument miopen_int8_conv_pack::compute(context& ctx,
                                        const shape& output_shape,
                                        const std::vector<argument>& args) const
{
    auto arg_desc      = make_tensor(args[0].get_shape());
    auto arg_desc_vec4 = make_tensor(args[0].get_shape(), true);
    float alpha        = 1;
    float beta         = 0;
    // pack input to vec4 format
    auto status = miopenTransformTensor(ctx.get_stream().get_miopen(),
                                        &alpha,
                                        arg_desc.get(),
                                        args[0].implicit(),
                                        &beta,
                                        arg_desc_vec4.get(),
                                        args[1].implicit());
    if(status != miopenStatusSuccess)
    {
        MIGRAPHX_THROW("INT8_CONV_PACK: transform input tensor failed");
    }

    return args[1];
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
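`miopen_int8_conv_pack` is the convolution counterpart of the gemm packing: MIOpen's int8 convolution wants its tensors in a vectorized layout, so `miopenTransformTensor` converts from the plain descriptor (`make_tensor(shape)`) into the vec4 one (`make_tensor(shape, true)`), writing into a buffer allocated from `pack_int8_shape` (see the lowering and quant_conv.cpp changes below, where only that function's signature change is visible). A hypothetical sketch of what such a shape helper does, padding channels up to a multiple of 4; this is an illustration, not the committed implementation:

```cpp
#include <migraphx/shape.hpp>

// Hypothetical pack_int8_shape-style helper: for an NCHW int8 tensor, round
// the channel dimension (dim 1) up to a multiple of 4 so MIOpen can use its
// vectorized int8 path; other types pass through unchanged.
migraphx::shape pack_int8_shape_sketch(const migraphx::shape& s)
{
    if(s.type() != migraphx::shape::int8_type)
        return s;
    auto lens = s.lens();
    lens[1]   = (lens[1] + 3) / 4 * 4; // pad channels to a multiple of 4
    return {s.type(), lens};
}
```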
#include <migraphx/gpu/int8_gemm_pack.hpp>
#include <migraphx/gpu/device/int8_gemm_pack.hpp>
#include <migraphx/gpu/context.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

shape hip_int8_gemm_pack_a::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{{inputs.at(0)}, *this}.has(1).not_broadcasted().packed();
    return inputs.at(0);
}

argument hip_int8_gemm_pack_a::compute(context& ctx,
                                       const shape&,
                                       const std::vector<argument>& args) const
{
    device::int8_gemm_pack_a(ctx.get_stream().get(), args[1], args[0]);
    return args[1];
}

shape hip_int8_gemm_pack_b::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{{inputs.at(0)}, *this}.has(1).not_broadcasted().packed();
    return inputs.at(0);
}

argument hip_int8_gemm_pack_b::compute(context& ctx,
                                       const shape&,
                                       const std::vector<argument>& args) const
{
    device::int8_gemm_pack_b(ctx.get_stream().get(), args[1], args[0]);
    return args[1];
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
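With packing exposed as its own operator, the device routine can also be exercised directly. A minimal sketch using existing MIGraphX helpers (`generate_argument`, `to_gpu`, `allocate_gpu`); the null stream stands in for the context's stream, and a real program would go through `gpu::int8_gemm_pack_a`'s `compute` instead:

```cpp
#include <migraphx/generate.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/device/int8_gemm_pack.hpp>

int main()
{
    // A 64x64 int8 matrix; dimensions divisible by 4 so the vec4 packing is clean.
    migraphx::shape s{migraphx::shape::int8_type, {64, 64}};
    auto input  = migraphx::gpu::to_gpu(migraphx::generate_argument(s));
    auto output = migraphx::gpu::allocate_gpu(s);

    hipStream_t stream = nullptr; // sketch only; a real run uses ctx.get_stream().get()
    migraphx::gpu::device::int8_gemm_pack_a(stream, output, input);
    // output now holds the operand repacked for rocBLAS's int8 gemm_ex
}
```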
@@ -59,6 +59,7 @@
#include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/pow.hpp>
#include <migraphx/gpu/sqdiff.hpp>
+#include <migraphx/gpu/int8_conv_pack.hpp>
#include <utility>
#include <functional>
#include <algorithm>
@@ -115,6 +116,7 @@ struct miopen_apply
        add_generic_op<hip_sqdiff>("sqdiff");
        add_extend_op<miopen_gemm, op::dot>("dot");
+       add_extend_op<rocblas_quant_gemm, op::quant_dot>("quant_dot");
        add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
        add_extend_op<hip_concat, op::concat>("concat");
        add_extend_op<hip_softmax, op::softmax>("softmax");
@@ -131,7 +133,7 @@ struct miopen_apply
        add_lrn_op();
        add_convolution_op();
        add_quant_convolution_op();
-       add_quant_dot_op();
+       // add_quant_dot_op();
        add_pooling_op();
        add_batch_norm_inference_op();
    }
@@ -182,32 +184,21 @@ struct miopen_apply
    {
        apply_map.emplace("quant_convolution", [=](instruction_ref ins) {
            auto&& op = any_cast<op::quant_convolution>(ins->get_operator());
            auto conv = miopen_quant_convolution{op, make_conv(op)};
            auto ws   = conv.compile(ctx, ins->get_shape(), to_shapes(ins->inputs()));

            auto workspace = insert_allocation(ins, ws, "workspace");
            auto output    = insert_allocation(ins, ins->get_shape());

+           auto args         = ins->inputs();
+           auto arg_x_vec4   = insert_allocation(ins, conv.pack_int8_shape(args[0]->get_shape()));
+           auto arg_x_packed = prog->insert_instruction(ins, miopen_int8_conv_pack{}, {args[0], arg_x_vec4});
+           auto arg_y_vec4   = insert_allocation(ins, conv.pack_int8_shape(args[1]->get_shape()));
+           auto arg_y_packed = prog->insert_instruction(ins, miopen_int8_conv_pack{}, {args[1], arg_y_vec4});

-           return prog->replace_instruction(
-               ins, conv, ins->inputs().at(0), ins->inputs().at(1), workspace, output);
+           return prog->replace_instruction(
+               ins, conv, arg_x_packed, arg_y_packed, workspace, output);
        });
    }

-   void add_quant_dot_op()
-   {
-       apply_map.emplace("quant_dot", [=](instruction_ref ins) {
-           auto&& op = any_cast<op::quant_dot>(ins->get_operator());
-
-           auto inputs    = ins->inputs();
-           auto in_shapes = to_shapes(inputs);
-           auto pack_a    = insert_allocation(ins, in_shapes[0], "pack_a");
-           auto pack_b    = insert_allocation(ins, in_shapes[1], "pack_b");
-           auto output    = insert_allocation(ins, ins->get_shape());
-           inputs.push_back(pack_a);
-           inputs.push_back(pack_b);
-           inputs.push_back(output);
-           return prog->replace_instruction(ins, miopen_quant_gemm{op}, inputs);
-       });
-   }
......
@@ -16,47 +16,19 @@ argument miopen_quant_convolution::compute(context& ctx,
                                           const shape& output_shape,
                                           const std::vector<argument>& args) const
{
-   auto x_desc      = make_tensor(args[0].get_shape());
-   auto x_desc_vec4 = make_tensor(args[0].get_shape(), true);
-   auto w_desc      = make_tensor(args[1].get_shape());
-   auto w_desc_vec4 = make_tensor(args[1].get_shape(), true);
+   auto x_desc = make_tensor(args[0].get_shape(), true);
+   auto w_desc = make_tensor(args[1].get_shape(), true);
    auto y_desc = make_tensor(output_shape);

    float alpha = 1;
    float beta  = 0;

-   // pack input to vec4 format
-   auto status = miopenTransformTensor(ctx.get_stream().get_miopen(),
-                                       &alpha,
-                                       x_desc.get(),
-                                       args[0].implicit(),
-                                       &beta,
-                                       x_desc_vec4.get(),
-                                       arg_vec4_x.implicit());
-   if(status != miopenStatusSuccess)
-   {
-       MIGRAPHX_THROW("QUANT_CONVOLUTION: transform input tensor failed");
-   }

-   // pack input to vec4 format
-   status = miopenTransformTensor(ctx.get_stream().get_miopen(),
-                                  &alpha,
-                                  w_desc.get(),
-                                  args[1].implicit(),
-                                  &beta,
-                                  w_desc_vec4.get(),
-                                  arg_vec4_w.implicit());
-   if(status != miopenStatusSuccess)
-   {
-       MIGRAPHX_THROW("QUANT_CONVOLUTION: transform weight tensor failed");
-   }

-   status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
+   auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
                                      &alpha,
-                                     x_desc_vec4.get(),
-                                     arg_vec4_x.implicit(),
-                                     w_desc_vec4.get(),
-                                     arg_vec4_w.implicit(),
+                                     x_desc.get(),
+                                     args[0].implicit(),
+                                     w_desc.get(),
+                                     args[1].implicit(),
                                      cd.get(),
                                      algo,
                                      &beta,
@@ -90,8 +62,8 @@ shape miopen_quant_convolution::compile(context& ctx,
                                     &workspace_size);
    workspace_shape = shape{shape::int8_type, {workspace_size}};

-   arg_vec4_x = to_gpu(generate_argument(pack_int8_shape(inputs[0])));
-   arg_vec4_w = to_gpu(generate_argument(pack_int8_shape(inputs[1])));
+   auto arg_vec4_x = to_gpu(generate_argument(pack_int8_shape(inputs[0])));
+   auto arg_vec4_w = to_gpu(generate_argument(pack_int8_shape(inputs[1])));
    auto y         = allocate_gpu(output_shape);
    auto workspace = allocate_gpu(workspace_shape);
@@ -133,7 +105,7 @@ void miopen_quant_convolution::finalize(context& ctx,
        MIGRAPHX_THROW("Workspace has changed during finalization.");
    }

-shape miopen_quant_convolution::pack_int8_shape(shape& s)
+shape miopen_quant_convolution::pack_int8_shape(const shape& s)
{
    if(s.type() != shape::int8_type)
    {
......
#include <migraphx/gpu/quant_gemm.hpp>
-#include <migraphx/gpu/device/pack.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
@@ -52,7 +51,7 @@ rb_type<T>* to_rocblas_type(T* x)
    return reinterpret_cast<rb_type<T>*>(x);
}

-shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
+shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    std::vector<shape> in_shapes(inputs);
    in_shapes.erase(in_shapes.begin() + in_shapes.size() - 3, in_shapes.end());
@@ -61,7 +60,7 @@ shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
    return op.compute_shape(in_shapes);
}

-argument miopen_quant_gemm::compute(context& ctx,
+argument rocblas_quant_gemm::compute(context& ctx,
                                     const shape& output_shape,
                                     const std::vector<argument>& args) const
{
@@ -70,24 +69,11 @@ argument miopen_quant_gemm::compute(context& ctx,
    auto n_dim = output_shape.lens().size();
    auto dim_1 = n_dim - 1;
    auto dim_0 = n_dim - 2;
-   auto arg_num    = args.size();
    rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
    rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
-   rocblas_int ldc = args[arg_num - 1].get_shape().strides()[dim_0];
+   rocblas_int ldc = args[2].get_shape().strides()[dim_0];

-   if(!transb)
-   {
-       device::pack_a(ctx.get_stream().get(), args[arg_num - 2], args[1]);
-   }
-
-   // need to pack A in this scenario, use the algorithm to pack B in the
-   // comment of the API
-   if(transa)
-   {
-       device::pack_b(ctx.get_stream().get(), args[arg_num - 3], args[0]);
-   }
-
-   bool is_3inputs = (arg_num == 6);
+   bool is_3inputs = (args.size() == 4);
    int32_t beta    = 0;
    if(is_3inputs)
    {
@@ -121,18 +107,17 @@ argument miopen_quant_gemm::compute(context& ctx,
                          m,
                          k,
                          &alpha_r,
-                         (!transb) ? to_pointer(args[arg_num - 2])
-                                   : to_pointer(args.at(1)),
+                         to_pointer(args.at(1)),
                          rocblas_datatype_i8_r,
                          ldb,
-                         transa ? to_pointer(args[arg_num - 3]) : to_pointer(args.at(0)),
+                         to_pointer(args.at(0)),
                          rocblas_datatype_i8_r,
                          lda,
                          &beta_r,
                          to_pointer(args[2]),
                          rocblas_datatype_i32_r,
                          ldc,
-                         to_pointer(args[arg_num - 1]),
+                         is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                          rocblas_datatype_i32_r,
                          ldc,
                          rocblas_datatype_i32_r,
@@ -152,11 +137,11 @@ argument miopen_quant_gemm::compute(context& ctx,
                          m,
                          k,
                          &alpha_r,
-                         (!transb) ? to_pointer(args[arg_num - 2]) : to_pointer(args.at(1)),
+                         to_pointer(args.at(1)),
                          rocblas_datatype_i8_r,
                          ldb,
                          k * n,
-                         transa ? to_pointer(args[arg_num - 3]) : to_pointer(args.at(0)),
+                         to_pointer(args.at(0)),
                          rocblas_datatype_i8_r,
                          lda,
                          m * k,
@@ -165,7 +150,7 @@ argument miopen_quant_gemm::compute(context& ctx,
                          rocblas_datatype_i32_r,
                          ldc,
                          m * n,
-                         to_pointer(args[arg_num - 1]),
+                         is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                          rocblas_datatype_i32_r,
                          ldc,
                          m * n,
@@ -179,7 +164,7 @@ argument miopen_quant_gemm::compute(context& ctx,
        }
    });

-   return args[arg_num - 1];
+   return is_3inputs ? args[3] : args[2];
}

} // namespace gpu
......
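With the packing hoisted out of the gemm, its argument list shrinks to the operator's own inputs plus the output allocation, which is what the `is_3inputs` logic above encodes. Summarized as a small sketch (illustrative helper, not part of the commit):

```cpp
#include <cstddef>

// Post-refactor convention for gpu::quant_dot arguments:
// {A, B, output} or {A, B, C, output}; A and B arrive pre-packed as int8.
std::size_t output_index(std::size_t arg_count)
{
    bool is_3inputs = (arg_count == 4); // A, B, C, output
    return is_3inputs ? 3 : 2;          // the result buffer is always last
}
```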