Unverified Commit d689e2d1 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

unary scalar input processing (#912)

* unary scalar input processing

* remove an unnecessary change

* remove unnecessary blank line
parent ccff6beb
...@@ -41,7 +41,11 @@ struct unary : op_name<Derived> ...@@ -41,7 +41,11 @@ struct unary : op_name<Derived>
{ {
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(1); check_shapes{inputs, static_cast<const Derived&>(*this)}.has(1);
auto s = inputs.at(0); auto s = inputs.at(0);
if(s.broadcasted()) if(s.scalar())
{
return s;
}
else if(s.broadcasted())
{ {
return {s.type(), s.lens()}; return {s.type(), s.lens()};
} }
......
...@@ -1583,4 +1583,13 @@ TEST_CASE(step_test) ...@@ -1583,4 +1583,13 @@ TEST_CASE(step_test)
} }
} }
TEST_CASE(unary_scalar_input)
{
migraphx::shape ss{migraphx::shape::half_type};
expect_shape(ss, migraphx::make_op("sin"), ss);
migraphx::shape s{migraphx::shape::float_type, {1}};
expect_shape(s, migraphx::make_op("sin"), s);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment