Commit 3e4991a6 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Handle case to convert shapes for (x1,x2,x3,..xn) == (x1,x2,x3,..,xn)

Still busted work in prgoress. Keep running into

terminate called after throwing an instance of 'migraphx::version_1::exception'
  what():  /code/AMDMIGraphX/src/normalize_attributes.cpp:91: tune_attribute: TUNE_VECTOR: value out of range!
Aborted (core dumped)
parent 9199f074
...@@ -28,6 +28,8 @@ ...@@ -28,6 +28,8 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/reduce_dims.hpp>
#include <algorithm>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -80,79 +82,71 @@ struct parse_if : op_parser<parse_if> ...@@ -80,79 +82,71 @@ struct parse_if : op_parser<parse_if>
if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic()) if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic())
{ {
if(then_out_shapes.at(0).scalar() && not else_out_shapes.at(0).scalar()) auto then_shape = then_out_shapes.at(0).lens();
auto else_shape = else_out_shapes.at(0).lens();
int dim_delta = abs((static_cast<int>(then_shape.size() - else_shape.size())));
auto throw_shapes = [&]() {
MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_graphs must compatible shapes ");
};
// check equivilant length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
if(dim_delta <= 1)
{ {
auto convert_ins = std::prev(then_mdl->end()); // make sure dims are equivalent in static shapes
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() && if(not equal(then_shape.begin(), then_shape.end(), else_shape.begin()) &&
then_out_shapes.at(0).elements() < 1) not equal(else_shape.begin(), else_shape.end(), then_shape.begin()))
{ {
convert_ins = then_mdl->insert_instruction( throw_shapes();
convert_ins,
migraphx::make_op("convert",
{{"target_type", else_out_shapes.at(0).type()}}),
convert_ins->inputs().back());
// then_mdl->replace_return({convert_ins});
} }
migraphx::shape s{else_out_shapes.at(0).type(), // find bigger dimension and pad if its 1 otherwise throw
else_out_shapes.at(0).lens(), if(dim_delta == 1)
else_out_shapes.at(0).strides()};
auto reshape_ins = then_mdl->insert_instruction(
convert_ins, migraphx::make_op("unsqueeze", {{"axes", {1}}}), convert_ins);
then_mdl->replace_return({reshape_ins});
}
else if(not then_out_shapes.at(0).scalar() && else_out_shapes.at(0).scalar())
{
auto convert_ins = std::prev(else_mdl->end());
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() &&
else_out_shapes.at(0).elements() < 1)
{ {
convert_ins = then_mdl->insert_instruction( bool invalid_last_dim = true;
std::prev(then_mdl->end()),
migraphx::make_op("convert", // Find which dim to pad
{{"target_type", then_out_shapes.at(0).type()}}), if(then_shape.size() < else_shape.size())
std::prev(then_mdl->end())->inputs().front()); {
then_mdl->replace_return({convert_ins}); auto last_else = *(--(else_shape.end()));
if(last_else == 1)
{
invalid_last_dim = false;
// migraphx::shape s{else_out_shapes.at(0).type(), {1,1,1,1}};
// else_out_shapes.at(0) = reduce_dims({else_out_shapes, s});
auto convert_ins = else_mdl->insert_instruction(
std::prev(else_mdl->end()),
migraphx::make_op("squeeze", {{"axes", {else_shape.size()}}}),
std::prev(else_mdl->end())->inputs().front());
else_mdl->replace_return({convert_ins});
}
}
else
{
auto last_then = *(--(then_shape.end()));
if(last_then == 1)
{
invalid_last_dim = false;
// migraphx::shape s{else_out_shapes.at(0).type(), {1,1,1,1}};
// then_out_shapes = reduce_dims({then_out_shapes, s});
auto convert_ins = then_mdl->insert_instruction(
std::prev(then_mdl->end()),
migraphx::make_op("squeeze", {{"axes", {then_shape.size()}}}),
std::prev(then_mdl->end())->inputs().front());
then_mdl->replace_return({convert_ins});
}
}
if(invalid_last_dim)
{
throw_shapes();
}
} }
migraphx::shape s{then_out_shapes.at(0).type(),
then_out_shapes.at(0).lens(),
then_out_shapes.at(0).strides()};
auto reshape_ins = then_mdl->insert_instruction(
convert_ins, migraphx::make_op("unsqueeze", {{"axes", {1}}}), convert_ins);
else_mdl->replace_return({reshape_ins});
}
// First dimension must agree
if(then_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0))
{
MIGRAPHX_THROW("PARSE_IF: " + then_out_shapes.at(0).type_string() + " & " +
else_out_shapes.at(0).type_string() +
" are incompatible output shapes for then/cases");
} }
auto then_out_strides = then_out_shapes.at(0).strides(); else
auto else_out_strides = else_out_shapes.at(0).strides();
// Generate compatible output types based on largest dimension with rank 1 tensor
if(then_out_strides.size() > else_out_strides.size())
{
auto reshape_ins = else_mdl->insert_instruction(
std::prev(else_mdl->end()),
migraphx::make_op("reshape",
{{"dims", {else_out_shapes.at(0).lens().at(0), 1}}}),
std::prev(else_mdl->end())->inputs().front());
else_mdl->replace_return({reshape_ins});
}
else if(then_out_strides.size() < else_out_strides.size())
{ {
auto reshape_ins = then_mdl->insert_instruction( throw_shapes();
std::prev(then_mdl->end()),
migraphx::make_op("reshape",
{{"dims", {then_out_shapes.at(0).lens().at(0), 1}}}),
std::prev(then_mdl->end())->inputs().front());
then_mdl->replace_return({reshape_ins});
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment