"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "fd2921b53cdb2a1fa66ceaf8878035fe81367a3a"
Commit 0ff0839d authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_broadcast' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_unsqueeze

parents e026d93c ab812826
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <memory> #include <memory>
#include <numeric> #include <numeric>
#include <exception> #include <exception>
#include <array>
#include <vector> #include <vector>
#include <cassert> #include <cassert>
#include <iostream> #include <iostream>
......
...@@ -128,8 +128,8 @@ struct broadcast ...@@ -128,8 +128,8 @@ struct broadcast
{ {
MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with static s1 axis " MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with static s1 axis "
"dimension length (" + "dimension length (" +
migraphx::to_string(s0.dyn_dims()[0]) + migraphx::to_string(s0.lens()[0]) +
" != " + migraphx::to_string(s1.dyn_dims()[axis]) + ")"); " != " + migraphx::to_string(s1.lens()[axis]) + ")");
} }
std::vector<size_t> bcast_strides(s1.ndim(), 0); std::vector<size_t> bcast_strides(s1.ndim(), 0);
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis); std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
......
...@@ -58,10 +58,10 @@ struct parse_binary_op : op_parser<parse_binary_op> ...@@ -58,10 +58,10 @@ struct parse_binary_op : op_parser<parse_binary_op>
if(broadcasted != 0) if(broadcasted != 0)
{ {
if(std::any_of( if(std::any_of(
args.cbegin(), args.cend(), [](auto a) { a->get_shape().dynamic(); })) args.cbegin(), args.cend(), [](auto a) { return a->get_shape().dynamic(); }))
{ {
MIGRAPHX_THROW( MIGRAPHX_THROW(
"binary op broadcast attribute not supported for dynamic input shapes"); "Binary op broadcast attribute not supported for dynamic input shapes");
} }
uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>(); uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>();
auto l = info.add_instruction( auto l = info.add_instruction(
......
...@@ -194,7 +194,7 @@ TEST_CASE(test_shape_ndim_static) ...@@ -194,7 +194,7 @@ TEST_CASE(test_shape_ndim_static)
EXPECT(s1.ndim() == 4); EXPECT(s1.ndim() == 4);
migraphx::shape s2{migraphx::shape::float_type, {2, 4, 4, 1, 3}}; migraphx::shape s2{migraphx::shape::float_type, {2, 4, 4, 1, 3}};
EXPECT(s1.ndim() == 5); EXPECT(s2.ndim() == 5);
} }
TEST_CASE(test_shape_ndim_dyn) TEST_CASE(test_shape_ndim_dyn)
...@@ -207,7 +207,7 @@ TEST_CASE(test_shape_ndim_dyn) ...@@ -207,7 +207,7 @@ TEST_CASE(test_shape_ndim_dyn)
migraphx::shape s2{migraphx::shape::float_type, migraphx::shape s2{migraphx::shape::float_type,
{{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {1, 1, 1}, {3, 3, 0}}}; {{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {1, 1, 1}, {3, 3, 0}}};
EXPECT(s1.ndim() == 5); EXPECT(s2.ndim() == 5);
} }
TEST_CASE(test_shape_non_packed_single_dim) TEST_CASE(test_shape_non_packed_single_dim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment