#ifndef MIGRAPHX_GUARD_OPERATORS_MULTIBROADCAST_HPP #define MIGRAPHX_GUARD_OPERATORS_MULTIBROADCAST_HPP #include #include #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace op { struct multibroadcast { std::vector output_lens; template static auto reflect(Self& self, F f) { return pack(f(self.output_lens, "output_lens")); } std::string name() const { return "multibroadcast"; } shape compute_shape(std::vector inputs) const { check_shapes{inputs, *this}.has(1); auto t = inputs.at(0).type(); auto input = inputs.at(0); if(input.lens().empty()) MIGRAPHX_THROW("inputs dimensions should be > 0"); if(input.lens().size() > output_lens.size()) MIGRAPHX_THROW("inputs dimensions should <= output size"); std::vector bcast_strides(output_lens.size(), 0); auto offset = output_lens.size() - input.lens().size(); for(int i = input.lens().size() - 1; i >= 0; i--) { if(output_lens[i + offset] == input.lens()[i]) { bcast_strides[i + offset] = input.strides()[i]; } } return {t, output_lens, bcast_strides}; } argument compute(shape output_shape, std::vector args) const { return {std::move(output_shape), std::move(args.at(0).data)}; } int output_alias(const std::vector&) const { return 0; } }; } // namespace op } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif