"vscode:/vscode.git/clone" did not exist on "ab3c9971cf132fe89ac5a0b95d4dbf9996cb1411"
Commit 50e256a1 authored by Paul's avatar Paul
Browse files

Always use block for now

parent fdd86cc4
......@@ -62,27 +62,6 @@ static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
{
return inputs.front().elements() / inputs.back().elements();
}
static std::size_t get_reduce_elements(const std::vector<instruction_ref>& inputs)
{
return get_reduce_elements(to_shapes(inputs));
}
static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>& input_lens,
const std::vector<std::size_t>& output_lens)
{
std::vector<std::size_t> reduce_lens;
std::transform(output_lens.begin(),
output_lens.end(),
input_lens.begin(),
std::back_inserter(reduce_lens),
[](auto x, auto y) -> std::size_t {
if(x == y)
return 1;
else
return y;
});
return reduce_lens;
}
template <class T>
static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
......@@ -93,9 +72,10 @@ static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
return shape{s.type(), lens};
}
static std::string get_reduce_algo(const std::vector<shape>& inputs)
template<class ReduceLens>
static std::string get_reduce_algo(const std::vector<shape>& inputs, ReduceLens rlens)
{
auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
#if 0
const auto init = std::numeric_limits<std::size_t>::max();
// The minimum stride
auto min_stride = std::inner_product(
......@@ -107,6 +87,7 @@ static std::string get_reduce_algo(const std::vector<shape>& inputs)
[](auto len, auto stride) { return len == 1 ? init : stride; });
if(min_stride > 2)
return "lane";
#endif
return "block";
}
......@@ -136,7 +117,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
}
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs, reduced_shape.lens()));
if(algo == "block")
{
auto block_size = compute_block_size(relements, 256);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment