Commit c56d6e9e authored by Paul's avatar Paul
Browse files

Format

parent f20c1124
...@@ -81,21 +81,21 @@ preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs) ...@@ -81,21 +81,21 @@ preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs)
const std::size_t max_lds_bytes = 4096; const std::size_t max_lds_bytes = 4096;
std::vector<bool> result(inputs.size()); std::vector<bool> result(inputs.size());
std::vector<std::size_t> preloaded; std::vector<std::size_t> preloaded;
for(auto i:range(inputs.size())) for(auto i : range(inputs.size()))
{ {
if (inputs[i].strides()[axis] == 0) if(inputs[i].strides()[axis] == 0)
preloaded.push_back(i); preloaded.push_back(i);
} }
std::sort(preloaded.begin(), preloaded.end(), by(std::less<>{}, [&](auto i) { std::sort(preloaded.begin(), preloaded.end(), by(std::less<>{}, [&](auto i) {
return inputs[i].bytes(); return inputs[i].bytes();
})); }));
std::size_t bytes = 0; std::size_t bytes = 0;
for(auto i:preloaded) for(auto i : preloaded)
{ {
auto input = inputs[i]; auto input = inputs[i];
bytes += input.bytes(); bytes += input.bytes();
if (bytes > max_lds_bytes) if(bytes > max_lds_bytes)
break; break;
result[i] = true; result[i] = true;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment