// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/models/llama/llama_decoder_kernels.h"
#include "src/turbomind/utils/cuda_utils.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda_fp16.h>

namespace cg = cooperative_groups;

namespace turbomind {

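// Elementwise helpers for the fused residual + RMSNorm kernel below; the
// primary template is empty and specialized for half and float.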
template<typename T>
struct res_norm_ops_t {
};

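// Applies the per-type ops lane-wise over uint4 packs (8 halves or 4 floats
// per 16-byte vector).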
template<typename T>
struct res_norm_t {
    res_norm_ops_t<T> f;
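    // r + x (+ bias) per lane, accumulating the sum of squares into `accum`
    // for the RMS statistic.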
    __device__ uint4 addvec(const uint4& a, const uint4& b, const uint4& bias, float& accum) const
    {
        uint4 c;
        c.x = f.cast(f.add(f.cast(a.x), f.cast(b.x), f.cast(bias.x), accum));
        c.y = f.cast(f.add(f.cast(a.y), f.cast(b.y), f.cast(bias.y), accum));
        c.z = f.cast(f.add(f.cast(a.z), f.cast(b.z), f.cast(bias.z), accum));
        c.w = f.cast(f.add(f.cast(a.w), f.cast(b.w), f.cast(bias.w), accum));
        return c;
    }
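    // RMSNorm scaling: out = value * scale * (1 / rms).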
    __device__ uint4 normvec(const uint4& u, const uint4& s, float factor) const
    {
        uint4 v;
        v.x = f.cast(f.norm(f.cast(u.x), f.cast(s.x), factor));
        v.y = f.cast(f.norm(f.cast(u.y), f.cast(s.y), factor));
        v.z = f.cast(f.norm(f.cast(u.z), f.cast(s.z), factor));
        v.w = f.cast(f.norm(f.cast(u.w), f.cast(s.w), factor));
        return v;
    }
};

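// half specialization: each uint is a packed half2; arithmetic is done in
// float for accuracy.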
template<>
struct res_norm_ops_t<half> {
    __device__ float2 cast(const uint& x) const
    {
        return __half22float2(reinterpret_cast<const half2&>(x));
    }
    __device__ uint cast(const float2& x) const
    {
        auto y = __float22half2_rn(x);
        return reinterpret_cast<uint&>(y);
    }
    __device__ float2 add(const float2& a, const float2& b, const float2& bias, float& accum) const
    {
        float2 c{a.x + b.x + bias.x, a.y + b.y + bias.y};
        accum += c.x * c.x + c.y * c.y;
        return c;
    }
    __device__ float2 norm(const float2& a, const float2& s, float factor) const
    {
        return {a.x * s.x * factor, a.y * s.y * factor};
    }
};

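// float specialization: each uint is the bit pattern of one float.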
template<>
struct res_norm_ops_t<float> {
    __device__ float cast(const uint& x) const
    {
        return reinterpret_cast<const float&>(x);
    }
    __device__ uint cast(const float& x) const
    {
        return reinterpret_cast<const uint&>(x);
    }
    __device__ float add(const float& a, const float& b, const float& bias, float& accum) const
    {
        float c = a + b + bias;
        accum += c * c;
        return c;
    }
    __device__ float norm(const float& a, const float& s, float factor) const
    {
        return a * s * factor;
    }
};

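// Two-stage block-wide sum: intra-warp reduction via cg::reduce, then the
// per-warp partials are combined by every warp through shared memory
// (assumes at most 32 warps per block and float accumulation).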
template<typename T>
__device__ T blockReduceSum(const cg::thread_block& block, T value)
{
    __shared__ float partial[32];

    auto tile = cg::tiled_partition<32>(block);
    value     = cg::reduce(tile, value, cg::plus<float>{});

    if (tile.thread_rank() == 0) {
        partial[tile.meta_group_rank()] = value;
    }

    block.sync();

    value = tile.thread_rank() < tile.meta_group_size() ? partial[tile.thread_rank()] : T{};
    return cg::reduce(tile, value, cg::plus<float>{});
}

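// One thread block per token: pass 1 fuses residual + input (+ bias) into
// r_data while accumulating the sum of squares; pass 2 writes the
// RMS-normalized result into x_data.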
template<typename T>
__global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data,
                                         T* __restrict__ x_data,
                                         const T* __restrict__ bias,
                                         const T* __restrict__ scale,
                                         float eps,
                                         int   batch_size,
                                         int   n_dims)
{
    auto block = cg::this_thread_block();
    auto grid  = cg::this_grid();

    constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);

    const auto batch_idx            = grid.block_rank();
    uint4* __restrict__ r_ptr       = reinterpret_cast<uint4*>(r_data + batch_idx * n_dims);
    uint4* __restrict__ x_ptr       = reinterpret_cast<uint4*>(x_data + batch_idx * n_dims);
    const uint4* __restrict__ b_ptr = reinterpret_cast<const uint4*>(bias);

    res_norm_t<T> ops;

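    // Pass 1: fused add; `bias` may be null, in which case a zero vector is used.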
    float thread_sum{};
    for (auto i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.num_threads()) {
        auto  r  = r_ptr[i];
        auto  x  = x_ptr[i];
        uint4 b  = b_ptr ? b_ptr[i] : uint4{};
        r        = ops.addvec(r, x, b, thread_sum);
        r_ptr[i] = r;
    }

    auto total_sum = blockReduceSum(block, thread_sum);

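    // inv_rms = 1 / sqrt(mean(x^2) + eps)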
    float s_inv_mean = rsqrt(total_sum / n_dims + eps);

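    // Pass 2: scale the fused result by `scale * inv_rms` and store into x_data.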
    const uint4* __restrict__ s_ptr = reinterpret_cast<const uint4*>(scale);
    for (uint i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.num_threads()) {
        auto r   = r_ptr[i];
        auto s   = s_ptr[i];
        auto o   = ops.normvec(r, s, s_inv_mean);
        x_ptr[i] = o;
    }
}

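// Launcher: picks a block size that spreads the uint4 packs evenly across
// loop iterations so the final iteration is not left nearly empty.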
template<typename T>
void invokeFusedAddBiasResidualRMSNorm(
    T* residual, T* in_out, const T* bias, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream)
{
    constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);
    FT_CHECK(n_dims % PACK_DIM == 0);
    const int n_pack    = n_dims / PACK_DIM;
    const int n_iter    = ((n_pack + 1023) / 1024);        // iterations when block size == 1024
    int       n_threads = (n_pack + n_iter - 1) / n_iter;  // adjust block size to avoid tail effect
    n_threads           = (n_threads + 31) / 32 * 32;      // round up to the nearest multiple of warp size

    fusedAddBiasResidualNorm<<<batch_size, n_threads, 0, stream>>>(
        residual, in_out, bias, scale, eps, batch_size, n_dims);
}

template void
invokeFusedAddBiasResidualRMSNorm(float*, float*, const float*, const float*, float, int, int, cudaStream_t);
template void invokeFusedAddBiasResidualRMSNorm(half*, half*, const half*, const half*, float, int, int, cudaStream_t);

}  // namespace turbomind