Commit 02faf773 authored by umangyadav's avatar umangyadav
Browse files

only return 1 solution

parent 956f5bc7
...@@ -120,35 +120,25 @@ shape miopen_deconvolution::find(context& ctx, const shape& output_shape, std::v ...@@ -120,35 +120,25 @@ shape miopen_deconvolution::find(context& ctx, const shape& output_shape, std::v
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Deconvolution: find convolution failed"); MIGRAPHX_THROW("MIOpen Deconvolution: find convolution failed");
size_t solution_count; size_t solution_count = 1;
status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(), miopenConvSolution_t deconv_solution;
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
&solution_count);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Deconvolution: get solution count failed");
std::vector<miopenConvSolution_t> solutions(solution_count);
status = miopenConvolutionForwardGetSolution(ctx.get_stream().get_miopen(), status = miopenConvolutionForwardGetSolution(ctx.get_stream().get_miopen(),
w_desc.get(), w_desc.get(),
x_desc.get(), x_desc.get(),
cd.get(), cd.get(),
y_desc.get(), y_desc.get(),
solution_count, 1,
&solution_count, &solution_count,
solutions.data()); &deconv_solution);
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess or solution_count != 1)
MIGRAPHX_THROW("MIOpen Deconvolution: get solution failed"); MIGRAPHX_THROW("MIOpen Deconvolution: get solution failed");
const auto& best_solution = solutions.front(); solution_id = deconv_solution.solution_id;
solution_id = best_solution.solution_id; algo = deconv_solution.algorithm;
algo = best_solution.algorithm;
return shape{shape::int8_type, {best_solution.workspace_size}}; return shape{shape::int8_type, {deconv_solution.workspace_size}};
} }
void miopen_deconvolution::finalize(context& ctx, void miopen_deconvolution::finalize(context& ctx,
......
...@@ -113,36 +113,25 @@ shape miopen_quant_convolution::find(context& ctx, ...@@ -113,36 +113,25 @@ shape miopen_quant_convolution::find(context& ctx,
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Quant Convolution: find convolution failed"); MIGRAPHX_THROW("MIOpen Quant Convolution: find convolution failed");
size_t solution_count; size_t solution_count = 1;
status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(), miopenConvSolution_t qconv_solution;
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
&solution_count);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Quant Convolution: get solution count failed");
std::vector<miopenConvSolution_t> solutions(solution_count);
status = miopenConvolutionForwardGetSolution(ctx.get_stream().get_miopen(), status = miopenConvolutionForwardGetSolution(ctx.get_stream().get_miopen(),
w_desc.get(), w_desc.get(),
x_desc.get(), x_desc.get(),
cd.get(), cd.get(),
y_desc.get(), y_desc.get(),
solution_count, 1,
&solution_count, &solution_count,
solutions.data()); &qconv_solution);
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess or solution_count != 1)
MIGRAPHX_THROW("MIOpen Quant Convolution: get solution failed"); MIGRAPHX_THROW("MIOpen Quant Convolution: get solution failed");
const auto& best_solution = solutions.front(); solution_id = qconv_solution.solution_id;
algo = qconv_solution.algorithm;
solution_id = best_solution.solution_id;
algo = best_solution.algorithm;
return shape{shape::int8_type, {best_solution.workspace_size}}; return shape{shape::int8_type, {qconv_solution.workspace_size}};
} }
void miopen_quant_convolution::finalize(context& ctx, void miopen_quant_convolution::finalize(context& ctx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment