"vscode:/vscode.git/clone" did not exist on "be44758c1e3b1dc1a7c9aadd69bc6a068d7f40ef"
Commit 93c632cb authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Fix bugs in the GPU implementation of quant_convolution.

parent 585bb331
...@@ -34,9 +34,10 @@ argument miopen_quant_convolution::compute(context& ctx, ...@@ -34,9 +34,10 @@ argument miopen_quant_convolution::compute(context& ctx,
arg_vec4_x.implicit()); arg_vec4_x.implicit());
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
{ {
MIGRAPHX_THROW("QUANT_CONVOLUTION: transform input tensfor failed"); MIGRAPHX_THROW("QUANT_CONVOLUTION: transform input tensor failed");
} }
// pack weights to vec4 format
status = miopenTransformTensor(ctx.get_stream().get_miopen(), status = miopenTransformTensor(ctx.get_stream().get_miopen(),
&alpha, &alpha,
w_desc.get(), w_desc.get(),
...@@ -46,15 +47,15 @@ argument miopen_quant_convolution::compute(context& ctx, ...@@ -46,15 +47,15 @@ argument miopen_quant_convolution::compute(context& ctx,
arg_vec4_w.implicit()); arg_vec4_w.implicit());
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
{ {
MIGRAPHX_THROW("QUANT_CONVOLUTION: transform weight tensfor failed"); MIGRAPHX_THROW("QUANT_CONVOLUTION: transform weight tensor failed");
} }
status = miopenConvolutionForward(ctx.get_stream().get_miopen(), status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha, &alpha,
x_desc.get(), x_desc_vec4.get(),
arg_vec4_x.implicit(), arg_vec4_x.implicit(),
w_desc.get(), w_desc_vec4.get(),
args[1].implicit(), arg_vec4_w.implicit(),
cd.get(), cd.get(),
algo, algo,
&beta, &beta,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment