layernorm_kernels.cu 11.2 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
5
6
7
8
#include "layernorm_kernels_impl.cuh"
#include "dispatch_utils.h"

#include <stdexcept>

// Root-mean-square normalization of every row of `input`, scaled by `weight`,
// via vllm::rms_norm_kernel.
//
// The last dimension of `input` is the hidden size; the tensor is treated as
// num_tokens = numel / hidden_size rows. One CUDA block is launched per row
// with min(hidden_size, 1024) threads and no dynamic shared memory, on the
// current CUDA stream.
//
// Parameters:
//   out       - destination; written as int8 when `use_quant` is true,
//               otherwise with the same element type as `input`.
//   input     - source rows; its dtype selects the kernel instantiation.
//   weight    - per-channel scale, read with the same element type as `input`.
//   epsilon   - forwarded unchanged to the kernel.
//   use_quant - selects the int8-output instantiation of rms_norm_kernel.
//
// An empty `input` is a no-op. Without this guard a zero-sized last dimension
// would divide by zero below, and zero rows would request a launch with an
// empty grid, which CUDA rejects as an invalid configuration.
void rms_norm(Tensor &out,    // [..., hidden_size]
              Tensor &input,  // [..., hidden_size]
              Tensor &weight, // [hidden_size]
              float epsilon,
              bool use_quant) {
    if (input.numel() == 0) {
        return;
    }

    int hidden_size = input.size(-1);
    int num_tokens  = input.numel() / hidden_size;
    dim3 grid(num_tokens);
    dim3 block(std::min(hidden_size, 1024));
    const cudaStream_t stream = getCurrentCUDAStream();
    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
        if (use_quant) {
            vllm::rms_norm_kernel<scalar_t, int8_t, true><<<grid, block, 0, stream>>>(out.data_ptr<int8_t>(),
                                                                                      input.data_ptr<scalar_t>(),
                                                                                      weight.data_ptr<scalar_t>(),
                                                                                      epsilon,
                                                                                      num_tokens,
                                                                                      hidden_size);
        } else {
            vllm::rms_norm_kernel<scalar_t, scalar_t, false><<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                                                         input.data_ptr<scalar_t>(),
                                                                                         weight.data_ptr<scalar_t>(),
                                                                                         epsilon,
                                                                                         num_tokens,
                                                                                         hidden_size);
        }
    });
}

// Layer normalization over the last dimension with optional weight and bias,
// via vllm::generalLayerNorm.
//
// `input` is treated as num_tokens = numel / hidden_size rows. The kernel is
// instantiated on packed_as<scalar_t, 2>::type, so every buffer is
// reinterpreted as a sequence of two-element packs.
// NOTE(review): this presumably requires an even hidden_size and buffers
// aligned to the packed type — confirm against generalLayerNorm.
//
// Launch layout: one block per row; min(hidden_size, 256) threads rounded up
// to a whole number of warps; dynamic shared memory of
// hidden_size * input.scalar_size() bytes (one row of input elements); current
// CUDA stream.
//
// Parameters:
//   out     - destination, same element type as `input`.
//   input   - source rows; its dtype selects the kernel instantiation.
//   weight  - optional scale; passed to the kernel as nullptr when invalid.
//   bias    - optional shift; passed to the kernel as nullptr when invalid.
//   epsilon - forwarded unchanged to the kernel.
//
// An empty `input` is a no-op. This avoids dividing by a zero-sized last
// dimension and launching an empty grid.
void layernorm_general(Tensor out, Tensor input, Tensor weight, Tensor bias, float epsilon) {
    if (input.numel() == 0) {
        return;
    }

    int hidden_size = input.size(-1);
    int num_tokens  = input.numel() / hidden_size;
    dim3 grid(num_tokens);
    dim3 block(std::min(hidden_size, 256));
    block.x = 32 * ((block.x + 31) / 32); // round up to whole warps

    // Room for one full row of input elements.
    size_t size_shmem = input.scalar_size() * hidden_size;

    const cudaStream_t stream = getCurrentCUDAStream();
    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "generalLayerNorm", [&] {
        using T = typename packed_as<scalar_t, 2>::type;
        // The three nullptr arguments presumably disable the scale / quantized
        // output path. The final `true` is presumably the kernel's use_shmem
        // flag, matching the dynamic shared memory requested above.
        vllm::generalLayerNorm<T, half, true><<<grid, block, size_shmem, stream>>>(
            reinterpret_cast<T *>(input.data_ptr<scalar_t>()),
            weight.valid() ? reinterpret_cast<T *>(weight.data_ptr<scalar_t>()) : nullptr,
            bias.valid() ? reinterpret_cast<T *>(bias.data_ptr<scalar_t>()) : nullptr,
            reinterpret_cast<T *>(out.data_ptr<scalar_t>()),
            epsilon,
            num_tokens,
            hidden_size,
            nullptr,
            nullptr,
            nullptr,
            true);
    });
}

Muyang Li's avatar
Muyang Li committed
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
// Normalizes every row of `input` with vllm::generalLayerNorm (no bias; the
// function name indicates RMS normalization) and writes int8 output.
//
// `input` is treated as num_tokens = numel / hidden_size rows. One block is
// launched per row with min(hidden_size, 1024) threads rounded up to a whole
// number of warps, no dynamic shared memory, on the current CUDA stream.
//
// Parameters:
//   out                 - int8 destination.
//   input               - source rows; its dtype selects the instantiation.
//   weight              - per-channel scale, same element type as `input`.
//   scaling             - half-precision scale buffer ([tokens] or [1]).
//                         It is handed to the kernel's per-token scale slot
//                         when `use_per_token_quant` is true, and to the
//                         per-tensor slot otherwise. The other slot gets
//                         nullptr. NOTE(review): whether the kernel reads or
//                         writes `scaling` is not visible here — confirm.
//   epsilon             - forwarded unchanged to the kernel.
//   use_per_token_quant - selects which scale slot receives `scaling`.
//
// An empty `input` is a no-op. This avoids dividing by a zero-sized last
// dimension and launching an empty grid.
void rms_norm_general(Tensor &out,     // [..., hidden_size]
                      Tensor &input,   // [..., hidden_size]
                      Tensor &weight,  // [hidden_size]
                      Tensor &scaling, // [tokens] or [1]
                      float epsilon,
                      bool use_per_token_quant) {
    if (input.numel() == 0) {
        return;
    }

    int hidden_size = input.size(-1);
    int num_tokens  = input.numel() / hidden_size;
    dim3 grid(num_tokens);
    dim3 block(std::min(hidden_size, 1024));
    block.x = 32 * ((block.x + 31) / 32); // round up to whole warps

    const cudaStream_t stream = getCurrentCUDAStream();
    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "generalLayerNorm", [&] {
        using T = scalar_t;

        // Per-token and per-tensor modes differ only in which scale slot
        // receives `scaling`; the other slot stays null.
        auto scale_ptr        = scaling.data_ptr<half>();
        auto per_tensor_scale = use_per_token_quant ? nullptr : scale_ptr;
        auto per_token_scale  = use_per_token_quant ? scale_ptr : nullptr;

        // Kernel arguments: input, gamma, beta, normed_output, eps, tokens,
        // hidden_dim, per_tensor_scale, per_token_scale, normed_output_quant,
        // use_shmem.
        vllm::generalLayerNorm<T, half>
            <<<grid, block, 0, stream>>>(reinterpret_cast<T *>(input.data_ptr<scalar_t>()),
                                         reinterpret_cast<T *>(weight.data_ptr<scalar_t>()),
                                         nullptr, // beta: no bias
                                         nullptr, // normed_output: no unquantized output requested
                                         epsilon,
                                         num_tokens,
                                         hidden_size,
                                         per_tensor_scale,
                                         per_token_scale,
                                         out.data_ptr<int8_t>(),
                                         false); // use_shmem
    });
}

Muyang Li's avatar
Muyang Li committed
111
112
113
114
115
116
117
118
119
120
121
122
// Variant of rms_norm_general that launches
// vllm::generalLayerNorm_fuse_sum and additionally hands it `input_sum`.
//
// Only per-token quantization is implemented. Per-tensor mode is rejected with
// std::runtime_error before any work is done. A plain assert() would be
// compiled out under NDEBUG and let release builds launch an unimplemented
// configuration.
//
// `input` is treated as num_tokens = numel / hidden_size rows. One block is
// launched per row with min(hidden_size, 1024) threads rounded up to a whole
// number of warps, no dynamic shared memory, on the current CUDA stream.
//
// Parameters:
//   out                 - int8 destination.
//   input               - source rows; its dtype selects the instantiation.
//   weight              - per-channel scale, same element type as `input`.
//   input_sum           - half-precision buffer ([tokens] or [1]) passed to
//                         the kernel's input-sum slot. NOTE(review):
//                         presumably receives a per-token sum — confirm
//                         against the kernel.
//   scaling             - half-precision buffer passed as the per-token scale.
//   epsilon             - forwarded unchanged to the kernel.
//   use_per_token_quant - must be true.
//
// Throws std::runtime_error when `use_per_token_quant` is false.
// An empty `input` is otherwise a no-op. This avoids dividing by a zero-sized
// last dimension and launching an empty grid.
void rms_norm_general_fuse_sum(Tensor &out,       // [..., hidden_size]
                               Tensor &input,     // [..., hidden_size]
                               Tensor &weight,    // [hidden_size]
                               Tensor &input_sum, // [tokens] or [1]
                               Tensor &scaling,   // [tokens] or [1]
                               float epsilon,
                               bool use_per_token_quant) {
    if (!use_per_token_quant) {
        // Per-tensor input_sum is not implemented; fail loudly in every build
        // configuration instead of relying on assert().
        throw std::runtime_error("rms_norm_general_fuse_sum: per-tensor quantization is not implemented; "
                                 "use_per_token_quant must be true");
    }
    if (input.numel() == 0) {
        return;
    }

    int hidden_size = input.size(-1);
    int num_tokens  = input.numel() / hidden_size;
    dim3 grid(num_tokens);
    dim3 block(std::min(hidden_size, 1024));
    block.x = 32 * ((block.x + 31) / 32); // round up to whole warps

    const cudaStream_t stream = getCurrentCUDAStream();
    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "generalLayerNorm_fuse_sum", [&] {
        using T = scalar_t;
        // Kernel arguments (per-token mode): input, gamma, beta, normed_output,
        // eps, tokens, hidden_dim, input_sum, per_tensor_scale,
        // per_token_scale, normed_output_quant, use_shmem.
        // NOTE(review): slot names inferred from the arguments passed — confirm
        // against the kernel signature.
        vllm::generalLayerNorm_fuse_sum<T, half>
            <<<grid, block, 0, stream>>>(reinterpret_cast<T *>(input.data_ptr<scalar_t>()),
                                         reinterpret_cast<T *>(weight.data_ptr<scalar_t>()),
                                         nullptr, // beta: no bias
                                         nullptr, // normed_output: no unquantized output requested
                                         epsilon,
                                         num_tokens,
                                         hidden_size,
                                         input_sum.data_ptr<half>(),
                                         nullptr, // per_tensor_scale: unused in per-token mode
                                         scaling.data_ptr<half>(),
                                         out.data_ptr<int8_t>(),
                                         false); // use_shmem
    });
}
Zhekai Zhang's avatar
Zhekai Zhang committed
168

Muyang Li's avatar
Muyang Li committed
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
// Launches vllm::dequant_add_residual_rms_norm_quant_kernel with a single
// scalar `scale`. By its name, the kernel dequantizes `input`, adds
// `residual`, RMS-normalizes with `gamma`, and quantizes to int8.
// NOTE(review): confirm the exact math against the kernel.
//
// `input` is treated as num_tokens = numel / hidden_size rows. One block is
// launched per row with min(hidden_size, 1024) threads, no dynamic shared
// memory, on the current CUDA stream.
//
// Parameters:
//   out      - int8 destination.
//   input    - int32 data; defines hidden_size and the row count.
//   residual - its dtype selects the kernel instantiation.
//   gamma    - per-channel weight, same element type as `residual`.
//   scale    - single half value passed to the kernel by value.
//   epsilon  - forwarded unchanged to the kernel.
//
// An empty `input` is a no-op. This avoids dividing by a zero-sized last
// dimension and launching an empty grid.
void invoke_dequant_add_residual_rms_norm_quant(Tensor &out,      // [..., hidden_size]
                                                Tensor &input,    // [..., hidden_size]
                                                Tensor &residual, // [..., hidden_size]
                                                Tensor &gamma,    // [hidden_size]
                                                half scale,
                                                float epsilon) {
    if (input.numel() == 0) {
        return;
    }

    int hidden_size = input.size(-1);
    int num_tokens  = input.numel() / hidden_size;
    dim3 grid(num_tokens);
    dim3 block(std::min(hidden_size, 1024));
    const cudaStream_t stream = getCurrentCUDAStream();
    VLLM_DISPATCH_FLOATING_TYPES(residual.scalar_type(), "dequant_add_residual_rms_norm_quant_kernel", [&] {
        // `false`: scalar-scale instantiation (the Tensor overload uses `true`
        // with a half pointer).
        vllm::dequant_add_residual_rms_norm_quant_kernel<scalar_t, half, false>
            <<<grid, block, 0, stream>>>(input.data_ptr<int32_t>(),
                                         residual.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         gamma.data_ptr<scalar_t>(),
                                         epsilon,
                                         scale,
                                         num_tokens,
                                         hidden_size);
    });
}

Muyang Li's avatar
Muyang Li committed
193
194
195
196
197
198
199
200
// Overload of invoke_dequant_add_residual_rms_norm_quant that takes the scale
// as a tensor ([num_tokens] per the parameter comment) and passes it to the
// kernel as a half pointer. NOTE(review): presumably one scale per token —
// confirm against dequant_add_residual_rms_norm_quant_kernel.
//
// `input` is treated as num_tokens = numel / hidden_size rows. One block is
// launched per row with min(hidden_size, 1024) threads, no dynamic shared
// memory, on the current CUDA stream.
//
// Parameters:
//   out      - int8 destination.
//   input    - int32 data; defines hidden_size and the row count.
//   residual - its dtype selects the kernel instantiation.
//   gamma    - per-channel weight, same element type as `residual`.
//   scale    - half-precision scale buffer.
//   epsilon  - forwarded unchanged to the kernel.
//
// An empty `input` is a no-op. This avoids dividing by a zero-sized last
// dimension and launching an empty grid.
void invoke_dequant_add_residual_rms_norm_quant(Tensor &out,      // [..., hidden_size]
                                                Tensor &input,    // [..., hidden_size]
                                                Tensor &residual, // [..., hidden_size]
                                                Tensor &gamma,    // [hidden_size]
                                                Tensor &scale,    // [num_tokens]
                                                float epsilon) {
    if (input.numel() == 0) {
        return;
    }

    int hidden_size = input.size(-1);
    int num_tokens  = input.numel() / hidden_size;
    dim3 grid(num_tokens);
    dim3 block(std::min(hidden_size, 1024));

    const cudaStream_t stream = getCurrentCUDAStream();
    VLLM_DISPATCH_FLOATING_TYPES(residual.scalar_type(), "dequant_add_residual_rms_norm_quant_kernel", [&] {
        // `half *` + `true`: pointer-scale instantiation (the scalar overload
        // uses `half` + `false`).
        vllm::dequant_add_residual_rms_norm_quant_kernel<scalar_t, half *, true>
            <<<grid, block, 0, stream>>>(input.data_ptr<int32_t>(),
                                         residual.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         gamma.data_ptr<scalar_t>(),
                                         epsilon,
                                         scale.data_ptr<half>(),
                                         num_tokens,
                                         hidden_size);
    });
}