Commit b72f2b9e authored by umangyadav's avatar umangyadav
Browse files

changes after find2.0

parent 5a14c0bf
...@@ -190,7 +190,7 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins) ...@@ -190,7 +190,7 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
auto wei = ins->inputs().at(1)->get_shape(); auto wei = ins->inputs().at(1)->get_shape();
assert(wei.lens().size() == 4); assert(wei.lens().size() == 4);
auto miopen_conv_op = ins->get_operator().to_value(); auto miopen_conv_op = ins->get_operator().to_value();
auto algo = miopen_conv_op.at("algo").to<miopenConvFwdAlgorithm_t>(); auto algo = miopen_conv_op.at("algo").to<miopenConvAlgorithm_t>();
auto conv_op = from_value<op::convolution>(miopen_conv_op["op"]); auto conv_op = from_value<op::convolution>(miopen_conv_op["op"]);
if(conv_op.group > 1) if(conv_op.group > 1)
return false; return false;
......
...@@ -238,35 +238,24 @@ struct miopen_convolution ...@@ -238,35 +238,24 @@ struct miopen_convolution
false); false);
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed"); MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed");
algo = perf.fwd_algo;
size_t solution_count; size_t solution_count = 1;
miopenConvSolution_t conv_solution;
status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
&solution_count);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen " + op.name() + ": get solution count failed");
std::vector<miopenConvSolution_t> solutions(solution_count);
status = miopenConvolutionForwardGetSolution(ctx.get_stream().get_miopen(), status = miopenConvolutionForwardGetSolution(ctx.get_stream().get_miopen(),
w_desc.get(), w_desc.get(),
x_desc.get(), x_desc.get(),
cd.get(), cd.get(),
y_desc.get(), y_desc.get(),
solution_count, 1,
&solution_count, &solution_count,
solutions.data()); &conv_solution);
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess or solution_count != 1)
MIGRAPHX_THROW("MIOpen " + op.name() + ": get solution failed"); MIGRAPHX_THROW("MIOpen " + op.name() + ": get solution failed");
solution_id = solutions.front().solution_id; solution_id = conv_solution.solution_id;
algo = conv_solution.algorithm;
return shape{shape::int8_type, {perf.memory}}; return shape{shape::int8_type, {conv_solution.workspace_size}};
#endif #endif
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment