"test/tuner_test/naive_trial.py" did not exist on "ac5fda4d5d2a798b540ccab3b9e953e943f40285"
dispatch_policy.hpp 1.67 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// Adapted from https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
#pragma once

#include <cutlass/gemm/dispatch_policy.hpp>

namespace cutlass::gemm {

//////////////////////////////////////////////////////////////////////////////

// FP8 related policies (including Blocked Scaled Accumulation).
//
// Kernel schedule tag type: it declares no members of its own and only
// derives from `KernelTmaWarpSpecializedCooperative`.
//
//  `ScaleGranularityM` is the scaling granularity along M. A value of zero
//  means the granularity along M is `size<0>(TileShape_MNK{})`.
template <int ScaleGranularityM = 0>
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
    : KernelTmaWarpSpecializedCooperative {
};

// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
// specialized dynamic schedule For FP8 kernels with Block Scaling
19
20
21
22
23
24
25
26
template <
    int Stages_,
    class ClusterShape_ = Shape<_1, _1, _1>,
    class KernelSchedule = KernelTmaWarpSpecialized,
    int ScaleGranularityM = 0  // `ScaleGranularityM` specifies scaling granularity along M,
                               // while zero-value `ScaleGranularityM` indicates that scaling
                               // granularity is `size<0>(TileShape_MNK{})` along M.
    >
27
28
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
    : MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
29
30
31
32
  static_assert(
      cute::
          is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
      "KernelSchedule must be one of the warp specialized policies");
33
34
35
36
37
};

//////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::gemm