Commit 87fc2260 authored by charlie's avatar charlie
Browse files

initial

parent 57884353
...@@ -33,18 +33,24 @@ namespace migraphx { ...@@ -33,18 +33,24 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
/// The broadcast operator performs the numpy-style broadcasting of an axis of a given tensor. This /**
/// is achieved primarily by setting the stride of the broadcasted axis to zero. Linear indicies are * 1 input version:
/// computed from multi-indicies by computing the inner product on the multi-index with the strides. * Broadcasts a tensor from the original shape to the broadcast_lens by setting the stride of
/// For example, if we have a tensor A(2,3) it has lengths of (2,3) and strides of (3,1). If we want * broadcasted dimensions to zero. `axis` attribute for a 1D input shape is the output dimension
/// to compute the linear offset that corresponds to the element on the 2nd row (i = 1) and 3rd * that stays the same. ex: broadcasting shape [1024] -> [4, 1024, 3] has axis = 1 For higher rank
/// column (j = 2), we compute the following inner product (1,2) dot (3, 1) = 1*3 + 2*1 = 5. It is * input shapes, axis is an offset parameter for the broadcasting. Such that this operator would
/// obvious from there that we can negate the effects of a given axis by setting the stride of that * work in the opposite direction of NumPy broadcasting. ex: broadcasting shape [2, 2] -> [2, 2, 3]
/// axis to zero. * with axis = 0
*
* 2 input version:
* Broadcast the first input 1D shape into the second input shape based on the axis parameter.
* Handles broadcasting a 1D fixed shape into a higher rank dynamic shape.
* broadcast_lens is not used
*/
struct broadcast struct broadcast
{ {
uint64_t axis = 0; uint64_t axis = 0;
std::vector<std::size_t> broadcast_lens; std::vector<std::size_t> broadcast_lens = {};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -55,32 +61,53 @@ struct broadcast ...@@ -55,32 +61,53 @@ struct broadcast
std::string name() const { return "broadcast"; } std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(1); check_shapes{inputs, *this, true}.has(1, 2);
auto s0 = inputs.at(0); auto s0 = inputs.at(0);
auto t = s0.type(); auto t = s0.type();
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0); if(inputs.size() == 1)
// the broadcast op is deprecated now, so not handling the negative
// value of axis anymore
if(axis >= broadcast_lens.size())
{ {
MIGRAPHX_THROW("BROADCAST : axis is out of range"); // the ONNX broadcast op is deprecated now, so not handling the negative
} // value of axis anymore
if(axis >= broadcast_lens.size())
MIGRAPHX_THROW("BROADCAST : axis is out of range");
if(broadcast_lens.size() - axis < s0.lens().size())
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
if(broadcast_lens.size() - axis < s0.lens().size()) std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
{ std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims"); shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < s0.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
return output;
} }
else
if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))
{ {
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match"); // two inputs
} auto s1 = inputs.at(1);
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis); if(s0.dynamic())
MIGRAPHX_THROW("BROADCAST_2in: s0 is a static shape, does not handle broadcasting "
"a static shape");
if(s0.ndim() != 1)
MIGRAPHX_THROW("BROADCAST_2in: s0 has ndim " + migraphx::to_string(s0.ndim()) +
", only handle ndim = 1");
if(axis > s1.ndim())
MIGRAPHX_THROW("BROADCAST_2in: axis is out of range");
if(s1.ndim() - axis < s0.ndim())
MIGRAPHX_THROW("BROADCAST_2in: (s1_ndim - axis) is less than s0 ndim");
shape output{t, broadcast_lens, std::move(bcast_strides)}; if(s1.dynamic())
if(output.elements() < s0.elements()) return s1;
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
return output; std::vector<size_t> bcast_strides(s1.ndim(), 0);
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
shape output{t, s1.lens(), std::move(bcast_strides)};
if(output.elements() < s0.elements())
MIGRAPHX_THROW(
"BROADCAST_2in: output size must be greater than or equal to s0 size");
return output;
}
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment