Commit 984fd19c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix bugs in gpu implementation of the quant_convolution operator

parent 28a644f7
......@@ -45,9 +45,10 @@ struct quant_convolution
const shape& weights = inputs.at(1);
auto t = input.type();
// all input type must be int8_type and output is float_type
if(t != shape::int8_type)
{
MIGRAPHX_THROW("QUANT_THROW: only accept input of type int8_t");
MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t");
}
t = shape::float_type;
......
......@@ -219,7 +219,8 @@ struct cpu_quant_convolution
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
result.visit([&](auto output) {
visit_all(args[0], args[1])([&](auto input, auto weights) {
auto in = input.get_shape().lens();
auto in_h = in[2];
auto in_w = in[3];
......@@ -239,7 +240,7 @@ struct cpu_quant_convolution
const int start_y = j * op.stride[1] - op.padding[1];
const int group_id = w / (wei_n / op.group);
double acc = 0;
float acc = 0;
dfor(wei_c, wei_h, wei_w)([&](std::size_t k, std::size_t x, std::size_t y) {
const int in_x = start_x + x;
const int in_y = start_y + y;
......@@ -252,6 +253,8 @@ struct cpu_quant_convolution
output(o, w, i, j) = acc;
});
});
});
return result;
}
};
......
......@@ -34,11 +34,11 @@ Result make_obj(F f, Ts... xs)
auto status = f(&x, xs...);
Result r{x};
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen call failed");
MIGRAPHX_THROW("MAKE_OBJ: MIOpen call failed");
return r;
}
inline tensor_descriptor make_tensor(const migraphx::shape& s)
inline tensor_descriptor make_tensor(const migraphx::shape& s, bool pack = false)
{
auto t = make_obj<tensor_descriptor>(&miopenCreateTensorDescriptor);
// Convert to ints
......@@ -50,10 +50,25 @@ inline tensor_descriptor make_tensor(const migraphx::shape& s)
else if(s.type() == shape::half_type)
d = miopenHalf;
else if(s.type() == shape::int8_type)
{
if (pack)
{
// update the lens and corresponding strides
d = miopenInt8x4;
lens[1] = ((lens[1] + 3) / 4) * 4;
strides[0] = strides[1] * lens[1];
}
else
{
d = miopenInt8;
}
}
else
MIGRAPHX_THROW("Unsupported type");
{
MIGRAPHX_THROW("MAKE_TENSOR: unsupported type");
}
miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data());
return t;
}
......
......@@ -17,6 +17,8 @@ struct miopen_quant_convolution
shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{};
miopenHandle_t handle = nullptr;
argument arg_vec4_x{};
argument arg_vec4_w{};
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -32,6 +34,9 @@ struct miopen_quant_convolution
shape compile(context& ctx, const shape& output_shape, std::vector<shape> inputs);
void finalize(context& ctx, const shape& output_shape, std::vector<shape> inputs);
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
private:
shape pack_int8_shape(shape& s);
};
} // namespace gpu
......
......@@ -16,15 +16,43 @@ argument miopen_quant_convolution::compute(context& ctx,
const std::vector<argument>& args) const
{
auto x_desc = make_tensor(args[0].get_shape());
auto x_desc_vec4 = make_tensor(args[0].get_shape(), true);
auto w_desc = make_tensor(args[1].get_shape());
auto w_desc_vec4 = make_tensor(args[1].get_shape(), true);
auto y_desc = make_tensor(output_shape);
int8_t alpha = 1;
int8_t beta = 0;
auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
float alpha = 1;
float beta = 0;
// pack input to vec4 format
auto status = miopenTransformTensor(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
args[0].implicit(),
&beta,
x_desc_vec4.get(),
arg_vec4_x.implicit());
if (status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: transform input tensfor failed");
}
status = miopenTransformTensor(ctx.get_stream().get_miopen(),
&alpha,
w_desc.get(),
args[1].implicit(),
&beta,
w_desc_vec4.get(),
arg_vec4_w.implicit());
if (status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: transform weight tensfor failed");
}
status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
arg_vec4_x.implicit(),
w_desc.get(),
args[1].implicit(),
cd.get(),
......@@ -34,7 +62,10 @@ argument miopen_quant_convolution::compute(context& ctx,
args[3].implicit(),
args[2].implicit(),
args[2].get_shape().bytes());
assert(status == miopenStatusSuccess);
if (status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: run convolution forward failed");
}
return args[3];
}
......@@ -43,8 +74,8 @@ shape miopen_quant_convolution::compile(context& ctx,
std::vector<shape> inputs)
{
shape workspace_shape{};
auto x_desc = make_tensor(inputs[0]);
auto w_desc = make_tensor(inputs[1]);
auto x_desc = make_tensor(inputs[0], true);
auto w_desc = make_tensor(inputs[1], true);
auto y_desc = make_tensor(output_shape);
std::size_t workspace_size = 0;
......@@ -56,8 +87,8 @@ shape miopen_quant_convolution::compile(context& ctx,
&workspace_size);
workspace_shape = shape{shape::int8_type, {workspace_size}};
auto x = to_gpu(generate_argument(inputs[0]));
auto w = to_gpu(generate_argument(inputs[1]));
arg_vec4_x = to_gpu(generate_argument(pack_int8_shape(inputs[0])));
arg_vec4_w = to_gpu(generate_argument(pack_int8_shape(inputs[1])));
auto y = allocate_gpu(output_shape);
auto workspace = allocate_gpu(workspace_shape);
......@@ -65,9 +96,9 @@ shape miopen_quant_convolution::compile(context& ctx,
miopenConvAlgoPerf_t perf;
auto status = miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(),
x_desc.get(),
x.implicit(),
arg_vec4_x.implicit(),
w_desc.get(),
w.implicit(),
arg_vec4_x.implicit(),
cd.get(),
y_desc.get(),
y.implicit(),
......@@ -78,7 +109,9 @@ shape miopen_quant_convolution::compile(context& ctx,
workspace_size,
false);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("Find convolution failed");
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: find convolution failed");
}
handle = ctx.get_stream().get_miopen();
algo = perf.fwd_algo;
return shape{shape::int8_type, {perf.memory}};
......@@ -97,6 +130,21 @@ void miopen_quant_convolution::finalize(context& ctx,
MIGRAPHX_THROW("Workspace has changed during finalization.");
}
// Build the shape MIOpen expects for its vectorized int8x4 tensor format:
// the channel dimension (lens[1]) is rounded up to the next multiple of 4,
// and the outermost stride is widened to cover the padded channels.
// NOTE(review): assumes a 4-D NCHW-style layout where index 1 is the
// channel dimension — confirm against the callers in compile()/compute().
// Throws PACK_INT8_SHAPE if the input tensor is not int8.
shape miopen_quant_convolution::pack_int8_shape(shape& s)
{
    // Packing is only defined for int8 tensors.
    if(s.type() != shape::int8_type)
    {
        MIGRAPHX_THROW("PACK_INT8_SHAPE: only process int8_type");
    }

    auto packed_lens    = s.lens();
    auto packed_strides = s.strides();

    // Round channels up to a multiple of 4; only the batch stride needs
    // adjusting, since the inner dimensions keep their original layout.
    packed_lens[1]    = ((packed_lens[1] + 3) / 4) * 4;
    packed_strides[0] = packed_strides[1] * packed_lens[1];

    return {s.type(), packed_lens, packed_strides};
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment