pos_encoding_kernels.cu 12.8 KB
Newer Older
1
#include <torch/all.h>
2
#include <ATen/cuda/CUDAContext.h>
3
#include <c10/cuda/CUDAGuard.h>
4

5
#include "cuda_compat.h"
6
7
#include "dispatch_utils.h"

Woosuk Kwon's avatar
Woosuk Kwon committed
8
namespace vllm {
9

10
// Rotates one (x, y) element pair of `arr` in place by the angle whose cosine
// and sine are stored at index `rot_offset` of `cos_ptr` / `sin_ptr`.
//
//   IS_NEOX == true  (GPT-NeoX layout): the pair is
//                    (rot_offset, rot_offset + embed_dim).
//   IS_NEOX == false (GPT-J layout):    the pair is
//                    (2 * rot_offset, 2 * rot_offset + 1).
//
// The rotation applied is
//   x' = x * cos - y * sin
//   y' = y * cos + x * sin
template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(
    scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
    const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
  // Positions of the two elements that form the rotated pair.
  const int first = IS_NEOX ? rot_offset : 2 * rot_offset;
  const int second = IS_NEOX ? embed_dim + rot_offset : 2 * rot_offset + 1;

  // Both layouts read the cos/sin entry at index rot_offset.
  const scalar_t c = VLLM_LDG(cos_ptr + rot_offset);
  const scalar_t s = VLLM_LDG(sin_ptr + rot_offset);

  // Read both inputs before writing either output.
  const scalar_t a = arr[first];
  const scalar_t b = arr[second];
  arr[first] = a * c - b * s;
  arr[second] = b * c + a * s;
}

36
// Applies the rotary embedding to every query head, and to every key head
// when `key` is not nullptr, of the token `token_idx`.
//
// `cache_ptr` is read as rot_dim / 2 cosine values followed by rot_dim / 2
// sine values. Work is shared between the threads of the calling block: each
// loop iteration handles one (head, rotation pair) combination, and a thread
// advances by blockDim.x between iterations.
template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // nullptr or
                                   // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* cache_ptr, const int head_size, const int num_heads,
    const int num_kv_heads, const int rot_dim, const int token_idx,
    const int64_t query_stride, const int64_t key_stride) {
  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cos_ptr + embed_dim;

  // Query: start of this token's data, then one step of head_size per head.
  scalar_t* const query_token = query + token_idx * query_stride;
  const int query_pairs = num_heads * embed_dim;
  for (int pair = threadIdx.x; pair < query_pairs; pair += blockDim.x) {
    const int head = pair / embed_dim;
    const int offset = pair % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
        query_token + head * head_size, cos_ptr, sin_ptr, offset, embed_dim);
  }

  // The key tensor is optional.
  if (key == nullptr) {
    return;
  }

  scalar_t* const key_token = key + token_idx * key_stride;
  const int key_pairs = num_kv_heads * embed_dim;
  for (int pair = threadIdx.x; pair < key_pairs; pair += blockDim.x) {
    const int head = pair / embed_dim;
    const int offset = pair % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
        key_token + head * head_size, cos_ptr, sin_ptr, offset, embed_dim);
  }
}

73
// In-place rotary position embedding of query and (optionally) key.
//
// Launch layout: blockIdx.x is used as the token index, so the grid needs one
// block per token. The threads of a block split that token's rotation pairs
// between them. No shared memory is used.
template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,           // [batch_size, seq_len, num_heads,
                                            // head_size] or [num_tokens,
                                            // num_heads, head_size]
    scalar_t* __restrict__ key,  // nullptr or
                                 // [batch_size, seq_len, num_kv_heads,
                                 // head_size] or [num_tokens, num_kv_heads,
                                 // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size) {
  // One thread block handles one token.
  const int token = blockIdx.x;

  // The token's position selects a row of rot_dim entries in the cache.
  const int64_t position = positions[token];
  const scalar_t* token_cache = cos_sin_cache + position * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, token_cache, head_size, num_heads, num_kv_heads, rot_dim,
      token, query_stride, key_stride);
}

98
// In-place rotary position embedding of query and (optionally) key, where
// each token adds its own offset to its position before the cache lookup.
//
// Launch layout: blockIdx.x is used as the token index, so the grid needs one
// block per token. The threads of a block split that token's rotation pairs
// between them. No shared memory is used.
template <typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,           // [batch_size, seq_len, num_heads,
                                            // head_size] or [num_tokens,
                                            // num_heads, head_size]
    scalar_t* __restrict__ key,  // nullptr or
                                 // [batch_size, seq_len, num_kv_heads,
                                 // head_size] or [num_tokens, num_kv_heads,
                                 // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len]
                                                        // or [num_tokens]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size) {
  // One thread block handles one token.
  const int token = blockIdx.x;

  // Cache row = per-token offset + position; each row has rot_dim entries.
  const int64_t cache_row = cos_sin_cache_offsets[token] + positions[token];
  const scalar_t* token_cache = cos_sin_cache + cache_row * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, token_cache, head_size, num_heads, num_kv_heads, rot_dim,
      token, query_stride, key_stride);
}

127
}  // namespace vllm
128

129
// Applies rotary position embedding in place to `query` and, when given,
// `key`, by launching vllm::rotary_embedding_kernel on the current CUDA
// stream of query's device (one thread block per token).
//
// Raises (via TORCH_CHECK) when the shapes of positions / query / key are
// inconsistent, when head_size does not evenly divide the per-token hidden
// size, when rot_dim (cos_sin_cache.size(1)) is outside [2, head_size], or
// when the kernel launch is rejected. Returns without launching anything when
// `positions` is empty.
void rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size] or
                           // [batch_size, seq_len, num_heads, head_size] or
                           // [num_tokens, num_heads, head_size]
    std::optional<torch::Tensor> key,
    // null or
    // [batch_size, seq_len, num_kv_heads * head_size] or
    // [num_tokens, num_kv_heads * head_size] or
    // [batch_size, seq_len, num_heads, head_size] or
    // [num_tokens, num_heads, head_size]
    int64_t head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox) {
  // num_tokens = batch_size * seq_len
  int64_t num_tokens = positions.numel();
  int positions_ndim = positions.dim();

  // Make sure num_tokens dim is consistent across positions, query, and key
  TORCH_CHECK(
      positions_ndim == 1 || positions_ndim == 2,
      "positions must have shape [num_tokens] or [batch_size, seq_len]");
  if (positions_ndim == 1) {
    TORCH_CHECK(query.size(0) == positions.size(0) &&
                    (!key.has_value() || key->size(0) == positions.size(0)),
                "query, key and positions must have the same number of tokens");
  }
  if (positions_ndim == 2) {
    TORCH_CHECK(
        query.size(0) == positions.size(0) &&
            (!key.has_value() || key->size(0) == positions.size(0)) &&
            query.size(1) == positions.size(1) &&
            (!key.has_value() || key->size(1) == positions.size(1)),
        "query, key and positions must have the same batch_size and seq_len");
  }

  // head_size is used as a divisor below.
  TORCH_CHECK(head_size > 0, "head_size must be positive, got ", head_size);

  // Nothing to rotate. Returning here also avoids dividing by num_tokens == 0
  // below and launching a grid with zero blocks.
  if (num_tokens == 0) {
    return;
  }

  // Make sure head_size is valid for query and key
  // hidden_size = num_heads * head_size
  int query_hidden_size = query.numel() / num_tokens;
  int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
  TORCH_CHECK(query_hidden_size % head_size == 0,
              "query hidden size (", query_hidden_size,
              ") must be a multiple of head_size (", head_size, ")");
  TORCH_CHECK(key_hidden_size % head_size == 0, "key hidden size (",
              key_hidden_size, ") must be a multiple of head_size (",
              head_size, ")");

  // Make sure query and key have consistent number of heads
  int num_heads = query_hidden_size / head_size;
  int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
  // num_kv_heads is a divisor below, and num_heads == 0 would request a
  // thread block with zero threads.
  TORCH_CHECK(num_heads > 0 && num_kv_heads > 0,
              "query (and key, if given) must contain at least one head per "
              "token, got num_heads=",
              num_heads, ", num_kv_heads=", num_kv_heads);
  TORCH_CHECK(num_heads % num_kv_heads == 0, "num_heads (", num_heads,
              ") must be a multiple of num_kv_heads (", num_kv_heads, ")");

  int rot_dim = cos_sin_cache.size(1);
  // The kernel rotates rot_dim / 2 pairs inside each head: fewer than one
  // pair gives an empty thread block, and rot_dim > head_size would make it
  // touch elements beyond the head.
  TORCH_CHECK(rot_dim >= 2 && rot_dim <= head_size,
              "rot_dim must be in [2, head_size], got rot_dim=", rot_dim,
              ", head_size=", head_size);

  // Stride (in elements) between consecutive tokens along the sequence dim.
  int seq_dim_idx = positions_ndim - 1;
  int64_t query_stride = query.stride(seq_dim_idx);
  int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;

  // One block per token; threads cover the (head, pair) combinations, capped
  // at 512 threads per block.
  dim3 grid(num_tokens);
  dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
    if (is_neox) {
      vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
          key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
          cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
          num_heads, num_kv_heads, head_size);
    } else {
      vllm::rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
              cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size);
    }
  });

  // A kernel launch does not report configuration errors by itself; query
  // the runtime so a rejected launch is raised here instead of surfacing in
  // an unrelated later call.
  const cudaError_t launch_status = cudaGetLastError();
  TORCH_CHECK(launch_status == cudaSuccess,
              "rotary_embedding kernel launch failed: ",
              cudaGetErrorString(launch_status));
}
Terry's avatar
Terry committed
204
205
206
207
208
209

/*
Batched version of rotary embedding, pack multiple LoRAs together
and process in batched manner.

Applies the embedding in place to `query` and, when given, `key`, by
launching vllm::batched_rotary_embedding_kernel on the current CUDA stream of
query's device (one thread block per entry of cos_sin_cache_offsets). Each
token's cache row is its position plus its entry of cos_sin_cache_offsets.

Raises (via TORCH_CHECK) when the shapes of positions / query / key /
cos_sin_cache_offsets are inconsistent, when head_size does not evenly divide
the per-token hidden size, when rot_dim is outside [2, head_size], or when the
kernel launch is rejected. Returns without launching anything when
cos_sin_cache_offsets is empty.
*/
void batched_rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size] or
                           // [batch_size, seq_len, num_heads, head_size] or
                           // [num_tokens, num_heads, head_size]
    std::optional<torch::Tensor>
        key,  // null or
              // [batch_size, seq_len, num_kv_heads * head_size] or
              // [num_tokens, num_kv_heads * head_size] or
              // [batch_size, seq_len, num_heads, head_size] or
              // [num_tokens, num_heads, head_size]
    int64_t head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox, int64_t rot_dim,
    torch::Tensor& cos_sin_cache_offsets  // [num_tokens] or [batch_size]
) {
  // num_tokens = batch_size * seq_len
  int64_t num_tokens = cos_sin_cache_offsets.size(0);
  // NOTE(review): the kernel indexes positions and cos_sin_cache_offsets with
  // the same flat token index and the grid has num_tokens blocks, so the
  // "[batch_size]" form of cos_sin_cache_offsets looks consistent only when
  // seq_len == 1 — confirm against callers.
  TORCH_CHECK(
      positions.size(0) == num_tokens || positions.numel() == num_tokens,
      "positions must have the same num_tokens or batch_size as "
      "cos_sin_cache_offsets");

  int positions_ndim = positions.dim();
  // Make sure num_tokens dim is consistent across positions, query, and key
  TORCH_CHECK(
      positions_ndim == 1 || positions_ndim == 2,
      "positions must have shape [num_tokens] or [batch_size, seq_len]");
  if (positions_ndim == 1) {
    TORCH_CHECK(query.size(0) == positions.size(0) &&
                    (!key.has_value() || key->size(0) == positions.size(0)),
                "query, key and positions must have the same number of tokens");
  }
  if (positions_ndim == 2) {
    TORCH_CHECK(
        query.size(0) == positions.size(0) &&
            (!key.has_value() || key->size(0) == positions.size(0)) &&
            query.size(1) == positions.size(1) &&
            (!key.has_value() || key->size(1) == positions.size(1)),
        "query, key and positions must have the same batch_size and seq_len");
  }

  // head_size is used as a divisor below.
  TORCH_CHECK(head_size > 0, "head_size must be positive, got ", head_size);

  // Nothing to rotate. Returning here also avoids dividing by num_tokens == 0
  // below and launching a grid with zero blocks.
  if (num_tokens == 0) {
    return;
  }

  // Make sure head_size is valid for query and key
  int query_hidden_size = query.numel() / num_tokens;
  int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
  TORCH_CHECK(query_hidden_size % head_size == 0,
              "query hidden size (", query_hidden_size,
              ") must be a multiple of head_size (", head_size, ")");
  TORCH_CHECK(key_hidden_size % head_size == 0, "key hidden size (",
              key_hidden_size, ") must be a multiple of head_size (",
              head_size, ")");

  // Make sure query and key have consistent number of heads
  int num_heads = query_hidden_size / head_size;
  int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
  // num_kv_heads is a divisor below, and num_heads == 0 would request a
  // thread block with zero threads.
  TORCH_CHECK(num_heads > 0 && num_kv_heads > 0,
              "query (and key, if given) must contain at least one head per "
              "token, got num_heads=",
              num_heads, ", num_kv_heads=", num_kv_heads);
  TORCH_CHECK(num_heads % num_kv_heads == 0, "num_heads (", num_heads,
              ") must be a multiple of num_kv_heads (", num_kv_heads, ")");

  // The kernel rotates rot_dim / 2 pairs inside each head: fewer than one
  // pair gives an empty thread block, and rot_dim > head_size would make it
  // touch elements beyond the head.
  TORCH_CHECK(rot_dim >= 2 && rot_dim <= head_size,
              "rot_dim must be in [2, head_size], got rot_dim=", rot_dim,
              ", head_size=", head_size);

  // Stride (in elements) between consecutive tokens along the sequence dim.
  int seq_dim_idx = positions_ndim - 1;
  int64_t query_stride = query.stride(seq_dim_idx);
  int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;

  // One block per token; threads cover the (head, pair) combinations, capped
  // at 512 threads per block.
  dim3 grid(num_tokens);
  dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
    if (is_neox) {
      vllm::batched_rotary_embedding_kernel<scalar_t, true>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
              cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size);
    } else {
      vllm::batched_rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
              cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size);
    }
  });

  // A kernel launch does not report configuration errors by itself; query
  // the runtime so a rejected launch is raised here instead of surfacing in
  // an unrelated later call.
  const cudaError_t launch_status = cudaGetLastError();
  TORCH_CHECK(launch_status == cudaSuccess,
              "batched_rotary_embedding kernel launch failed: ",
              cudaGetErrorString(launch_status));
}