Commit 50e256a1 authored by Paul's avatar Paul
Browse files

Always use block for now

parent fdd86cc4
...@@ -62,27 +62,6 @@ static std::size_t get_reduce_elements(const std::vector<shape>& inputs) ...@@ -62,27 +62,6 @@ static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
{ {
return inputs.front().elements() / inputs.back().elements(); return inputs.front().elements() / inputs.back().elements();
} }
static std::size_t get_reduce_elements(const std::vector<instruction_ref>& inputs)
{
return get_reduce_elements(to_shapes(inputs));
}
static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>& input_lens,
const std::vector<std::size_t>& output_lens)
{
std::vector<std::size_t> reduce_lens;
std::transform(output_lens.begin(),
output_lens.end(),
input_lens.begin(),
std::back_inserter(reduce_lens),
[](auto x, auto y) -> std::size_t {
if(x == y)
return 1;
else
return y;
});
return reduce_lens;
}
template <class T> template <class T>
static shape get_reduced_shape(const shape& s, const std::vector<T>& axes) static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
...@@ -93,9 +72,10 @@ static shape get_reduced_shape(const shape& s, const std::vector<T>& axes) ...@@ -93,9 +72,10 @@ static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
return shape{s.type(), lens}; return shape{s.type(), lens};
} }
static std::string get_reduce_algo(const std::vector<shape>& inputs) template<class ReduceLens>
static std::string get_reduce_algo(const std::vector<shape>& inputs, ReduceLens rlens)
{ {
auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens()); #if 0
const auto init = std::numeric_limits<std::size_t>::max(); const auto init = std::numeric_limits<std::size_t>::max();
// The minimum stride // The minimum stride
auto min_stride = std::inner_product( auto min_stride = std::inner_product(
...@@ -107,6 +87,7 @@ static std::string get_reduce_algo(const std::vector<shape>& inputs) ...@@ -107,6 +87,7 @@ static std::string get_reduce_algo(const std::vector<shape>& inputs)
[](auto len, auto stride) { return len == 1 ? init : stride; }); [](auto len, auto stride) { return len == 1 ? init : stride; });
if(min_stride > 2) if(min_stride > 2)
return "lane"; return "lane";
#endif
return "block"; return "block";
} }
...@@ -136,7 +117,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler> ...@@ -136,7 +117,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
} }
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size; auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto nelements = options.virtual_inputs.back().elements(); auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs)); auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs, reduced_shape.lens()));
if(algo == "block") if(algo == "block")
{ {
auto block_size = compute_block_size(relements, 256); auto block_size = compute_block_size(relements, 256);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment