Commit 585bb331 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 984fd19c
...@@ -232,9 +232,9 @@ struct cpu_quant_convolution ...@@ -232,9 +232,9 @@ struct cpu_quant_convolution
auto wei_w = wei[3]; auto wei_w = wei[3];
par_dfor(output_shape.lens()[0], par_dfor(output_shape.lens()[0],
output_shape.lens()[1], output_shape.lens()[1],
output_shape.lens()[2], output_shape.lens()[2],
output_shape.lens()[3])( output_shape.lens()[3])(
[&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) { [&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
const int start_x = i * op.stride[0] - op.padding[0]; const int start_x = i * op.stride[0] - op.padding[0];
const int start_y = j * op.stride[1] - op.padding[1]; const int start_y = j * op.stride[1] - op.padding[1];
...@@ -254,7 +254,7 @@ struct cpu_quant_convolution ...@@ -254,7 +254,7 @@ struct cpu_quant_convolution
}); });
}); });
}); });
return result; return result;
} }
}; };
......
...@@ -49,21 +49,21 @@ inline tensor_descriptor make_tensor(const migraphx::shape& s, bool pack = false ...@@ -49,21 +49,21 @@ inline tensor_descriptor make_tensor(const migraphx::shape& s, bool pack = false
d = miopenFloat; d = miopenFloat;
else if(s.type() == shape::half_type) else if(s.type() == shape::half_type)
d = miopenHalf; d = miopenHalf;
else if(s.type() == shape::int8_type) else if(s.type() == shape::int8_type)
{ {
if (pack) if(pack)
{ {
// update the lens and corresponding strides // update the lens and corresponding strides
d = miopenInt8x4; d = miopenInt8x4;
lens[1] = ((lens[1] + 3) / 4) * 4; lens[1] = ((lens[1] + 3) / 4) * 4;
strides[0] = strides[1] * lens[1]; strides[0] = strides[1] * lens[1];
} }
else else
{ {
d = miopenInt8; d = miopenInt8;
} }
} }
else else
{ {
MIGRAPHX_THROW("MAKE_TENSOR: unsupported type"); MIGRAPHX_THROW("MAKE_TENSOR: unsupported type");
} }
......
...@@ -15,54 +15,54 @@ argument miopen_quant_convolution::compute(context& ctx, ...@@ -15,54 +15,54 @@ argument miopen_quant_convolution::compute(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
auto x_desc_vec4 = make_tensor(args[0].get_shape(), true); auto x_desc_vec4 = make_tensor(args[0].get_shape(), true);
auto w_desc = make_tensor(args[1].get_shape()); auto w_desc = make_tensor(args[1].get_shape());
auto w_desc_vec4 = make_tensor(args[1].get_shape(), true); auto w_desc_vec4 = make_tensor(args[1].get_shape(), true);
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
float alpha = 1; float alpha = 1;
float beta = 0; float beta = 0;
// pack input to vec4 format // pack input to vec4 format
auto status = miopenTransformTensor(ctx.get_stream().get_miopen(), auto status = miopenTransformTensor(ctx.get_stream().get_miopen(),
&alpha, &alpha,
x_desc.get(), x_desc.get(),
args[0].implicit(), args[0].implicit(),
&beta, &beta,
x_desc_vec4.get(), x_desc_vec4.get(),
arg_vec4_x.implicit()); arg_vec4_x.implicit());
if (status != miopenStatusSuccess) if(status != miopenStatusSuccess)
{ {
MIGRAPHX_THROW("QUANT_CONVOLUTION: transform input tensfor failed"); MIGRAPHX_THROW("QUANT_CONVOLUTION: transform input tensfor failed");
} }
status = miopenTransformTensor(ctx.get_stream().get_miopen(), status = miopenTransformTensor(ctx.get_stream().get_miopen(),
&alpha, &alpha,
w_desc.get(), w_desc.get(),
args[1].implicit(), args[1].implicit(),
&beta, &beta,
w_desc_vec4.get(), w_desc_vec4.get(),
arg_vec4_w.implicit()); arg_vec4_w.implicit());
if (status != miopenStatusSuccess) if(status != miopenStatusSuccess)
{ {
MIGRAPHX_THROW("QUANT_CONVOLUTION: transform weight tensfor failed"); MIGRAPHX_THROW("QUANT_CONVOLUTION: transform weight tensfor failed");
} }
status = miopenConvolutionForward(ctx.get_stream().get_miopen(), status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha, &alpha,
x_desc.get(), x_desc.get(),
arg_vec4_x.implicit(), arg_vec4_x.implicit(),
w_desc.get(), w_desc.get(),
args[1].implicit(), args[1].implicit(),
cd.get(), cd.get(),
algo, algo,
&beta, &beta,
y_desc.get(), y_desc.get(),
args[3].implicit(), args[3].implicit(),
args[2].implicit(), args[2].implicit(),
args[2].get_shape().bytes()); args[2].get_shape().bytes());
if (status != miopenStatusSuccess) if(status != miopenStatusSuccess)
{ {
MIGRAPHX_THROW("QUANT_CONVOLUTION: run convolution forward failed"); MIGRAPHX_THROW("QUANT_CONVOLUTION: run convolution forward failed");
} }
...@@ -132,15 +132,15 @@ void miopen_quant_convolution::finalize(context& ctx, ...@@ -132,15 +132,15 @@ void miopen_quant_convolution::finalize(context& ctx,
shape miopen_quant_convolution::pack_int8_shape(shape& s) shape miopen_quant_convolution::pack_int8_shape(shape& s)
{ {
if (s.type() != shape::int8_type) if(s.type() != shape::int8_type)
{ {
MIGRAPHX_THROW("PACK_INT8_SHAPE: only process int8_type"); MIGRAPHX_THROW("PACK_INT8_SHAPE: only process int8_type");
} }
auto lens = s.lens(); auto lens = s.lens();
auto strides = s.strides(); auto strides = s.strides();
lens[1] = (lens[1] + 3) / 4 * 4; lens[1] = (lens[1] + 3) / 4 * 4;
strides[0] = strides[1] * lens[1]; strides[0] = strides[1] * lens[1];
return {s.type(), lens, strides}; return {s.type(), lens, strides};
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment