Commit de99db23 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

simplify the layernorm kernel arguments

parent 780fffc8
...@@ -176,7 +176,11 @@ template <index_int N, class T, class... Ts> ...@@ -176,7 +176,11 @@ template <index_int N, class T, class... Ts>
auto hip_vec_visit_all(T&& x, Ts&&... xs) auto hip_vec_visit_all(T&& x, Ts&&... xs)
{ {
return [&](auto f) { return [&](auto f) {
hip_visit_all_impl(get_shape(x), auto sx = get_shape(x);
auto lens = sx.lens();
lens.back() /= N;
shape ssx{sx.type(), lens};
hip_visit_all_impl(ssx,
make_hip_convert([](auto* p) { return as_vec<N>(device_cast(p)); }), make_hip_convert([](auto* p) { return as_vec<N>(device_cast(p)); }),
f, f,
x, x,
......
...@@ -81,16 +81,14 @@ __device__ auto auto_block_reduce(index idx, Op op, T init, index_int n, F f) ...@@ -81,16 +81,14 @@ __device__ auto auto_block_reduce(index idx, Op op, T init, index_int n, F f)
} }
template <index_int MaxBlockSize, class Input, class Output> template <index_int MaxBlockSize, class Input, class Output>
__device__ void layernorm(index_int i, __device__ void layernorm(index idx,
index idx,
std::size_t block_size_div,
index_int relements, index_int relements,
Input input, Input input,
Output output) Output output)
{ {
using value_type = decltype(input(idx.local)); using value_type = decltype(input(idx.local));
const auto relements_v = relements / vector_size<value_type>{}; const auto relements_v = relements / vector_size<value_type>{};
const auto out_idx = fast_div(i, block_size_div); const auto out_idx = blockIdx.x;
const auto base_idx = out_idx * relements_v; const auto base_idx = out_idx * relements_v;
const auto input_idx = base_idx + idx.local; const auto input_idx = base_idx + idx.local;
const bool in_range = idx.local < relements_v; const bool in_range = idx.local < relements_v;
...@@ -133,14 +131,11 @@ void layernorm_vec_impl(hipStream_t stream, ...@@ -133,14 +131,11 @@ void layernorm_vec_impl(hipStream_t stream,
const auto relements_v = relements / N; const auto relements_v = relements / N;
const std::size_t max_block_size = 256; const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(relements_v, max_block_size); const std::size_t block_size = compute_block_size(relements_v, max_block_size);
const std::size_t block_size_div = encode_divisor(block_size);
assert(relements_v <= block_size); assert(relements_v <= block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ { gs_launch(stream, nelements * block_size, block_size)([=](auto, auto idx) __device__ {
layernorm<max_block_size>( layernorm<max_block_size>(
i,
idx, idx,
block_size_div,
relements, relements,
[&](auto input_idx) { return in(inputs.data()[input_idx]...); }, [&](auto input_idx) { return in(inputs.data()[input_idx]...); },
[&](auto input_idx, auto x) { [&](auto input_idx, auto x) {
...@@ -162,14 +157,11 @@ void layernorm_impl(hipStream_t stream, ...@@ -162,14 +157,11 @@ void layernorm_impl(hipStream_t stream,
hip_visit_all(result, args...)([&](auto output, auto... inputs) { hip_visit_all(result, args...)([&](auto output, auto... inputs) {
const std::size_t max_block_size = 256; const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(relements, max_block_size); const std::size_t block_size = compute_block_size(relements, max_block_size);
const std::size_t block_size_div = encode_divisor(block_size);
assert(relements <= block_size); assert(relements <= block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ { gs_launch(stream, nelements * block_size, block_size)([=](auto, auto idx) __device__ {
layernorm<max_block_size>( layernorm<max_block_size>(
i,
idx, idx,
block_size_div,
relements, relements,
[&](auto input_idx) { return in(inputs.data()[input_idx]...); }, [&](auto input_idx) { return in(inputs.data()[input_idx]...); },
[&](auto input_idx, auto x) { [&](auto input_idx, auto x) {
...@@ -188,10 +180,6 @@ auto layernorm_fusion(hipStream_t stream, ...@@ -188,10 +180,6 @@ auto layernorm_fusion(hipStream_t stream,
return [=](auto input, auto output) { return [=](auto input, auto output) {
auto relements = arg1.get_shape().lens().back(); auto relements = arg1.get_shape().lens().back();
auto nelements = result.get_shape().elements() / relements; auto nelements = result.get_shape().elements() / relements;
// auto output_shape = result.get_shape();
// auto reduce_output_lens(output_shape.lens());
// reduce_output_lens.back() = 1;
if((relements % 4) == 0) if((relements % 4) == 0)
layernorm_vec_impl<4>( layernorm_vec_impl<4>(
stream, nelements, relements, input, output, result, arg1, args...); stream, nelements, relements, input, output, result, arg1, args...);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment