// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/macro.h"
#include "src/turbomind/models/llama/llama_decoder_kernels.h"
#include "src/turbomind/utils/cuda_type_utils.cuh"
#include "src/turbomind/utils/cuda_utils.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda_fp16.h>

namespace cg = cooperative_groups;

namespace turbomind {

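// Element-wise helpers for the fused residual + RMSNorm kernel. The primary
// template is intentionally empty; per-type behaviour comes from the
// specializations below.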
template<typename T>
struct res_norm_ops_t {
};

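// Applies the per-type ops to uint4-packed vectors (16 bytes per load/store).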
template<typename T>
struct res_norm_t {
    res_norm_ops_t<T> f;
    __device__ uint4  addvec(const uint4& a, const uint4& b, const uint4& bias, float& accum) const
    {
        uint4 c;
        c.x = f.cast(f.add(f.cast(a.x), f.cast(b.x), f.cast(bias.x), accum));
        c.y = f.cast(f.add(f.cast(a.y), f.cast(b.y), f.cast(bias.y), accum));
        c.z = f.cast(f.add(f.cast(a.z), f.cast(b.z), f.cast(bias.z), accum));
        c.w = f.cast(f.add(f.cast(a.w), f.cast(b.w), f.cast(bias.w), accum));
        return c;
    }
    __device__ uint4 normvec(const uint4& u, const uint4& s, float factor) const
    {
        uint4 v;
        v.x = f.cast(f.norm(f.cast(u.x), f.cast(s.x), factor));
        v.y = f.cast(f.norm(f.cast(u.y), f.cast(s.y), factor));
        v.z = f.cast(f.norm(f.cast(u.z), f.cast(s.z), factor));
        v.w = f.cast(f.norm(f.cast(u.w), f.cast(s.w), factor));
        return v;
    }
};

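// half specialization: each 32-bit lane holds a half2, so every cast/add/norm
// call processes two elements at once.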
template<>
struct res_norm_ops_t<half> {
    __device__ float2 cast(const uint& x) const
    {
        return __half22float2(reinterpret_cast<const half2&>(x));
    }
    __device__ uint cast(const float2& x) const
    {
        auto y = __float22half2_rn(x);
        return reinterpret_cast<uint&>(y);
    }
    __device__ float2 add(const float2& a, const float2& b, const float2& bias, float& accum) const
    {
        float2 c{a.x + b.x + bias.x, a.y + b.y + bias.y};
        accum += c.x * c.x + c.y * c.y;
        return c;
    }
    __device__ float2 norm(const float2& a, const float2& s, float factor) const
    {
        return {a.x * s.x * factor, a.y * s.y * factor};
    }
};

template<>
struct res_norm_ops_t<float> {
    __device__ float cast(const uint& x) const
    {
        return reinterpret_cast<const float&>(x);
    }
    __device__ uint cast(const float& x) const
    {
        return reinterpret_cast<const uint&>(x);
    }
    __device__ float add(const float& a, const float& b, const float& bias, float& accum) const
    {
        float c = a + b + bias;
        accum += c * c;
        return c;
    }
    __device__ float norm(const float& a, const float& s, float factor) const
    {
        return a * s * factor;
    }
};

#ifdef ENABLE_BF16
template<>
struct res_norm_ops_t<__nv_bfloat16> {
    __device__ float2 cast(const uint& x) const
    {
        return cuda_cast<float2, __nv_bfloat162>(reinterpret_cast<const __nv_bfloat162&>(x));
    }
    __device__ uint cast(const float2& x) const
    {
        auto y = cuda_cast<__nv_bfloat162, float2>(x);
        return reinterpret_cast<uint&>(y);
    }
    __device__ float2 add(const float2& a, const float2& b, const float2& bias, float& accum) const
    {
        float2 c{a.x + b.x + bias.x, a.y + b.y + bias.y};
        accum += c.x * c.x + c.y * c.y;
        return c;
    }
    __device__ float2 norm(const float2& a, const float2& s, float factor) const
    {
        return {a.x * s.x * factor, a.y * s.y * factor};
    }
};

#endif

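// Block-wide sum: reduce within each 32-thread tile, stage the per-tile sums
// in shared memory, then reduce the staged partials again so every thread in
// the block returns the total.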
template<typename T>
__device__ T blockReduceSum(const cg::thread_block& block, T value)
{
    __shared__ float partial[32];

    auto tile = cg::tiled_partition<32>(block);
    value     = cg::reduce(tile, value, cg::plus<float>{});

    if (tile.thread_rank() == 0) {
        partial[tile.meta_group_rank()] = value;
    }

    block.sync();

    value = tile.thread_rank() < tile.meta_group_size() ? partial[tile.thread_rank()] : T{};
    return cg::reduce(tile, value, cg::plus<float>{});
}

// r' = r + x
// x' = norm(r') * scales
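// One thread block per row (token); data is processed in uint4 (16-byte) packs.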
template<typename T>
__global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data,
                                         T* __restrict__ x_data,
                                         const T* __restrict__ bias,
                                         const T* __restrict__ scale,
                                         float eps,
                                         int   batch_size,
                                         int   n_dims)
{
    auto block = cg::this_thread_block();
    auto grid  = cg::this_grid();

    constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);

    const auto batch_idx            = block.group_index().x;
    uint4* __restrict__ r_ptr       = reinterpret_cast<uint4*>(r_data + batch_idx * n_dims);
    uint4* __restrict__ x_ptr       = reinterpret_cast<uint4*>(x_data + batch_idx * n_dims);
    const uint4* __restrict__ b_ptr = reinterpret_cast<const uint4*>(bias);

    res_norm_t<T> ops;

    float thread_sum{};
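    // Pass 1: add residual, input and (optional) bias in place, accumulating
    // the per-thread sum of squares for the RMS statistic.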
    for (auto i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.size()) {
        auto  r  = r_ptr[i];
        auto  x  = x_ptr[i];
        uint4 b  = b_ptr ? b_ptr[i] : uint4{};
        r        = ops.addvec(r, x, b, thread_sum);
        r_ptr[i] = r;
    }

    auto total_sum = blockReduceSum(block, thread_sum);

    float s_inv_mean = rsqrt(total_sum / n_dims + eps);

    const uint4* __restrict__ s_ptr = reinterpret_cast<const uint4*>(scale);
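    // Pass 2: scale the updated residual by rsqrt(mean(r'^2) + eps) * scale
    // and write the normalized result to x.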
    for (uint i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.size()) {
        auto r   = r_ptr[i];
        auto s   = s_ptr[i];
        auto o   = ops.normvec(r, s, s_inv_mean);
        x_ptr[i] = o;
    }
}

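// Host-side launcher: one thread block per row (grid = batch_size).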
template<typename T>
void invokeFusedAddBiasResidualRMSNorm(
    T* residual, T* in_out, const T* bias, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream)
{
    constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);
    FT_CHECK(n_dims % PACK_DIM == 0);
    const int n_pack    = n_dims / PACK_DIM;
    const int n_iter    = ((n_pack + 1023) / 1024);        // iterations when block size == 1024
    int       n_threads = (n_pack + n_iter - 1) / n_iter;  // adjust block size to avoid tail effect
    n_threads           = (n_threads + 31) / 32 * 32;      // round up to the nearest multiple of warp size

    fusedAddBiasResidualNorm<<<batch_size, n_threads, 0, stream>>>(
        residual, in_out, bias, scale, eps, batch_size, n_dims);
}
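
// Example usage (illustrative sketch; d_residual, d_hidden, d_bias, d_scale
// are hypothetical device buffers owned by the caller):
//   invokeFusedAddBiasResidualRMSNorm(d_residual, d_hidden, d_bias, d_scale,
//                                     1e-6f, batch_size, hidden_dim, stream);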

template void
invokeFusedAddBiasResidualRMSNorm(float*, float*, const float*, const float*, float, int, int, cudaStream_t);
template void invokeFusedAddBiasResidualRMSNorm(half*, half*, const half*, const half*, float, int, int, cudaStream_t);
#ifdef ENABLE_BF16
template void invokeFusedAddBiasResidualRMSNorm(
    __nv_bfloat16*, __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, float, int, int, cudaStream_t);
#endif
}  // namespace turbomind