// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/macro.h"
#include "src/turbomind/models/llama/llama_decoder_kernels.h"
#include "src/turbomind/utils/cuda_utils.h"
#include <cooperative_groups.h>
// #include <cooperative_groups/reduce.h>
#include <cuda_fp16.h>

namespace cg = cooperative_groups;

namespace turbomind {

// Elementwise cast/add/norm helpers; only the specializations for half and
// float below are ever instantiated.
template<typename T>
struct res_norm_ops_t {
};

// Applies the scalar ops lane-wise over uint4 packs (16 bytes = 8 halves or
// 4 floats), so global memory traffic stays 128-bit vectorized.
template<typename T>
struct res_norm_t {
    res_norm_ops_t<T> f;
    __device__ uint4  addvec(const uint4& a, const uint4& b, const uint4& bias, float& accum) const
    {
        uint4 c;
        c.x = f.cast(f.add(f.cast(a.x), f.cast(b.x), f.cast(bias.x), accum));
        c.y = f.cast(f.add(f.cast(a.y), f.cast(b.y), f.cast(bias.y), accum));
        c.z = f.cast(f.add(f.cast(a.z), f.cast(b.z), f.cast(bias.z), accum));
        c.w = f.cast(f.add(f.cast(a.w), f.cast(b.w), f.cast(bias.w), accum));
        return c;
    }
    __device__ uint4 normvec(const uint4& u, const uint4& s, float factor) const
    {
        uint4 v;
        v.x = f.cast(f.norm(f.cast(u.x), f.cast(s.x), factor));
        v.y = f.cast(f.norm(f.cast(u.y), f.cast(s.y), factor));
        v.z = f.cast(f.norm(f.cast(u.z), f.cast(s.z), factor));
        v.w = f.cast(f.norm(f.cast(u.w), f.cast(s.w), factor));
        return v;
    }
};

template<>
struct res_norm_ops_t<half> {
    __device__ float2 cast(const uint& x) const
    {
        return __half22float2(reinterpret_cast<const half2&>(x));
    }
    __device__ uint cast(const float2& x) const
    {
        auto y = __float22half2_rn(x);
        return reinterpret_cast<uint&>(y);
    }
    __device__ float2 add(const float2& a, const float2& b, const float2& bias, float& accum) const
    {
        float2 c{a.x + b.x + bias.x, a.y + b.y + bias.y};
        accum += c.x * c.x + c.y * c.y;
        return c;
    }
    __device__ float2 norm(const float2& a, const float2& s, float factor) const
    {
        return {a.x * s.x * factor, a.y * s.y * factor};
    }
};

template<>
struct res_norm_ops_t<float> {
    __device__ float cast(const uint& x) const
    {
        return reinterpret_cast<const float&>(x);
    }
    __device__ uint cast(const float& x) const
    {
        return reinterpret_cast<const uint&>(x);
    }
    __device__ float add(const float& a, const float& b, const float& bias, float& accum) const
    {
        float c = a + b + bias;
        accum += c * c;
        return c;
    }
    __device__ float norm(const float& a, const float& s, float factor) const
    {
        return a * s * factor;
    }
};
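
// Reference semantics of the two specializations above, written out as scalar
// pseudocode (a clarifying sketch, not part of the original source):
//
//   add (a, b, bias, accum): c = a + b + bias; accum += c * c; return c;
//   norm(c, s, factor):      return c * s * factor;   // factor = 1 / rms
//
// One addvec/normvec pair therefore implements, per lane,
//   out = (residual + hidden + bias) * scale / rms(residual + hidden + bias).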

// template<typename T>
// __device__ T blockReduceSum(const cg::thread_block& block, T value)
// {
//     __shared__ float partial[32];

//     auto tile = cg::tiled_partition<32>(block);
//     value     = cg::reduce(tile, value, cg::plus<float>{});

//     if (tile.thread_rank() == 0) {
//         partial[tile.meta_group_rank()] = value;
//     }

//     block.sync();

//     value = tile.thread_rank() < tile.meta_group_size() ? partial[tile.thread_rank()] : T{};
//     return cg::reduce(tile, value, cg::plus<float>{});
// }
// Width of one warp/wavefront as assumed by the reductions below. This must
// match the hardware: 64 is a 64-lane wavefront, whereas NVIDIA warps are
// 32 lanes and a mismatched value makes the shuffle offsets wrong.
#define WARPSIZE 64

// Tree reduction within one warp via shuffles: after log2(WARPSIZE) steps,
// lane 0 holds the sum over all WARPSIZE lanes.
template<typename T>
__inline__ __device__ T warpReduceSum_xiabo(T value)
{
#pragma unroll
    for (int offset = WARPSIZE / 2; offset > 0; offset >>= 1)
        value += __shfl_down_sync(0xffffffff, value, offset);
    return value;
}

// Block-wide sum built from warp-level reductions. Assumes the block size
// (blockDim.x * blockDim.y) is a multiple of WARPSIZE; a partial warp would
// race on shared[] below.
template<typename T>
__inline__ __device__ T blockReduceSum_xiabo(T val)
{
    __shared__ T shared[WARPSIZE];
    // 1. reduce within each warp; lane 0 ends up with the warp's partial sum
    T sum = warpReduceSum_xiabo(val);
    __syncthreads();
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    // 2. lane 0 of each warp publishes its partial sum; slots of warps that
    //    do not exist are zeroed so step 3 can reduce a full WARPSIZE span
    if (tid % WARPSIZE == 0) {
        shared[tid / WARPSIZE] = sum;
    }
    if (tid >= blockDim.x * blockDim.y / WARPSIZE && tid < WARPSIZE) {
        shared[tid] = (T)(0.0f);
    }
    __syncthreads();
    // 3. the first warp reduces the per-warp partials; all threads then read
    //    the block total back from shared[0]
    if (tid / WARPSIZE == 0) {
        sum = warpReduceSum_xiabo(shared[tid]);
        if (tid == 0) {
            shared[0] = sum;
        }
    }
    __syncthreads();
    return shared[0];
}
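
// Worked example (illustrative): a 256-thread block with WARPSIZE 64 has four
// warps. Step 1 leaves one partial sum per warp, step 2 stores them in
// shared[0..3] and zeroes shared[4..63], and step 3 lets warp 0 reduce the
// full 64-slot span, so every thread returns the same block-wide total.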

// One thread block per batch row (gridDim.x spans the batch, so batch_size is
// unused inside the kernel): pass 1 fuses x + r + bias into r_data while
// accumulating the sum of squares, pass 2 writes the RMS-normalized, scaled
// result into x_data.
template<typename T>
__global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data,
                                         T* __restrict__ x_data,
                                         const T* __restrict__ bias,
                                         const T* __restrict__ scale,
                                         float eps,
                                         int   batch_size,
                                         int   n_dims)
{
    auto block = cg::this_thread_block();
    // auto grid  = cg::this_grid();

    constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);  // elements per 16-byte pack: 8 for half, 4 for float

    const auto batch_idx            = block.group_index().x;
    uint4* __restrict__ r_ptr       = reinterpret_cast<uint4*>(r_data + batch_idx * n_dims);
    uint4* __restrict__ x_ptr       = reinterpret_cast<uint4*>(x_data + batch_idx * n_dims);
    const uint4* __restrict__ b_ptr = reinterpret_cast<const uint4*>(bias);

    res_norm_t<T> ops;

    float thread_sum{};
    for (auto i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.size()) {
        auto  r  = r_ptr[i];
        auto  x  = x_ptr[i];
        uint4 b  = b_ptr ? b_ptr[i] : uint4{};
        r        = ops.addvec(r, x, b, thread_sum);
        r_ptr[i] = r;
    }

    // auto total_sum = blockReduceSum(block, thread_sum);
    auto total_sum = blockReduceSum_xiabo(thread_sum);

    float s_inv_mean = rsqrt(total_sum / n_dims + eps);  // inverse RMS: 1 / sqrt(mean(c^2) + eps)

    const uint4* __restrict__ s_ptr = reinterpret_cast<const uint4*>(scale);
    for (uint i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.size()) {
        auto r   = r_ptr[i];
        auto s   = s_ptr[i];
        auto o   = ops.normvec(r, s, s_inv_mean);
        x_ptr[i] = o;
    }
}
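
// Per batch row, the kernel above is equivalent to (arithmetic in float, with
// intermediates rounded back to T between the two passes):
//   r = r + x + bias                        // pass 1, written to r_data
//   x = r * scale * rsqrt(mean(r*r) + eps)  // pass 2, written to x_data
// i.e. the usual "add residual, then RMSNorm" epilogue of a decoder layer.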

template<typename T>
void invokeFusedAddBiasResidualRMSNorm(
    T* residual, T* in_out, const T* bias, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream)
{
    constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);
    FT_CHECK(n_dims % PACK_DIM == 0);
    const int n_pack    = n_dims / PACK_DIM;
    const int n_iter    = ((n_pack + 1023) / 1024);        // iterations when block size == 1024
    int       n_threads = (n_pack + n_iter - 1) / n_iter;  // adjust block size to avoid tail effect
    n_threads           = (n_threads + WARPSIZE - 1) / WARPSIZE * WARPSIZE;  // round up to a multiple of WARPSIZE; blockReduceSum_xiabo assumes no partial warps

    fusedAddBiasResidualNorm<<<batch_size, n_threads, 0, stream>>>(
        residual, in_out, bias, scale, eps, batch_size, n_dims);
}
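
// Sizing example (illustrative): n_dims = 4096 with T = half gives
// PACK_DIM = 8, n_pack = 512, n_iter = 1 and n_threads = 512, so each thread
// of the 512-thread block handles exactly one uint4 pack per pass.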

template void
invokeFusedAddBiasResidualRMSNorm(float*, float*, const float*, const float*, float, int, int, cudaStream_t);
template void invokeFusedAddBiasResidualRMSNorm(half*, half*, const half*, const half*, float, int, int, cudaStream_t);
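
// Usage sketch (hypothetical caller; buffer names are illustrative only):
//
//   // d_residual, d_hidden, d_bias and d_scale are device buffers of half;
//   // n_dims (the hidden size) must be a multiple of PACK_DIM = 8 for half.
//   invokeFusedAddBiasResidualRMSNorm(d_residual, d_hidden, d_bias, d_scale,
//                                     1e-6f, batch_size, n_dims, stream);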

}  // namespace turbomind