pos_encoding_kernels.cu 4.41 KB
Newer Older
1
2
3
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

4
5
#include "dispatch_utils.h"

Woosuk Kwon's avatar
Woosuk Kwon committed
6
namespace vllm {
7

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// Rotates a single (x, y) element pair of `arr` in place using the cosine and
// sine values read at `rot_offset` from `cos_ptr` / `sin_ptr`.
//
// Pair layout, selected at compile time by IS_NEOX:
//   - GPT-NeoX style: x = arr[rot_offset],     y = arr[embed_dim + rot_offset]
//   - GPT-J style:    x = arr[2 * rot_offset], y = arr[2 * rot_offset + 1]
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
  scalar_t* __restrict__ arr,
  const scalar_t* __restrict__ cos_ptr,
  const scalar_t* __restrict__ sin_ptr,
  int rot_offset,
  int embed_dim)
{
  // Positions of the two elements that form the rotated pair.
  const int first_idx = IS_NEOX ? rot_offset : 2 * rot_offset;
  const int second_idx = IS_NEOX ? embed_dim + rot_offset : 2 * rot_offset + 1;

  // Both layouts read the cos/sin tables at rot_offset: NeoX indexes them with
  // x_index (== rot_offset) and GPT-J with x_index / 2 (== rot_offset).
  const scalar_t cos_val = __ldg(cos_ptr + rot_offset);
  const scalar_t sin_val = __ldg(sin_ptr + rot_offset);

  // Load both inputs before storing so the second result uses the original x.
  const scalar_t x_val = arr[first_idx];
  const scalar_t y_val = arr[second_idx];
  arr[first_idx] = x_val * cos_val - y_val * sin_val;
  arr[second_idx] = y_val * cos_val + x_val * sin_val;
}

// Applies rotary positional embedding to `query` and `key` via
// apply_rotary_embedding, which updates the element pairs in place.
//
// Launch layout: blockIdx.x selects the token, so the grid must have one block
// per token. Threads of a block walk the (head, rotation pair) work items with
// a step of blockDim.x, so any positive block size is valid. The kernel
// declares no shared memory.
//
// Element offsets into `query` / `key` are computed in 64-bit arithmetic:
// token_idx * stride exceeds the range of a 32-bit int for large inputs.
template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
  const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
  scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int query_stride,
  const int key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  const int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  // The cache row for a position stores embed_dim cosines followed by
  // embed_dim sines.
  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;

  // Widen before multiplying so the per-token base offset cannot wrap.
  const int64_t query_token_offset = static_cast<int64_t>(token_idx) * query_stride;
  const int64_t key_token_offset = static_cast<int64_t>(token_idx) * key_stride;

  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head =
      query_token_offset + static_cast<int64_t>(head_idx) * head_size;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
                                              sin_ptr, rot_offset, embed_dim);
  }

  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head =
      key_token_offset + static_cast<int64_t>(head_idx) * head_size;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
                                              sin_ptr, rot_offset, embed_dim);
  }
}

Woosuk Kwon's avatar
Woosuk Kwon committed
78
} // namespace vllm
79

80
// Host entry point: launches vllm::rotary_embedding_kernel on the current CUDA
// stream, passing mutable pointers to `query` and `key`.
//
// The grid has one block per token; the block has up to 512 threads. The
// stream is not synchronized here, so only launch-time errors are reported.
//
// Raises (via TORCH_CHECK) if `head_size` is not positive or if the kernel
// launch is rejected. Returns without launching when `query` has no elements.
void rotary_embedding(
  torch::Tensor& positions,         // [batch_size, seq_len] or [num_tokens]
  torch::Tensor& query,             // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
  torch::Tensor& key,               // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
  int head_size,
  torch::Tensor& cos_sin_cache,     // [max_position, rot_dim]
  bool is_neox) {
  // head_size is used as a divisor below.
  TORCH_CHECK(head_size > 0,
              "rotary_embedding: head_size must be positive, got ", head_size);

  // An empty query would mean a zero-sized grid (an invalid launch
  // configuration) or a division by zero when the last dimension is 0.
  if (query.numel() == 0) {
    return;
  }

  int64_t num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
  int query_stride = query.stride(-2);
  int key_stride = key.stride(-2);

  dim3 grid(num_tokens);
  // Clamp to at least one thread: a block dimension of 0 is rejected by the
  // runtime, and the kernel's strided loops are correct for any block size.
  dim3 block(std::max(1, std::min(num_heads * rot_dim / 2, 512)));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    query.scalar_type(),
    "rotary_embedding",
    [&] {
      if (is_neox) {
        vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      } else {
        vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      }
    });

  // Kernel launches do not return a status; fetch it explicitly so a rejected
  // launch is reported instead of silently leaving the tensors untouched.
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess,
              "rotary_embedding: kernel launch failed: ",
              cudaGetErrorString(launch_err));
}