Commit de99db23 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

simplify the layernorm kernel arguments

parent 780fffc8
......@@ -176,7 +176,11 @@ template <index_int N, class T, class... Ts>
auto hip_vec_visit_all(T&& x, Ts&&... xs)
{
return [&](auto f) {
hip_visit_all_impl(get_shape(x),
auto sx = get_shape(x);
auto lens = sx.lens();
lens.back() /= N;
shape ssx{sx.type(), lens};
hip_visit_all_impl(ssx,
make_hip_convert([](auto* p) { return as_vec<N>(device_cast(p)); }),
f,
x,
......
......@@ -81,16 +81,14 @@ __device__ auto auto_block_reduce(index idx, Op op, T init, index_int n, F f)
}
template <index_int MaxBlockSize, class Input, class Output>
__device__ void layernorm(index_int i,
index idx,
std::size_t block_size_div,
__device__ void layernorm(index idx,
index_int relements,
Input input,
Output output)
{
using value_type = decltype(input(idx.local));
const auto relements_v = relements / vector_size<value_type>{};
const auto out_idx = fast_div(i, block_size_div);
const auto out_idx = blockIdx.x;
const auto base_idx = out_idx * relements_v;
const auto input_idx = base_idx + idx.local;
const bool in_range = idx.local < relements_v;
......@@ -133,14 +131,11 @@ void layernorm_vec_impl(hipStream_t stream,
const auto relements_v = relements / N;
const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(relements_v, max_block_size);
const std::size_t block_size_div = encode_divisor(block_size);
assert(relements_v <= block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
gs_launch(stream, nelements * block_size, block_size)([=](auto, auto idx) __device__ {
layernorm<max_block_size>(
i,
idx,
block_size_div,
relements,
[&](auto input_idx) { return in(inputs.data()[input_idx]...); },
[&](auto input_idx, auto x) {
......@@ -162,14 +157,11 @@ void layernorm_impl(hipStream_t stream,
hip_visit_all(result, args...)([&](auto output, auto... inputs) {
const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(relements, max_block_size);
const std::size_t block_size_div = encode_divisor(block_size);
assert(relements <= block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
gs_launch(stream, nelements * block_size, block_size)([=](auto, auto idx) __device__ {
layernorm<max_block_size>(
i,
idx,
block_size_div,
relements,
[&](auto input_idx) { return in(inputs.data()[input_idx]...); },
[&](auto input_idx, auto x) {
......@@ -188,10 +180,6 @@ auto layernorm_fusion(hipStream_t stream,
return [=](auto input, auto output) {
auto relements = arg1.get_shape().lens().back();
auto nelements = result.get_shape().elements() / relements;
// auto output_shape = result.get_shape();
// auto reduce_output_lens(output_shape.lens());
// reduce_output_lens.back() = 1;
if((relements % 4) == 0)
layernorm_vec_impl<4>(
stream, nelements, relements, input, output, result, arg1, args...);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment