/*

Adapted from the VLLM project:
https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu

*/

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "pos_encoding.h"

// Dispatch cases covering the floating-point dtypes the rotary-embedding
// kernels support: fp32, fp16, and bf16.
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)              \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

// Switch on TYPE and instantiate the lambda body once per supported dtype;
// inside the body `scalar_t` names the concrete C++ type, as with ATen's
// AT_DISPATCH_FLOATING_TYPES.
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)             \
  AT_DISPATCH_SWITCH(                                             \
    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

// Rotate one (x, y) pair of a head's rotary slice in place.
//
// `arr` points at the start of a single head's data for one token;
// `cos_ptr`/`sin_ptr` point at the cos/sin halves of the cache row for this
// token's position. IS_NEOX selects the element layout:
//   - NeoX:  x is in the first half of the slice, y in the second half.
//   - GPT-J: x and y are interleaved (even/odd positions).
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
  scalar_t* __restrict__ arr,
  const scalar_t* __restrict__ cos_ptr,
  const scalar_t* __restrict__ sin_ptr,
  int rot_offset,
  int embed_dim)
{
  // Pick the element pair for this rotation offset (compile-time branch).
  const int x_index = IS_NEOX ? rot_offset : 2 * rot_offset;
  const int y_index = IS_NEOX ? (embed_dim + rot_offset) : (2 * rot_offset + 1);

  // Both layouts read cos/sin at the rotation offset — in the GPT-J case
  // x_index / 2 == rot_offset. __ldg routes through the read-only data cache.
  const scalar_t c = __ldg(cos_ptr + rot_offset);
  const scalar_t s = __ldg(sin_ptr + rot_offset);

  const scalar_t x = arr[x_index];
  const scalar_t y = arr[y_index];
  // Standard 2-D rotation: [x', y'] = [x*c - y*s, y*c + x*s].
  arr[x_index] = x * c - y * s;
  arr[y_index] = y * c + x * s;
}

// Apply rotary position embedding to the query and key tensors of one batch.
//
// Launch layout: one thread block per token (gridDim.x == num_tokens);
// threads of a block stripe over (head, rotation-offset) pairs. No shared
// memory or inter-thread synchronization is needed — every (head, offset)
// pair touches a disjoint element pair.
template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
  const int64_t* __restrict__ positions,        // [num_tokens]
  scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int query_stride,
  const int key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  const int token_idx = blockIdx.x;
  const int64_t pos = positions[token_idx];
  const int embed_dim = rot_dim / 2;

  // Each cache row holds [cos(0..embed_dim) | sin(0..embed_dim)] for one
  // position.
  const scalar_t* cache_row = cos_sin_cache + pos * rot_dim;
  const scalar_t* cos_ptr = cache_row;
  const scalar_t* sin_ptr = cache_row + embed_dim;

  // Rotate every query head of this token.
  const int num_query_elems = num_heads * embed_dim;
  for (int i = threadIdx.x; i < num_query_elems; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int rot_offset = i - head_idx * embed_dim;
    scalar_t* head_ptr = query + token_idx * query_stride + head_idx * head_size;
    apply_rotary_embedding<scalar_t, IS_NEOX>(head_ptr, cos_ptr, sin_ptr,
                                              rot_offset, embed_dim);
  }

  // Rotate every key head; keys carry their own head count (num_kv_heads).
  const int num_key_elems = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < num_key_elems; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int rot_offset = i - head_idx * embed_dim;
    scalar_t* head_ptr = key + token_idx * key_stride + head_idx * head_size;
    apply_rotary_embedding<scalar_t, IS_NEOX>(head_ptr, cos_ptr, sin_ptr,
                                              rot_offset, embed_dim);
  }
}
// Host entry point: apply rotary embedding in place to `query` and `key`.
//
// Shapes are inferred from the tensors: the rotary dimension comes from the
// cache's second dimension, and head counts from size(1) / head_size. Row
// strides are forwarded so the kernel also works on tensors whose rows are
// not densely packed. The kernel is launched on the current CUDA stream.
void rotary_embedding(
  torch::Tensor& positions,         // [num_tokens]
  torch::Tensor& query,             // [num_tokens, num_heads * head_size]
  torch::Tensor& key,               // [num_tokens, num_kv_heads * head_size]
  int head_size,
  torch::Tensor& cos_sin_cache,     // [max_position, rot_dim]
  bool is_neox) {
  const int num_tokens = query.size(0);
  const int rot_dim = cos_sin_cache.size(1);
  const int num_heads = query.size(1) / head_size;
  const int num_kv_heads = key.size(1) / head_size;
  const int query_stride = query.stride(0);
  const int key_stride = key.stride(0);

  // One block per token; at most 512 threads stripe over the
  // num_heads * (rot_dim / 2) rotation pairs of each token.
  const dim3 grid(num_tokens);
  const dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  VLLM_DISPATCH_FLOATING_TYPES(
    query.scalar_type(),
    "rotary_embedding",
    [&] {
      int64_t* pos_ptr = positions.data_ptr<int64_t>();
      scalar_t* query_ptr = query.data_ptr<scalar_t>();
      scalar_t* key_ptr = key.data_ptr<scalar_t>();
      const scalar_t* cache_ptr = cos_sin_cache.data_ptr<scalar_t>();
      // IS_NEOX is a template parameter, so the branch picks the
      // instantiation at runtime.
      if (is_neox) {
        rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          pos_ptr, query_ptr, key_ptr, cache_ptr, rot_dim,
          query_stride, key_stride, num_heads, num_kv_heads, head_size);
      } else {
        rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
          pos_ptr, query_ptr, key_ptr, cache_ptr, rot_dim,
          query_stride, key_stride, num_heads, num_kv_heads, head_size);
      }
    });
}