Commit b72f2b9e authored by umangyadav's avatar umangyadav
Browse files

changes after find2.0

parent 5a14c0bf
......@@ -190,7 +190,7 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
auto wei = ins->inputs().at(1)->get_shape();
assert(wei.lens().size() == 4);
auto miopen_conv_op = ins->get_operator().to_value();
auto algo = miopen_conv_op.at("algo").to<miopenConvFwdAlgorithm_t>();
auto algo = miopen_conv_op.at("algo").to<miopenConvAlgorithm_t>();
auto conv_op = from_value<op::convolution>(miopen_conv_op["op"]);
if(conv_op.group > 1)
return false;
......
......@@ -238,35 +238,24 @@ struct miopen_convolution
false);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed");
algo = perf.fwd_algo;
size_t solution_count;
status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
&solution_count);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen " + op.name() + ": get solution count failed");
std::vector<miopenConvSolution_t> solutions(solution_count);
size_t solution_count = 1;
miopenConvSolution_t conv_solution;
status = miopenConvolutionForwardGetSolution(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
solution_count,
1,
&solution_count,
solutions.data());
if(status != miopenStatusSuccess)
&conv_solution);
if(status != miopenStatusSuccess or solution_count != 1)
MIGRAPHX_THROW("MIOpen " + op.name() + ": get solution failed");
solution_id = solutions.front().solution_id;
return shape{shape::int8_type, {perf.memory}};
solution_id = conv_solution.solution_id;
algo = conv_solution.algorithm;
return shape{shape::int8_type, {conv_solution.workspace_size}};
#endif
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment