Commit db2def39 authored by Paul's avatar Paul
Browse files

Format

parent f1f60be1
......@@ -30,8 +30,8 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
auto len = input.lens()[axis];
if(stride != 0 and stride != 1)
return 1;
if (len == 1)
return sizes.front();
if(len == 1)
return sizes.front();
auto it = std::find_if(
sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
if(it != sizes.end())
......
......@@ -101,7 +101,7 @@ struct reduce_compiler : compiler<reduce_compiler>
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs);
auto faxis = find_fast_axis({options.virtual_inputs.front()});
auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{};
// Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1)
......@@ -110,27 +110,24 @@ struct reduce_compiler : compiler<reduce_compiler>
}
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
if(algo == "block")
{
auto block_size = compute_block_size(relements, 256);
options.set_launch_params(
v,
compute_global_for(ctx, nelements * block_size, 256),
block_size);
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
}
else if(algo == "lane")
{
options.set_launch_params(
v, compute_global_for(ctx, nelements, 256));
options.set_launch_params(v, compute_global_for(ctx, nelements, 256));
}
else
{
MIGRAPHX_THROW("Unknown reduce algo: " + algo);
}
options.kernel_name = "reduce_kernel";
std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel,
options.kernel_name = "reduce_kernel";
std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel,
{{"reduction", v.at("reduction").to<std::string>()},
{"init", v.get("init", std::string{"0"})},
{"read", v.get("read", identity)},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment