Commit 87fc2260 authored by charlie's avatar charlie
Browse files

initial

parent 57884353
......@@ -33,18 +33,24 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/// The broadcast operator performs the numpy-style broadcasting of an axis of a given tensor. This
/// is achieved primarily by setting the stride of the broadcasted axis to zero. Linear indicies are
/// computed from multi-indicies by computing the inner product on the multi-index with the strides.
/// For example, if we have a tensor A(2,3) it has lengths of (2,3) and strides of (3,1). If we want
/// to compute the linear offset that corresponds to the element on the 2nd row (i = 1) and 3rd
/// column (j = 2), we compute the following inner product (1,2) dot (3, 1) = 1*3 + 2*1 = 5. It is
/// obvious from there that we can negate the effects of a given axis by setting the stride of that
/// axis to zero.
/**
* 1 input version:
* Broadcasts a tensor from the original shape to the broadcast_lens by setting the stride of
* broadcasted dimensions to zero. `axis` attribute for a 1D input shape is the output dimension
* that stays the same. ex: broadcasting shape [1024] -> [4, 1024, 3] has axis = 1 For higher rank
* input shapes, axis is an offset parameter for the broadcasting. Such that this operator would
* work in the opposite direction of NumPy broadcasting. ex: broadcasting shape [2, 2] -> [2, 2, 3]
* with axis = 0
*
* 2 input version:
* Broadcast the first input 1D shape into the second input shape based on the axis parameter.
* Handles broadcasting a 1D fixed shape into a higher rank dynamic shape.
* broadcast_lens is not used
*/
struct broadcast
{
uint64_t axis = 0;
std::vector<std::size_t> broadcast_lens;
std::vector<std::size_t> broadcast_lens = {};
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -55,33 +61,54 @@ struct broadcast
std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
check_shapes{inputs, *this, true}.has(1, 2);
auto s0 = inputs.at(0);
auto t = s0.type();
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
// the broadcast op is deprecated now, so not handling the negative
if(inputs.size() == 1)
{
// the ONNX broadcast op is deprecated now, so not handling the negative
// value of axis anymore
if(axis >= broadcast_lens.size())
{
MIGRAPHX_THROW("BROADCAST : axis is out of range");
}
if(broadcast_lens.size() - axis < s0.lens().size())
{
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
}
if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))
{
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
}
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < s0.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
return output;
}
else
{
// two inputs
auto s1 = inputs.at(1);
if(s0.dynamic())
MIGRAPHX_THROW("BROADCAST_2in: s0 is a static shape, does not handle broadcasting "
"a static shape");
if(s0.ndim() != 1)
MIGRAPHX_THROW("BROADCAST_2in: s0 has ndim " + migraphx::to_string(s0.ndim()) +
", only handle ndim = 1");
if(axis > s1.ndim())
MIGRAPHX_THROW("BROADCAST_2in: axis is out of range");
if(s1.ndim() - axis < s0.ndim())
MIGRAPHX_THROW("BROADCAST_2in: (s1_ndim - axis) is less than s0 ndim");
if(s1.dynamic())
return s1;
std::vector<size_t> bcast_strides(s1.ndim(), 0);
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
shape output{t, s1.lens(), std::move(bcast_strides)};
if(output.elements() < s0.elements())
MIGRAPHX_THROW(
"BROADCAST_2in: output size must be greater than or equal to s0 size");
return output;
}
}
argument compute(shape output_shape, std::vector<argument> args) const
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment