"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "298fc26b54ef827ec3261e7f84d976b7a347c4e2"
Commit c13780c2 authored by Paul's avatar Paul
Browse files

Format

parent 15fd8205
...@@ -101,7 +101,7 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -101,7 +101,7 @@ struct reduce_compiler : compiler<reduce_compiler>
auto faxis = find_fast_axis({inputs.front()}); auto faxis = find_fast_axis({inputs.front()});
vectorize vec{}; vectorize vec{};
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
if (inputs.back().lens()[faxis] == 1) if(inputs.back().lens()[faxis] == 1)
{ {
vec = vectorize::elements(faxis, inputs); vec = vectorize::elements(faxis, inputs);
} }
...@@ -111,11 +111,14 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -111,11 +111,14 @@ struct reduce_compiler : compiler<reduce_compiler>
{ {
auto block_size = compute_block_size(reduce_elements, 256); auto block_size = compute_block_size(reduce_elements, 256);
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements() * block_size / vec.size, 256), block_size); v,
compute_global_for(ctx, inputs.back().elements() * block_size / vec.size, 256),
block_size);
} }
else if(algo == "lane") else if(algo == "lane")
{ {
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements() / vec.size, 256)); options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements() / vec.size, 256));
} }
else else
{ {
...@@ -124,7 +127,7 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -124,7 +127,7 @@ struct reduce_compiler : compiler<reduce_compiler>
options.inputs = inputs; options.inputs = inputs;
options.output = inputs.back(); options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(inputs);
options.kernel_name = "reduce_kernel"; options.kernel_name = "reduce_kernel";
std::string identity = "[](auto x) { return x; }"; std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel, auto src = interpolate_string(simple_reduce_kernel,
{{"reduction", v.at("reduction").to<std::string>()}, {{"reduction", v.at("reduction").to<std::string>()},
......
...@@ -163,9 +163,12 @@ struct block ...@@ -163,9 +163,12 @@ struct block
__device__ auto reduce(Op op, T init, Read read) const __device__ auto reduce(Op op, T init, Read read) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slicer, [=](auto x, auto... xs) {
return vec_reduce(block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) { return vec_reduce(block_reduce(idx,
return read(x[j], xs[j]...); op,
}), op); init,
x.get_shape().elements(),
[&](auto j) { return read(x[j], xs[j]...); }),
op);
}); });
} }
......
...@@ -146,7 +146,7 @@ constexpr auto vec_packed_transform(Ts... xs) ...@@ -146,7 +146,7 @@ constexpr auto vec_packed_transform(Ts... xs)
}; };
} }
template<class T, class Op> template <class T, class Op>
constexpr auto vec_reduce(T x, Op op) constexpr auto vec_reduce(T x, Op op)
{ {
if constexpr(vec_size<T>() < 2) if constexpr(vec_size<T>() < 2)
...@@ -155,7 +155,7 @@ constexpr auto vec_reduce(T x, Op op) ...@@ -155,7 +155,7 @@ constexpr auto vec_reduce(T x, Op op)
{ {
vec_type<T> result; vec_type<T> result;
for(int i = 1; i < vec_size<T>(); i++) for(int i = 1; i < vec_size<T>(); i++)
result = op(result[i-1], result[i]); result = op(result[i - 1], result[i]);
return result; return result;
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment