llama_decoder_kernels.cu
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/fastertransformer/models/llama/llama_decoder_kernels.h"
#include "src/fastertransformer/utils/cuda_utils.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda_fp16.h>

namespace cg = cooperative_groups;

namespace fastertransformer {

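// Helpers for the fused residual-add + RMSNorm kernel. Data are moved as 16-byte
// uint4 packs so loads and stores stay fully vectorized: res_norm_ops_t<T>
// supplies the per-lane cast/add/norm primitives for element type T, and
// res_norm_t<T> applies them to all four 32-bit lanes of a pack (8 half values
// or 4 float values per pack).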
template<typename T>
struct res_norm_ops_t {};

template<typename T>
struct res_norm_t {
    res_norm_ops_t<T> f;
    __device__ uint4  addvec(const uint4& a, const uint4& b, float& accum) const
    {
        uint4 c;
        c.x = f.cast(f.add(f.cast(a.x), f.cast(b.x), accum));
        c.y = f.cast(f.add(f.cast(a.y), f.cast(b.y), accum));
        c.z = f.cast(f.add(f.cast(a.z), f.cast(b.z), accum));
        c.w = f.cast(f.add(f.cast(a.w), f.cast(b.w), accum));
        return c;
    }
    __device__ uint4 normvec(const uint4& u, const uint4& s, float factor) const
    {
        uint4 v;
        v.x = f.cast(f.norm(f.cast(u.x), f.cast(s.x), factor));
        v.y = f.cast(f.norm(f.cast(u.y), f.cast(s.y), factor));
        v.z = f.cast(f.norm(f.cast(u.z), f.cast(s.z), factor));
        v.w = f.cast(f.norm(f.cast(u.w), f.cast(s.w), factor));
        return v;
    }
};

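// half specialization: each 32-bit lane packs a half2 pair. Arithmetic is done
// in float (__half22float2 / __float22half2_rn) so the squared sum feeding the
// RMS statistic is accumulated at full precision.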
template<>
struct res_norm_ops_t<half> {
    __device__ float2 cast(const uint& x) const
    {
        return __half22float2(reinterpret_cast<const half2&>(x));
    }
    __device__ uint cast(const float2& x) const
    {
        auto y = __float22half2_rn(x);
        return reinterpret_cast<uint&>(y);
    }
    __device__ float2 add(const float2& a, const float2& b, float& accum) const
    {
        float2 c{a.x + b.x, a.y + b.y};
        accum += c.x * c.x + c.y * c.y;
        return c;
    }
    __device__ float2 norm(const float2& a, const float2& s, float factor) const
    {
        return {a.x * s.x * factor, a.y * s.y * factor};
    }
};

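// float specialization: each 32-bit lane holds a single float, so the casts are
// pure bit reinterpretations and the math runs directly in float.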
template<>
struct res_norm_ops_t<float> {
    __device__ float cast(const uint& x) const
    {
        return reinterpret_cast<const float&>(x);
    }
    __device__ uint cast(const float& x) const
    {
        return reinterpret_cast<const uint&>(x);
    }
    __device__ float add(const float& a, const float& b, float& accum) const
    {
        float c = a + b;
        accum += c * c;
        return c;
    }
    __device__ float norm(const float& a, const float& s, float factor) const
    {
        return a * s * factor;
    }
};

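// Block-wide sum in two stages: a warp-level cg::reduce per 32-thread tile,
// per-warp partials staged in shared memory, then a second warp-level reduce
// over those partials. All threads in the block end up with the full sum.
// Assumes at most 32 warps per block (block size <= 1024).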
template<typename T>
__device__ T blockReduceSum(const cg::thread_block& block, T value)
{
    __shared__ float partial[32];

    auto tile = cg::tiled_partition<32>(block);
    value     = cg::reduce(tile, value, cg::plus<float>{});

    if (tile.thread_rank() == 0) {
        partial[tile.meta_group_rank()] = value;
    }

    block.sync();

    value = tile.thread_rank() < tile.meta_group_size() ? partial[tile.thread_rank()] : T{};
    return cg::reduce(tile, value, cg::plus<float>{});
}

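// One thread block per token (row of n_dims elements). Pass 1 adds x_data onto
// r_data in place while accumulating the per-row sum of squares of the result;
// pass 2 scales the updated residual by rsqrt(mean_square + eps) and the learned
// scale, writing the normalized output to x_data. Rows are processed as uint4
// packs, so n_dims must be a multiple of PACK_DIM.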
template<typename T>
__global__ void fusedAddResidualNorm(
    T* __restrict__ r_data, T* __restrict__ x_data, const T* __restrict__ scale, float eps, int batch_size, int n_dims)
{
    auto block = cg::this_thread_block();
    auto grid  = cg::this_grid();

    constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);

    const auto b              = grid.block_rank();
    uint4* __restrict__ r_ptr = reinterpret_cast<uint4*>(r_data + b * n_dims);
    uint4* __restrict__ x_ptr = reinterpret_cast<uint4*>(x_data + b * n_dims);

    res_norm_t<T> ops;

    float thread_sum{};
    for (auto i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.num_threads()) {
        auto r   = r_ptr[i];
        auto x   = x_ptr[i];
        r        = ops.addvec(r, x, thread_sum);
        r_ptr[i] = r;
    }

    auto total_sum = blockReduceSum(block, thread_sum);

    float s_inv_mean = rsqrt(total_sum / n_dims + eps);

    const uint4* __restrict__ s_ptr = reinterpret_cast<const uint4*>(scale);
    for (uint i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.num_threads()) {
        auto r   = r_ptr[i];
        auto s   = s_ptr[i];
        auto o   = ops.normvec(r, s, s_inv_mean);
        x_ptr[i] = o;
    }
}

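// Host-side launcher: one block per row. The block size is derived from the
// number of uint4 packs per row so work is spread evenly across loop iterations
// (avoiding a mostly idle tail iteration), then rounded up to a warp multiple.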
template<typename T>
void invokeFusedAddResidualRMSNorm(
    T* residual, T* inout, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream)
{
    constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);
    FT_CHECK(n_dims % PACK_DIM == 0);
    const int n_pack    = n_dims / PACK_DIM;
    const int n_iter    = ((n_pack + 1023) / 1024);        // iterations when block size == 1024
    int       n_threads = (n_pack + n_iter - 1) / n_iter;  // adjust block size to avoid tail effect
    n_threads           = (n_threads + 31) / 32 * 32;      // round up to the nearest multiple of warp size

    fusedAddResidualNorm<<<batch_size, n_threads, 0, stream>>>(residual, inout, scale, eps, batch_size, n_dims);
}

template void invokeFusedAddResidualRMSNorm(float*, float*, const float*, float, int, int, cudaStream_t);
template void invokeFusedAddResidualRMSNorm(half*, half*, const half*, float, int, int, cudaStream_t);

}  // namespace fastertransformer
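
// ---------------------------------------------------------------------------
// Illustrative usage sketch: a minimal host-side driver for
// invokeFusedAddResidualRMSNorm, assuming fp16 activations. The function name,
// buffer sizes, and epsilon below are placeholders chosen for the example; for
// half, n_dims must be a multiple of 8 so each row splits into whole 16-byte
// uint4 packs.

void example_fused_add_residual_rms_norm(cudaStream_t stream)
{
    const int   batch_size = 4;     // tokens processed this step (placeholder)
    const int   n_dims     = 4096;  // hidden size, divisible by 8 for half (placeholder)
    const float eps        = 1e-6f; // RMSNorm epsilon (placeholder)

    half* residual{};
    half* inout{};
    half* scale{};
    cudaMalloc(reinterpret_cast<void**>(&residual), sizeof(half) * batch_size * n_dims);
    cudaMalloc(reinterpret_cast<void**>(&inout), sizeof(half) * batch_size * n_dims);
    cudaMalloc(reinterpret_cast<void**>(&scale), sizeof(half) * n_dims);
    // ... fill `residual`/`inout` with activations and `scale` with RMSNorm weights ...

    // residual += inout; inout = rmsnorm(residual) * scale
    fastertransformer::invokeFusedAddResidualRMSNorm(
        residual, inout, scale, eps, batch_size, n_dims, stream);

    cudaStreamSynchronize(stream);
    cudaFree(residual);
    cudaFree(inout);
    cudaFree(scale);
}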