Commit 02faf773 authored by umangyadav's avatar umangyadav
Browse files

only return 1 solution

parent 956f5bc7
......@@ -120,35 +120,25 @@ shape miopen_deconvolution::find(context& ctx, const shape& output_shape, std::v
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Deconvolution: find convolution failed");
size_t solution_count;
status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
&solution_count);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Deconvolution: get solution count failed");
size_t solution_count = 1;
std::vector<miopenConvSolution_t> solutions(solution_count);
miopenConvSolution_t deconv_solution;
status = miopenConvolutionForwardGetSolution(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
solution_count,
1,
&solution_count,
solutions.data());
if(status != miopenStatusSuccess)
&deconv_solution);
if(status != miopenStatusSuccess or solution_count != 1)
MIGRAPHX_THROW("MIOpen Deconvolution: get solution failed");
const auto& best_solution = solutions.front();
solution_id = best_solution.solution_id;
algo = best_solution.algorithm;
solution_id = deconv_solution.solution_id;
algo = deconv_solution.algorithm;
return shape{shape::int8_type, {best_solution.workspace_size}};
return shape{shape::int8_type, {deconv_solution.workspace_size}};
}
void miopen_deconvolution::finalize(context& ctx,
......
......@@ -113,36 +113,25 @@ shape miopen_quant_convolution::find(context& ctx,
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Quant Convolution: find convolution failed");
size_t solution_count;
status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
&solution_count);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Quant Convolution: get solution count failed");
size_t solution_count = 1;
std::vector<miopenConvSolution_t> solutions(solution_count);
miopenConvSolution_t qconv_solution;
status = miopenConvolutionForwardGetSolution(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
solution_count,
1,
&solution_count,
solutions.data());
if(status != miopenStatusSuccess)
&qconv_solution);
if(status != miopenStatusSuccess or solution_count != 1)
MIGRAPHX_THROW("MIOpen Quant Convolution: get solution failed");
const auto& best_solution = solutions.front();
solution_id = best_solution.solution_id;
algo = best_solution.algorithm;
solution_id = qconv_solution.solution_id;
algo = qconv_solution.algorithm;
return shape{shape::int8_type, {best_solution.workspace_size}};
return shape{shape::int8_type, {qconv_solution.workspace_size}};
}
void miopen_quant_convolution::finalize(context& ctx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment