Commit e4e4040e authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Work in progress. can get the correct size literal but not sure if buffer is...

Work in progress. can get the correct size literal but not sure if buffer is being alloced correctly.
parent 2b182551
......@@ -51,12 +51,12 @@ struct if_op
auto out_shapes0 = mods[0]->get_output_shapes();
auto out_shapes1 = mods[1]->get_output_shapes();
if(not std::equal(
/*if(not std::equal(
out_shapes1.begin(), out_shapes1.end(), out_shapes0.begin(), out_shapes0.end()))
{
MIGRAPHX_THROW("IF:" + mods[0]->name() + " & " + mods[1]->name() +
" output shapes of submodules must be the same.");
}
" output types of submodules must be the same.");
}*/
return {out_shapes0};
}
......
......@@ -80,32 +80,34 @@ struct parse_if : op_parser<parse_if>
if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic())
{
// unsqueeze up to a 1d vector for now
// allocate buffer
if(then_out_shapes.at(0).scalar() && not else_out_shapes.at(0).scalar())
{
auto ins = std::prev(then_mdl->end());
auto reshape_ins = then_mdl->insert_instruction(
ins, migraphx::make_op("unsqueeze", {{"axes", {0}}}), ins);
auto ins = std::prev(then_mdl->end());
auto l = migraphx::literal{else_out_shapes.at(0), else_out_shapes.at(0).lens()};
auto new_lit = then_mdl->insert_literal(ins, l);
then_mdl->replace_return({reshape_ins});
then_mdl->replace_return({new_lit});
}
else if(not then_out_shapes.at(0).scalar() && else_out_shapes.at(0).scalar())
{
auto ins = std::prev(else_mdl->end());
auto reshape_ins = then_mdl->insert_instruction(
ins, migraphx::make_op("unsqueeze", {{"axes", {0}}}), ins);
auto ins = std::prev(else_mdl->end());
auto l =
migraphx::literal{else_out_shapes.at(0).type(), then_out_shapes.at(0).lens()};
auto new_lit = else_mdl->insert_literal(ins, l);
else_mdl->replace_return({reshape_ins});
else_mdl->replace_return({new_lit});
}
// First dimension must agree
if(then_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0))
else
{
MIGRAPHX_THROW("PARSE_IF: " + then_out_shapes.at(0).type_string() + " & " +
else_out_shapes.at(0).type_string() +
" are incompatible output shapes for then/cases");
// First dimension must agree
if(then_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0))
{
MIGRAPHX_THROW("PARSE_IF: " + then_out_shapes.at(0).type_string() + " & " +
else_out_shapes.at(0).type_string() +
" are incompatible output shapes for then/cases");
}
}
auto then_out_strides = then_out_shapes.at(0).strides();
auto else_out_strides = else_out_shapes.at(0).strides();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment