Commit 585bb331 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 984fd19c
......@@ -232,9 +232,9 @@ struct cpu_quant_convolution
auto wei_w = wei[3];
par_dfor(output_shape.lens()[0],
output_shape.lens()[1],
output_shape.lens()[2],
output_shape.lens()[3])(
output_shape.lens()[1],
output_shape.lens()[2],
output_shape.lens()[3])(
[&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
const int start_x = i * op.stride[0] - op.padding[0];
const int start_y = j * op.stride[1] - op.padding[1];
......@@ -254,7 +254,7 @@ struct cpu_quant_convolution
});
});
});
return result;
}
};
......
......@@ -49,21 +49,21 @@ inline tensor_descriptor make_tensor(const migraphx::shape& s, bool pack = false
d = miopenFloat;
else if(s.type() == shape::half_type)
d = miopenHalf;
else if(s.type() == shape::int8_type)
else if(s.type() == shape::int8_type)
{
if (pack)
if(pack)
{
// update the lens and corresponding strides
d = miopenInt8x4;
lens[1] = ((lens[1] + 3) / 4) * 4;
d = miopenInt8x4;
lens[1] = ((lens[1] + 3) / 4) * 4;
strides[0] = strides[1] * lens[1];
}
else
{
d = miopenInt8;
}
d = miopenInt8;
}
}
else
else
{
MIGRAPHX_THROW("MAKE_TENSOR: unsupported type");
}
......
......@@ -15,54 +15,54 @@ argument miopen_quant_convolution::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
auto x_desc = make_tensor(args[0].get_shape());
auto x_desc = make_tensor(args[0].get_shape());
auto x_desc_vec4 = make_tensor(args[0].get_shape(), true);
auto w_desc = make_tensor(args[1].get_shape());
auto w_desc = make_tensor(args[1].get_shape());
auto w_desc_vec4 = make_tensor(args[1].get_shape(), true);
auto y_desc = make_tensor(output_shape);
auto y_desc = make_tensor(output_shape);
float alpha = 1;
float beta = 0;
// pack input to vec4 format
auto status = miopenTransformTensor(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
args[0].implicit(),
&beta,
x_desc_vec4.get(),
arg_vec4_x.implicit());
if (status != miopenStatusSuccess)
&alpha,
x_desc.get(),
args[0].implicit(),
&beta,
x_desc_vec4.get(),
arg_vec4_x.implicit());
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: transform input tensfor failed");
}
status = miopenTransformTensor(ctx.get_stream().get_miopen(),
&alpha,
w_desc.get(),
args[1].implicit(),
&beta,
w_desc_vec4.get(),
arg_vec4_w.implicit());
if (status != miopenStatusSuccess)
&alpha,
w_desc.get(),
args[1].implicit(),
&beta,
w_desc_vec4.get(),
arg_vec4_w.implicit());
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: transform weight tensfor failed");
}
status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
arg_vec4_x.implicit(),
w_desc.get(),
args[1].implicit(),
cd.get(),
algo,
&beta,
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
args[2].get_shape().bytes());
if (status != miopenStatusSuccess)
status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
arg_vec4_x.implicit(),
w_desc.get(),
args[1].implicit(),
cd.get(),
algo,
&beta,
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
args[2].get_shape().bytes());
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: run convolution forward failed");
}
......@@ -132,15 +132,15 @@ void miopen_quant_convolution::finalize(context& ctx,
shape miopen_quant_convolution::pack_int8_shape(shape& s)
{
if (s.type() != shape::int8_type)
if(s.type() != shape::int8_type)
{
MIGRAPHX_THROW("PACK_INT8_SHAPE: only process int8_type");
}
auto lens = s.lens();
auto lens = s.lens();
auto strides = s.strides();
lens[1] = (lens[1] + 3) / 4 * 4;
strides[0] = strides[1] * lens[1];
lens[1] = (lens[1] + 3) / 4 * 4;
strides[0] = strides[1] * lens[1];
return {s.type(), lens, strides};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment