// pos_encoding_kernels.cu
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "dispatch_utils.h"

namespace vllm {

// Rotates one (x, y) pair of a single head in place, rotate-half layout:
// element `rot_offset` is paired with element `embed_dim + rot_offset`.
//
//   head       - pointer to the first element of one head's vector
//   cache_ptr  - cos_sin_cache row for this token's position; the first
//                embed_dim entries are cos values, the next embed_dim are sin
//   rot_offset - index of the pair inside the rotary slice, in [0, embed_dim)
//   embed_dim  - rot_dim / 2
template<typename scalar_t>
inline __device__ void apply_neox_rotary_pair(
  scalar_t* __restrict__ head,
  const scalar_t* __restrict__ cache_ptr,
  const int rot_offset,
  const int embed_dim) {
  const int x_index = rot_offset;
  const int y_index = embed_dim + rot_offset;

  // The cache is only read here, so route it through the read-only data path.
  const scalar_t cos_val = __ldg(cache_ptr + x_index);
  const scalar_t sin_val = __ldg(cache_ptr + y_index);

  const scalar_t x = head[x_index];
  const scalar_t y = head[y_index];
  head[x_index] = x * cos_val - y * sin_val;
  head[y_index] = y * cos_val + x * sin_val;
}

// Applies rotary position embedding in place to `query` and `key`.
//
// Launch layout: grid = (num_tokens), one thread block per token; any 1-D
// block size works because each thread strides over the (head, pair) work
// items by blockDim.x. No shared memory is used.
//
// Preconditions (validated by the host launcher):
//   - rot_dim is even and rot_dim <= head_size; only the first rot_dim
//     elements of every head are rotated, the rest are left untouched.
//   - cos_sin_cache rows are contiguous, with row `pos` starting at
//     pos * rot_dim.
//   - within a token, heads are laid out contiguously with head_size elements
//     each; consecutive tokens are query_stride / key_stride elements apart.
//
// Each thread reads and writes only its own (x, y) pair, so no intra-block
// synchronization is needed.
template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
  const int64_t* __restrict__ positions,        // [num_tokens]
  scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int query_stride,
  const int key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  // Each thread block is responsible for one token.
  // 64-bit on purpose: token_idx * stride can exceed INT_MAX for long
  // batches or wide (e.g. fused QKV) strides.
  const int64_t token_idx = blockIdx.x;
  const int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  const int embed_dim = rot_dim / 2;
  scalar_t* token_query = query + token_idx * query_stride;
  scalar_t* token_key = key + token_idx * key_stride;

  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int rot_offset = i % embed_dim;
    apply_neox_rotary_pair<scalar_t>(
      token_query + head_idx * head_size, cache_ptr, rot_offset, embed_dim);
  }

  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int rot_offset = i % embed_dim;
    apply_neox_rotary_pair<scalar_t>(
      token_key + head_idx * head_size, cache_ptr, rot_offset, embed_dim);
  }
}

} // namespace vllm

// Host launcher: applies NeoX-style rotary embedding in place to `query` and
// `key` on the current CUDA stream.
//
//   positions     - int64 position id per token; must be contiguous.
//   query, key    - 2-D, rows of num_heads * head_size (resp.
//                   num_kv_heads * head_size) elements; the last dimension
//                   must have unit stride, row stride may be arbitrary.
//   head_size     - elements per head; must be positive.
//   cos_sin_cache - contiguous, same dtype as query/key; row `pos` holds
//                   rot_dim / 2 cos values followed by rot_dim / 2 sin values.
//
// Throws (via TORCH_CHECK) on violated preconditions or a failed kernel
// launch. Empty inputs are a no-op.
void rotary_embedding_neox(
  torch::Tensor& positions,         // [num_tokens]
  torch::Tensor& query,             // [num_tokens, num_heads * head_size]
  torch::Tensor& key,               // [num_tokens, num_kv_heads * head_size]
  int head_size,
  torch::Tensor& cos_sin_cache)     // [max_position, rot_dim]
{
  TORCH_CHECK(head_size > 0,
              "rotary_embedding_neox: head_size must be positive, got ", head_size);

  const int64_t num_tokens = query.size(0);
  const int rot_dim = cos_sin_cache.size(1);
  const int num_heads = query.size(1) / head_size;
  const int num_kv_heads = key.size(1) / head_size;
  const int query_stride = query.stride(0);
  const int key_stride = key.stride(0);

  // The kernel pairs element j with element j + rot_dim / 2 inside each head
  // and indexes the cache as pos * rot_dim, so these layouts are required.
  TORCH_CHECK(rot_dim % 2 == 0 && rot_dim <= head_size,
              "rotary_embedding_neox: rot_dim must be even and <= head_size, got rot_dim=",
              rot_dim, ", head_size=", head_size);
  TORCH_CHECK(query.stride(1) == 1 && key.stride(1) == 1,
              "rotary_embedding_neox: query and key must have a contiguous last dimension");
  TORCH_CHECK(cos_sin_cache.is_contiguous(),
              "rotary_embedding_neox: cos_sin_cache must be contiguous");
  TORCH_CHECK(positions.is_contiguous() && positions.numel() == num_tokens,
              "rotary_embedding_neox: positions must be contiguous with one entry per token");
  TORCH_CHECK(key.size(0) == num_tokens,
              "rotary_embedding_neox: query and key must have the same number of tokens");

  // Size the block for whichever of query/key has more rotary pairs; the
  // kernel strides by blockDim.x, so any positive block size is correct.
  const int pairs_per_token = std::max(num_heads, num_kv_heads) * (rot_dim / 2);
  if (num_tokens == 0 || pairs_per_token == 0) {
    // Nothing to rotate; a zero-sized grid or block is an invalid launch.
    return;
  }

  dim3 grid(static_cast<unsigned int>(num_tokens));
  dim3 block(std::min(pairs_per_token, 512));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    query.scalar_type(),
    "rotary_embedding_neox",
    [&] {
      vllm::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
        positions.data_ptr<int64_t>(),
        query.data_ptr<scalar_t>(),
        key.data_ptr<scalar_t>(),
        cos_sin_cache.data_ptr<scalar_t>(),
        rot_dim,
        query_stride,
        key_stride,
        num_heads,
        num_kv_heads,
        head_size);
    });

  // Launch-configuration errors are only reported through cudaGetLastError.
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess,
              "rotary_embedding_neox: kernel launch failed: ",
              cudaGetErrorString(launch_err));
}