Commit 5d821f5e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add the broadcasted shape to the reflect method

parent 853085f1
...@@ -27,14 +27,15 @@ namespace op { ...@@ -27,14 +27,15 @@ namespace op {
struct broadcast struct broadcast
{ {
uint64_t axis = 0; uint64_t axis = 0;
std::vector<std::size_t> broadcast_lens;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axis, "axis")); return pack(f(self.axis, "axis"),
f(self.broadcast_lens, "dims"));
} }
std::vector<std::size_t> broadcast_lens;
std::string name() const { return "broadcast"; } std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment