llama_decoder_kernels.cu 7.61 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
// Copyright (c) OpenMMLab. All rights reserved.

Chen Xin's avatar
Chen Xin committed
3
#include "src/turbomind/macro.h"
lvhan028's avatar
lvhan028 committed
4
#include "src/turbomind/models/llama/llama_decoder_kernels.h"
q.yao's avatar
q.yao committed
5
#include "src/turbomind/utils/cuda_type_utils.cuh"
lvhan028's avatar
lvhan028 committed
6
#include "src/turbomind/utils/cuda_utils.h"
Li Zhang's avatar
Li Zhang committed
7
#include <cooperative_groups.h>
8
// #include <cooperative_groups/reduce.h>
Li Zhang's avatar
Li Zhang committed
9
10
11
12
#include <cuda_fp16.h>

namespace cg = cooperative_groups;

lvhan028's avatar
lvhan028 committed
13
namespace turbomind {
Li Zhang's avatar
Li Zhang committed
14
15

// Per-element math (cast/add/norm) used by the fused add-bias-residual +
// RMSNorm kernel. The primary template is intentionally empty; only the
// specializations below (half, float, and optionally __nv_bfloat16)
// provide the operations.
template<typename T>
struct res_norm_ops_t {
};
Li Zhang's avatar
Li Zhang committed
18
19
20
21

// Vectorized (uint4, i.e. 16-byte) wrappers around the scalar element ops.
// Each 32-bit lane of the uint4 is decoded/encoded by the res_norm_ops_t<T>
// specialization, which also does the arithmetic in float.
template<typename T>
struct res_norm_t {
    res_norm_ops_t<T> f;

    // out = a + b + bias, per element; the squares of the results are
    // accumulated into `accum` for the later RMS computation.
    __device__ uint4 addvec(const uint4& a, const uint4& b, const uint4& bias, float& accum) const
    {
        uint4 out;
        out.x = f.cast(f.add(f.cast(a.x), f.cast(b.x), f.cast(bias.x), accum));
        out.y = f.cast(f.add(f.cast(a.y), f.cast(b.y), f.cast(bias.y), accum));
        out.z = f.cast(f.add(f.cast(a.z), f.cast(b.z), f.cast(bias.z), accum));
        out.w = f.cast(f.add(f.cast(a.w), f.cast(b.w), f.cast(bias.w), accum));
        return out;
    }

    // out = u * s * factor, per element (normalized value times the
    // learned scale).
    __device__ uint4 normvec(const uint4& u, const uint4& s, float factor) const
    {
        uint4 out;
        out.x = f.cast(f.norm(f.cast(u.x), f.cast(s.x), factor));
        out.y = f.cast(f.norm(f.cast(u.y), f.cast(s.y), factor));
        out.z = f.cast(f.norm(f.cast(u.z), f.cast(s.z), factor));
        out.w = f.cast(f.norm(f.cast(u.w), f.cast(s.w), factor));
        return out;
    }
};

// half: each 32-bit lane holds a half2 pair; arithmetic is carried out in
// float2 for accuracy and converted back with round-to-nearest.
template<>
struct res_norm_ops_t<half> {
    __device__ float2 cast(const uint& x) const
    {
        return __half22float2(reinterpret_cast<const half2&>(x));
    }
    __device__ uint cast(const float2& x) const
    {
        auto h2 = __float22half2_rn(x);
        return reinterpret_cast<uint&>(h2);
    }
    // residual-add of one pair; squares of the results are folded into accum
    __device__ float2 add(const float2& a, const float2& b, const float2& bias, float& accum) const
    {
        const float2 sum{a.x + b.x + bias.x, a.y + b.y + bias.y};
        accum += sum.x * sum.x + sum.y * sum.y;
        return sum;
    }
    // elementwise scaling: a * s * factor
    __device__ float2 norm(const float2& a, const float2& s, float factor) const
    {
        return {a.x * s.x * factor, a.y * s.y * factor};
    }
};

// float: one element per 32-bit lane; the casts are pure bit
// reinterpretation between uint and float.
template<>
struct res_norm_ops_t<float> {
    __device__ float cast(const uint& x) const
    {
        return reinterpret_cast<const float&>(x);
    }
    __device__ uint cast(const float& x) const
    {
        return reinterpret_cast<const uint&>(x);
    }
    // residual-add; the square of the result is folded into accum
    __device__ float add(const float& a, const float& b, const float& bias, float& accum) const
    {
        const float sum = a + b + bias;
        accum += sum * sum;
        return sum;
    }
    // elementwise scaling: a * s * factor
    __device__ float norm(const float& a, const float& s, float factor) const
    {
        return a * s * factor;
    }
};

q.yao's avatar
q.yao committed
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#ifdef ENABLE_BF16
// bfloat16: each 32-bit lane holds a __nv_bfloat162 pair; math is done in
// float2 via cuda_cast (see cuda_type_utils.cuh).
template<>
struct res_norm_ops_t<__nv_bfloat16> {
    __device__ float2 cast(const uint& x) const
    {
        return cuda_cast<float2, __nv_bfloat162>(reinterpret_cast<const __nv_bfloat162&>(x));
    }
    __device__ uint cast(const float2& x) const
    {
        auto bf2 = cuda_cast<__nv_bfloat162, float2>(x);
        return reinterpret_cast<uint&>(bf2);
    }
    // residual-add of one pair; squares of the results are folded into accum
    __device__ float2 add(const float2& a, const float2& b, const float2& bias, float& accum) const
    {
        const float2 sum{a.x + b.x + bias.x, a.y + b.y + bias.y};
        accum += sum.x * sum.x + sum.y * sum.y;
        return sum;
    }
    // elementwise scaling: a * s * factor
    __device__ float2 norm(const float2& a, const float2& s, float factor) const
    {
        return {a.x * s.x * factor, a.y * s.y * factor};
    }
};

#endif

xiabo's avatar
xiabo committed
113
114
115
116
// template<typename T>
// __device__ T blockReduceSum(const cg::thread_block& block, T value)
// {
//     __shared__ float partial[32];
Li Zhang's avatar
Li Zhang committed
117

xiabo's avatar
xiabo committed
118
119
//     auto tile = cg::tiled_partition<32>(block);
//     value     = cg::reduce(tile, value, cg::plus<float>{});
Li Zhang's avatar
Li Zhang committed
120

xiabo's avatar
xiabo committed
121
122
123
//     if (tile.thread_rank() == 0) {
//         partial[tile.meta_group_rank()] = value;
//     }
Li Zhang's avatar
Li Zhang committed
124

xiabo's avatar
xiabo committed
125
//     block.sync();
Li Zhang's avatar
Li Zhang committed
126

xiabo's avatar
xiabo committed
127
128
129
130
131
132
133
134
135
136
137
138
//     value = tile.thread_rank() < tile.meta_group_size() ? partial[tile.thread_rank()] : T{};
//     return cg::reduce(tile, value, cg::plus<float>{});
// }
// Hardware warp size on NVIDIA GPUs. NOTE(review): this was previously 64,
// which is wrong for CUDA — __shfl_down_sync takes a 32-bit lane mask and a
// warp has 32 lanes, so the first shuffle (offset 32) read out of range,
// which returns the caller's own value and doubles every warp's partial sum.
#define WARPSIZE 32

// Shuffle-based sum reduction across the 32 lanes of a warp.
// Assumes all 32 lanes are active (the launch configs in this file round
// the block size up to a multiple of 32). On return, lane 0 holds the full
// warp sum; the other lanes hold partial sums.
template<typename T>
__inline__ __device__ T warpReduceSum_xiabo(T value)
{
#pragma unroll
    for (int offset = WARPSIZE / 2; offset > 0; offset >>= 1)
        value += __shfl_down_sync(0xffffffff, value, offset);
    return value;
}

xiabo's avatar
xiabo committed
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
// Block-wide sum reduction; every thread receives the block total.
// Assumes blockDim.x * blockDim.y <= 1024 (guaranteed by the launcher in
// this file, which also rounds the block size up to a multiple of 32).
//
// Rewritten self-contained with the hardware warp size of 32. The previous
// version assumed 64-lane warps, which on CUDA (a) made the first
// __shfl_down_sync offset read outside the 32-lane warp, doubling each
// warp's partial sum, (b) dropped the partial sums of odd-numbered warps
// (only tid % 64 == 0 published to shared memory), and (c) zero-filled
// shared slots concurrently with a partial group's write, a data race when
// the block size was not a multiple of 64.
template<typename T>
__inline__ __device__ T blockReduceSum_xiabo(T val)
{
    constexpr int kWarpSize = 32;
    __shared__ T  shared[kWarpSize];  // one slot per warp (<= 1024/32 warps)

    const int tid     = threadIdx.x + threadIdx.y * blockDim.x;
    const int n_warps = (blockDim.x * blockDim.y + kWarpSize - 1) / kWarpSize;

    // 1. reduce within each warp; lane 0 ends up with the warp's sum
    T sum = val;
#pragma unroll
    for (int offset = kWarpSize / 2; offset > 0; offset >>= 1)
        sum += __shfl_down_sync(0xffffffff, sum, offset);

    // guard against a previous call's read of shared[0] still in flight
    __syncthreads();

    // 2. lane 0 of each warp publishes its partial sum
    if (tid % kWarpSize == 0) {
        shared[tid / kWarpSize] = sum;
    }
    __syncthreads();

    // 3. the first warp reduces the per-warp partials (predicated read
    //    instead of a racy zero-fill of the unused slots)
    if (tid < kWarpSize) {
        sum = tid < n_warps ? shared[tid] : T{};
#pragma unroll
        for (int offset = kWarpSize / 2; offset > 0; offset >>= 1)
            sum += __shfl_down_sync(0xffffffff, sum, offset);
        if (tid == 0) {
            shared[0] = sum;
        }
    }
    __syncthreads();
    return shared[0];
}
165
166
// r' = r + x
// x' = norm(r') * scales
Li Zhang's avatar
Li Zhang committed
167
// Fused kernel: r' = r + x + bias; x' = RMSNorm(r') * scale.
// One thread block per token (gridDim.x == batch_size). Data is moved in
// 16-byte uint4 packs, so n_dims must be divisible by PACK_DIM (checked by
// the host-side launcher). `bias` may be null, in which case zeros are used.
template<typename T>
__global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data,
                                         T* __restrict__ x_data,
                                         const T* __restrict__ bias,
                                         const T* __restrict__ scale,
                                         float eps,
                                         int   batch_size,
                                         int   n_dims)
{
    auto block = cg::this_thread_block();

    constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);

    const auto token_idx = block.group_index().x;
    uint4* __restrict__ res_pack        = reinterpret_cast<uint4*>(r_data + token_idx * n_dims);
    uint4* __restrict__ inout_pack      = reinterpret_cast<uint4*>(x_data + token_idx * n_dims);
    const uint4* __restrict__ bias_pack = reinterpret_cast<const uint4*>(bias);

    res_norm_t<T> ops;

    // Pass 1: residual += x + bias, accumulating the per-thread sum of
    // squares of the updated residual.
    float partial_sq_sum{};
    for (auto i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.size()) {
        uint4 updated = ops.addvec(res_pack[i], inout_pack[i], bias_pack ? bias_pack[i] : uint4{}, partial_sq_sum);
        res_pack[i]   = updated;
    }

    // auto total_sq_sum = blockReduceSum(block, partial_sq_sum);
    auto total_sq_sum = blockReduceSum_xiabo(partial_sq_sum);

    // 1 / sqrt(mean(r'^2) + eps)
    float inv_rms = rsqrt(total_sq_sum / n_dims + eps);

    // Pass 2: x = normalized residual times the learned scale.
    const uint4* __restrict__ scale_pack = reinterpret_cast<const uint4*>(scale);
    for (uint i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.size()) {
        inout_pack[i] = ops.normvec(res_pack[i], scale_pack[i], inv_rms);
    }
}

// Host-side launcher for fusedAddBiasResidualNorm.
//   residual : [batch_size, n_dims], updated in place to residual+in_out+bias
//   in_out   : [batch_size, n_dims], overwritten with RMSNorm(residual)*scale
//   bias     : [n_dims] or nullptr
//   scale    : [n_dims] RMSNorm weight
// n_dims must be a multiple of the uint4 pack width for T.
template<typename T>
void invokeFusedAddBiasResidualRMSNorm(
    T* residual, T* in_out, const T* bias, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream)
{
    constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);  // elements per 16-byte vectorized access
    FT_CHECK(n_dims % PACK_DIM == 0);
    const int n_pack = n_dims / PACK_DIM;
    if (batch_size == 0 || n_pack == 0) {
        // Nothing to do. Also avoids a division by zero below (n_iter == 0
        // when n_pack == 0) and an invalid launch with 0 blocks.
        return;
    }
    const int n_iter    = (n_pack + 1023) / 1024;          // iterations when block size == 1024
    int       n_threads = (n_pack + n_iter - 1) / n_iter;  // adjust block size to avoid tail effect
    n_threads           = (n_threads + 31) / 32 * 32;      // round up to the nearest multiple of warp size

    fusedAddBiasResidualNorm<<<batch_size, n_threads, 0, stream>>>(
        residual, in_out, bias, scale, eps, batch_size, n_dims);
}

Li Zhang's avatar
Li Zhang committed
226
227
228
// Explicit instantiations for the supported element types.
template void
invokeFusedAddBiasResidualRMSNorm(float*, float*, const float*, const float*, float, int, int, cudaStream_t);
template void invokeFusedAddBiasResidualRMSNorm(half*, half*, const half*, const half*, float, int, int, cudaStream_t);
#ifdef ENABLE_BF16
template void invokeFusedAddBiasResidualRMSNorm(
    __nv_bfloat16*, __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, float, int, int, cudaStream_t);
#endif
lvhan028's avatar
lvhan028 committed
233
}  // namespace turbomind