Commit f20c1124 authored by Paul
Browse files

Improve preloader

parent 817543c7
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/cpp_generator.hpp> #include <migraphx/cpp_generator.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -78,25 +79,26 @@ std::string vectorize::str() const ...@@ -78,25 +79,26 @@ std::string vectorize::str() const
/// Decide which inputs should be preloaded.
///
/// An input is a candidate when its stride along `axis` is zero. Candidates
/// are accepted smallest-first while their cumulative size in bytes stays
/// within `max_lds_bytes`; every other input is left unmarked.
///
/// @param axis   Index into each input's strides used for the zero-stride test.
///               NOTE(review): assumes `axis` is a valid index for every
///               input's strides - confirm against callers.
/// @param inputs Shapes of the inputs, in argument order.
/// @return A `preload` built from one flag per input (true = preload it).
preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs)
{
    // Byte budget for preloaded data (presumably the LDS capacity to spend).
    const std::size_t max_lds_bytes = 4096;
    std::vector<bool> result(inputs.size());

    // Collect the indices of all inputs with a zero stride along `axis`.
    std::vector<std::size_t> preloaded;
    for(auto i : range(inputs.size()))
    {
        if(inputs[i].strides()[axis] == 0)
            preloaded.push_back(i);
    }

    // Order candidates from smallest to largest. A stable sort keeps
    // equal-sized candidates in input order, so the selection is deterministic.
    std::stable_sort(preloaded.begin(), preloaded.end(), by(std::less<>{}, [&](auto i) {
                         return inputs[i].bytes();
                     }));

    std::size_t bytes = 0;
    for(auto i : preloaded)
    {
        // Bind by reference: only the size is needed, no copy of the shape.
        const auto& input = inputs[i];
        bytes += input.bytes();
        // Candidates are sorted ascending, so once the budget is exceeded no
        // later (equal or larger) candidate can fit either.
        if(bytes > max_lds_bytes)
            break;
        result[i] = true;
    }
    return {result};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment