// pos_encoding_kernels.cu
#include <algorithm>

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

// Rotates one (x, y) element pair of a single head vector in place.
//
//   arr         start of one head's data; the pair is overwritten.
//   cos_ptr     cos half of one cache row (embed_dim entries).
//   sin_ptr     sin half of the same cache row (embed_dim entries).
//   rot_offset  index of the pair to rotate, in [0, embed_dim).
//   embed_dim   rot_dim / 2.
//
// IS_NEOX selects how the pair is formed:
//   NeoX  : (rot_offset, rot_offset + embed_dim)   -- first/second half
//   GPT-J : (2 * rot_offset, 2 * rot_offset + 1)   -- adjacent elements
// Both layouts read the angle at cache index rot_offset (for GPT-J,
// (2 * rot_offset) / 2 == rot_offset).
template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(
    scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
    const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
  const int x_pos = IS_NEOX ? rot_offset : 2 * rot_offset;
  const int y_pos = IS_NEOX ? rot_offset + embed_dim : 2 * rot_offset + 1;

  const scalar_t c = VLLM_LDG(cos_ptr + rot_offset);
  const scalar_t s = VLLM_LDG(sin_ptr + rot_offset);

  const scalar_t x = arr[x_pos];
  const scalar_t y = arr[y_pos];
  arr[x_pos] = x * c - y * s;
  arr[y_pos] = y * c + x * s;
}

// Rotates every (head, pair) work item of one token for a single tensor.
// Work items are numbered head * embed_dim + pair; each thread of the block
// starts at threadIdx.x and advances by blockDim.x, so any block size covers
// all n_heads * embed_dim items.
//
//   data          base pointer of the tensor (query or key).
//   token_offset  element offset of this token's first head inside `data`.
template <typename scalar_t, bool IS_NEOX>
inline __device__ void rotate_token_heads(
    scalar_t* __restrict__ data, const int64_t token_offset,
    const scalar_t* cos_ptr, const scalar_t* sin_ptr, const int n_heads,
    const int head_size, const int embed_dim) {
  const int n_items = n_heads * embed_dim;
  for (int item = threadIdx.x; item < n_items; item += blockDim.x) {
    const int head = item / embed_dim;
    const int pair = item % embed_dim;
    const int64_t head_start = token_offset + head * head_size;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
        data + head_start, cos_ptr, sin_ptr, pair, embed_dim);
  }
}

// Rotates all query heads, then all key heads, of token `token_idx`.
// `cache_ptr` is one cache row: rot_dim / 2 cos values followed by
// rot_dim / 2 sin values. Token t's heads start at t * query_stride
// (resp. t * key_stride), and heads are head_size elements apart.
template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* cache_ptr, const int head_size, const int num_heads,
    const int num_kv_heads, const int rot_dim, const int token_idx,
    const int64_t query_stride, const int64_t key_stride) {
  const int pairs_per_head = rot_dim / 2;
  const scalar_t* cos_row = cache_ptr;
  const scalar_t* sin_row = cache_ptr + pairs_per_head;

  rotate_token_heads<scalar_t, IS_NEOX>(query, token_idx * query_stride,
                                        cos_row, sin_row, num_heads,
                                        head_size, pairs_per_head);
  rotate_token_heads<scalar_t, IS_NEOX>(key, token_idx * key_stride, cos_row,
                                        sin_row, num_kv_heads, head_size,
                                        pairs_per_head);
}

// Launch layout: one thread block per token (blockIdx.x is the token index).
// The token's cache row is row positions[token] of cos_sin_cache, which is
// indexed as contiguous rows of rot_dim elements.
template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2,
                                                 //  rot_dim // 2]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size) {
  const int token_idx = blockIdx.x;
  const scalar_t* cache_row = cos_sin_cache + positions[token_idx] * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, cache_row, head_size, num_heads, num_kv_heads, rot_dim,
      token_idx, query_stride, key_stride);
}

// Same as rotary_embedding_kernel, except the token's cache row is
// cos_sin_cache_offsets[token] + positions[token].
template <typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2,
                                                 //  rot_dim // 2]
    const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len]
                                                        // or [num_tokens]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size) {
  const int token_idx = blockIdx.x;
  const int64_t row = cos_sin_cache_offsets[token_idx] + positions[token_idx];
  const scalar_t* cache_row = cos_sin_cache + row * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, cache_row, head_size, num_heads, num_kv_heads, rot_dim,
      token_idx, query_stride, key_stride);
}

}  // namespace vllm

/*
 * Applies rotary positional embedding in place to `query` and `key`.
 *
 *   positions      int64 positions, [batch_size, seq_len] or [num_tokens];
 *                  exactly one entry per token of `query`.
 *   query          [..., num_heads * head_size]. All leading dims are folded
 *                  into the token count. Token t is assumed to start at
 *                  t * query.stride(-2).
 *   key            [..., num_kv_heads * head_size], addressed the same way
 *                  with key.stride(-2).
 *   head_size      elements per head; must divide the last dim of query/key.
 *   cos_sin_cache  [max_position, rot_dim], same dtype as query. Each row holds
 *                  rot_dim / 2 cos values followed by rot_dim / 2 sin values.
 *   is_neox        true = NeoX (half-split) pairing,
 *                  false = GPT-J (interleaved) pairing.
 *
 * Launches one block per token on the current CUDA stream of query's device.
 * Inconsistent shapes and launch failures raise c10::Error.
 */
void rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size]
    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
                           // [num_tokens, num_kv_heads * head_size]
    int head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox) {
  // Nothing to rotate. Returning here also avoids a 0 / 0 below when the last
  // dim is empty, and avoids a zero-sized grid. A zero-sized grid is an
  // invalid launch configuration whose error would otherwise be reported by
  // some later, unrelated launch check.
  if (query.numel() == 0) {
    return;
  }
  TORCH_CHECK(head_size > 0,
              "rotary_embedding: head_size must be positive, got ", head_size);
  TORCH_CHECK(query.size(-1) % head_size == 0,
              "rotary_embedding: query last dim (", query.size(-1),
              ") is not a multiple of head_size (", head_size, ")");
  TORCH_CHECK(key.size(-1) % head_size == 0,
              "rotary_embedding: key last dim (", key.size(-1),
              ") is not a multiple of head_size (", head_size, ")");

  const int64_t num_tokens = query.numel() / query.size(-1);
  // The kernel reads positions[t] for every token t of query.
  TORCH_CHECK(positions.numel() == num_tokens, "rotary_embedding: expected ",
              num_tokens, " positions (one per query token), got ",
              positions.numel());

  const int rot_dim = cos_sin_cache.size(1);
  // Each cache row is split into a cos half and a sin half of rot_dim / 2.
  TORCH_CHECK(rot_dim > 0 && rot_dim % 2 == 0,
              "rotary_embedding: rot_dim (cos_sin_cache.size(1)) must be a "
              "positive even number, got ",
              rot_dim);

  const int num_heads = query.size(-1) / head_size;
  const int num_kv_heads = key.size(-1) / head_size;
  const int64_t query_stride = query.stride(-2);
  const int64_t key_stride = key.stride(-2);

  dim3 grid(num_tokens);
  // Threads stride over the work items inside the kernel, so capping the
  // block at 512 threads is safe for any head count.
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
    if (is_neox) {
      vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
          query_stride, key_stride, num_heads, num_kv_heads, head_size);
    } else {
      vllm::rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
              head_size);
    }
  });
  // Surface launch-configuration errors here rather than at a later launch.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

/*
 * Batched variant of rotary_embedding(), used to pack several LoRAs into a
 * single launch. Token t reads cache row
 *   cos_sin_cache_offsets[t] + positions[t]
 * instead of positions[t]. Presumably each LoRA owns its own slice of
 * cos_sin_cache; confirm with the caller.
 *
 *   positions              int64, [batch_size, seq_len] or [num_tokens];
 *                          one entry per token of `query`.
 *   query / key            as in rotary_embedding(): last dim is
 *                          num_(kv_)heads * head_size, leading dims are folded
 *                          into tokens, and token t is assumed to start at
 *                          t * stride(-2).
 *   head_size              elements per head; must divide query/key last dim.
 *   cos_sin_cache          same dtype as query. It is indexed as contiguous
 *                          rows of `rot_dim` elements: rot_dim / 2 cos values
 *                          followed by rot_dim / 2 sin values.
 *   is_neox                true = NeoX pairing, false = GPT-J pairing.
 *   rot_dim                row length used to index cos_sin_cache.
 *   cos_sin_cache_offsets  int64 per-token row offsets, same layout and
 *                          element count as `positions`.
 *
 * Launches one block per token on the current CUDA stream of query's device.
 * Inconsistent shapes and launch failures raise c10::Error.
 */
void batched_rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size]
    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
                           // [num_tokens, num_kv_heads * head_size]
    int head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox, int rot_dim,
    torch::Tensor& cos_sin_cache_offsets  // [num_tokens] or
                                          // [batch_size, seq_len]
) {
  // Nothing to rotate. Returning here also avoids a 0 / 0 below and an
  // invalid zero-sized grid launch.
  if (query.numel() == 0) {
    return;
  }
  TORCH_CHECK(head_size > 0,
              "batched_rotary_embedding: head_size must be positive, got ",
              head_size);
  TORCH_CHECK(query.size(-1) % head_size == 0,
              "batched_rotary_embedding: query last dim (", query.size(-1),
              ") is not a multiple of head_size (", head_size, ")");
  TORCH_CHECK(key.size(-1) % head_size == 0,
              "batched_rotary_embedding: key last dim (", key.size(-1),
              ") is not a multiple of head_size (", head_size, ")");
  TORCH_CHECK(rot_dim > 0 && rot_dim % 2 == 0,
              "batched_rotary_embedding: rot_dim must be a positive even "
              "number, got ",
              rot_dim);

  // Count tokens over every leading dim. The previous
  // cos_sin_cache_offsets.size(0) only saw batch_size tokens when the
  // offsets were laid out as [batch_size, seq_len].
  const int64_t num_tokens = query.numel() / query.size(-1);
  TORCH_CHECK(positions.numel() == num_tokens,
              "batched_rotary_embedding: expected ", num_tokens,
              " positions (one per query token), got ", positions.numel());
  TORCH_CHECK(cos_sin_cache_offsets.numel() == num_tokens,
              "batched_rotary_embedding: expected ", num_tokens,
              " cos_sin_cache_offsets (one per query token), got ",
              cos_sin_cache_offsets.numel());

  const int num_heads = query.size(-1) / head_size;
  const int num_kv_heads = key.size(-1) / head_size;
  const int64_t query_stride = query.stride(-2);
  const int64_t key_stride = key.stride(-2);

  dim3 grid(num_tokens);
  // Threads stride over the work items inside the kernel, so capping the
  // block at 512 threads is safe for any head count.
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "batched_rotary_embedding", [&] {
        if (is_neox) {
          vllm::batched_rotary_embedding_kernel<scalar_t, true>
              <<<grid, block, 0, stream>>>(
                  positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
                  key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
                  cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim,
                  query_stride, key_stride, num_heads, num_kv_heads, head_size);
        } else {
          vllm::batched_rotary_embedding_kernel<scalar_t, false>
              <<<grid, block, 0, stream>>>(
                  positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
                  key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
                  cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim,
                  query_stride, key_stride, num_heads, num_kv_heads, head_size);
        }
      });
  // Surface launch-configuration errors here rather than at a later launch.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}