Commit 6893dea9 authored by Shucai Xiao

code refinement for int8 convolution.

parent 7f3a960b
@@ -7,7 +7,6 @@
 #include <migraphx/op/mul.hpp>
 #include <migraphx/op/add.hpp>
 #include <migraphx/op/quant_dot.hpp>
-#include <migraphx/op/capture.hpp>
 #include <migraphx/op/convolution.hpp>
 #include <migraphx/op/quant_convolution.hpp>
 #include <migraphx/op/multibroadcast.hpp>
@@ -352,64 +351,5 @@ void quantize_int8(program& prog)
     quantize_int8(prog, ins_names, int8_quant_params);
 }
-
-// For each input argument of the instructions to be quantized, we need to
-// insert a capture operator to compute the scale and shift
-void capture_arguments(program& prog,
-                       const std::vector<std::string>& ins_names,
-                       std::function<void(std::size_t, std::vector<argument>)> func)
-{
-    size_t num_quant_params = 0;
-    // the int8 quantization only supports dot and convolution
-    std::vector<std::string> op_names = {"dot", "convolution", "quant_dot", "quant_convolution"};
-    if(!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) {
-           return std::find(op_names.begin(), op_names.end(), name) != op_names.end();
-       }))
-    {
-        MIGRAPHX_THROW("CAPTURE_ARGUMENTS: input operator is not supported");
-    }
-
-    std::unordered_map<instruction_ref, instruction_ref> ins_map;
-    for(auto ins : iterator_for(prog))
-    {
-        if(not contains(ins_names, ins->name()))
-        {
-            continue;
-        }
-
-        auto inputs = ins->inputs();
-        std::vector<instruction_ref> new_args;
-        for(auto input : inputs)
-        {
-            instruction_ref new_ins{};
-            if(ins_map.count(input) > 0)
-            {
-                new_ins = ins_map[input];
-            }
-            else
-            {
-                // wrap the input in a capture op so its value can be observed
-                // at evaluation time
-                new_ins = prog.insert_instruction(
-                    std::next(input), op::capture{num_quant_params++, func}, input);
-                ins_map[input] = new_ins;
-            }
-            new_args.push_back(new_ins);
-        }
-        instruction::replace(ins, ins->get_operator(), ins->get_shape(), new_args);
-    }
-
-    // set one pair of parameters for each argument
-    int8_quant_params.resize(num_quant_params, std::make_pair(-1.0f, -1.0f));
-}
-
-void capture_arguments(program& prog, const std::vector<std::string>& ins_names)
-{
-    capture_arguments(prog, ins_names, calc_quant_params);
-}
-
-void capture_arguments(program& prog)
-{
-    std::vector<std::string> ins_names = {"dot", "convolution"};
-    capture_arguments(prog, ins_names);
-}
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
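For context, the deleted pass wrapped every input of the listed operators in an op::capture instruction whose callback sees the actual argument at evaluation time, so a per-input (scale, shift) pair can be derived. A hedged usage sketch; record_args is a hypothetical callback standing in for the library's calc_quant_params:

// Hypothetical callback: invoked once per captured argument while the
// program evaluates; param_index identifies which (scale, shift) slot of
// int8_quant_params to fill in.
void record_args(std::size_t param_index, std::vector<migraphx::argument> args)
{
    // e.g. scan args.front() for its min/max and derive a scale for param_index
}

// After capture_arguments(prog, {"dot", "convolution"}, record_args), an
// instruction dot(a, b) becomes dot(capture[0](a), capture[1](b)).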
// migraphx/op/convert.hpp
#ifndef MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP

#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct convert : unary<convert>
{
    shape::type_t target_type = shape::half_type;
    float scale               = 1.0f;
    float shift               = 0.0f;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(
            f(self.target_type, "target_type"), f(self.scale, "scale"), f(self.shift, "shift"));
    }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        return {target_type, inputs.at(0).lens(), inputs.at(0).strides()};
    }

    auto apply() const
    {
        return [&](auto x) {
            float res = scale * x + shift;
            if(target_type == shape::int8_type)
            {
                int factor = (res >= 0.0f) ? 1 : -1;
                res        = res + factor * 0.5f;
                res        = res > 127.0f ? 127.0f : res;
                res        = res < -128.0f ? -128.0f : res;
            }
            return res;
        };
    }

    convert(shape::type_t t) : target_type{t} {}
    convert(shape::type_t t, float sle, float sft) : target_type{t}, scale{sle}, shift{sft} {}
    convert() {}
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
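For int8 targets, convert::apply() scales and shifts the value, rounds half away from zero, and saturates. The same arithmetic as a standalone sketch; quantize_to_int8 is a hypothetical helper written for illustration, not part of the header:

#include <algorithm>
#include <cstdint>

// Hypothetical helper mirroring the int8 branch of convert::apply():
// scale/shift, round half away from zero, then saturate to [-128, 127].
int8_t quantize_to_int8(float x, float scale, float shift)
{
    float res = scale * x + shift;
    res       = res + ((res >= 0.0f) ? 0.5f : -0.5f); // round half away from zero
    res       = std::min(res, 127.0f);                // clamp to int8 maximum
    res       = std::max(res, -128.0f);               // clamp to int8 minimum
    return static_cast<int8_t>(res);                  // truncation completes the rounding
}

// e.g. quantize_to_int8(0.526f, 100.0f, 0.0f) == 53; very large inputs pin at 127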
// quant_convolution.cpp (gpu target)
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/device/convert.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

shape miopen_quant_convolution::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs, *this}.has(5).standard();
    return op.compute_shape({inputs.at(0), inputs.at(1)});
}

argument miopen_quant_convolution::compute(context& ctx,
                                           const shape& output_shape,
                                           const std::vector<argument>& args) const
{
    auto x_desc      = make_tensor(args[0].get_shape());
    auto x_desc_vec4 = make_tensor(args[0].get_shape(), true);
    auto w_desc      = make_tensor(args[1].get_shape());
    auto w_desc_vec4 = make_tensor(args[1].get_shape(), true);
    shape tmp_output_shape{shape::float_type, output_shape.lens()};
    auto y_desc = make_tensor(tmp_output_shape);

    float alpha = 1;
    float beta  = 0;
    // pack input to vec4 format
    auto status = miopenTransformTensor(ctx.get_stream().get_miopen(),
                                        &alpha,
                                        x_desc.get(),
                                        args[0].implicit(),
                                        &beta,
                                        x_desc_vec4.get(),
                                        arg_vec4_x.implicit());
    if(status != miopenStatusSuccess)
    {
        MIGRAPHX_THROW("QUANT_CONVOLUTION: transform input tensor failed");
    }

    // pack weights to vec4 format
    status = miopenTransformTensor(ctx.get_stream().get_miopen(),
                                   &alpha,
                                   w_desc.get(),
                                   args[1].implicit(),
                                   &beta,
                                   w_desc_vec4.get(),
                                   arg_vec4_w.implicit());
    if(status != miopenStatusSuccess)
    {
        MIGRAPHX_THROW("QUANT_CONVOLUTION: transform weight tensor failed");
    }

    status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
                                      &alpha,
                                      x_desc_vec4.get(),
                                      arg_vec4_x.implicit(),
                                      w_desc_vec4.get(),
                                      arg_vec4_w.implicit(),
                                      cd.get(),
                                      algo,
                                      &beta,
                                      y_desc.get(),
                                      args[3].implicit(),
                                      args[2].implicit(),
                                      args[2].get_shape().bytes());
    if(status != miopenStatusSuccess)
    {
        MIGRAPHX_THROW("QUANT_CONVOLUTION: run convolution forward failed");
    }

    // Add a conversion from float to int32_t
    device::convert(ctx.get_stream().get(), args[4], args[3], 1.0f, 0.0f, shape::int32_type);
    return args[4];
}

shape miopen_quant_convolution::compile(context& ctx,
                                        const shape& output_shape,
                                        std::vector<shape> inputs)
{
    shape workspace_shape{};
    auto x_desc = make_tensor(inputs[0], true);
    auto w_desc = make_tensor(inputs[1], true);
    shape tmp_output_shape{shape::float_type, output_shape.lens()};
    auto y_desc = make_tensor(tmp_output_shape);

    // query the workspace size needed by the vec4-packed convolution
    std::size_t workspace_size = 0;
    miopenConvolutionForwardGetWorkSpaceSize(ctx.get_stream().get_miopen(),
                                             w_desc.get(),
                                             x_desc.get(),
                                             cd.get(),
                                             y_desc.get(),
                                             &workspace_size);
    workspace_shape = shape{shape::int8_type, {workspace_size}};

    arg_vec4_x     = to_gpu(generate_argument(pack_int8_shape(inputs[0])));
    arg_vec4_w     = to_gpu(generate_argument(pack_int8_shape(inputs[1])));
    auto y         = allocate_gpu(tmp_output_shape);
    auto workspace = allocate_gpu(workspace_shape);

    int algo_count = 1;
    miopenConvAlgoPerf_t perf;
    auto status = miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(),
                                                        x_desc.get(),
                                                        arg_vec4_x.implicit(),
                                                        w_desc.get(),
                                                        arg_vec4_w.implicit(),
                                                        cd.get(),
                                                        y_desc.get(),
                                                        y.implicit(),
                                                        1,
                                                        &algo_count,
                                                        &perf,
                                                        workspace.implicit(),
                                                        workspace_size,
                                                        false);
    if(status != miopenStatusSuccess)
    {
        MIGRAPHX_THROW("QUANT_CONVOLUTION: find convolution failed");
    }
    handle = ctx.get_stream().get_miopen();
    algo   = perf.fwd_algo;
    return shape{shape::int8_type, {perf.memory}};
}

void miopen_quant_convolution::finalize(context& ctx,
                                        const shape& output_shape,
                                        std::vector<shape> inputs)
{
    if(handle == ctx.get_stream().get_miopen())
        return;
    // Check that workspace hasn't changed
    auto size = inputs.at(2).bytes();
    auto ws   = compile(ctx, output_shape, std::move(inputs));
    if(ws.bytes() > size)
        MIGRAPHX_THROW("Workspace has changed during finalization.");
}

shape miopen_quant_convolution::pack_int8_shape(shape& s)
{
    if(s.type() != shape::int8_type)
    {
        MIGRAPHX_THROW("PACK_INT8_SHAPE: only process int8_type");
    }

    auto lens    = s.lens();
    auto strides = s.strides();
    lens[1]      = (lens[1] + 3) / 4 * 4;
    strides[0]   = strides[1] * lens[1];
    return {s.type(), lens, strides};
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
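pack_int8_shape exists because MIOpen's int8 convolution path consumes input channels in groups of four (the vec4 layout), so the channel dimension is padded up to the next multiple of 4 and the batch stride widened to match. The same arithmetic in isolation, under a hypothetical standalone name:

#include <cstddef>
#include <vector>

// Hypothetical standalone version of the lens/strides adjustment in
// pack_int8_shape, for an NCHW lens/strides pair.
void pack_int8_dims(std::vector<std::size_t>& lens, std::vector<std::size_t>& strides)
{
    lens[1]    = (lens[1] + 3) / 4 * 4; // round channels up: 1..4 -> 4, 5..8 -> 8
    strides[0] = strides[1] * lens[1];  // batch stride now spans the padded channels
}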