"tests/vscode:/vscode.git/clone" did not exist on "1b91856d0eee7b6fb58340e9b54ea2c3d5424311"
Commit 7e522c7e authored by Paul's avatar Paul
Browse files

Formatting

parent 33f53196
...@@ -426,12 +426,13 @@ struct broadcast ...@@ -426,12 +426,13 @@ struct broadcast
std::string name() const { return "broadcast"; } std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
auto result = inputs.at(0); auto result = inputs.at(0);
auto input = inputs.at(1); auto input = inputs.at(1);
std::vector<size_t> bcast_strides(result.lens().size(), 0); std::vector<size_t> bcast_strides(result.lens().size(), 0);
if(std::all_of(result.lens().cbegin(), result.lens().cend(), [&](auto x) { return x == 1; })) if(std::all_of(
result.lens().cbegin(), result.lens().cend(), [&](auto x) { return x == 1; }))
{ {
if(axis != 0) if(axis != 0)
RTG_THROW("when broadcasting tensor of size 1, axis should be 0"); RTG_THROW("when broadcasting tensor of size 1, axis should be 0");
...@@ -439,10 +440,10 @@ struct broadcast ...@@ -439,10 +440,10 @@ struct broadcast
} }
else else
{ {
assert(result.lens().size()-axis >= input.lens().size()); assert(result.lens().size() - axis >= input.lens().size());
if(!std::equal(input.lens().begin(), input.lens().end(), result.lens().begin()+axis)) if(!std::equal(input.lens().begin(), input.lens().end(), result.lens().begin() + axis))
RTG_THROW("when broadcasting success sizes must match"); RTG_THROW("when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin()+axis); std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, result.lens(), std::move(bcast_strides)}; return {t, result.lens(), std::move(bcast_strides)};
} }
} }
......
...@@ -91,10 +91,10 @@ void broadcast_test() ...@@ -91,10 +91,10 @@ void broadcast_test()
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
auto output = result.get<int32_t>(); auto output = result.get<int32_t>();
EXPECT(output(0,0) == -2); EXPECT(output(0, 0) == -2);
EXPECT(output(0,1) == -2); EXPECT(output(0, 1) == -2);
EXPECT(output(1,0) == -3); EXPECT(output(1, 0) == -3);
EXPECT(output(1,1) == -3); EXPECT(output(1, 1) == -3);
} }
void add_broadcast_test() void add_broadcast_test()
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment