Commit 5d821f5e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add the broadcasted shape to the reflect method

parent 853085f1
......@@ -27,14 +27,15 @@ namespace op {
struct broadcast
{
uint64_t axis = 0;
std::vector<std::size_t> broadcast_lens;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
return pack(f(self.axis, "axis"),
f(self.broadcast_lens, "dims"));
}
std::vector<std::size_t> broadcast_lens;
std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment