Commit 1b4216ca authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'op_capture' into int8_quantize

parents 0e7e27cc eab3cafb
...@@ -32,7 +32,9 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") ...@@ -32,7 +32,9 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
endif() endif()
endif() endif()
if(CMAKE_CXX_COMPILER MATCHES ".*hcc") include(CheckCXXCompilerFlag)
check_cxx_compiler_flag("--cuda-host-only -x hip" HAS_HIP)
if(HAS_HIP)
message(STATUS "Enable miopen backend") message(STATUS "Enable miopen backend")
set(MIGRAPHX_ENABLE_GPU On CACHE BOOL "") set(MIGRAPHX_ENABLE_GPU On CACHE BOOL "")
else() else()
......
...@@ -82,8 +82,6 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed ...@@ -82,8 +82,6 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed
{ {
std::vector<T> result(s.elements()); std::vector<T> result(s.elements());
std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed}); std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed});
// divide a value to avoid integer overflow
std::transform(result.begin(), result.end(), result.begin(), [](auto i) { return i / 32; });
// std::generate(result.begin(), result.end(), [&]{ return seed % 7; }); // std::generate(result.begin(), result.end(), [&]{ return seed % 7; });
// std::generate(result.begin(), result.end(), []{ return 1; }); // std::generate(result.begin(), result.end(), []{ return 1; });
return result; return result;
......
...@@ -20,14 +20,11 @@ namespace op { ...@@ -20,14 +20,11 @@ namespace op {
struct convert : unary<convert> struct convert : unary<convert>
{ {
shape::type_t target_type = shape::half_type; shape::type_t target_type = shape::half_type;
float scale = 1.0f;
float shift = 0.0f;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack( return pack(f(self.target_type, "target_type"));
f(self.target_type, "target_type"), f(self.scale, "scale"), f(self.shift, "shift"));
} }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
...@@ -38,22 +35,10 @@ struct convert : unary<convert> ...@@ -38,22 +35,10 @@ struct convert : unary<convert>
auto apply() const auto apply() const
{ {
return [&](auto x) { return [](auto x) { return x; };
float res = scale * x + shift;
if(target_type == shape::int8_type)
{
int factor = (res >= 0.0f) ? 1 : -1;
res = res + factor * 0.5f;
res = res > 127.0f ? 127.0f : res;
res = res < -128.0f ? -128.0f : res;
}
return res;
};
} }
convert(shape::type_t t) : target_type{t} {} convert(shape::type_t t) : target_type{t} {}
convert(shape::type_t t, float sle, float sft) : target_type{t}, scale{sle}, shift{sft} {}
convert() {} convert() {}
}; };
......
...@@ -19,7 +19,7 @@ void quantize(program& prog); ...@@ -19,7 +19,7 @@ void quantize(program& prog);
// to int8 // to int8
void capture_arguments(program& prog, void capture_arguments(program& prog,
const std::vector<std::string>& ins_names, const std::vector<std::string>& ins_names,
std::function<void(std::size_t, std::vector<argument>)> func); const std::function<void(std::size_t, std::vector<argument>)>& func);
void capture_arguments(program& prog, const std::vector<std::string>& ins_names); void capture_arguments(program& prog, const std::vector<std::string>& ins_names);
void capture_arguments(program& prog); void capture_arguments(program& prog);
......
...@@ -23,9 +23,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -23,9 +23,7 @@ inline namespace MIGRAPHX_INLINE_NS {
instruction_ref insert_quant_ins(program& prog, instruction_ref insert_quant_ins(program& prog,
instruction_ref& ins, instruction_ref& ins,
shape::type_t type, shape::type_t type,
std::unordered_map<instruction_ref, instruction_ref>& map_ins, std::unordered_map<instruction_ref, instruction_ref>& map_ins)
float scale = 1.0f,
float shift = 0.0f)
{ {
if(map_ins.count(ins) > 0) if(map_ins.count(ins) > 0)
{ {
...@@ -37,16 +35,11 @@ instruction_ref insert_quant_ins(program& prog, ...@@ -37,16 +35,11 @@ instruction_ref insert_quant_ins(program& prog,
return ins; return ins;
} }
if(scale < 0.0f)
{
MIGRAPHX_THROW("INSERT_QUANT_INS: scale less than 0");
}
assert(ins->get_shape().type() == shape::float_type || assert(ins->get_shape().type() == shape::float_type ||
ins->get_shape().type() == shape::double_type || ins->get_shape().type() == shape::double_type ||
ins->get_shape().type() == shape::int32_type); ins->get_shape().type() == shape::int32_type);
instruction_ref quant_ins{}; instruction_ref quant_ins{};
quant_ins = prog.insert_instruction(std::next(ins), op::convert{type, scale, shift}, ins); quant_ins = prog.insert_instruction(std::next(ins), op::convert{type}, ins);
map_ins[ins] = quant_ins; map_ins[ins] = quant_ins;
return quant_ins; return quant_ins;
...@@ -337,7 +330,7 @@ void quantize_int8(program& prog) ...@@ -337,7 +330,7 @@ void quantize_int8(program& prog)
// capture operator to compute the scale and shift // capture operator to compute the scale and shift
void capture_arguments(program& prog, void capture_arguments(program& prog,
const std::vector<std::string>& ins_names, const std::vector<std::string>& ins_names,
std::function<void(std::size_t, std::vector<argument>)> func) const std::function<void(std::size_t, std::vector<argument>)>& func)
{ {
size_t num_quant_params = 0; size_t num_quant_params = 0;
......
...@@ -37,7 +37,7 @@ add_library(migraphx_device ...@@ -37,7 +37,7 @@ add_library(migraphx_device
device/pad.cpp device/pad.cpp
device/gather.cpp device/gather.cpp
device/sub.cpp device/sub.cpp
device/pack.cpp device/int8_gemm_pack.cpp
device/div.cpp device/div.cpp
device/clip.cpp device/clip.cpp
device/reduce_sum.cpp device/reduce_sum.cpp
...@@ -85,9 +85,12 @@ add_library(migraphx_gpu ...@@ -85,9 +85,12 @@ add_library(migraphx_gpu
lrn.cpp lrn.cpp
schedule_model.cpp schedule_model.cpp
adjust_allocation.cpp adjust_allocation.cpp
pack_int8_args.cpp
clip.cpp clip.cpp
reduce_sum.cpp reduce_sum.cpp
reduce_mean.cpp reduce_mean.cpp
int8_gemm_pack.cpp
int8_conv_pack.cpp
) )
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu) set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
rocm_clang_tidy_check(migraphx_gpu) rocm_clang_tidy_check(migraphx_gpu)
......
...@@ -15,7 +15,7 @@ shape hip_convert::compute_shape(std::vector<shape> inputs) const ...@@ -15,7 +15,7 @@ shape hip_convert::compute_shape(std::vector<shape> inputs) const
argument hip_convert::compute(context& ctx, const shape&, const std::vector<argument>& args) const argument hip_convert::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
device::convert(ctx.get_stream().get(), args[1], args[0], op.scale, op.shift, op.target_type); device::convert(ctx.get_stream().get(), args[1], args[0]);
return args[1]; return args[1];
} }
......
...@@ -6,31 +6,14 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -6,31 +6,14 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void convert(hipStream_t stream, void convert(hipStream_t stream, const argument& result, const argument& arg)
const argument& result,
const argument& arg,
float scale,
float shift,
shape::type_t target_type)
{ {
result.visit([&](auto output) { result.visit([&](auto output) {
arg.visit([&](auto input) { arg.visit([&](auto input) {
const auto* input_ptr = device_cast(input.data()); const auto* input_ptr = device_cast(input.data());
auto* output_ptr = device_cast(output.data()); auto* output_ptr = device_cast(output.data());
if(target_type == shape::int8_type) gs_launch(stream,
{ result.get_shape().elements())([=](auto i) { output_ptr[i] = input_ptr[i]; });
gs_launch(stream, result.get_shape().elements())([=](auto i) {
float res = input_ptr[i] * scale + shift;
int factor = (res >= 0.0f) ? 1 : -1;
output_ptr[i] = static_cast<int8_t>(
std::min<float>(std::max<float>(-128.0f, res + factor * 0.5), 127.0f));
});
}
else
{
gs_launch(stream, result.get_shape().elements())(
[=](auto i) { output_ptr[i] = input_ptr[i] * scale + shift; });
}
}); });
}); });
} }
......
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/gpu/device/pack.hpp> #include <migraphx/gpu/device/int8_gemm_pack.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp> #include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp> #include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
namespace migraphx { namespace migraphx {
...@@ -11,7 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,7 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void pack_a(hipStream_t stream, const argument& result, const argument& arg) void int8_gemm_pack_a(hipStream_t stream, const argument& result, const argument& arg)
{ {
auto comp_shape = arg.get_shape(); auto comp_shape = arg.get_shape();
auto out_lens = comp_shape.lens(); auto out_lens = comp_shape.lens();
...@@ -38,7 +38,7 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg) ...@@ -38,7 +38,7 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg)
}); });
} }
void pack_b(hipStream_t stream, const argument& result, const argument& arg) void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument& arg)
{ {
auto trans_shape = arg.get_shape(); auto trans_shape = arg.get_shape();
auto out_lens = trans_shape.lens(); auto out_lens = trans_shape.lens();
......
...@@ -11,12 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,12 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void convert(hipStream_t stream, void convert(hipStream_t stream, const argument& result, const argument& arg);
const argument& result,
const argument& arg,
float scale,
float shift,
shape::type_t target_type);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_PACK_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_INT8_GEMM_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_PACK_HPP #define MIGRAPHX_GUARD_RTGLIB_DEVICE_INT8_GEMM_PACK_HPP
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -10,11 +10,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,11 +10,9 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void pack_a(hipStream_t stream, const argument& result, const argument& arg); void int8_gemm_pack_a(hipStream_t stream, const argument& result, const argument& arg);
void pack_b(hipStream_t stream, const argument& result, const argument& arg); void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument& arg);
void sync_stream(hipStream_t stream);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
#ifndef MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct miopen_int8_conv_pack
{
std::string name() const { return "gpu::int8_conv_pack"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_int8_gemm_pack_a
{
std::string name() const { return "gpu::int8_gemm_pack_a"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
struct hip_int8_gemm_pack_b
{
std::string name() const { return "gpu::int8_gemm_pack_b"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_PACK_INT8_ARGS_HPP
#define MIGRAPHX_GUARD_RTGLIB_PACK_INT8_ARGS_HPP
#include <migraphx/program.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct pack_int8_args
{
std::string name() const { return "gpu::pack_int8_args"; }
void apply(program& p) const;
shape pack_int8_shape(const shape& s) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -17,8 +17,6 @@ struct miopen_quant_convolution ...@@ -17,8 +17,6 @@ struct miopen_quant_convolution
shared<convolution_descriptor> cd; shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{}; miopenConvFwdAlgorithm_t algo{};
miopenHandle_t handle = nullptr; miopenHandle_t handle = nullptr;
argument arg_vec4_x{};
argument arg_vec4_w{};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -39,7 +37,7 @@ struct miopen_quant_convolution ...@@ -39,7 +37,7 @@ struct miopen_quant_convolution
} }
private: private:
shape pack_int8_shape(shape& s); shape pack_int8_shape(const shape& s) const;
}; };
} // namespace gpu } // namespace gpu
......
...@@ -10,7 +10,7 @@ namespace gpu { ...@@ -10,7 +10,7 @@ namespace gpu {
struct context; struct context;
struct miopen_quant_gemm struct rocblas_quant_gemm
{ {
op::quant_dot op; op::quant_dot op;
...@@ -24,6 +24,7 @@ struct miopen_quant_gemm ...@@ -24,6 +24,7 @@ struct miopen_quant_gemm
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
void batch_not_transposed(const std::vector<std::size_t>& strides) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return shapes.size() - 1;
......
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape miopen_int8_conv_pack::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{{inputs.at(0)}, *this}.has(1).standard();
return inputs.at(0);
}
argument
miopen_int8_conv_pack::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
auto arg_desc = make_tensor(args[0].get_shape());
auto arg_desc_vec4 = make_tensor(args[0].get_shape(), true);
float alpha = 1;
float beta = 0;
// pack input to vec4 format
auto status = miopenTransformTensor(ctx.get_stream().get_miopen(),
&alpha,
arg_desc.get(),
args[0].implicit(),
&beta,
arg_desc_vec4.get(),
args[1].implicit());
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("INT8_CONV_PACK: transform input tensor failed");
}
return args[1];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/int8_gemm_pack.hpp>
#include <migraphx/gpu/device/int8_gemm_pack.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_int8_gemm_pack_a::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{{inputs.at(0)}, *this}.has(1).not_broadcasted().packed();
return inputs.at(0);
}
argument
hip_int8_gemm_pack_a::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::int8_gemm_pack_a(ctx.get_stream().get(), args[1], args[0]);
return args[1];
}
shape hip_int8_gemm_pack_b::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{{inputs.at(0)}, *this}.has(1).not_broadcasted().packed();
return inputs.at(0);
}
argument
hip_int8_gemm_pack_b::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::int8_gemm_pack_b(ctx.get_stream().get(), args[1], args[0]);
return args[1];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -59,6 +59,7 @@ ...@@ -59,6 +59,7 @@
#include <migraphx/gpu/reduce_mean.hpp> #include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/pow.hpp> #include <migraphx/gpu/pow.hpp>
#include <migraphx/gpu/sqdiff.hpp> #include <migraphx/gpu/sqdiff.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <utility> #include <utility>
#include <functional> #include <functional>
#include <algorithm> #include <algorithm>
...@@ -115,6 +116,7 @@ struct miopen_apply ...@@ -115,6 +116,7 @@ struct miopen_apply
add_generic_op<hip_sqdiff>("sqdiff"); add_generic_op<hip_sqdiff>("sqdiff");
add_extend_op<miopen_gemm, op::dot>("dot"); add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op<rocblas_quant_gemm, op::quant_dot>("quant_dot");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous"); add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
add_extend_op<hip_concat, op::concat>("concat"); add_extend_op<hip_concat, op::concat>("concat");
add_extend_op<hip_softmax, op::softmax>("softmax"); add_extend_op<hip_softmax, op::softmax>("softmax");
...@@ -131,7 +133,7 @@ struct miopen_apply ...@@ -131,7 +133,7 @@ struct miopen_apply
add_lrn_op(); add_lrn_op();
add_convolution_op(); add_convolution_op();
add_quant_convolution_op(); add_quant_convolution_op();
add_quant_dot_op(); // add_quant_dot_op();
add_pooling_op(); add_pooling_op();
add_batch_norm_inference_op(); add_batch_norm_inference_op();
} }
...@@ -182,32 +184,14 @@ struct miopen_apply ...@@ -182,32 +184,14 @@ struct miopen_apply
{ {
apply_map.emplace("quant_convolution", [=](instruction_ref ins) { apply_map.emplace("quant_convolution", [=](instruction_ref ins) {
auto&& op = any_cast<op::quant_convolution>(ins->get_operator()); auto&& op = any_cast<op::quant_convolution>(ins->get_operator());
auto conv = miopen_quant_convolution{op, make_conv(op)}; auto conv = miopen_quant_convolution{op, make_conv(op)};
auto ws = conv.compile(ctx, ins->get_shape(), to_shapes(ins->inputs())); auto ws = conv.compile(ctx, ins->get_shape(), to_shapes(ins->inputs()));
auto args = ins->inputs();
auto workspace = insert_allocation(ins, ws, "workspace"); auto workspace = insert_allocation(ins, ws, "workspace");
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction( return prog->replace_instruction(ins, conv, args[0], args[1], workspace, output);
ins, conv, ins->inputs().at(0), ins->inputs().at(1), workspace, output);
});
}
void add_quant_dot_op()
{
apply_map.emplace("quant_dot", [=](instruction_ref ins) {
auto&& op = any_cast<op::quant_dot>(ins->get_operator());
auto inputs = ins->inputs();
auto in_shapes = to_shapes(inputs);
auto pack_a = insert_allocation(ins, in_shapes[0], "pack_a");
auto pack_b = insert_allocation(ins, in_shapes[1], "pack_b");
auto output = insert_allocation(ins, ins->get_shape());
inputs.push_back(pack_a);
inputs.push_back(pack_b);
inputs.push_back(output);
return prog->replace_instruction(ins, miopen_quant_gemm{op}, inputs);
}); });
} }
......
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/int8_gemm_pack.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
void pack_int8_args::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
if(ins->name() == "gpu::quant_gemm")
{
auto inputs = ins->inputs();
bool transa = inputs[0]->get_shape().transposed();
bool transb = inputs[1]->get_shape().transposed();
if(!transb)
{
auto packed_b = p.insert_instruction(ins, hip_allocate{inputs[1]->get_shape()});
auto output_b =
p.insert_instruction(ins, hip_int8_gemm_pack_a{}, {inputs[1], packed_b});
instruction::replace_argument(ins, inputs[1], output_b);
}
if(transa)
{
auto packed_a = p.insert_instruction(ins, hip_allocate{inputs[0]->get_shape()});
auto output_a =
p.insert_instruction(ins, hip_int8_gemm_pack_b{}, {inputs[0], packed_a});
instruction::replace_argument(ins, inputs[0], output_a);
}
}
else if(ins->name() == "gpu::quant_convolution")
{
auto inputs = ins->inputs();
auto packed_x =
p.insert_instruction(ins, hip_allocate{pack_int8_shape(inputs[0]->get_shape())});
auto output_x =
p.insert_instruction(ins, miopen_int8_conv_pack{}, {inputs[0], packed_x});
instruction::replace_argument(ins, inputs[0], output_x);
auto packed_w =
p.insert_instruction(ins, hip_allocate{pack_int8_shape(inputs[1]->get_shape())});
auto output_w =
p.insert_instruction(ins, miopen_int8_conv_pack{}, {inputs[1], packed_w});
instruction::replace_argument(ins, inputs[1], output_w);
}
}
}
shape pack_int8_args::pack_int8_shape(const shape& s) const
{
if(s.type() != shape::int8_type)
{
MIGRAPHX_THROW("PACK_INT8_ARGS: only process int8_type");
}
auto lens = s.lens();
auto strides = s.strides();
lens[1] = (lens[1] + 3) / 4 * 4;
strides[0] = strides[1] * lens[1];
return {s.type(), lens, strides};
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment