machete_collective_builder.cuh 1.29 KB
Newer Older
raojy's avatar
raojy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#pragma once

#include "cutlass_extensions/vllm_collective_builder.cuh"
#include "machete_mainloop.cuh"

namespace cutlass::gemm::collective {
using namespace cute;

// Empty tag type. It carries no state; it is used purely as the first
// template argument of VLLMCollectiveBuilder so that the partial
// specialization below is selected for Machete kernels.
struct MacheteKernelTag {};

// Partial specialization of VLLMCollectiveBuilder that is selected when:
//   * the kernel tag is MacheteKernelTag,
//   * the architecture tag is arch::Sm90,
//   * the operator class is arch::OpClassTensorOp, and
//   * KernelScheduleType is exactly one of KernelTmaWarpSpecialized,
//     KernelTmaWarpSpecializedPingpong or KernelTmaWarpSpecializedCooperative.
//
// The last condition is enforced through the trailing cute::enable_if_t
// argument: for any other schedule type the enable_if_t is ill-formed, so
// this specialization drops out of consideration (SFINAE) instead of
// producing a hard error here.
//
// All remaining template parameters are forwarded, unchanged and in the same
// order, to machete::MacheteCollectiveMma; this builder adds no logic of its
// own beyond the dispatch.
//
// NOTE(review): the "Pair" in ElementPairA_/ElementPairB_ suggests these may
// carry more than a bare element type -- confirm against the definition of
// MacheteCollectiveMma in machete_mainloop.cuh.
template <class ElementPairA_, class GmemLayoutA_, int AlignmentA,
          class ElementPairB_, class GmemLayoutB_, int AlignmentB,
          class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
          class StageCountType, class KernelScheduleType>
struct VLLMCollectiveBuilder<
    MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_,
    GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB,
    ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
    KernelScheduleType,
    cute::enable_if_t<(
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
        cute::is_same_v<KernelScheduleType,
                        KernelTmaWarpSpecializedCooperative>)>> {
  // The collective mainloop type produced by this builder.
  using CollectiveOp = machete::MacheteCollectiveMma<
      ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
      AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
      StageCountType, KernelScheduleType>;
};

// No ';' after the closing brace: a namespace definition is not terminated by
// a semicolon, and a stray one is an empty declaration that trips
// -Wextra-semi / -Wpedantic.
}  // namespace cutlass::gemm::collective