/*
Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cutlass/cutlass.h>
#include <cutlass/kernel_hardware_info.h>
#include <torch/all.h>

#include <cute/tensor.hpp>
#include <iostream>

#include "cutlass_sm100_mla/device/sm100_mla.hpp"
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
#include "utils.h"

// clang-format off
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
void cutlass_mla_decode(
    torch::Tensor const& out,
    torch::Tensor const& q_nope,
    torch::Tensor const& q_pe,
    torch::Tensor const& kv_c_and_k_pe_cache,
    torch::Tensor const& seq_lens,
    torch::Tensor const& page_table,
    torch::Tensor const& workspace,
    double sm_scale,
    int64_t num_kv_splits) {
  TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
}
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
  TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size");
}
#else

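// Convert a cutlass::Status into a TORCH_CHECK failure carrying the
// human-readable status string.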
#define CUTLASS_CHECK(status)                                                       \
  {                                                                                 \
    cutlass::Status error = status;                                                 \
    TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
  }

using namespace cute;
using namespace cutlass::fmha::kernel;

template <bool v>
struct IsPersistent {
  static const bool value = v;
};

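// Compile-time configuration bundle for the SM100 MLA kernel. TileShape fixes
// 128 query heads and 128 KV tokens per tile, with the head dimension split
// into 512 latent + 64 rope components. IsPaged128 selects the TMA load path
// (kIsCpAsync = false) when KV pages hold exactly 128 tokens; other page sizes
// use the cp.async path. PersistenceOption chooses between the persistent and
// the per-tile (individual) scheduler.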
template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
struct MlaSm100 {
  using Element = T;
  using ElementAcc = float;
  using ElementOut = T;

  using TileShape = Shape<_128, _128, Shape<_512, _64>>;
  using TileShapeH = cute::tuple_element_t<0, TileShape>;
  using TileShapeD = cute::tuple_element_t<2, TileShape>;

  // H K (D_latent D_rope) B
  using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;

  using StrideQ = cute::tuple<int64_t, _1, int64_t>;  // H D B
  using StrideK = cute::tuple<int64_t, _1, int64_t>;  // K D B
  using StrideO = StrideK;                            // H D B
  using StrideLSE = cute::tuple<_1, int>;             // H B

  using TileScheduler =
      std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler, Sm100MlaIndividualTileScheduler>;

  using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
      TileShape,
      Element,
      ElementAcc,
      ElementOut,
      ElementAcc,
      TileScheduler,
      /*kIsCpAsync=*/!IsPaged128>;
  using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
};

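// Build kernel arguments from PyTorch tensors. The expected layouts, as
// implied by the sizes and strides used below (a reading of this code, not a
// checked contract):
//   q_nope:              [B, H, D_latent]
//   q_pe:                [B, H, D_rope]
//   kv_c_and_k_pe_cache: [page_count_total, page_size, D_latent + D_rope]
//   seq_lens:            [B], int32
//   page_table:          [B, page_count_per_seq], int32
//   out:                 [B, H, D_latent], contiguous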
template <typename T>
typename T::Fmha::Arguments args_from_options(
    at::Tensor const& out,
    at::Tensor const& q_nope,
    at::Tensor const& q_pe,
    at::Tensor const& kv_c_and_k_pe_cache,
    at::Tensor const& seq_lens,
    at::Tensor const& page_table,
    double sm_scale,
    int64_t num_kv_splits) {
  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = q_nope.device().index();
  hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

  int batches = q_nope.size(0);
  int page_count_per_seq = page_table.size(1);
  int page_count_total = kv_c_and_k_pe_cache.size(0);
  int page_size = kv_c_and_k_pe_cache.size(1);
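  // Upper bound on sequence length implied by the page-table width; actual
  // per-sequence lengths come from seq_lens at runtime.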
  int max_seq_len = page_size * page_count_per_seq;
  using TileShapeH = typename T::TileShapeH;
  using TileShapeD = typename T::TileShapeD;
  auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);

  auto [H, K, D, B] = problem_shape;
  auto [D_latent, D_rope] = D;

  float scale = float(sm_scale);

  using StrideQ = typename T::StrideQ;
  using StrideK = typename T::StrideK;
  using StrideO = typename T::StrideO;
  using StrideLSE = typename T::StrideLSE;

  StrideQ stride_Q_nope = cute::make_tuple(
      static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
  StrideQ stride_Q_pe = cute::make_tuple(
      static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));

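  // The cache is treated as contiguous [page_count_total, page_size,
  // D_latent + D_rope], giving (token, dim, page) strides of
  // (D_latent + D_rope, 1, page_size * (D_latent + D_rope)); the rope part of
  // each token lives at offset D_latent in the same rows (see C_ptr + D_latent
  // below). `0 + X` coerces a cute static integer to a runtime value.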
  StrideK stride_C = cute::make_tuple(
      static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
  StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
  StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
  StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));

  using Element = typename T::Element;
  using ElementOut = typename T::ElementOut;
  using ElementAcc = typename T::ElementAcc;
  auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
  auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
  auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
  typename T::Fmha::Arguments arguments{
      problem_shape,
      {scale,
       Q_nope_ptr,
       stride_Q_nope,
       Q_pe_ptr,
       stride_Q_pe,
       C_ptr,
       stride_C,
       C_ptr + D_latent,
       stride_C,
       static_cast<int*>(seq_lens.data_ptr()),
       static_cast<int*>(page_table.data_ptr()),
       stride_PT,
       page_count_total,
       page_size},
      {static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE},
      hw_info,
      // TODO(trevor-m): Change split_kv back to -1 once
      // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. split_kv=1
      // performs worse at larger context lengths and smaller batch sizes.
      static_cast<int>(num_kv_splits), // split_kv
      nullptr,       // is_var_split_kv
  };
  // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
  // split_kv automatically based on batch size and sequence length to balance
  // workload across available SMs. Consider using var_split_kv for manual
  // control if needed.
  T::Fmha::set_split_kv(arguments);
  return arguments;
}

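// Instantiate and launch one kernel configuration via the standard CUTLASS
// device-level flow: can_implement -> initialize (with user workspace) -> run.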
template <typename Element, bool IsPaged128, typename PersistenceOption>
void runMla(
    at::Tensor const& out,
    at::Tensor const& q_nope,
    at::Tensor const& q_pe,
    at::Tensor const& kv_c_and_k_pe_cache,
    at::Tensor const& seq_lens,
    at::Tensor const& page_table,
    at::Tensor const& workspace,
    double sm_scale,
    int64_t num_kv_splits,
    cudaStream_t stream) {
  using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
  typename MlaSm100Type::Fmha fmha;
  auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);

  CUTLASS_CHECK(fmha.can_implement(arguments));

  CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));

  CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
}

#define DISPATCH_BOOL(expr, const_expr, ...) \
  [&]() -> bool {                            \
    if (expr) {                              \
      constexpr bool const_expr = true;      \
      return __VA_ARGS__();                  \
    } else {                                 \
      constexpr bool const_expr = false;     \
      return __VA_ARGS__();                  \
    }                                        \
  }()
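
// DISPATCH_BOOL lifts a runtime bool into a constexpr that can be used as a
// template argument: the trailing lambda is pasted into both branches, so it
// is compiled twice, each time with `const_expr` bound to a different
// compile-time value. A minimal usage sketch (hypothetical names, mirroring
// the real dispatch below):
//
//   DISPATCH_BOOL(page_size == 128, kPaged128, [&] {
//     runMla<cutlass::half_t, kPaged128, IsPersistent<true>>(/*...*/);
//     return true;
//   });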

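// Public entry point: MLA decode over a paged latent KV cache (one query
// token per sequence). Dispatches at runtime on the page size (TMA fast path
// for page_size == 128) and on whether split-KV is automatic
// (num_kv_splits <= 1, persistent scheduler) or manual.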
void cutlass_mla_decode(
    torch::Tensor const& out,
    torch::Tensor const& q_nope,
    torch::Tensor const& q_pe,
    torch::Tensor const& kv_c_and_k_pe_cache,
    torch::Tensor const& seq_lens,
    torch::Tensor const& page_table,
    torch::Tensor const& workspace,
    double sm_scale,
    int64_t num_kv_splits) {
  auto sm_version = getSMVersion();
  // Restrict to SM100: on SM103a, about half of the accuracy tests fail.
  TORCH_CHECK(sm_version == 100, "cutlass_mla_decode is only supported on compute capability 10.0, but found SM version ", sm_version);

  auto in_dtype = q_nope.dtype();
  at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
  const int page_size = kv_c_and_k_pe_cache.size(1);

  // NOTE(alcanderian): the persistent scheduler (IsPersistent<true>) has a bug with
  // manual split_kv: the kernel can hang when the batch is large and num_kv_splits is
  // large (for example bs=8, num_kv_splits=8). Per-batch split_kv may fix this; until
  // then, persistence is disabled whenever num_kv_splits > 1.
  DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
    DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
      if (in_dtype == at::ScalarType::Half) {
        runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
      } else if (in_dtype == at::ScalarType::BFloat16) {
        runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
      } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
        runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
      } else {
        TORCH_CHECK(false, "Unsupported input data type of MLA");
      }
      return true;
    });
    return true;
  });
}

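// Compute the scratch-workspace size for a worst-case problem shape. Only the
// problem shape, SM count, and split-KV factor matter here, so a
// half_t / paged-128 instantiation stands in for all element types.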
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
  // Workspace size depends only on ElementAcc and ElementLSE (the same type as
  // ElementAcc), both of which are float, so the Element type chosen here doesn't matter.
  using MlaSm100Type = MlaSm100<cutlass::half_t, true>;

  // Get split kv. Requires problem shape and sm_count only.
  typename MlaSm100Type::Fmha::Arguments arguments;
  using TileShapeH = typename MlaSm100Type::TileShapeH;
  using TileShapeD = typename MlaSm100Type::TileShapeD;
  arguments.problem_shape =
      cute::make_tuple(TileShapeH{}, static_cast<int>(max_seq_len), TileShapeD{}, static_cast<int>(num_batches));
  // Assumes device 0 when getting sm_count.
  arguments.hw_info.sm_count =
      sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
  arguments.split_kv = static_cast<int>(num_kv_splits);
  MlaSm100Type::Fmha::set_split_kv(arguments);

  return MlaSm100Type::Fmha::get_workspace_size(arguments);
}

#endif
// clang-format on