rope.cu 5.81 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

17
#include "pos_enc.cuh"
18
19
20
21
22
23
24
25
26
27
28
29
#include "pytorch_extension_utils.h"

using namespace flashinfer;

// Throws (via TORCH_CHECK) unless `t` lives on `device` and has scalar type `dtype`.
// `name` labels the tensor in the error message. The launcher below reinterprets raw
// data pointers with a fixed element type (q's c_type, float, int64_t), so without this
// check a device or dtype mismatch would surface as garbage results or an illegal memory
// access instead of a clear error.
static void check_tensor_device_and_dtype(
    const at::Tensor& t, const char* name, const at::Device& device, at::ScalarType dtype) {
  TORCH_CHECK(t.device() == device, name, " must be on device ", device, ", but is on ", t.device());
  TORCH_CHECK(t.scalar_type() == dtype, name, " must have dtype ", dtype, ", but has ", t.scalar_type());
}

// Applies rotary position embedding (RoPE) to `q` and `k` and writes the results into
// `q_rope` and `k_rope`, using cos/sin values taken from `cos_sin_cache` selected by
// `pos_ids` (presumably one position per token — NOTE(review): confirm in pos_enc.cuh).
//
//   q, q_rope      : (nnz, H_Q, D), last dim contiguous; q_rope must match q's shape,
//                    dtype and device.
//   k, k_rope      : (nnz, H_K, D), last dim contiguous; same dtype/device as q, and
//                    k_rope must match k's shape.
//   cos_sin_cache  : (max_seq_len, R), float32, validated with CHECK_INPUT. The first
//                    half of R holds cos, the second half sin. R is passed to the kernel
//                    as rotary_dim and must be even and <= D.
//   pos_ids        : int64, validated with CHECK_INPUT, on q's device.
//   interleave     : forwarded unchanged to the kernel (rotary pairing layout —
//                    NOTE(review): confirm exact semantics in pos_enc.cuh).
//   cuda_stream    : raw cudaStream_t handle passed as an integer; work is enqueued on
//                    it and this function issues no explicit synchronization.
//   v, k_buffer, v_buffer, kv_cache_loc :
//                    optional. Passing `v` selects the
//                    BatchQKApplyRotaryPosIdsCosSinCacheEnhanced path, which then also
//                    requires the other three. v / k_buffer / v_buffer must be 3-D, last
//                    dim contiguous, with q's dtype and device. kv_cache_loc must be a
//                    1-D int64 tensor (CHECK_INPUT) on q's device. Presumably the kernel
//                    writes the rotated k and v into k_buffer / v_buffer at rows
//                    kv_cache_loc — NOTE(review): confirm in pos_enc.cuh.
//
// Raises c10::Error (TORCH_CHECK) on invalid arguments or when the launcher returns a
// CUDA error.
void apply_rope_pos_ids_cos_sin_cache(
    at::Tensor q,
    at::Tensor k,
    at::Tensor q_rope,
    at::Tensor k_rope,
    at::Tensor cos_sin_cache,
    at::Tensor pos_ids,
    bool interleave,
    int64_t cuda_stream,
    const std::optional<at::Tensor>& v,
    const std::optional<at::Tensor>& k_buffer,
    const std::optional<at::Tensor>& v_buffer,
    const std::optional<at::Tensor>& kv_cache_loc) {
  // ---- q / k and their outputs -------------------------------------------------------
  // Check rank before CHECK_LAST_DIM_CONTIGUOUS: that macro indexes the last stride, which
  // would be out of bounds for a 0-D tensor.
  CHECK_DIM(3, q);  // q: (nnz, H_Q, D)
  CHECK_DIM(3, k);  // k: (nnz, H_K, D)
  CHECK_DIM(3, q_rope);
  CHECK_DIM(3, k_rope);
  CHECK_LAST_DIM_CONTIGUOUS(q);
  CHECK_LAST_DIM_CONTIGUOUS(k);
  CHECK_LAST_DIM_CONTIGUOUS(q_rope);
  CHECK_LAST_DIM_CONTIGUOUS(k_rope);
  CHECK_EQ(q.size(0), k.size(0));
  CHECK_EQ(q.size(2), k.size(2));
  // The outputs are addressed with q's / k's (nnz, heads, head_dim) geometry and only
  // their own strides, so their shapes must match exactly.
  TORCH_CHECK(
      q_rope.sizes() == q.sizes(), "q_rope must have the same shape as q, got ", q_rope.sizes(), " vs ", q.sizes());
  TORCH_CHECK(
      k_rope.sizes() == k.sizes(), "k_rope must have the same shape as k, got ", k_rope.sizes(), " vs ", k.sizes());

  const auto device = q.device();
  const auto dtype = q.scalar_type();
  check_tensor_device_and_dtype(k, "k", device, dtype);
  check_tensor_device_and_dtype(q_rope, "q_rope", device, dtype);
  check_tensor_device_and_dtype(k_rope, "k_rope", device, dtype);

  // ---- cos/sin cache and positions ---------------------------------------------------
  // cos_sin_cache: (max_seq_len, R)
  // First half of R is cos, second half is sin
  CHECK_INPUT(cos_sin_cache);
  CHECK_INPUT(pos_ids);
  CHECK_DIM(2, cos_sin_cache);
  // These are reinterpreted as float* / int64_t* below.
  check_tensor_device_and_dtype(cos_sin_cache, "cos_sin_cache", device, at::kFloat);
  check_tensor_device_and_dtype(pos_ids, "pos_ids", device, at::kLong);

  const unsigned int rotary_dim = cos_sin_cache.size(1);
  const unsigned int num_qo_heads = q.size(1);
  const unsigned int num_kv_heads = k.size(1);
  const unsigned int head_dim = q.size(2);
  const unsigned int nnz = q.size(0);
  // R is split into equal cos/sin halves, and rotation cannot extend past the head.
  TORCH_CHECK(rotary_dim % 2 == 0, "cos_sin_cache.size(1) (rotary dim) must be even, got ", rotary_dim);
  TORCH_CHECK(
      rotary_dim <= head_dim,
      "cos_sin_cache.size(1) (rotary dim) = ",
      rotary_dim,
      " must not exceed head_dim = ",
      head_dim);

  const size_t q_stride_n = q.stride(0);
  const size_t q_stride_h = q.stride(1);
  const size_t k_stride_n = k.stride(0);
  const size_t k_stride_h = k.stride(1);
  const size_t q_rope_stride_n = q_rope.stride(0);
  const size_t q_rope_stride_h = q_rope.stride(1);
  const size_t k_rope_stride_n = k_rope.stride(0);
  const size_t k_rope_stride_h = k_rope.stride(1);

  // ---- optional fused KV-cache write -------------------------------------------------
  const bool save_kv_cache = v.has_value();
  size_t v_stride_n = 0;
  size_t v_stride_h = 0;
  size_t k_buffer_stride_n = 0;
  size_t k_buffer_stride_h = 0;
  size_t v_buffer_stride_n = 0;
  size_t v_buffer_stride_h = 0;
  int64_t* kv_cache_loc_ptr = nullptr;
  if (save_kv_cache) {
    TORCH_CHECK(k_buffer.has_value(), "k_buffer is required when v is provided");
    TORCH_CHECK(v_buffer.has_value(), "v_buffer is required when v is provided");
    TORCH_CHECK(kv_cache_loc.has_value(), "kv_cache_loc is required when v is provided");
    CHECK_DIM(3, v.value());             // v: (nnz, H_V, D)
    CHECK_DIM(3, k_buffer.value());      // k_buffer: 3-D; rows presumably addressed by kv_cache_loc
    CHECK_DIM(3, v_buffer.value());      // v_buffer: 3-D; rows presumably addressed by kv_cache_loc
    CHECK_LAST_DIM_CONTIGUOUS(v.value());
    CHECK_LAST_DIM_CONTIGUOUS(k_buffer.value());
    CHECK_LAST_DIM_CONTIGUOUS(v_buffer.value());
    CHECK_INPUT(kv_cache_loc.value());
    CHECK_DIM(1, kv_cache_loc.value());  // kv_cache_loc: (n,)
    // v / k_buffer / v_buffer are reinterpreted as q's c_type*, kv_cache_loc as int64_t*.
    check_tensor_device_and_dtype(*v, "v", device, dtype);
    check_tensor_device_and_dtype(*k_buffer, "k_buffer", device, dtype);
    check_tensor_device_and_dtype(*v_buffer, "v_buffer", device, dtype);
    check_tensor_device_and_dtype(*kv_cache_loc, "kv_cache_loc", device, at::kLong);

    v_stride_n = v->stride(0);
    v_stride_h = v->stride(1);
    k_buffer_stride_n = k_buffer->stride(0);
    k_buffer_stride_h = k_buffer->stride(1);
    v_buffer_stride_n = v_buffer->stride(0);
    v_buffer_stride_h = v_buffer->stride(1);
    kv_cache_loc_ptr = static_cast<int64_t*>(kv_cache_loc->data_ptr());
  }

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  // Dispatch on q's dtype (presumably half-precision types only, per the macro name).
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
    // TODO temporarily only use `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` when save_kv_cache
    // to avoid changing original code path; but this branch is feature-complete and should switch to this later
    if (save_kv_cache) {
      cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
          static_cast<c_type*>(q.data_ptr()),
          static_cast<c_type*>(k.data_ptr()),
          static_cast<c_type*>(v->data_ptr()),
          static_cast<c_type*>(q_rope.data_ptr()),
          static_cast<c_type*>(k_rope.data_ptr()),
          static_cast<c_type*>(k_buffer->data_ptr()),
          static_cast<c_type*>(v_buffer->data_ptr()),
          static_cast<float*>(cos_sin_cache.data_ptr()),
          static_cast<int64_t*>(pos_ids.data_ptr()),
          nnz,
          num_qo_heads,
          num_kv_heads,
          rotary_dim,
          head_dim,
          q_stride_n,
          q_stride_h,
          k_stride_n,
          k_stride_h,
          v_stride_n,
          v_stride_h,
          q_rope_stride_n,
          q_rope_stride_h,
          k_rope_stride_n,
          k_rope_stride_h,
          k_buffer_stride_n,
          k_buffer_stride_h,
          v_buffer_stride_n,
          v_buffer_stride_h,
          kv_cache_loc_ptr,
          interleave,
          save_kv_cache,
          stream);
      TORCH_CHECK(
          status == cudaSuccess,
          "BatchQKApplyRotaryPosIdsCosSinCacheEnhanced failed with error code " +
              std::string(cudaGetErrorString(status)));
    } else {
      cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
          static_cast<c_type*>(q.data_ptr()),
          static_cast<c_type*>(k.data_ptr()),
          static_cast<c_type*>(q_rope.data_ptr()),
          static_cast<c_type*>(k_rope.data_ptr()),
          static_cast<float*>(cos_sin_cache.data_ptr()),
          static_cast<int64_t*>(pos_ids.data_ptr()),
          nnz,
          num_qo_heads,
          num_kv_heads,
          rotary_dim,
          head_dim,
          q_stride_n,
          q_stride_h,
          k_stride_n,
          k_stride_h,
          q_rope_stride_n,
          q_rope_stride_h,
          k_rope_stride_n,
          k_rope_stride_h,
          interleave,
          stream);
      TORCH_CHECK(
          status == cudaSuccess,
          "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " + std::string(cudaGetErrorString(status)));
    }
    // The dispatch macro expects the lambda to return a bool.
    return true;
  });
}