fp8_gemm_kernel.cu 54.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_template.h
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm89.h
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm90.h

#include <ATen/cuda/CUDAContext.h>
#include <cudaTypedefs.h>
#include <cutlass/arch/arch.h>
#include <cutlass/arch/memory.h>
#include <cutlass/arch/mma.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/activation.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
#include <cutlass/gemm/thread/mma.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/matrix_coord.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tensor_ref.h>
#include <torch/all.h>

#include <cute/tensor.hpp>
#include <cutlass/epilogue/collective/collective_builder.hpp>
#include <cutlass/epilogue/collective/default_epilogue.hpp>
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>

51
#include "math.hpp"
52
53
54
55
56
#include "utils.h"

using namespace cute;

#if defined CUDA_VERSION && CUDA_VERSION >= 12040
57
58
59
60
61
62
63
64
65
66
67
template <
    typename ElementType,
    typename OutElementType,
    typename AccumElementType,
    typename CtaShape,
    typename WarpShape,
    int Stages,
    bool WithBias,
    typename FP8MathOperator = cutlass::arch::OpMultiplyAdd,
    template <typename...> typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT,
    typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>>
// CUTLASS 2.x device-GEMM type bundle for SM89 (Ada) FP8 GEMM with row-wise scaling.
//
// The fused epilogue visitor tree (EVT) computes, for each output element (i, j):
//   WithBias == false:  D[i][j] = scales_a[i] * (acc[i][j] * scales_b[j])
//   WithBias == true:   D[i][j] = scales_a[i] * (acc[i][j] * scales_b[j]) + bias[j]
// scales_b and bias are row-broadcast vectors (one value per column j); scales_a is a
// column-broadcast vector (one value per row i). The inner product is kept in float and
// the outer multiply / multiply-add produces ElementC (= OutElementType).
//
// Template parameters:
//   ElementType        - operand element type; must be cutlass::float_e4m3_t.
//   OutElementType     - element type of D and of the bias vector.
//   AccumElementType   - mainloop accumulator type.
//   CtaShape/WarpShape - threadblock and warp tile shapes (cutlass::gemm::GemmShape).
//   Stages             - number of mainloop pipeline stages.
//   WithBias           - selects the epilogue tree with the fused bias add.
//   FP8MathOperator    - MMA math operator tag.
//   EpilogueVisitor    - EVT node template used to compose every node of the epilogue tree
//                        (defaults to Sm80EVT).
//   ThreadblockSwizzle - CTA rasterization functor.
struct DeviceGemmFp8RowwiseSm89 {
  static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");

  // Operand A: row-major FP8, 128-bit vectorized global accesses.
  using ElementA = ElementType;
  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  // Operand B: column-major FP8, 128-bit vectorized global accesses.
  using ElementB = ElementType;
  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  // C/D share the output element type and a row-major layout.
  using ElementC = OutElementType;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

  using ElementOutput = OutElementType;
  using LayoutOutput = cutlass::layout::RowMajor;
  static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  using ElementAccumulator = AccumElementType;
  using ElementComputeEpilogue = float;
  using ArchTag = cutlass::arch::Sm89;
  using OperatorClass = cutlass::arch::OpClassTensorOp;

  // FP8 tensor-core MMA instruction shape.
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
  // Number of epilogue stages in EVT
  static constexpr int EVTEpilogueStages = 1;

  using OutputTileThreadMap = cutlass::epilogue::threadblock::
      OutputTileThreadLayout<CtaShape, WarpShape, ElementC, AlignmentC, EVTEpilogueStages>;

  // ---- Epilogue visitor tree -------------------------------------------------------

  // Leaf: the raw accumulator tile.
  using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch;

  // acc * scales_b[j], kept in float.
  using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cutlass::FloatRoundStyle::round_to_nearest>;
  // Per-column scale vector: stride (0, 1, 0) broadcasts it along rows.
  using bScaleSrc = cutlass::epilogue::threadblock::
      VisitorRowBroadcast<OutputTileThreadMap, ElementComputeEpilogue, Stride<_0, _1, _0>>;

  using EpilogueBScale = EpilogueVisitor<ComputeBScale, accSrc, bScaleSrc>;

  // (acc * scales_b[j]) * scales_a[i], produced as ElementC.
  using ComputeAScale = cutlass::epilogue::threadblock::
      VisitorCompute<cutlass::multiplies, ElementC, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>;
  // Per-row scale vector: stride (1, 0, 0) broadcasts it along columns.
  using aScaleSrc = cutlass::epilogue::threadblock::
      VisitorColBroadcast<OutputTileThreadMap, ElementComputeEpilogue, Stride<_1, _0, _0>>;

  using EpilogueAScale = EpilogueVisitor<ComputeAScale, EpilogueBScale, aScaleSrc>;

  // With bias: per-column bias vector in the output element type.
  using biasSrc =
      cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementOutput, Stride<_0, _1, _0>>;

  // (acc * scales_b[j]) * scales_a[i] + bias[j], produced as ElementC.
  using ComputeAScaleWithBias = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add,
      ElementC,
      ElementComputeEpilogue,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EpilogueAScaleWithBias = EpilogueVisitor<ComputeAScaleWithBias, EpilogueBScale, aScaleSrc, biasSrc>;

  // Root: store the result to D with a runtime row stride and unit column stride.
  using dTar = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap,
      ElementC,
      cutlass::FloatRoundStyle::round_to_nearest,
      Stride<int64_t, _1, _0>>;
  using EpilogueStore = typename cutlass::platform::conditional<
      WithBias,
      EpilogueVisitor<dTar, EpilogueAScaleWithBias>,
      EpilogueVisitor<dTar, EpilogueAScale>>::type;

  using EpilogueOp = EpilogueStore;

  using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
      ElementA,
      LayoutA,
      cutlass::ComplexTransform::kNone,
      AlignmentA,
      ElementB,
      LayoutB,
      cutlass::ComplexTransform::kNone,
      AlignmentB,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementAccumulator,
      ElementComputeEpilogue,
      OperatorClass,
      ArchTag,
      CtaShape,
      WarpShape,
      InstructionShape,
      EpilogueOp,
      ThreadblockSwizzle,
      Stages,
      FP8MathOperator,
      EVTEpilogueStages>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};

template <typename Gemm, bool WithBias>
// Builds the CUTLASS 2.x GemmUniversal arguments for the SM89 row-wise-scaled FP8 GEMM.
//
// out      - output D (m x n); its row stride out.stride(0) is used for the epilogue store.
// a        - operand A (m x k); a.stride(0) is passed as lda.
// b        - operand B (k x n); b.stride(1) is passed as ldb (column-major access).
// scales_a - per-row scale vector read as float (one value per row of D).
// scales_b - per-column scale vector read as float (one value per column of D).
// bias     - per-column bias in the output element type; required when WithBias is true.
//
// Throws (via TORCH_CHECK) if WithBias is true but no bias tensor was supplied.
typename Gemm::Arguments prepare_sm89_fp8_args(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ElementT = typename Gemm::ElementA;
  using ElementOutput = typename Gemm::ElementD;
  using ElementComputeEpilogue = float;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  // Batch strides are ignored in kGemm mode, but they are still evaluated: compute them in
  // 64-bit so m*k / n*k / m*n cannot overflow int32 (signed overflow is UB).
  int64_t const batch_stride_a = static_cast<int64_t>(m) * k;
  int64_t const batch_stride_b = static_cast<int64_t>(n) * k;
  int64_t const batch_stride_cd = static_cast<int64_t>(m) * n;

  ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
  ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
  ElementOutput const* ptr_bias = nullptr;
  if constexpr (WithBias) {
    TORCH_CHECK(bias.has_value(), "prepare_sm89_fp8_args: a bias tensor is required when WithBias is true");
    ptr_bias = reinterpret_cast<ElementOutput const*>(bias.value().data_ptr());
  }
  ElementOutput* ptr_d = reinterpret_cast<ElementOutput*>(out.data_ptr());
  ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
  ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());

  typename Gemm::Arguments args(
      cutlass::gemm::GemmUniversalMode::kGemm,  // Mode
      {m, n, k},                                // Problem size
      1,                                        // Split-k factor
      {},                                       // Epilogue args (filled in below)
      ptr_a,                                    // a pointer
      ptr_b,                                    // b pointer
      nullptr,                                  // c pointer (unused)
      nullptr,                                  // d pointer (unused; D is written by the EVT store)
      batch_stride_a,                           // batch stride a (unused)
      batch_stride_b,                           // batch stride b (unused)
      batch_stride_cd,                          // batch stride c (unused)
      batch_stride_cd,                          // batch stride d (unused)
      lda,                                      // stride a
      ldb,                                      // stride b
      ldc,                                      // stride c (unused)
      ldc);                                     // stride d (unused)

  // The epilogue store uses out's real row stride (ldc) rather than assuming a packed
  // n-wide row, so a row-padded output tensor is written correctly.
  if constexpr (WithBias) {
    args.epilogue = {
        {
            {
                {},  // Accumulator
                {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
                {}  // Multiplies
            },
            {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
            {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}},
            {}  // Multiply-add
        },
        {ptr_d, {ldc, _1{}, _0{}}}};
  } else {
    args.epilogue = {
        {
            {
                {},  // Accumulator
                {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
                {}  // Multiplies
            },
            {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
            {}  // Multiplies
        },
        {ptr_d, {ldc, _1{}, _0{}}}};
  }

  return args;
}

template <typename Gemm, bool WithBias>
// Validates and launches the SM89 FP8 row-wise-scaled GEMM `Gemm` on the current CUDA
// stream of a's device. The CUTLASS workspace is allocated as a uint8 tensor on a's device.
//
// Throws (via TORCH_CHECK) with the CUTLASS status string if the problem cannot be
// implemented or the launch fails.
void launch_sm89_fp8_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  auto args = prepare_sm89_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
  Gemm gemm_op;

  // Reject unsupported problems (e.g. misaligned leading dimensions) before allocating.
  cutlass::Status const can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(
      can_implement == cutlass::Status::kSuccess,
      "launch_sm89_fp8_scaled_mm: CUTLASS cannot implement this problem: ",
      cutlass::cutlassGetStatusString(can_implement));

  size_t const workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  cutlass::Status const status = gemm_op(args, workspace.data_ptr(), stream);
  TORCH_CHECK(
      status == cutlass::Status::kSuccess,
      "launch_sm89_fp8_scaled_mm: CUTLASS kernel failed: ",
      cutlass::cutlassGetStatusString(status));
}

template <typename OutType, typename CtaShape, typename WarpShape, int Stages>
// Picks the bias / no-bias instantiation of the SM89 FP8 GEMM for a fixed tile
// configuration and launches it. Both variants use e4m3 inputs and fp32 accumulation;
// they differ only in the trailing WithBias flag.
void sm89_fp8_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using Fp8Input = cutlass::float_e4m3_t;
  using Accumulator = float;

  if (!bias.has_value()) {
    using GemmNoBias = typename DeviceGemmFp8RowwiseSm89<
        Fp8Input,
        OutType,
        Accumulator,
        CtaShape,
        WarpShape,
        Stages,
        false>::Gemm;
    launch_sm89_fp8_scaled_mm<GemmNoBias, false>(out, a, b, scales_a, scales_b, bias);
    return;
  }

  using GemmWithBias = typename DeviceGemmFp8RowwiseSm89<
      Fp8Input,
      OutType,
      Accumulator,
      CtaShape,
      WarpShape,
      Stages,
      true>::Gemm;
  launch_sm89_fp8_scaled_mm<GemmWithBias, true>(out, a, b, scales_a, scales_b, bias);
}

template <typename OutType>
// Chooses a threadblock tile, warp tile and mainloop stage count for the SM89 FP8 GEMM
// from M (= a.size(0)) and N (= out.size(1)), then dispatches on bias presence.
// The table values look empirically tuned -- NOTE(review): confirm before changing.
void sm89_fp8_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using cutlass::gemm::GemmShape;

  const uint32_t m = a.size(0);
  const uint32_t n = out.size(1);

  // M == 1
  if (m == 1) {
    if (n > 8192) {
      return sm89_fp8_dispatch_bias<OutType, GemmShape<32, 64, 128>, GemmShape<16, 64, 64>, 5>(
          out, a, b, scales_a, scales_b, bias);
    }
    return sm89_fp8_dispatch_bias<OutType, GemmShape<16, 64, 128>, GemmShape<16, 64, 64>, 7>(
        out, a, b, scales_a, scales_b, bias);
  }

  // M in (1, 16]; an empty M == 0 problem also lands here.
  if (m <= 16) {
    if (n > 16384) {
      return sm89_fp8_dispatch_bias<OutType, GemmShape<16, 64, 128>, GemmShape<16, 64, 64>, 7>(
          out, a, b, scales_a, scales_b, bias);
    }
    if (n > 8192) {
      return sm89_fp8_dispatch_bias<OutType, GemmShape<32, 64, 128>, GemmShape<16, 64, 64>, 5>(
          out, a, b, scales_a, scales_b, bias);
    }
    return sm89_fp8_dispatch_bias<OutType, GemmShape<16, 64, 128>, GemmShape<16, 64, 64>, 4>(
        out, a, b, scales_a, scales_b, bias);
  }

  // M in (16, 64]
  if (m <= 64) {
    if (n > 16384) {
      return sm89_fp8_dispatch_bias<OutType, GemmShape<16, 64, 128>, GemmShape<16, 64, 64>, 7>(
          out, a, b, scales_a, scales_b, bias);
    }
    return sm89_fp8_dispatch_bias<OutType, GemmShape<32, 64, 128>, GemmShape<16, 64, 64>, 7>(
        out, a, b, scales_a, scales_b, bias);
  }

  // M in (64, 128]
  if (m <= 128) {
    if (n > 16384) {
      return sm89_fp8_dispatch_bias<OutType, GemmShape<32, 64, 128>, GemmShape<16, 64, 64>, 5>(
          out, a, b, scales_a, scales_b, bias);
    }
    if (n > 8192) {
      return sm89_fp8_dispatch_bias<OutType, GemmShape<64, 64, 128>, GemmShape<32, 64, 64>, 5>(
          out, a, b, scales_a, scales_b, bias);
    }
    return sm89_fp8_dispatch_bias<OutType, GemmShape<64, 64, 128>, GemmShape<32, 64, 64>, 4>(
        out, a, b, scales_a, scales_b, bias);
  }

  // M in (128, 256]
  if (m <= 256) {
    if (n > 16384) {
      return sm89_fp8_dispatch_bias<OutType, GemmShape<128, 64, 128>, GemmShape<64, 32, 128>, 4>(
          out, a, b, scales_a, scales_b, bias);
    }
    if (n > 8192) {
      return sm89_fp8_dispatch_bias<OutType, GemmShape<64, 128, 64>, GemmShape<64, 32, 64>, 7>(
          out, a, b, scales_a, scales_b, bias);
    }
    return sm89_fp8_dispatch_bias<OutType, GemmShape<128, 64, 64>, GemmShape<64, 32, 64>, 5>(
        out, a, b, scales_a, scales_b, bias);
  }

  // M in (256, 512]
  if (m <= 512) {
    if (n > 16384) {
      return sm89_fp8_dispatch_bias<OutType, GemmShape<128, 128, 64>, GemmShape<64, 32, 64>, 4>(
          out, a, b, scales_a, scales_b, bias);
    }
    return sm89_fp8_dispatch_bias<OutType, GemmShape<128, 128, 64>, GemmShape<64, 32, 64>, 2>(
        out, a, b, scales_a, scales_b, bias);
  }

  // M in (512, inf)
  if (n > 8192) {
    return sm89_fp8_dispatch_bias<OutType, GemmShape<128, 128, 64>, GemmShape<64, 32, 64>, 2>(
        out, a, b, scales_a, scales_b, bias);
  }
  return sm89_fp8_dispatch_bias<OutType, GemmShape<128, 128, 64>, GemmShape<64, 32, 64>, 3>(
      out, a, b, scales_a, scales_b, bias);
}
#endif

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
441
442
443
444
445
446
447
448
449
450
template <
    typename ElementType,
    typename OutElementType,
    typename AccumElementType,
    typename CTAShape,
    typename ClusterShape,
    typename MainloopScheduleType,
    typename EpilogueScheduleType,
    typename TileSchedulerType = void,
    bool WithBias = false>
// CUTLASS 3.x type bundle for SM90 (Hopper) FP8 GEMM with row-wise scaling fused into an
// Sm90 epilogue visitor tree (EVT).
//
// The epilogue computes, for each output element (i, j):
//   WithBias == false:  D[i][j] = scales_a[i] * (scales_b[j] * acc[i][j])
//   WithBias == true:   D[i][j] = scales_a[i] * (scales_b[j] * acc[i][j]) + bias[j]
// XScale (column broadcast, one value per row) carries scales_a; WScale and Bias
// (row broadcast, one value per column) carry scales_b and bias.
//
// MainloopScheduleType and EpilogueScheduleType are forwarded to the mainloop and
// epilogue collective builders respectively; TileSchedulerType to GemmUniversal.
struct DeviceGemmFp8RowwiseSm90 {
  static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");

  // A matrix configuration: row-major FP8, 16-byte (128-bit) aligned accesses.
  using ElementA = ElementType;
  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  // B matrix configuration: column-major FP8, 16-byte (128-bit) aligned accesses.
  using ElementB = ElementType;
  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  // C matrix configuration. ElementC is void (no C source element type); the alignment is
  // still expressed in units of the output element type.
  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutElementType>::value;

  // Output (D) matrix configuration.
  using ElementOutput = OutElementType;
  using LayoutOutput = cutlass::layout::RowMajor;
  static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  // Multiply-accumulate blocking/pipelining details.
  using ElementAccumulator = AccumElementType;
  using ElementCompute = float;
  using ElementComputeEpilogue = float;
  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using TileShape = CTAShape;

  // Descriptive flags. They are not referenced by the type construction in this struct.
  static constexpr bool PONG = false;
  static constexpr bool FAST_ACCUM = true;
  static constexpr bool USE_BIAS = WithBias;

  // Not referenced below (stages use StageCountAutoCarveout; the schedule comes from
  // MainloopScheduleType); kept for source compatibility.
  using StageCountType = cutlass::gemm::collective::StageCountAuto;
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;

  // ---- Row-wise scaling epilogue ---------------------------------------------------

  // Per-row scale (scales_a): stride (1, 0, 0) broadcasts it along columns.
  using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

  // Per-column scale (scales_b): stride (0, 1, 0) broadcasts it along rows.
  using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  // Per-column bias in the output element type.
  using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      ElementOutput,
      ElementOutput,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  // scales_b[j] * acc, kept in float.
  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies,
      ElementComputeEpilogue,  // First stage output type.
      ElementComputeEpilogue,  // First stage input types.
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;

  // scales_a[i] * (scales_b[j] * acc), produced as ElementOutput.
  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies,
      ElementOutput,
      ElementComputeEpilogue,  // Second stage input types.
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;

  // With bias: scales_a[i] * (scales_b[j] * acc) + bias[j], produced as ElementOutput.
  using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add,
      ElementOutput,
      ElementComputeEpilogue,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;

  using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;

  // The epilogue schedule comes from the template parameter instead of being hard-coded,
  // so callers can pair it with a matching mainloop schedule.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90,
      cutlass::arch::OpClassTensorOp,
      TileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementComputeEpilogue,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementOutput,
      LayoutOutput,
      AlignmentOutput,
      EpilogueScheduleType,
      EpilogueEVT>::CollectiveOp;

  // Schedule aliases; not referenced below (the mainloop uses MainloopScheduleType).
  using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
  using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using FastDefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using FastPongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;

  using SlowAccum = DefaultSchedule;
  using FastAccum = FastPongSchedule;  // Default apply Pingpong

  // Mainloop stage count: whatever shared memory remains after the epilogue's storage.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      LayoutA,
      AlignmentA,
      ElementB,
      LayoutB,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      MainloopScheduleType>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape (M, N, K, L)
      CollectiveMainloop,
      CollectiveEpilogue,
      TileSchedulerType>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};

template <typename Gemm, bool WithBias>
// Builds the CUTLASS 3.x GemmUniversal arguments for the SM90 row-wise-scaled FP8 GEMM.
//
// out      - output D (m x n).
// a        - operand A (m x k).
// b        - operand B (k x n); n is taken from b.size(1).
// scales_a - per-row scale vector read as float (XScale node).
// scales_b - per-column scale vector read as float (WScale node).
// bias     - per-column bias in the output element type; required when WithBias is true.
//
// All strides are derived from the logical shapes via make_cute_packed_stride, i.e. the
// operands are treated as fully packed. NOTE(review): presumably the caller validates
// contiguity -- confirm.
//
// Throws (via TORCH_CHECK) if WithBias is true but no bias tensor was supplied.
typename Gemm::Arguments prepare_sm90_fp8_args(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ElementT = typename Gemm::ElementA;
  using ElementOutput = typename Gemm::ElementD;
  using ElementComputeEpilogue = float;
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);
  ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
  ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
  ElementOutput const* ptr_bias = nullptr;
  if constexpr (WithBias) {
    TORCH_CHECK(bias.has_value(), "prepare_sm90_fp8_args: a bias tensor is required when WithBias is true");
    ptr_bias = reinterpret_cast<ElementOutput const*>(bias.value().data_ptr());
  }
  ElementOutput* ptr_d = reinterpret_cast<ElementOutput*>(out.data_ptr());
  ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
  ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
  // C is never read (null pointer below), but give its stride a well-defined value rather
  // than relying on a default-constructed stride tuple.
  StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, make_shape(m, n, 1));
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));

  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      {ptr_a, stride_a, ptr_b, stride_b},
      {{},  // epilogue.thread (filled in below)
       nullptr,
       stride_c,
       ptr_d,
       stride_d}};

  // Fusion arguments mirror the EVT tree: XScale, EVTCompute0{WScale, Accum, op}, [Bias], op.
  if constexpr (WithBias) {
    args.epilogue.thread = {
        {ptr_scales_a},
        {
            {ptr_scales_b},
            {},  // Accumulator
            {}   // Multiplies
        },
        {ptr_bias},
        {},  // Multiply-add
    };
  } else {
    args.epilogue.thread = {
        {ptr_scales_a},
        {
            {ptr_scales_b},
            {},  // Accumulator
            {}   // Multiplies
        },
        {},  // Multiplies
    };
  }

  return args;
}

template <typename Gemm, bool WithBias>
// Validates and launches the SM90 FP8 row-wise-scaled GEMM `Gemm` on the current CUDA
// stream of a's device. The CUTLASS workspace is allocated as a uint8 tensor on a's device.
//
// Throws (via TORCH_CHECK) with the CUTLASS status string if the problem cannot be
// implemented or the launch fails.
void launch_sm90_fp8_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  auto args = prepare_sm90_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
  Gemm gemm_op;

  // Reject unsupported problems (e.g. TMA alignment violations) before allocating.
  cutlass::Status const can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(
      can_implement == cutlass::Status::kSuccess,
      "launch_sm90_fp8_scaled_mm: CUTLASS cannot implement this problem: ",
      cutlass::cutlassGetStatusString(can_implement));

  size_t const workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  cutlass::Status const status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(
      status == cutlass::Status::kSuccess,
      "launch_sm90_fp8_scaled_mm: CUTLASS kernel failed: ",
      cutlass::cutlassGetStatusString(status));
}

692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
template <
    typename OutType,
    typename CTAShape,
    typename ClusterShape,
    typename MainloopScheduleType,
    typename TileSchedulerType>
// Picks the bias / no-bias instantiation of the SM90 FP8 GEMM for a fixed tile, cluster,
// mainloop schedule and tile scheduler, then launches it. Inputs are e4m3, accumulation is
// fp32 and the epilogue schedule is TmaWarpSpecialized.
//
// NOTE(review): fast_accum and use_persistent are accepted but not read here; the schedule
// choice is made by the caller through MainloopScheduleType / TileSchedulerType.
void sm90_fp8_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias,
    bool fast_accum = true,
    bool use_persistent = false) {
  using Fp8Input = cutlass::float_e4m3_t;
  using Accumulator = float;
  using EpiSchedule = cutlass::epilogue::TmaWarpSpecialized;

  if (!bias.has_value()) {
    using GemmNoBias = typename DeviceGemmFp8RowwiseSm90<
        Fp8Input,
        OutType,
        Accumulator,
        CTAShape,
        ClusterShape,
        MainloopScheduleType,
        EpiSchedule,
        TileSchedulerType,
        false>::Gemm;
    launch_sm90_fp8_scaled_mm<GemmNoBias, false>(out, a, b, scales_a, scales_b, bias);
    return;
  }

  using GemmWithBias = typename DeviceGemmFp8RowwiseSm90<
      Fp8Input,
      OutType,
      Accumulator,
      CTAShape,
      ClusterShape,
      MainloopScheduleType,
      EpiSchedule,
      TileSchedulerType,
      true>::Gemm;
  launch_sm90_fp8_scaled_mm<GemmWithBias, true>(out, a, b, scales_a, scales_b, bias);
}

template <typename OutType>
// Chooses CTA tile, cluster shape, mainloop schedule and tile scheduler for the SM90 FP8
// GEMM from M (= a.size(0)) alone, then dispatches on bias presence. Buckets are checked
// from the largest M downwards:
//   M > 1024        : 128x128x128 tile, 2x1x1 cluster, pingpong FP8 fast-accum, persistent
//   M in (256,1024] : 128x128x128 tile, 1x1x1 cluster, pingpong FP8 fast-accum, persistent
//   M in (64, 256]  : 64x64x128 tile,   1x1x1 cluster, pingpong FP8 fast-accum, persistent
//   M in (1, 64]    : 64x64x128 tile,   1x4x1 cluster, pingpong FP8 fast-accum, persistent
//   M <= 1          : 64x64x128 tile,   1x8x1 cluster, basic FP8 fast-accum, non-persistent
void sm90_fp8_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using PingpongFastAccum = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using BasicFastAccum = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using Persistent = cutlass::gemm::PersistentScheduler;
  using NonPersistent = void;

  const uint32_t m = a.size(0);

  if (m > 1024) {
    return sm90_fp8_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_2, _1, _1>, PingpongFastAccum, Persistent>(
        out, a, b, scales_a, scales_b, bias);
  }
  if (m > 256) {
    return sm90_fp8_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_1, _1, _1>, PingpongFastAccum, Persistent>(
        out, a, b, scales_a, scales_b, bias);
  }
  if (m > 64) {
    return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _1, _1>, PingpongFastAccum, Persistent>(
        out, a, b, scales_a, scales_b, bias);
  }
  if (m > 1) {
    return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _4, _1>, PingpongFastAccum, Persistent>(
        out, a, b, scales_a, scales_b, bias);
  }
  // M <= 1 (including an empty M == 0 problem).
  return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _8, _1>, BasicFastAccum, NonPersistent>(
      out, a, b, scales_a, scales_b, bias);
}
#endif

796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
template <
    typename ElementType,
    typename OutElementType,
    typename AccumElementType,
    typename CTAShape,
    typename ClusterShape,
    typename MainloopScheduleType,
    typename EpilogueScheduleType,
    typename TileSchedulerType = void,
    bool WithBias = false>
// Bundles the CUTLASS 3.x types for a row-wise scaled FP8 (e4m3) GEMM on SM100.
//
// The fused epilogue (EVT) computes, per output element (m, n):
//   WithBias == false : D[m, n] = scale_a[m] * (scale_b[n] * acc[m, n])
//   WithBias == true  : D[m, n] = scale_a[m] * (scale_b[n] * acc[m, n]) + bias[n]
// The result is converted to OutElementType with round-to-nearest.
//   - scale_a is broadcast along N (stride (1, 0, 0)).
//   - scale_b and bias are broadcast along M (stride (0, 1, 0)).
//   - A is row-major, B is column-major and D is row-major.
//   - There is no C source operand (ElementC = void).
//
// TileSchedulerType is forwarded to GemmUniversal. Its default of `void` lets
// CUTLASS pick the scheduler.
struct DeviceGemmFp8RowwiseSm100 {
  static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
  using TileShape = CTAShape;
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using ElementComputeEpilogue = float;
  // Per-row (M) scale of A.
  using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

  // Per-column (N) scale of B.
  using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  // Per-column (N) bias, stored in the output element type.
  using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      OutElementType,
      OutElementType,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  using Compute0 = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiplies, float, float, cutlass::FloatRoundStyle::round_to_nearest>;

  // scale_b * acc
  using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementType>::value;

  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementType>::value;

  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutElementType>::value;

  using LayoutD = cutlass::layout::RowMajor;
  static constexpr int AlignmentD = AlignmentC;

  using Compute1MulAdd = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiply_add, OutElementType, float, cutlass::FloatRoundStyle::round_to_nearest>;
  using Compute1Mul = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiplies, OutElementType, float, cutlass::FloatRoundStyle::round_to_nearest>;

  // scale_a * (scale_b * acc) [+ bias]
  using EVTCompute = typename std::conditional_t<
      WithBias,
      cutlass::epilogue::fusion::Sm90EVT<Compute1MulAdd, ScaleA, EVTCompute0, Bias>,
      cutlass::epilogue::fusion::Sm90EVT<Compute1Mul, ScaleA, EVTCompute0>>;
  using ArgumentType = typename EVTCompute::Arguments;
  // MMA type
  using ElementAccumulator = AccumElementType;

  // Epilogue types
  using ElementCompute = float;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm100,
      cutlass::arch::OpClassTensorOp,
      TileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      OutElementType,
      LayoutD,
      AlignmentD,
      EpilogueScheduleType,
      EVTCompute>::CollectiveOp;

  // The mainloop stage count is chosen automatically from the shared memory
  // left over after the epilogue's SharedStorage is carved out.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm100,
      cutlass::arch::OpClassTensorOp,
      ElementType,
      LayoutA,
      AlignmentA,
      ElementType,
      LayoutB,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      MainloopScheduleType>::CollectiveOp;

  // Forward the requested tile scheduler instead of hard-wiring `void`.
  using GemmKernel = cutlass::gemm::kernel::
      GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  // Wraps a tensor's raw device pointer (reinterpreted as T*) in the argument
  // struct of one of the broadcast epilogue nodes.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    static_assert(
        std::is_same_v<Descriptor, ScaleA> || std::is_same_v<Descriptor, ScaleB> || std::is_same_v<Descriptor, Bias>);
    return Arguments{data_ptr};
  }

  // Builds the EVT epilogue arguments from the scale tensors (read as float)
  // and, when WithBias, the bias tensor (read as OutElementType).
  // `bias` must hold a value when WithBias is true; bias.value() throws otherwise.
  static ArgumentType prepare_args(
      torch::Tensor const& a_scales,
      torch::Tensor const& b_scales,
      std::optional<torch::Tensor> const& bias = std::nullopt) {
    auto a_args = args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};

    if constexpr (WithBias) {
      auto bias_args = args_from_tensor<Bias, OutElementType>(bias.value());
      return ArgumentType{a_args, evt0_args, bias_args, {}};
    } else {
      return ArgumentType{a_args, evt0_args, {}};
    }
  }
};

template <typename GemmType, bool WithBias>
// Builds the CUTLASS GemmKernel::Arguments for one SM100 FP8 GEMM launch.
//
// Problem shape is (m, n, k, 1) with m = a.size(0), n = b.size(1), k = a.size(1).
// A, B and D strides are computed with make_cute_packed_stride from the
// logical shapes only, so the tensors must really be packed in that layout:
//   a   : row-major    (a.stride(0) == k, unless m <= 1)
//   b   : column-major (b.stride(1) == k, unless n <= 1)
//   out : contiguous row-major
// These layouts are checked here. A non-packed view would otherwise be read or
// written with the wrong leading dimension and silently produce garbage.
// `out` is used both as the (unused, ElementC = void) C pointer and as D.
// Throws (TORCH_CHECK) on layout violations or when WithBias is true but
// `bias` is empty.
typename GemmType::Gemm::Arguments prepare_sm100_fp8_args(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using Gemm = typename GemmType::Gemm;
  using GemmKernel = typename Gemm::GemmKernel;
  using ElementT = typename Gemm::ElementA;
  using ElementOutput = typename Gemm::ElementD;

  using StrideA = typename GemmKernel::StrideA;
  using StrideB = typename GemmKernel::StrideB;
  using StrideC = typename GemmKernel::StrideC;
  using StrideD = StrideC;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  // The packed strides below ignore the tensors' real strides, so reject any
  // layout they would misdescribe. A leading stride is irrelevant when that
  // dimension has a single element.
  TORCH_CHECK(m <= 1 || a.stride(0) == k, "a must be a packed row-major tensor (stride(0) == size(1))");
  TORCH_CHECK(n <= 1 || b.stride(1) == k, "b must be a packed column-major tensor (stride(1) == size(0))");
  TORCH_CHECK(out.is_contiguous(), "out must be a contiguous row-major tensor");

  ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
  ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));

  typename GemmKernel::MainloopArguments mainloop_args{ptr_a, stride_a, ptr_b, stride_b};

  typename GemmKernel::ProblemShape prob_shape = {m, n, k, 1};
  cutlass::KernelHardwareInfo hw_info;
  typename GemmKernel::TileSchedulerArguments scheduler = {};

  auto ptr_c = static_cast<ElementOutput*>(out.data_ptr());

  // The bias operand is a compile-time part of the epilogue, so the two
  // argument layouts are built in separate `if constexpr` branches.
  auto make_epilogue_args = [&]() {
    if constexpr (WithBias) {
      TORCH_CHECK(bias.has_value(), "Bias tensor is required but not provided.");
      return typename GemmKernel::EpilogueArguments{
          GemmType::prepare_args(scales_a, scales_b, bias.value()), ptr_c, stride_c, ptr_c, stride_d};
    } else {
      return typename GemmKernel::EpilogueArguments{
          GemmType::prepare_args(scales_a, scales_b), ptr_c, stride_c, ptr_c, stride_d};
    }
  };

  typename GemmKernel::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemm,
      prob_shape,
      mainloop_args,
      make_epilogue_args(),
      hw_info,
      scheduler};
  return args;
}

template <typename Gemm, bool WithBias>
// Prepares the arguments for the given SM100 GEMM configuration, validates
// them with CUTLASS, and launches the kernel on the current CUDA stream of
// a's device. The workspace is a uint8 tensor on a's device.
// Throws (TORCH_CHECK) with the CUTLASS status string if the problem cannot be
// implemented or the launch fails.
void launch_sm100_fp8_scaled_mm(
    torch::Tensor& out,
    torch::Tensor const& a,
    torch::Tensor const& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  auto args = prepare_sm100_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);

  typename Gemm::Gemm gemm_op;

  // Reject unsupported problems before allocating any workspace.
  cutlass::Status const can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(
      can_implement == cutlass::Status::kSuccess,
      "sm100 fp8 scaled_mm: CUTLASS cannot implement this problem: ",
      cutlassGetStatusString(can_implement));

  size_t const workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty({static_cast<int64_t>(workspace_size)}, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
  cutlass::Status const status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(
      status == cutlass::Status::kSuccess,
      "sm100 fp8 scaled_mm: CUTLASS kernel launch failed: ",
      cutlassGetStatusString(status));
}

template <typename OutType>
void sm100_fp8_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
  using CTAShapeDefault = Shape<_256, _128, _64>;
  using ClusterShapeDefault = Shape<_2, _2, _1>;

  using CTAShape256 = Shape<_128, _128, _128>;
  using ClusterShape256 = Shape<_2, _1, _1>;

  using CTAShape64 = Shape<_64, _64, _128>;
  using ClusterShape64 = Shape<_1, _1, _1>;

  using CTAShape16 = Shape<_64, _64, _128>;
  using ClusterShape16 = Shape<_1, _4, _1>;

1035
1036
1037
1038
1039
1040
1041
1042
  using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
  using TileSchedulerType = void;

  using ElementInput = cutlass::float_e4m3_t;
  using ElementOutput = OutType;
  using AccumElementType = float;

1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
  // Gemm type with bias
  using BiasGemmDefault = DeviceGemmFp8RowwiseSm100<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShapeDefault,
      ClusterShapeDefault,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      true>;
  using BiasGemm256 = DeviceGemmFp8RowwiseSm100<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShape256,
      ClusterShape256,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      true>;
  using BiasGemm64 = DeviceGemmFp8RowwiseSm100<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShape64,
      ClusterShape64,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      true>;
  using BiasGemm16 = DeviceGemmFp8RowwiseSm100<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShape16,
      ClusterShape16,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      true>;

  // Gemm type without bias
  using GemmDefault = DeviceGemmFp8RowwiseSm100<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShapeDefault,
      ClusterShapeDefault,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      false>;
  using Gemm256 = DeviceGemmFp8RowwiseSm100<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShape256,
      ClusterShape256,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      false>;
  using Gemm64 = DeviceGemmFp8RowwiseSm100<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShape64,
      ClusterShape64,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      false>;
  using Gemm16 = DeviceGemmFp8RowwiseSm100<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShape16,
      ClusterShape16,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      false>;

  // next power of 2 (minimum 16)
  uint32_t const m = a.size(0);
  uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));

1131
  if (bias) {
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
    if (mp2 <= 16) {
      // m in [1, 16]
      return launch_sm100_fp8_scaled_mm<BiasGemm16, true>(out, a, b, scales_a, scales_b, bias);
    } else if (mp2 <= 64) {
      // m in (16, 64]
      return launch_sm100_fp8_scaled_mm<BiasGemm64, true>(out, a, b, scales_a, scales_b, bias);
    } else if (mp2 <= 256) {
      // m in (64, 256]
      return launch_sm100_fp8_scaled_mm<BiasGemm256, true>(out, a, b, scales_a, scales_b, bias);
    } else {
      // m in (256, inf]
      return launch_sm100_fp8_scaled_mm<BiasGemmDefault, true>(out, a, b, scales_a, scales_b, bias);
    }
1145
  } else {
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
    if (mp2 <= 16) {
      // m in [1, 16]
      return launch_sm100_fp8_scaled_mm<Gemm16, false>(out, a, b, scales_a, scales_b, bias);
    } else if (mp2 <= 64) {
      // m in (16, 64]
      return launch_sm100_fp8_scaled_mm<Gemm64, false>(out, a, b, scales_a, scales_b, bias);
    } else if (mp2 <= 256) {
      // m in (64, 256]
      return launch_sm100_fp8_scaled_mm<Gemm256, false>(out, a, b, scales_a, scales_b, bias);
    } else {
      return launch_sm100_fp8_scaled_mm<GemmDefault, false>(out, a, b, scales_a, scales_b, bias);
    }
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
  }
}

template <typename OutType>
// Shape-level entry point for SM100. There is no shape-specific logic at this
// level; the call is forwarded unchanged to sm100_fp8_dispatch_bias, which
// performs the tile selection.
void sm100_fp8_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  sm100_fp8_dispatch_bias<OutType>(out, a, b, scales_a, scales_b, bias);
}
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447

template <
    typename ElementType,
    typename OutElementType,
    typename AccumElementType,
    typename CTAShape,
    typename ClusterShape,
    typename MainloopScheduleType,
    typename EpilogueScheduleType,
    typename TileSchedulerType = void,
    bool WithBias = false>
// Bundles the CUTLASS 3.x types for a row-wise scaled FP8 (e4m3) GEMM on SM120.
//
// The fused epilogue (EVT) computes, per output element (m, n):
//   WithBias == false : D[m, n] = scale_a[m] * (scale_b[n] * acc[m, n])
//   WithBias == true  : D[m, n] = scale_a[m] * (scale_b[n] * acc[m, n]) + bias[n]
// The result is converted to OutElementType with round-to-nearest.
//   - scale_a is broadcast along N (stride (1, 0, 0)).
//   - scale_b and bias are broadcast along M (stride (0, 1, 0)).
//   - A is row-major, B is column-major and D is row-major.
//   - There is no C source operand (ElementC = void).
//
// TileSchedulerType is forwarded to GemmUniversal. Its default of `void` lets
// CUTLASS pick the scheduler.
struct DeviceGemmFp8RowwiseSm120 {
  static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
  using TileShape = CTAShape;
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using ElementComputeEpilogue = float;
  // Per-row (M) scale of A.
  using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

  // Per-column (N) scale of B.
  using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  // Per-column (N) bias, stored in the output element type.
  using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      OutElementType,
      OutElementType,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  using Compute0 = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiplies, float, float, cutlass::FloatRoundStyle::round_to_nearest>;

  // scale_b * acc
  using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementType>::value;

  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementType>::value;

  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutElementType>::value;

  using LayoutD = cutlass::layout::RowMajor;
  static constexpr int AlignmentD = AlignmentC;

  using Compute1MulAdd = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiply_add, OutElementType, float, cutlass::FloatRoundStyle::round_to_nearest>;
  using Compute1Mul = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiplies, OutElementType, float, cutlass::FloatRoundStyle::round_to_nearest>;

  // scale_a * (scale_b * acc) [+ bias]
  using EVTCompute = typename std::conditional_t<
      WithBias,
      cutlass::epilogue::fusion::Sm90EVT<Compute1MulAdd, ScaleA, EVTCompute0, Bias>,
      cutlass::epilogue::fusion::Sm90EVT<Compute1Mul, ScaleA, EVTCompute0>>;
  using ArgumentType = typename EVTCompute::Arguments;
  // MMA type
  using ElementAccumulator = AccumElementType;

  // Epilogue types
  using ElementCompute = float;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm120,
      cutlass::arch::OpClassTensorOp,
      TileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      OutElementType,
      LayoutD,
      AlignmentD,
      EpilogueScheduleType,
      EVTCompute>::CollectiveOp;

  // The mainloop stage count is chosen automatically from the shared memory
  // left over after the epilogue's SharedStorage is carved out.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm120,
      cutlass::arch::OpClassTensorOp,
      ElementType,
      LayoutA,
      AlignmentA,
      ElementType,
      LayoutB,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      MainloopScheduleType>::CollectiveOp;

  // Forward the requested tile scheduler instead of hard-wiring `void`.
  using GemmKernel = cutlass::gemm::kernel::
      GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  // Wraps a tensor's raw device pointer (reinterpreted as T*) in the argument
  // struct of one of the broadcast epilogue nodes.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    static_assert(
        std::is_same_v<Descriptor, ScaleA> || std::is_same_v<Descriptor, ScaleB> || std::is_same_v<Descriptor, Bias>);
    return Arguments{data_ptr};
  }

  // Builds the EVT epilogue arguments from the scale tensors (read as float)
  // and, when WithBias, the bias tensor (read as OutElementType).
  // `bias` must hold a value when WithBias is true; bias.value() throws otherwise.
  static ArgumentType prepare_args(
      torch::Tensor const& a_scales,
      torch::Tensor const& b_scales,
      std::optional<torch::Tensor> const& bias = std::nullopt) {
    auto a_args = args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};

    if constexpr (WithBias) {
      auto bias_args = args_from_tensor<Bias, OutElementType>(bias.value());
      return ArgumentType{a_args, evt0_args, bias_args, {}};
    } else {
      return ArgumentType{a_args, evt0_args, {}};
    }
  }
};

template <typename GemmType, bool WithBias>
// Builds the CUTLASS GemmKernel::Arguments for one SM120 FP8 GEMM launch.
//
// Problem shape is (m, n, k, 1) with m = a.size(0), n = b.size(1), k = a.size(1).
// A, B and D strides are computed with make_cute_packed_stride from the
// logical shapes only, so the tensors must really be packed in that layout:
//   a   : row-major    (a.stride(0) == k, unless m <= 1)
//   b   : column-major (b.stride(1) == k, unless n <= 1)
//   out : contiguous row-major
// These layouts are checked here. A non-packed view would otherwise be read or
// written with the wrong leading dimension and silently produce garbage.
// `out` is used both as the (unused, ElementC = void) C pointer and as D.
// Throws (TORCH_CHECK) on layout violations or when WithBias is true but
// `bias` is empty.
typename GemmType::Gemm::Arguments prepare_sm120_fp8_args(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using Gemm = typename GemmType::Gemm;
  using GemmKernel = typename Gemm::GemmKernel;
  using ElementT = typename Gemm::ElementA;
  using ElementOutput = typename Gemm::ElementD;

  using StrideA = typename GemmKernel::StrideA;
  using StrideB = typename GemmKernel::StrideB;
  using StrideC = typename GemmKernel::StrideC;
  using StrideD = StrideC;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  // The packed strides below ignore the tensors' real strides, so reject any
  // layout they would misdescribe. A leading stride is irrelevant when that
  // dimension has a single element.
  TORCH_CHECK(m <= 1 || a.stride(0) == k, "a must be a packed row-major tensor (stride(0) == size(1))");
  TORCH_CHECK(n <= 1 || b.stride(1) == k, "b must be a packed column-major tensor (stride(1) == size(0))");
  TORCH_CHECK(out.is_contiguous(), "out must be a contiguous row-major tensor");

  ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
  ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));

  typename GemmKernel::MainloopArguments mainloop_args{ptr_a, stride_a, ptr_b, stride_b};

  typename GemmKernel::ProblemShape prob_shape = {m, n, k, 1};
  cutlass::KernelHardwareInfo hw_info;
  typename GemmKernel::TileSchedulerArguments scheduler = {};

  auto ptr_c = static_cast<ElementOutput*>(out.data_ptr());

  // The bias operand is a compile-time part of the epilogue, so the two
  // argument layouts are built in separate `if constexpr` branches.
  auto make_epilogue_args = [&]() {
    if constexpr (WithBias) {
      TORCH_CHECK(bias.has_value(), "Bias tensor is required but not provided.");
      return typename GemmKernel::EpilogueArguments{
          GemmType::prepare_args(scales_a, scales_b, bias.value()), ptr_c, stride_c, ptr_c, stride_d};
    } else {
      return typename GemmKernel::EpilogueArguments{
          GemmType::prepare_args(scales_a, scales_b), ptr_c, stride_c, ptr_c, stride_d};
    }
  };

  typename GemmKernel::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemm,
      prob_shape,
      mainloop_args,
      make_epilogue_args(),
      hw_info,
      scheduler};
  return args;
}

template <typename Gemm, bool WithBias>
// Prepares the arguments for the given SM120 GEMM configuration, validates
// them with CUTLASS, and launches the kernel on the current CUDA stream of
// a's device. The workspace is a uint8 tensor on a's device.
// Throws (TORCH_CHECK) with the CUTLASS status string if the problem cannot be
// implemented or the launch fails.
void launch_sm120_fp8_scaled_mm(
    torch::Tensor& out,
    torch::Tensor const& a,
    torch::Tensor const& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  auto args = prepare_sm120_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);

  typename Gemm::Gemm gemm_op;

  // Reject unsupported problems before allocating any workspace.
  cutlass::Status const can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(
      can_implement == cutlass::Status::kSuccess,
      "sm120 fp8 scaled_mm: CUTLASS cannot implement this problem: ",
      cutlassGetStatusString(can_implement));

  size_t const workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty({static_cast<int64_t>(workspace_size)}, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
  cutlass::Status const status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(
      status == cutlass::Status::kSuccess,
      "sm120 fp8 scaled_mm: CUTLASS kernel launch failed: ",
      cutlassGetStatusString(status));
}

template <typename OutType>
void sm120_fp8_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using CTAShapeDefault = Shape<_128, _128, _128>;
  using ClusterShapeDefault = Shape<_1, _1, _1>;

  using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
  using TileSchedulerType = void;

  using ElementInput = cutlass::float_e4m3_t;
  using ElementOutput = OutType;
  using AccumElementType = float;

  using BiasGemmDefault = DeviceGemmFp8RowwiseSm120<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShapeDefault,
      ClusterShapeDefault,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      true>;

  using GemmDefault = DeviceGemmFp8RowwiseSm120<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShapeDefault,
      ClusterShapeDefault,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      false>;

  if (bias) {
    return launch_sm120_fp8_scaled_mm<BiasGemmDefault, true>(out, a, b, scales_a, scales_b, bias);
  } else {
    return launch_sm120_fp8_scaled_mm<GemmDefault, false>(out, a, b, scales_a, scales_b, bias);
  }
}

template <typename OutType>
// Shape-level entry point for SM120. There is no shape-specific logic at this
// level; the call is forwarded unchanged to sm120_fp8_dispatch_bias.
void sm120_fp8_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  sm120_fp8_dispatch_bias<OutType>(out, a, b, scales_a, scales_b, bias);
}
1448
1449
#endif

1450
1451
1452
1453
1454
1455
1456
// Row-wise scaled FP8 GEMM entry point.
//
// Validates the inputs and allocates `out` with shape
// (mat_a.size(0), mat_b.size(1)) and dtype `out_dtype`. It then dispatches on
// getSMVersion() and the CUDA toolkit version:
//   SM120+ / SM100+  (CUDA >= 12.8)
//   SM90+            (CUDA >= 12.0)
//   SM89             (CUDA >= 12.4)
// Requirements:
//   - mat_a is row-major, mat_b is column-major, both FP8 e4m3.
//   - scales_a has one float per row of mat_a; scales_b has one float per
//     column of mat_b.
//   - The optional bias has one element per column of mat_b and dtype out_dtype.
//   - Every tensor must be on the same CUDA device, because the kernels
//     receive their raw data pointers.
// Throws (TORCH_CHECK) on invalid input, and TORCH_CHECK_NOT_IMPLEMENTED when
// no kernel exists for the current compute capability.
torch::Tensor fp8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype,
    const c10::optional<torch::Tensor>& bias) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.device() == mat_b.device(), "mat_a and mat_b must be on the same device");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");

  // The contiguous (K) extent of each operand must be a multiple of 16 bytes
  // to satisfy the kernels' memory alignment requirements.
  TORCH_CHECK(
      (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(
      (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");

  // Scale (and bias) pointers are handed straight to the GPU epilogue, so a
  // CPU tensor or one on another GPU would fault instead of erroring cleanly.
  TORCH_CHECK(scales_a.device() == mat_a.device(), "scales_a must be on the same device as mat_a");
  TORCH_CHECK(scales_b.device() == mat_a.device(), "scales_b must be on the same device as mat_a");
  TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched");
  TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched");
  TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous");
  TORCH_CHECK(scales_b.is_contiguous(), "scales_b must be contiguous");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");

  if (bias) {
    TORCH_CHECK(bias->device() == mat_a.device(), "bias must be on the same device as mat_a");
    TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched");
    TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous");
    TORCH_CHECK(bias->scalar_type() == out_dtype, "bias dtype must match output dtype");
  }

  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));
  TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment");

  auto sm_version = getSMVersion();

#if defined CUDA_VERSION && CUDA_VERSION >= 12080
  if (sm_version >= 120) {
    if (out_dtype == torch::kBFloat16) {
      sm120_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm120_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return out;
  } else if (sm_version >= 100) {
    if (out_dtype == torch::kBFloat16) {
      sm100_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm100_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return out;
  }
#endif

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
  if (sm_version >= 90) {
    if (out_dtype == torch::kBFloat16) {
      sm90_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm90_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return out;
  }
#endif

#if defined CUDA_VERSION && CUDA_VERSION >= 12040
  if (sm_version == 89) {
    if (out_dtype == torch::kBFloat16) {
      sm89_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm89_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return out;
  }
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented fp8_scaled_mm for current compute capability: ", sm_version);
}