Unverified Commit e4dc75ea authored by Lakhinder Walia's avatar Lakhinder Walia Committed by GitHub
Browse files

Lw/fix half shape (#2000)

* Use shape of Instruction (instead of a default) in add_return()

* Instruction validation fix: not to use a default shape value for comparison

* Fix instruction::replace() to recompute shape for "@return"

* handle the case of missing shape in an Instruction related Test

* use compute_shape() to get op shapes + test case for tuple_type

* add test case shape_test/return_shape_tuple

* Add test for @return to check for half type

* Move @return unit-tests around..; Address review comments

* Broken comparison fix: comparison to a (default) shape of tuple_type

* Test cases: (add) return_shape_empty & (modify) return_shape_tuple

* modify the assert() statement
parent 5d91adcf
......@@ -90,7 +90,17 @@ struct param
struct returns
{
std::string name() const { return "@return"; }
shape compute_shape(const std::vector<shape>&) const { return {}; }
shape compute_shape(const std::vector<shape>& arg) const
{
if(arg.empty())
return {};
else if(arg.size() == 1)
return arg[0];
else
return arg;
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPHX_THROW("builtin");
......
......@@ -64,10 +64,7 @@ void instruction::replace(const shape& r)
result = r;
for(auto&& ins : output)
{
if(ins->name() == "@return")
continue;
assert(ins->name().front() != '@');
assert(ins->name() == "@return" or ins->name().front() != '@');
ins->recompute_shape();
}
}
......@@ -122,10 +119,6 @@ bool instruction::valid() const
{
computed = result;
}
else if(op.name() == "@return")
{
computed = {};
}
else
{
try
......@@ -145,6 +138,7 @@ bool instruction::valid() const
}
shape instruction::get_shape() const { return result; }
const literal& instruction::get_literal() const
{
assert(op.name() == "@literal");
......
......@@ -460,11 +460,11 @@ instruction_ref module::add_parameter(std::string name, shape s)
instruction_ref module::add_return(std::vector<instruction_ref> args)
{
impl->push_back({builtin::returns{}, {}, std::move(args)});
shape instr_shape = compute_shape(builtin::returns{}, args);
impl->push_back({builtin::returns{}, instr_shape, std::move(args)});
auto result = std::prev(impl->instructions.end());
instruction::backreference(result);
assert(result->valid(begin()));
return result;
}
......
......@@ -323,7 +323,7 @@ TEST_CASE(conv_dyn_batch)
TEST_CASE(conv_dyn_img)
{
migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {3, 3}, {5, 20}, {5, 20}}};
{{1, 1}, {3, 3}, {5, 20}, {5, 20}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {1, 1}, {3, 18}, {3, 18}}};
......@@ -376,7 +376,7 @@ TEST_CASE(conv_autopad_dyn_batch)
{
// auto_pad dynamic batch
migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 10}, {3, 3}, {5, 5}, {5, 5}}};
{{1, 10}, {3, 3}, {5, 5}, {5, 5}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 10}, {1, 1}, {5, 5}, {5, 5}}};
......@@ -393,7 +393,7 @@ TEST_CASE(conv_autopad_dyn_img)
{
// auto_pad dynamic img
migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {3, 3}, {5, 10}, {5, 10}}};
{{1, 1}, {3, 3}, {5, 10}, {5, 10}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {1, 1}, {5, 10}, {5, 10}}};
......@@ -2597,6 +2597,36 @@ TEST_CASE(reshape_non_fixed_not_matching_error)
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
TEST_CASE(return_shape_tuple)
{
using migraphx::shape;
auto op = migraphx::make_op("@return");
shape s0{shape::bool_type, {1, 1}};
shape s1{shape::float_type, {2, 3}};
std::vector<shape> s{s0, s1};
auto s_out = op.compute_shape(s);
EXPECT(s_out.type() == shape::tuple_type);
EXPECT(s0 == s_out.sub_shapes()[0]);
EXPECT(s1 == s_out.sub_shapes()[1]);
}
TEST_CASE(return_shape_half)
{
using migraphx::shape;
auto op = migraphx::make_op("@return");
std::vector<shape> s{{shape::half_type}};
EXPECT(op.compute_shape(s) == shape{shape::half_type});
}
TEST_CASE(return_shape_empty)
{
using migraphx::shape;
auto op = migraphx::make_op("@return");
std::vector<shape> s;
EXPECT(op.compute_shape(s) == shape{});
}
TEST_CASE(rnn)
{
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment