Commit 8a6ae079 authored by Paul's avatar Paul
Browse files

Format

parent d60364a3
...@@ -19,7 +19,7 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs) ...@@ -19,7 +19,7 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
} }
vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs) vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs)
{ {
auto sizes = vector_sizes(inputs); auto sizes = vector_sizes(inputs);
std::vector<std::size_t> max_vec_size; std::vector<std::size_t> max_vec_size;
std::transform(inputs.begin(), std::transform(inputs.begin(),
...@@ -37,7 +37,7 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs ...@@ -37,7 +37,7 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
return 1; return 1;
}); });
return {*std::min_element(max_vec_size.begin(), max_vec_size.end()), axis}; return {*std::min_element(max_vec_size.begin(), max_vec_size.end()), axis};
} }
std::string vectorize::str() const std::string vectorize::str() const
{ {
...@@ -45,7 +45,7 @@ std::string vectorize::str() const ...@@ -45,7 +45,7 @@ std::string vectorize::str() const
} }
preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs) preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs)
{ {
const std::size_t max_lds_bytes = 4096; const std::size_t max_lds_bytes = 4096;
std::vector<bool> result; std::vector<bool> result;
std::transform(inputs.begin(), std::transform(inputs.begin(),
...@@ -67,7 +67,7 @@ preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs) ...@@ -67,7 +67,7 @@ preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs)
// TODO: Try to partially preload items // TODO: Try to partially preload items
std::fill(result.begin(), result.end(), false); std::fill(result.begin(), result.end(), false);
return {result}; return {result};
} }
std::string preload::str() const std::string preload::str() const
{ {
......
...@@ -33,7 +33,7 @@ std::size_t find_fast_axis(const std::vector<shape>& inputs); ...@@ -33,7 +33,7 @@ std::size_t find_fast_axis(const std::vector<shape>& inputs);
std::string make_transformer_args(std::vector<std::string> transformers); std::string make_transformer_args(std::vector<std::string> transformers);
template<class... Ts> template <class... Ts>
std::string make_transformer_args(Ts... xs) std::string make_transformer_args(Ts... xs)
{ {
return make_transformer_args({xs.str()...}); return make_transformer_args({xs.str()...});
......
...@@ -63,7 +63,8 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -63,7 +63,8 @@ struct pointwise_compiler : compiler<pointwise_compiler>
auto axis = find_fast_axis(options.virtual_inputs); auto axis = find_fast_axis(options.virtual_inputs);
auto vec = vectorize::elements(axis, options.virtual_inputs); auto vec = vectorize::elements(axis, options.virtual_inputs);
auto preloads = preload::broadcasts(axis, inputs); auto preloads = preload::broadcasts(axis, inputs);
options.set_launch_params(v, options.set_launch_params(
v,
compute_global_for(ctx, compute_global_for(ctx,
options.output.elements() / vec.size, options.output.elements() / vec.size,
oversubscribe_if(not preloads.is_preloading()))); oversubscribe_if(not preloads.is_preloading())));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment