Commit 8320b11e authored by Paul
Browse files

Adjust base on broadcast

parent c5d87f8f
...@@ -94,8 +94,10 @@ vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector< ...@@ -94,8 +94,10 @@ vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<
std::size_t max_global = ctx.get_current_device().get_cu_count() * std::size_t max_global = ctx.get_current_device().get_cu_count() *
ctx.get_current_device().get_max_workitems_per_cu(); ctx.get_current_device().get_max_workitems_per_cu();
std::size_t over = n / max_global; std::size_t over = n / max_global;
std::cout << "over: " << over << std::endl;
bool broadcasted = std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); });
std::vector<std::size_t> sizes; std::vector<std::size_t> sizes;
if(over > 8) if(broadcasted and over > 8)
sizes.push_back(8); sizes.push_back(8);
if(over > 4) if(over > 4)
sizes.push_back(4); sizes.push_back(4);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment