pos_encoding_kernels.cu 8.11 KB
Newer Older
1
#include <torch/all.h>
2
#include <ATen/cuda/CUDAContext.h>
3
#include <c10/cuda/CUDAGuard.h>
4

5
#include "cuda_compat.h"
6
7
#include "dispatch_utils.h"

Woosuk Kwon's avatar
Woosuk Kwon committed
8
namespace vllm {

// Rotates a single (x, y) element pair of one head in place.
//
// `arr` points at the first rotated element of the head (the caller has
// already added rope_dim_offset). `rot_offset` selects which pair:
//   IS_NEOX:  x = arr[rot_offset],     y = arr[rot_offset + embed_dim]
//   !IS_NEOX: x = arr[2 * rot_offset], y = arr[2 * rot_offset + 1]
// Both layouts use cos_ptr[rot_offset] and sin_ptr[rot_offset]. The math is
// done in fp32 and the results are cast back to scalar_t. When `inverse` is
// set, the sine is negated before use.
template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(
    scalar_t* __restrict__ arr, const float* __restrict__ cos_ptr,
    const float* __restrict__ sin_ptr, int rot_offset, int embed_dim,
    const bool inverse) {
  // The interleaved layout's cache index (2 * rot_offset) / 2 reduces to
  // rot_offset, so both layouts read the same cache slot.
  const float cos_val = VLLM_LDG(cos_ptr + rot_offset);
  const float raw_sin = VLLM_LDG(sin_ptr + rot_offset);
  const float sin_val = inverse ? -raw_sin : raw_sin;

  const int first = IS_NEOX ? rot_offset : 2 * rot_offset;
  const int second = IS_NEOX ? embed_dim + rot_offset : first + 1;

  const float a = static_cast<float>(arr[first]);
  const float b = static_cast<float>(arr[second]);
  arr[first] = static_cast<scalar_t>(a * cos_val - b * sin_val);
  arr[second] = static_cast<scalar_t>(b * cos_val + a * sin_val);
}

// Rotates every pair of every query head of token `token_idx` and, when `key`
// is non-null, every pair of every key head of that token.
//
// The (head, pair) work items are flattened to j = head * embed_dim + pair and
// shared among the block's threads in steps of blockDim.x. `cache_ptr` is this
// token's cache row: rot_dim / 2 cos values followed by rot_dim / 2 sin values.
// `head_size` is part of the signature but is not read here.
template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // nullptr or
                                   // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const float* cache_ptr, const int head_size, const int num_heads,
    const int num_kv_heads, const int rot_dim, const int token_idx,
    const int64_t query_stride, const int64_t key_stride,
    const int64_t head_stride, const int64_t rope_dim_offset,
    const bool inverse) {
  const int embed_dim = rot_dim / 2;
  const float* cos_half = cache_ptr;
  const float* sin_half = cache_ptr + embed_dim;

  // Rotates all `n_heads` heads of one tensor for this token; `token_stride`
  // is that tensor's stride between tokens.
  auto rotate_tensor = [&](scalar_t* base, const int64_t token_stride,
                           const int n_heads) {
    const int total = n_heads * embed_dim;
    for (int j = threadIdx.x; j < total; j += blockDim.x) {
      const int head = j / embed_dim;
      const int pair = j % embed_dim;
      const int64_t offset =
          token_idx * token_stride + head * head_stride + rope_dim_offset;
      apply_token_rotary_embedding<scalar_t, IS_NEOX>(
          base + offset, cos_half, sin_half, pair, embed_dim, inverse);
    }
  };

  rotate_tensor(query, query_stride, num_heads);
  if (key != nullptr) {
    rotate_tensor(key, key_stride, num_kv_heads);
  }
}

// Kernel entry: blockIdx.x is the token index into positions/query/key.
// Reads that token's position, selects the matching row of the fp32 cos/sin
// cache and rotates the token's heads in place. No shared memory is used.
// Position values are not bounds-checked against the cache here.
template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,           // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,  // nullptr or
                                 // [batch_size, seq_len, num_kv_heads,
                                 // head_size] or [num_tokens, num_kv_heads,
                                 // head_size]
    const float* __restrict__ cos_sin_cache,  // [max_position, rot_dim] fp32
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int64_t head_stride, const int num_heads, const int num_kv_heads,
    const int head_size, const int64_t rope_dim_offset, const bool inverse) {
  const int token_idx = blockIdx.x;
  const float* row = cos_sin_cache + positions[token_idx] * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, row, head_size, num_heads, num_kv_heads, rot_dim, token_idx,
      query_stride, key_stride, head_stride, rope_dim_offset, inverse);
}

}  // namespace vllm
104

105
// Host entry point: applies rotary position embeddings in place to `query`
// and, if present, `key`.
//
//   positions:     [num_tokens] or [batch_size, seq_len], read as int64
//   query:         see shape comments on the parameter below
//   key:           optional; same layouts as query with num_kv_heads heads
//   cos_sin_cache: [max_position, rot_dim]; each row is read as rot_dim / 2
//                  cos values followed by rot_dim / 2 sin values, converted
//                  to fp32 before the launch
//
// Only elements [rope_dim_offset, rope_dim_offset + rot_dim) of each head are
// rotated. With is_neox, element i is paired with i + rot_dim / 2; otherwise
// elements 2i and 2i + 1 are paired. `inverse` negates the sine.
// Launches one block per token on the current CUDA stream of query's device.
void rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size] or
                           // [batch_size, seq_len, num_heads, head_size] or
                           // [num_tokens, num_heads, head_size]
    std::optional<torch::Tensor> key,
    // null or
    // [batch_size, seq_len, num_kv_heads * head_size] or
    // [num_tokens, num_kv_heads * head_size] or
    // [batch_size, seq_len, num_kv_heads, head_size] or
    // [num_tokens, num_kv_heads, head_size]
    int64_t head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox, int64_t rope_dim_offset, bool inverse) {
  // num_tokens = batch_size * seq_len
  const int64_t num_tokens = positions.numel();
  const int positions_ndim = positions.dim();

  // Make sure num_tokens dim is consistent across positions, query, and key
  TORCH_CHECK(
      positions_ndim == 1 || positions_ndim == 2,
      "positions must have shape [num_tokens] or [batch_size, seq_len]");
  if (positions_ndim == 1) {
    TORCH_CHECK(query.size(0) == positions.size(0) &&
                    (!key.has_value() || key->size(0) == positions.size(0)),
                "query, key and positions must have the same number of tokens");
  }
  if (positions_ndim == 2) {
    TORCH_CHECK(
        query.size(0) == positions.size(0) &&
            (!key.has_value() || key->size(0) == positions.size(0)) &&
            query.size(1) == positions.size(1) &&
            (!key.has_value() || key->size(1) == positions.size(1)),
        "query, key and positions must have the same batch_size and seq_len");
  }

  // Nothing to rotate. Returning here also avoids the integer division by
  // num_tokens below and an invalid zero-sized grid launch.
  if (num_tokens == 0) {
    return;
  }
  TORCH_CHECK(head_size > 0, "head_size must be positive");

  // Make sure head_size is valid for query and key
  // hidden_size = num_heads * head_size
  const int64_t query_hidden_size = query.numel() / num_tokens;
  const int64_t key_hidden_size =
      key.has_value() ? key->numel() / num_tokens : 0;
  TORCH_CHECK(query_hidden_size % head_size == 0);
  TORCH_CHECK(key_hidden_size % head_size == 0);

  // Make sure query and key have consistent number of heads. The
  // num_kv_heads > 0 test comes first so the modulo never divides by zero.
  const int num_heads = static_cast<int>(query_hidden_size / head_size);
  const int num_kv_heads =
      key.has_value() ? static_cast<int>(key_hidden_size / head_size)
                      : num_heads;
  TORCH_CHECK(num_kv_heads > 0 && num_heads % num_kv_heads == 0);

  const int rot_dim = cos_sin_cache.size(1);
  TORCH_CHECK((rot_dim + rope_dim_offset) <= head_size);

  // Stride between consecutive tokens, taken from the innermost token dim.
  const int seq_dim_idx = positions_ndim - 1;
  const int64_t query_stride = query.stride(seq_dim_idx);
  const int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;

  // Stride between consecutive heads: for [*, heads, head_size] it is the
  // stride of the heads dim (dim -2); for flat [*, heads * head_size] the
  // heads are contiguous blocks of head_size elements.
  const int64_t head_stride =
      (query.dim() == positions_ndim + 2) ? query.stride(-2) : head_size;

  // The kernel applies this single head stride to key as well, so a key whose
  // heads are laid out differently would have the wrong elements rotated.
  // With one kv head the stride is multiplied by 0 and cannot matter.
  if (key.has_value() && num_kv_heads > 1) {
    const int64_t key_head_stride =
        (key->dim() == positions_ndim + 2) ? key->stride(-2) : head_size;
    TORCH_CHECK(key_head_stride == head_stride,
                "query and key must use the same stride between heads");
  }

  // One thread per rotated pair of the widest tensor, capped at 512; the
  // kernel loops over any remainder. Taking the max with num_kv_heads only
  // matters when query has zero heads but key does not.
  const int64_t max_heads = std::max(num_heads, num_kv_heads);
  const int64_t pairs_per_token = max_heads * rot_dim / 2;
  if (pairs_per_token == 0) {
    return;  // Nothing to rotate (e.g. rot_dim == 0); avoids a 0-thread block.
  }

  const dim3 grid(num_tokens);
  const dim3 block(std::min<int64_t>(pairs_per_token, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // The kernel indexes positions[token] and cache + pos * rot_dim as dense
  // buffers, and reads the cache as fp32. contiguous() is a no-op for tensors
  // that are already contiguous.
  const torch::Tensor positions_contig = positions.contiguous();
  const torch::Tensor cache_f32 =
      cos_sin_cache.to(torch::kFloat32).contiguous();

  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
    if (is_neox) {
      vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions_contig.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
          key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
          cache_f32.data_ptr<float>(), rot_dim, query_stride, key_stride,
          head_stride, num_heads, num_kv_heads, head_size, rope_dim_offset,
          inverse);
    } else {
      vllm::rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions_contig.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
              cache_f32.data_ptr<float>(), rot_dim, query_stride, key_stride,
              head_stride, num_heads, num_kv_heads, head_size, rope_dim_offset,
              inverse);
    }
  });
  // Surface launch-configuration errors here rather than at a later,
  // unrelated CUDA call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}