Commit f1f60be1 authored by Paul's avatar Paul
Browse files

Fix vec issues

parent c13780c2
......@@ -30,6 +30,8 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
auto len = input.lens()[axis];
if(stride != 0 and stride != 1)
return 1;
if (len == 1)
return sizes.front();
auto it = std::find_if(
sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
if(it != sizes.end())
......
......@@ -98,35 +98,36 @@ struct reduce_compiler : compiler<reduce_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
auto faxis = find_fast_axis({inputs.front()});
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs);
auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{};
// Vectorize if the axis is a reduction axis
if(inputs.back().lens()[faxis] == 1)
if(options.virtual_inputs.back().lens()[faxis] == 1)
{
vec = vectorize::elements(faxis, inputs);
vec = vectorize::elements(faxis, options.virtual_inputs);
}
auto reduce_elements = get_reduce_elements(inputs) / vec.size;
auto algo = v.get("algo", get_reduce_algo(inputs));
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
if(algo == "block")
{
auto block_size = compute_block_size(reduce_elements, 256);
auto block_size = compute_block_size(relements, 256);
options.set_launch_params(
v,
compute_global_for(ctx, inputs.back().elements() * block_size / vec.size, 256),
compute_global_for(ctx, nelements * block_size, 256),
block_size);
}
else if(algo == "lane")
{
options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements() / vec.size, 256));
v, compute_global_for(ctx, nelements, 256));
}
else
{
MIGRAPHX_THROW("Unknown reduce algo: " + algo);
}
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs);
options.kernel_name = "reduce_kernel";
std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel,
......
......@@ -153,9 +153,9 @@ constexpr auto vec_reduce(T x, Op op)
return x;
else
{
vec_type<T> result;
vec_type<T> result = 0;
for(int i = 1; i < vec_size<T>(); i++)
result = op(result[i - 1], result[i]);
result = op(x[i - 1], x[i]);
return result;
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment