Commit db2def39 authored by Paul's avatar Paul
Browse files

Format

parent f1f60be1
...@@ -30,8 +30,8 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs ...@@ -30,8 +30,8 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
auto len = input.lens()[axis]; auto len = input.lens()[axis];
if(stride != 0 and stride != 1) if(stride != 0 and stride != 1)
return 1; return 1;
if (len == 1) if(len == 1)
return sizes.front(); return sizes.front();
auto it = std::find_if( auto it = std::find_if(
sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; }); sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
if(it != sizes.end()) if(it != sizes.end())
......
...@@ -101,7 +101,7 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -101,7 +101,7 @@ struct reduce_compiler : compiler<reduce_compiler>
options.inputs = inputs; options.inputs = inputs;
options.output = inputs.back(); options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(inputs);
auto faxis = find_fast_axis({options.virtual_inputs.front()}); auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{}; vectorize vec{};
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1) if(options.virtual_inputs.back().lens()[faxis] == 1)
...@@ -110,27 +110,24 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -110,27 +110,24 @@ struct reduce_compiler : compiler<reduce_compiler>
} }
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size; auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto nelements = options.virtual_inputs.back().elements(); auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs)); auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
if(algo == "block") if(algo == "block")
{ {
auto block_size = compute_block_size(relements, 256); auto block_size = compute_block_size(relements, 256);
options.set_launch_params( options.set_launch_params(
v, v, compute_global_for(ctx, nelements * block_size, 256), block_size);
compute_global_for(ctx, nelements * block_size, 256),
block_size);
} }
else if(algo == "lane") else if(algo == "lane")
{ {
options.set_launch_params( options.set_launch_params(v, compute_global_for(ctx, nelements, 256));
v, compute_global_for(ctx, nelements, 256));
} }
else else
{ {
MIGRAPHX_THROW("Unknown reduce algo: " + algo); MIGRAPHX_THROW("Unknown reduce algo: " + algo);
} }
options.kernel_name = "reduce_kernel"; options.kernel_name = "reduce_kernel";
std::string identity = "[](auto x) { return x; }"; std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel, auto src = interpolate_string(simple_reduce_kernel,
{{"reduction", v.at("reduction").to<std::string>()}, {{"reduction", v.at("reduction").to<std::string>()},
{"init", v.get("init", std::string{"0"})}, {"init", v.get("init", std::string{"0"})},
{"read", v.get("read", identity)}, {"read", v.get("read", identity)},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment