// pos_encoding_kernels.cu - CUDA kernels for rotary positional embedding (RoPE).
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

#include <algorithm>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

// Rotates one (x, y) element pair of a single head in place.
//
//   arr        - first element of one query or key head; written in place.
//   cos_ptr    - cosine table for this token's position.
//   sin_ptr    - sine table for this token's position.
//   rot_offset - index of the pair to rotate (callers pass values in [0, embed_dim)).
//   embed_dim  - rot_dim / 2, the number of rotated pairs per head.
//
// IS_NEOX selects which two elements form the pair:
//   GPT-NeoX (true):  arr[rot_offset]     and arr[embed_dim + rot_offset]
//   GPT-J    (false): arr[2 * rot_offset] and arr[2 * rot_offset + 1]
// In both layouts the pair is rotated with cos/sin table entry rot_offset.
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(
  scalar_t* __restrict__ arr,
  const scalar_t* __restrict__ cos_ptr,
  const scalar_t* __restrict__ sin_ptr,
  int rot_offset,
  int embed_dim)
{
  // NeoX reads table slot x_index == rot_offset and GPT-J reads x_index / 2.
  // Since (2 * rot_offset) / 2 == rot_offset, one load serves both layouts.
  const scalar_t c = VLLM_LDG(cos_ptr + rot_offset);
  const scalar_t s = VLLM_LDG(sin_ptr + rot_offset);

  // IS_NEOX is a template constant, so the compiler can resolve these selects.
  const int first = IS_NEOX ? rot_offset : 2 * rot_offset;
  const int second = IS_NEOX ? embed_dim + rot_offset : first + 1;

  const scalar_t x = arr[first];
  const scalar_t y = arr[second];
  arr[first] = x * c - y * s;
  arr[second] = y * c + x * s;
}

// Rotates every head of one token inside a single query or key tensor.
// The head_count * embed_dim (head, pair) work items are shared by the block's
// threads: thread t handles items t, t + blockDim.x, t + 2 * blockDim.x, ...
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding_to_heads(
  scalar_t* __restrict__ base,
  const scalar_t* cos_ptr,
  const scalar_t* sin_ptr,
  const int head_count,
  const int head_size,
  const int embed_dim,
  const int token_idx,
  const int64_t token_stride)
{
  const int64_t token_offset = token_idx * token_stride;
  const int num_items = head_count * embed_dim;
  for (int item = threadIdx.x; item < num_items; item += blockDim.x) {
    const int head = item / embed_dim;
    const int pair = item % embed_dim;
    const int64_t head_offset = token_offset + head * head_size;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
        base + head_offset, cos_ptr, sin_ptr, pair, embed_dim);
  }
}

// Applies rotary embedding to all query heads and all key heads of token
// `token_idx`. cache_ptr points at that token's cache row, which holds
// embed_dim cosines followed by embed_dim sines (embed_dim = rot_dim / 2).
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
  scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
  const scalar_t* cache_ptr,
  const int head_size,
  const int num_heads,
  const int num_kv_heads,
  const int rot_dim,
  const int token_idx,
  const int64_t query_stride,
  const int64_t key_stride)
{
  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_table = cache_ptr;
  const scalar_t* sin_table = cache_ptr + embed_dim;

  apply_rotary_embedding_to_heads<scalar_t, IS_NEOX>(
      query, cos_table, sin_table, num_heads, head_size, embed_dim, token_idx, query_stride);
  apply_rotary_embedding_to_heads<scalar_t, IS_NEOX>(
      key, cos_table, sin_table, num_kv_heads, head_size, embed_dim, token_idx, key_stride);
}

// One thread block per token: block b rotates token b using cache row
// positions[b]. The host launcher uses a grid of num_tokens blocks.
template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
  const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
  scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int64_t query_stride,
  const int64_t key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  const int token_idx = blockIdx.x;
  const int64_t position = positions[token_idx];
  const scalar_t* token_cache = cos_sin_cache + position * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, token_cache, head_size, num_heads, num_kv_heads,
      rot_dim, token_idx, query_stride, key_stride);
}

// Like rotary_embedding_kernel, but token b reads cache row
// cos_sin_cache_offsets[b] + positions[b], so tokens may use different
// sub-caches packed into one cos_sin_cache tensor.
template<typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel(
  const int64_t* __restrict__ positions,              // [batch_size, seq_len] or [num_tokens]
  scalar_t* __restrict__ query,                       // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                         // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,         // [max_position, 2, rot_dim // 2]
  const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len] or [num_tokens]
  const int rot_dim,
  const int64_t query_stride,
  const int64_t key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  const int token_idx = blockIdx.x;
  const int64_t cache_row = cos_sin_cache_offsets[token_idx] + positions[token_idx];
  const scalar_t* token_cache = cos_sin_cache + cache_row * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, token_cache, head_size, num_heads, num_kv_heads,
      rot_dim, token_idx, query_stride, key_stride);
}

} // namespace vllm

/*
 * Applies rotary position embedding (RoPE) in place to `query` and `key`.
 *
 * positions:     int64 position of each token, [batch_size, seq_len] or
 *                [num_tokens]; read as one flat contiguous array.
 * query:         last dim is num_heads * head_size; rotated in place.
 * key:           last dim is num_kv_heads * head_size; rotated in place.
 * head_size:     number of elements in one attention head (must be > 0).
 * cos_sin_cache: [max_position, rot_dim]; row p holds rot_dim / 2 cosines
 *                followed by rot_dim / 2 sines for position p.
 * is_neox:       true  -> GPT-NeoX layout (element i pairs with i + rot_dim / 2);
 *                false -> GPT-J layout (element 2i pairs with 2i + 1).
 *
 * Only the first rot_dim elements of each head are modified. One thread block
 * per token is launched on the current CUDA stream. Inconsistent shapes or
 * layouts raise (TORCH_CHECK) instead of causing out-of-bounds accesses.
 * NOTE: positions are not range-checked against max_position.
 */
void rotary_embedding(
  torch::Tensor& positions,         // [batch_size, seq_len] or [num_tokens]
  torch::Tensor& query,             // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
  torch::Tensor& key,               // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
  int head_size,
  torch::Tensor& cos_sin_cache,     // [max_position, rot_dim]
  bool is_neox) {
  TORCH_CHECK(head_size > 0,
              "rotary_embedding: head_size must be positive, got ", head_size);
  TORCH_CHECK(positions.is_contiguous(),
              "rotary_embedding: positions must be contiguous");
  // The kernel addresses cache rows as cos_sin_cache + pos * rot_dim.
  TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.is_contiguous(),
              "rotary_embedding: cos_sin_cache must be a contiguous [max_position, rot_dim] tensor");
  TORCH_CHECK(query.stride(-1) == 1 && key.stride(-1) == 1,
              "rotary_embedding: query and key must be contiguous in their last dimension");

  // The kernel reads positions[token] and indexes query/key rows by token, so
  // all three must agree on the token count (batch_size * seq_len).
  const int64_t num_tokens = positions.numel();
  TORCH_CHECK(query.numel() == num_tokens * query.size(-1) &&
                  key.numel() == num_tokens * key.size(-1),
              "rotary_embedding: positions, query and key must have the same number of tokens");
  TORCH_CHECK(query.size(-1) % head_size == 0 && key.size(-1) % head_size == 0,
              "rotary_embedding: query and key hidden sizes must be multiples of head_size (",
              head_size, ")");

  const int rot_dim = cos_sin_cache.size(1);
  // Each head rotates elements [0, rot_dim). A larger rot_dim would write into
  // the next head or past the end of the tensor.
  TORCH_CHECK(rot_dim <= head_size,
              "rotary_embedding: rot_dim (", rot_dim, ") must not exceed head_size (",
              head_size, ")");

  const int num_heads = query.size(-1) / head_size;
  const int num_kv_heads = key.size(-1) / head_size;
  const int64_t query_stride = query.stride(-2);
  const int64_t key_stride = key.stride(-2);

  const int embed_dim = rot_dim / 2;
  const int max_heads = std::max(num_heads, num_kv_heads);
  // Nothing to rotate. Returning here also avoids launching with a zero-sized
  // grid or block, which fails with cudaErrorInvalidConfiguration.
  if (num_tokens == 0 || embed_dim == 0 || max_heads == 0) {
    return;
  }

  dim3 grid(num_tokens);
  // One thread per (head, pair) item, capped at 512. The kernel's strided
  // loop covers any remaining items.
  dim3 block(std::min(max_heads * embed_dim, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    query.scalar_type(),
    "rotary_embedding",
    [&] {
      if (is_neox) {
        vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      } else {
        vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      }
    });
  // Surface launch-configuration errors here rather than in a later, unrelated call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

/*
 * Batched version of rotary_embedding that packs the caches of several LoRAs
 * into one cos_sin_cache tensor and processes all tokens in one launch.
 * Token i reads its cos/sin values from row
 * cos_sin_cache_offsets[i] + positions[i] of cos_sin_cache.
 *
 * positions:             int64 position of each token, [batch_size, seq_len]
 *                        or [num_tokens]; read as one flat contiguous array.
 * query / key:           same layouts as rotary_embedding; rotated in place.
 * head_size:             number of elements in one attention head (must be > 0).
 * cos_sin_cache:         [total_positions, rot_dim]; each row holds rot_dim / 2
 *                        cosines followed by rot_dim / 2 sines.
 * is_neox:               GPT-NeoX (true) or GPT-J (false) pair layout.
 * rot_dim:               number of rotated elements per head; must equal
 *                        cos_sin_cache.size(1) and not exceed head_size.
 * cos_sin_cache_offsets: int64 per-token row offset, one entry per token.
 *
 * Inconsistent shapes or layouts raise (TORCH_CHECK). Rows are not
 * range-checked against cos_sin_cache.size(0).
 */
void batched_rotary_embedding(
  torch::Tensor& positions,         // [batch_size, seq_len] or [num_tokens]
  torch::Tensor& query,             // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
  torch::Tensor& key,               // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
  int head_size,
  torch::Tensor& cos_sin_cache,     // [max_position, rot_dim]
  bool is_neox,
  int rot_dim,
  torch::Tensor& cos_sin_cache_offsets // [num_tokens]
) {
  TORCH_CHECK(head_size > 0,
              "batched_rotary_embedding: head_size must be positive, got ", head_size);
  // Each head rotates elements [0, rot_dim). A larger rot_dim would write into
  // the next head or past the end of the tensor.
  TORCH_CHECK(rot_dim >= 0 && rot_dim <= head_size,
              "batched_rotary_embedding: rot_dim (", rot_dim,
              ") must be in [0, head_size (", head_size, ")]");
  TORCH_CHECK(positions.is_contiguous() && cos_sin_cache_offsets.is_contiguous(),
              "batched_rotary_embedding: positions and cos_sin_cache_offsets must be contiguous");
  // The kernel addresses cache rows as cos_sin_cache + row * rot_dim, so rot_dim
  // must be the cache's real row stride.
  TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.is_contiguous() &&
                  cos_sin_cache.size(1) == rot_dim,
              "batched_rotary_embedding: cos_sin_cache must be a contiguous [max_position, rot_dim] tensor");
  TORCH_CHECK(query.stride(-1) == 1 && key.stride(-1) == 1,
              "batched_rotary_embedding: query and key must be contiguous in their last dimension");

  // The kernel reads positions[token] and cos_sin_cache_offsets[token] and
  // indexes query/key rows by token, so all of them must agree on num_tokens.
  const int64_t num_tokens = positions.numel();
  TORCH_CHECK(cos_sin_cache_offsets.numel() == num_tokens,
              "batched_rotary_embedding: cos_sin_cache_offsets must have one entry per token");
  TORCH_CHECK(query.numel() == num_tokens * query.size(-1) &&
                  key.numel() == num_tokens * key.size(-1),
              "batched_rotary_embedding: positions, query and key must have the same number of tokens");
  TORCH_CHECK(query.size(-1) % head_size == 0 && key.size(-1) % head_size == 0,
              "batched_rotary_embedding: query and key hidden sizes must be multiples of head_size (",
              head_size, ")");

  const int num_heads = query.size(-1) / head_size;
  const int num_kv_heads = key.size(-1) / head_size;
  const int64_t query_stride = query.stride(-2);
  const int64_t key_stride = key.stride(-2);

  const int embed_dim = rot_dim / 2;
  const int max_heads = std::max(num_heads, num_kv_heads);
  // Nothing to rotate. Returning here also avoids launching with a zero-sized
  // grid or block, which fails with cudaErrorInvalidConfiguration.
  if (num_tokens == 0 || embed_dim == 0 || max_heads == 0) {
    return;
  }

  dim3 grid(num_tokens);
  // One thread per (head, pair) item, capped at 512. The kernel's strided
  // loop covers any remaining items.
  dim3 block(std::min(max_heads * embed_dim, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    query.scalar_type(),
    "batched_rotary_embedding",
    [&] {
      if (is_neox) {
        vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          cos_sin_cache_offsets.data_ptr<int64_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      } else {
        vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          cos_sin_cache_offsets.data_ptr<int64_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      }
    });
  // Surface launch-configuration errors here rather than in a later, unrelated call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}