// Copyright (c) OpenMMLab. All rights reserved.

#include "src/fastertransformer/models/llama/llama_decoder_kernels.h"
#include "src/fastertransformer/utils/cuda_utils.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda_fp16.h>

namespace cg = cooperative_groups;

namespace fastertransformer {

// Primary template: type-specific cast/add/norm helpers used by the fused
// residual + RMSNorm kernel. Intentionally empty — only the float and half
// specializations below provide the operations.
template<typename T>
struct res_norm_ops_t {};

template<typename T>
struct res_norm_t {
    res_norm_ops_t<T> f;

    // Packed add: computes a + b + bias element-wise over one 16-byte vector
    // (each lane unpacked/re-packed via the type-specific ops), and folds the
    // squares of the results into `accum` for the RMS statistic.
    __device__ uint4 addvec(const uint4& a, const uint4& b, const uint4& bias, float& accum) const
    {
        uint4 out;
        out.x = f.cast(f.add(f.cast(a.x), f.cast(b.x), f.cast(bias.x), accum));
        out.y = f.cast(f.add(f.cast(a.y), f.cast(b.y), f.cast(bias.y), accum));
        out.z = f.cast(f.add(f.cast(a.z), f.cast(b.z), f.cast(bias.z), accum));
        out.w = f.cast(f.add(f.cast(a.w), f.cast(b.w), f.cast(bias.w), accum));
        return out;
    }

    // Packed normalize: computes u * s * factor element-wise over one
    // 16-byte vector (u = summed hidden state, s = RMSNorm scale weights).
    __device__ uint4 normvec(const uint4& u, const uint4& s, float factor) const
    {
        uint4 out;
        out.x = f.cast(f.norm(f.cast(u.x), f.cast(s.x), factor));
        out.y = f.cast(f.norm(f.cast(u.y), f.cast(s.y), factor));
        out.z = f.cast(f.norm(f.cast(u.z), f.cast(s.z), factor));
        out.w = f.cast(f.norm(f.cast(u.w), f.cast(s.w), factor));
        return out;
    }
};

// Half-precision ops: each 32-bit lane of a uint4 carries a packed half2,
// which is widened to float2 for all arithmetic and narrowed on store.
template<>
struct res_norm_ops_t<half> {
    // Unpack a bit-cast half2 lane into float2.
    __device__ float2 cast(const uint& x) const
    {
        return __half22float2(reinterpret_cast<const half2&>(x));
    }
    // Round a float2 back to half2 and return its raw 32-bit pattern.
    __device__ uint cast(const float2& x) const
    {
        auto h = __float22half2_rn(x);
        return reinterpret_cast<uint&>(h);
    }
    // Residual add in float; squares of both components feed `accum`.
    __device__ float2 add(const float2& a, const float2& b, const float2& bias, float& accum) const
    {
        const float2 sum{a.x + b.x + bias.x, a.y + b.y + bias.y};
        accum += sum.x * sum.x + sum.y * sum.y;
        return sum;
    }
    // RMSNorm scale: value * weight * (1 / rms).
    __device__ float2 norm(const float2& a, const float2& s, float factor) const
    {
        return {a.x * s.x * factor, a.y * s.y * factor};
    }
};

// Single-precision ops: each 32-bit lane of a uint4 is one float,
// moved in and out via bit-casts so the packed uint4 path stays generic.
template<>
struct res_norm_ops_t<float> {
    // Reinterpret a raw 32-bit lane as float.
    __device__ float cast(const uint& x) const
    {
        return reinterpret_cast<const float&>(x);
    }
    // Reinterpret a float as its raw 32-bit pattern.
    __device__ uint cast(const float& x) const
    {
        return reinterpret_cast<const uint&>(x);
    }
    // Residual add; the square of the result feeds `accum`.
    __device__ float add(const float& a, const float& b, const float& bias, float& accum) const
    {
        const float sum = a + b + bias;
        accum += sum * sum;
        return sum;
    }
    // RMSNorm scale: value * weight * (1 / rms).
    __device__ float norm(const float& a, const float& s, float factor) const
    {
        return a * s * factor;
    }
};

// Block-wide sum reduction via cooperative groups.
//
// Two-stage reduce: each warp reduces its lanes, lane 0 of every warp
// deposits the warp partial into shared memory, then every warp reduces the
// (up to 32) partials — so ALL threads of the block return the total.
//
// Fix: the original hard-coded `float` for the shared partials and the
// reduction functor, silently truncating any wider T; use T throughout so
// the template parameter is honored (bit-identical for T = float).
//
// Preconditions: blockDim <= 1024 (at most 32 warps, so `partial[32]`
// suffices); all threads of the block must call this (it contains a
// block-wide sync).
template<typename T>
__device__ T blockReduceSum(const cg::thread_block& block, T value)
{
    __shared__ T partial[32];  // one partial sum per warp

    auto tile = cg::tiled_partition<32>(block);
    value     = cg::reduce(tile, value, cg::plus<T>{});

    if (tile.thread_rank() == 0) {
        partial[tile.meta_group_rank()] = value;
    }

    block.sync();

    // Lanes beyond the warp count contribute the additive identity.
    value = tile.thread_rank() < tile.meta_group_size() ? partial[tile.thread_rank()] : T{};
    return cg::reduce(tile, value, cg::plus<T>{});
}

// Fused residual-add + bias + RMSNorm kernel.
//
// Launch layout (see the host wrapper): one block per batch row, 1-D blocks.
// Each thread streams 16-byte (uint4) packs with a block-stride loop, so
// n_dims must be divisible by the pack width (checked host-side).
//
// Pass 1: r_data[row] += x_data[row] + bias, accumulating sum of squares.
// Pass 2: x_data[row]  = r_data[row] * scale * rsqrt(mean_sq + eps).
// Net effect: r_data holds the updated residual, x_data the normalized output.
template<typename T>
__global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data,
                                         T* __restrict__ x_data,
                                         const T* __restrict__ bias,
                                         const T* __restrict__ scale,
                                         float eps,
                                         int   batch_size,
                                         int   n_dims)
{
    auto block = cg::this_thread_block();
    auto grid  = cg::this_grid();

    constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);  // elements per 16-byte pack
    const int     n_pack   = n_dims / PACK_DIM;

    // Each block owns one row of the batch.
    const auto row                     = grid.block_rank();
    uint4* __restrict__ residual_ptr   = reinterpret_cast<uint4*>(r_data + row * n_dims);
    uint4* __restrict__ inout_ptr      = reinterpret_cast<uint4*>(x_data + row * n_dims);
    const uint4* __restrict__ bias_ptr = reinterpret_cast<const uint4*>(bias);

    res_norm_t<T> vec_ops;

    // Pass 1: fused residual + bias add, writing back into the residual
    // buffer while accumulating this thread's sum of squares.
    float local_sq_sum{};
    for (uint i = block.thread_rank(); i < n_pack; i += block.num_threads()) {
        auto res = residual_ptr[i];
        auto x   = inout_ptr[i];
        // bias is optional; a null pointer means an all-zero bias.
        uint4 b         = bias_ptr ? bias_ptr[i] : uint4{};
        residual_ptr[i] = vec_ops.addvec(res, x, b, local_sq_sum);
    }

    auto block_sq_sum = blockReduceSum(block, local_sq_sum);

    // 1 / rms over the row.
    float inv_rms = rsqrt(block_sq_sum / n_dims + eps);

    // Pass 2: scale the updated residual into the output buffer.
    const uint4* __restrict__ scale_ptr = reinterpret_cast<const uint4*>(scale);
    for (uint i = block.thread_rank(); i < n_pack; i += block.num_threads()) {
        inout_ptr[i] = vec_ops.normvec(residual_ptr[i], scale_ptr[i], inv_rms);
    }
}

// Host launcher for the fused residual + bias + RMSNorm kernel.
//
// residual : [batch_size, n_dims], updated in place to residual + inout + bias
// inout    : [batch_size, n_dims], overwritten with the RMS-normalized result
// bias     : [n_dims] or nullptr (treated as zeros by the kernel)
// scale    : [n_dims] RMSNorm weights
//
// One block per batch row; the block size is tuned so the pack count divides
// evenly across iterations, avoiding a small tail iteration.
template<typename T>
void invokeFusedAddBiasResidualRMSNorm(
    T* residual, T* inout, const T* bias, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream)
{
    constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);
    // The kernel reads 16-byte packs, so n_dims must be pack-aligned; it must
    // also be positive — n_dims == 0 would make n_iter == 0 and divide by
    // zero below (the original check let 0 through).
    FT_CHECK(n_dims > 0 && n_dims % PACK_DIM == 0);
    const int n_pack    = n_dims / PACK_DIM;
    const int n_iter    = ((n_pack + 1023) / 1024);        // iterations when block size == 1024
    int       n_threads = (n_pack + n_iter - 1) / n_iter;  // adjust block size to avoid tail effect
    n_threads           = (n_threads + 31) / 32 * 32;      // round up to the nearest multiple of warp size

    fusedAddBiasResidualNorm<<<batch_size, n_threads, 0, stream>>>(
        residual, inout, bias, scale, eps, batch_size, n_dims);
}

// Explicit instantiations for the supported activation dtypes.
template void
invokeFusedAddBiasResidualRMSNorm(float*, float*, const float*, const float*, float, int, int, cudaStream_t);
template void invokeFusedAddBiasResidualRMSNorm(half*, half*, const half*, const half*, float, int, int, cudaStream_t);

}  // namespace fastertransformer