/*

Adapted from the vLLM project:
https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu

*/

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <algorithm>  // std::min
#include "pos_encoding.h"

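// Type dispatch helpers covering float32, float16 and bfloat16 inputs.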
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)              \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)             \
  AT_DISPATCH_SWITCH(                                             \
    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
  const int64_t* __restrict__ positions,        // [num_tokens]
  scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int query_stride,
  const int key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  const int embed_dim = rot_dim / 2;
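  // NeoX-style rotation: the pair (x, y) is taken from the first and second
  // halves of the rotary dimensions; the cached row stores the cos values
  // for this position followed by the sin values.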
  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int token_head = token_idx * query_stride + head_idx * head_size;

    const int rot_offset = i % embed_dim;
    const int x_index = rot_offset;
    const int y_index = embed_dim + rot_offset;

    const int out_x = token_idx * query_stride + head_idx * head_size + x_index;
    const int out_y = token_idx * query_stride + head_idx * head_size + y_index;

    const scalar_t cos = __ldg(cache_ptr + x_index);
    const scalar_t sin = __ldg(cache_ptr + y_index);

    const scalar_t q_x = query[token_head + x_index];
    const scalar_t q_y = query[token_head + y_index];
    query[out_x] = q_x * cos - q_y * sin;
    query[out_y] = q_y * cos + q_x * sin;
  }

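  // Rotate the key heads the same way; num_kv_heads can be smaller than
  // num_heads (e.g. with multi-query or grouped-query attention).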
  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int token_head = token_idx * key_stride + head_idx * head_size;

    const int rot_offset = i % embed_dim;
    const int x_index = rot_offset;
    const int y_index = embed_dim + rot_offset;

    const int out_x = token_idx * key_stride + head_idx * head_size + x_index;
    const int out_y = token_idx * key_stride + head_idx * head_size + y_index;

    const scalar_t cos = __ldg(cache_ptr + x_index);
    const scalar_t sin = __ldg(cache_ptr + y_index);

    const scalar_t k_x = key[token_head + x_index];
    const scalar_t k_y = key[token_head + y_index];
    key[out_x] = k_x * cos - k_y * sin;
    key[out_y] = k_y * cos + k_x * sin;
  }
}

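// Host-side launcher: rotates `query` and `key` in place, using `positions`
// to select rows of `cos_sin_cache`.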
void rotary_embedding_neox(
  torch::Tensor& positions,         // [num_tokens]
  torch::Tensor& query,             // [num_tokens, num_heads * head_size]
  torch::Tensor& key,               // [num_tokens, num_kv_heads * head_size]
  int head_size,
  torch::Tensor& cos_sin_cache)     // [max_position, rot_dim]
{
  int num_tokens = query.size(0);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(1) / head_size;
  int num_kv_heads = key.size(1) / head_size;
  int query_stride = query.stride(0);
  int key_stride = key.stride(0);

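  // One thread block per token; each thread rotates one (head, rotary-pair)
  // element, capped at 512 threads per block.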
  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    query.scalar_type(),
    "rotary_embedding_neox",
    [&] {
      rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
        positions.data_ptr<int64_t>(),
        query.data_ptr<scalar_t>(),
        key.data_ptr<scalar_t>(),
        cos_sin_cache.data_ptr<scalar_t>(),
        rot_dim,
        query_stride,
        key_stride,
        num_heads,
        num_kv_heads,
        head_size);
    });

}