Unverified Commit e4dc75ea authored by Lakhinder Walia's avatar Lakhinder Walia Committed by GitHub
Browse files

Lw/fix half shape (#2000)

* Use shape of Instruction (instead of a default) in add_return()

* Instruction validation fix: not to use a default shape value for comparison

* Fix instruction::replace() to recompute shape for "@return"

* handle the case of missing shape in an Instruction related Test

* use compute_shape() to get op shapes + test case for tuple_type

* add test case shape_test/return_shape_tuple

* Add test for @return to check for half type

* Move @return unit-tests around..; Address review comments

* Broken comparison fix: comparison to a (default) shape of tuple_type

* Test cases: (add) return_shape_empty & (modify) return_shape_tuple

* modify the assert() statement
parent 5d91adcf
...@@ -90,7 +90,17 @@ struct param ...@@ -90,7 +90,17 @@ struct param
struct returns struct returns
{ {
std::string name() const { return "@return"; } std::string name() const { return "@return"; }
shape compute_shape(const std::vector<shape>&) const { return {}; }
shape compute_shape(const std::vector<shape>& arg) const
{
if(arg.empty())
return {};
else if(arg.size() == 1)
return arg[0];
else
return arg;
}
argument compute(context&, const shape&, const std::vector<argument>&) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPHX_THROW("builtin"); MIGRAPHX_THROW("builtin");
......
...@@ -64,10 +64,7 @@ void instruction::replace(const shape& r) ...@@ -64,10 +64,7 @@ void instruction::replace(const shape& r)
result = r; result = r;
for(auto&& ins : output) for(auto&& ins : output)
{ {
if(ins->name() == "@return") assert(ins->name() == "@return" or ins->name().front() != '@');
continue;
assert(ins->name().front() != '@');
ins->recompute_shape(); ins->recompute_shape();
} }
} }
...@@ -122,10 +119,6 @@ bool instruction::valid() const ...@@ -122,10 +119,6 @@ bool instruction::valid() const
{ {
computed = result; computed = result;
} }
else if(op.name() == "@return")
{
computed = {};
}
else else
{ {
try try
...@@ -145,6 +138,7 @@ bool instruction::valid() const ...@@ -145,6 +138,7 @@ bool instruction::valid() const
} }
shape instruction::get_shape() const { return result; } shape instruction::get_shape() const { return result; }
const literal& instruction::get_literal() const const literal& instruction::get_literal() const
{ {
assert(op.name() == "@literal"); assert(op.name() == "@literal");
......
...@@ -460,11 +460,11 @@ instruction_ref module::add_parameter(std::string name, shape s) ...@@ -460,11 +460,11 @@ instruction_ref module::add_parameter(std::string name, shape s)
instruction_ref module::add_return(std::vector<instruction_ref> args) instruction_ref module::add_return(std::vector<instruction_ref> args)
{ {
impl->push_back({builtin::returns{}, {}, std::move(args)}); shape instr_shape = compute_shape(builtin::returns{}, args);
impl->push_back({builtin::returns{}, instr_shape, std::move(args)});
auto result = std::prev(impl->instructions.end()); auto result = std::prev(impl->instructions.end());
instruction::backreference(result); instruction::backreference(result);
assert(result->valid(begin())); assert(result->valid(begin()));
return result; return result;
} }
......
...@@ -323,7 +323,7 @@ TEST_CASE(conv_dyn_batch) ...@@ -323,7 +323,7 @@ TEST_CASE(conv_dyn_batch)
TEST_CASE(conv_dyn_img) TEST_CASE(conv_dyn_img)
{ {
migraphx::shape input_dyn_shape = {migraphx::shape::float_type, migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {3, 3}, {5, 20}, {5, 20}}}; {{1, 1}, {3, 3}, {5, 20}, {5, 20}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}}; migraphx::shape weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type, migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {1, 1}, {3, 18}, {3, 18}}}; {{1, 1}, {1, 1}, {3, 18}, {3, 18}}};
...@@ -376,7 +376,7 @@ TEST_CASE(conv_autopad_dyn_batch) ...@@ -376,7 +376,7 @@ TEST_CASE(conv_autopad_dyn_batch)
{ {
// auto_pad dynamic batch // auto_pad dynamic batch
migraphx::shape input_dyn_shape = {migraphx::shape::float_type, migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 10}, {3, 3}, {5, 5}, {5, 5}}}; {{1, 10}, {3, 3}, {5, 5}, {5, 5}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}}; migraphx::shape weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type, migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 10}, {1, 1}, {5, 5}, {5, 5}}}; {{1, 10}, {1, 1}, {5, 5}, {5, 5}}};
...@@ -393,7 +393,7 @@ TEST_CASE(conv_autopad_dyn_img) ...@@ -393,7 +393,7 @@ TEST_CASE(conv_autopad_dyn_img)
{ {
// auto_pad dynamic img // auto_pad dynamic img
migraphx::shape input_dyn_shape = {migraphx::shape::float_type, migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {3, 3}, {5, 10}, {5, 10}}}; {{1, 1}, {3, 3}, {5, 10}, {5, 10}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}}; migraphx::shape weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type, migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {1, 1}, {5, 10}, {5, 10}}}; {{1, 1}, {1, 1}, {5, 10}, {5, 10}}};
...@@ -2597,6 +2597,36 @@ TEST_CASE(reshape_non_fixed_not_matching_error) ...@@ -2597,6 +2597,36 @@ TEST_CASE(reshape_non_fixed_not_matching_error)
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input); throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
} }
TEST_CASE(return_shape_tuple)
{
using migraphx::shape;
auto op = migraphx::make_op("@return");
shape s0{shape::bool_type, {1, 1}};
shape s1{shape::float_type, {2, 3}};
std::vector<shape> s{s0, s1};
auto s_out = op.compute_shape(s);
EXPECT(s_out.type() == shape::tuple_type);
EXPECT(s0 == s_out.sub_shapes()[0]);
EXPECT(s1 == s_out.sub_shapes()[1]);
}
TEST_CASE(return_shape_half)
{
using migraphx::shape;
auto op = migraphx::make_op("@return");
std::vector<shape> s{{shape::half_type}};
EXPECT(op.compute_shape(s) == shape{shape::half_type});
}
TEST_CASE(return_shape_empty)
{
using migraphx::shape;
auto op = migraphx::make_op("@return");
std::vector<shape> s;
EXPECT(op.compute_shape(s) == shape{});
}
TEST_CASE(rnn) TEST_CASE(rnn)
{ {
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment