Commit 931fe619 authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_broadcast' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_broadcast

parents 34d14edd 775dfe8a
......@@ -70,7 +70,8 @@ struct broadcast
// value of axis anymore
if(axis >= broadcast_lens.size())
{
MIGRAPHX_THROW("BROADCAST : axis is out of range");
MIGRAPHX_THROW("BROADCAST : axis " + migraphx::to_string(axis) +
" is out of range");
}
if(broadcast_lens.size() - axis < s0.lens().size())
{
......@@ -107,21 +108,28 @@ struct broadcast
}
if(axis >= s1.ndim())
{
MIGRAPHX_THROW("BROADCAST_2in: axis is out of range");
MIGRAPHX_THROW("BROADCAST_2in: axis " + migraphx::to_string(axis) +
" is out of range");
}
if(s1.dynamic())
{
s0 = s0.to_dynamic();
if(s0.dyn_dims()[0] != s1.dyn_dims()[axis])
{
MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with dynamic s1 axis "
"dimension length");
"dimension length (" +
migraphx::to_string(s0.dyn_dims()[0]) +
" != " + migraphx::to_string(s1.dyn_dims()[axis]) + ")");
}
return s1;
}
if(s0.lens()[0] != s1.lens()[axis])
{
MIGRAPHX_THROW(
"BROADCAST_2in: s0 length doesn't match with static s1 axis dimension length");
MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with static s1 axis "
"dimension length (" +
migraphx::to_string(s0.dyn_dims()[0]) +
" != " + migraphx::to_string(s1.dyn_dims()[axis]) + ")");
}
std::vector<size_t> bcast_strides(s1.ndim(), 0);
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment