Commit 8a6ae079 authored by Paul's avatar Paul
Browse files

Format

parent d60364a3
......@@ -19,25 +19,25 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
}
vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs)
{
auto sizes = vector_sizes(inputs);
std::vector<std::size_t> max_vec_size;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(max_vec_size),
[&](const auto& input) -> std::size_t {
auto stride = input.strides()[axis];
auto len = input.lens()[axis];
if(stride != 0 and stride != 1)
return 1;
auto it = std::find_if(
sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
if(it != sizes.end())
return *it;
{
auto sizes = vector_sizes(inputs);
std::vector<std::size_t> max_vec_size;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(max_vec_size),
[&](const auto& input) -> std::size_t {
auto stride = input.strides()[axis];
auto len = input.lens()[axis];
if(stride != 0 and stride != 1)
return 1;
});
return {*std::min_element(max_vec_size.begin(), max_vec_size.end()), axis};
}
auto it = std::find_if(
sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
if(it != sizes.end())
return *it;
return 1;
});
return {*std::min_element(max_vec_size.begin(), max_vec_size.end()), axis};
}
std::string vectorize::str() const
{
......@@ -45,29 +45,29 @@ std::string vectorize::str() const
}
preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs)
{
const std::size_t max_lds_bytes = 4096;
std::vector<bool> result;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(result),
[&](const shape& input) { return input.strides()[axis] == 0; });
auto bytes = std::inner_product(inputs.begin(),
inputs.end(),
result.begin(),
std::size_t{0},
std::plus<>{},
[](const shape& s, bool b) -> std::size_t {
if(b)
return s.bytes();
return 0;
});
if(bytes < max_lds_bytes)
return {result};
// TODO: Try to partially preload items
std::fill(result.begin(), result.end(), false);
{
const std::size_t max_lds_bytes = 4096;
std::vector<bool> result;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(result),
[&](const shape& input) { return input.strides()[axis] == 0; });
auto bytes = std::inner_product(inputs.begin(),
inputs.end(),
result.begin(),
std::size_t{0},
std::plus<>{},
[](const shape& s, bool b) -> std::size_t {
if(b)
return s.bytes();
return 0;
});
if(bytes < max_lds_bytes)
return {result};
}
// TODO: Try to partially preload items
std::fill(result.begin(), result.end(), false);
return {result};
}
std::string preload::str() const
{
......
......@@ -33,7 +33,7 @@ std::size_t find_fast_axis(const std::vector<shape>& inputs);
std::string make_transformer_args(std::vector<std::string> transformers);
template<class... Ts>
template <class... Ts>
std::string make_transformer_args(Ts... xs)
{
return make_transformer_args({xs.str()...});
......
......@@ -61,12 +61,13 @@ struct pointwise_compiler : compiler<pointwise_compiler>
options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.virtual_inputs);
auto vec = vectorize::elements(axis, options.virtual_inputs);
auto preloads = preload::broadcasts(axis, inputs);
options.set_launch_params(v,
compute_global_for(ctx,
options.output.elements() / vec.size,
oversubscribe_if(not preloads.is_preloading())));
auto vec = vectorize::elements(axis, options.virtual_inputs);
auto preloads = preload::broadcasts(axis, inputs);
options.set_launch_params(
v,
compute_global_for(ctx,
options.output.elements() / vec.size,
oversubscribe_if(not preloads.is_preloading())));
auto src = interpolate_string(pointwise_kernel,
{{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment