Commit 7e522c7e authored by Paul's avatar Paul
Browse files

Formatting

parent 33f53196
...@@ -431,7 +431,8 @@ struct broadcast ...@@ -431,7 +431,8 @@ struct broadcast
auto input = inputs.at(1); auto input = inputs.at(1);
std::vector<size_t> bcast_strides(result.lens().size(), 0); std::vector<size_t> bcast_strides(result.lens().size(), 0);
if(std::all_of(result.lens().cbegin(), result.lens().cend(), [&](auto x) { return x == 1; })) if(std::all_of(
result.lens().cbegin(), result.lens().cend(), [&](auto x) { return x == 1; }))
{ {
if(axis != 0) if(axis != 0)
RTG_THROW("when broadcasting tensor of size 1, axis should be 0"); RTG_THROW("when broadcasting tensor of size 1, axis should be 0");
...@@ -439,10 +440,10 @@ struct broadcast ...@@ -439,10 +440,10 @@ struct broadcast
} }
else else
{ {
assert(result.lens().size()-axis >= input.lens().size()); assert(result.lens().size() - axis >= input.lens().size());
if(!std::equal(input.lens().begin(), input.lens().end(), result.lens().begin()+axis)) if(!std::equal(input.lens().begin(), input.lens().end(), result.lens().begin() + axis))
RTG_THROW("when broadcasting success sizes must match"); RTG_THROW("when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin()+axis); std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, result.lens(), std::move(bcast_strides)}; return {t, result.lens(), std::move(bcast_strides)};
} }
} }
......
...@@ -91,10 +91,10 @@ void broadcast_test() ...@@ -91,10 +91,10 @@ void broadcast_test()
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
auto output = result.get<int32_t>(); auto output = result.get<int32_t>();
EXPECT(output(0,0) == -2); EXPECT(output(0, 0) == -2);
EXPECT(output(0,1) == -2); EXPECT(output(0, 1) == -2);
EXPECT(output(1,0) == -3); EXPECT(output(1, 0) == -3);
EXPECT(output(1,1) == -3); EXPECT(output(1, 1) == -3);
} }
void add_broadcast_test() void add_broadcast_test()
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment