# vllm_cutlass_library_extension.py
import enum
from typing import Dict, Union

from cutlass_library import *

#
#   Extend cutlass library with custom types, and missing values
#


class VLLMDataType(enum.Enum):
    """Custom data types that extend those provided by cutlass_library.

    Members are keyed into ``VLLMDataTypeNames`` and ``VLLMDataTypeTag``
    alongside cutlass_library's own ``DataType`` members.

    NOTE(review): judging by the names, ``u4b8`` / ``u8b128`` look like
    unsigned 4-bit / 8-bit integers stored with a bias of 8 / 128 --
    confirm against the C++ types they are tagged with.
    """

    # enum.auto() comes straight from the stdlib module this file already
    # imports, instead of relying on a helper name leaked by the wildcard
    # import of cutlass_library.
    u4b8 = enum.auto()
    u8b128 = enum.auto()


class MixedInputKernelScheduleType(enum.Enum):
    """Extra kernel-schedule kinds that extend cutlass_library's set.

    Each member is mapped to a ``cutlass::gemm::Kernel...MixedInput`` tag
    string by ``VLLMKernelScheduleTag``.
    """

    # enum.auto() comes straight from the stdlib module this file already
    # imports, instead of relying on a helper name leaked by the wildcard
    # import of cutlass_library.
    TmaWarpSpecializedMixedInput = enum.auto()
    TmaWarpSpecializedPingpongMixedInput = enum.auto()
    TmaWarpSpecializedCooperativeMixedInput = enum.auto()


# Name lookup covering both cutlass_library's DataType members and the
# VLLMDataType additions: the stock DataTypeNames entries are copied in
# first, then one entry per custom type is appended.
VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
    **DataTypeNames,  # type: ignore
    VLLMDataType.u4b8: "u4b8",
    VLLMDataType.u8b128: "u8b128",
}

# Tag lookup covering both cutlass_library's DataType members and the
# VLLMDataType additions: the stock DataTypeTag entries are copied in
# first, then each custom type is mapped to its ``cutlass::vllm_*`` name.
VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
    **DataTypeTag,  # type: ignore
    VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
    VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t",
}

# Tag lookup covering both cutlass_library's KernelScheduleType members
# and the MixedInputKernelScheduleType additions: the stock
# KernelScheduleTag entries are copied in first, then each mixed-input
# schedule is mapped to its ``cutlass::gemm::Kernel...`` name.
VLLMKernelScheduleTag: Dict[
    Union[MixedInputKernelScheduleType, KernelScheduleType], str
] = {
    **KernelScheduleTag,  # type: ignore
    MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput: (
        "cutlass::gemm::KernelTmaWarpSpecializedMixedInput"
    ),
    MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput: (
        "cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput"
    ),
    MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput: (
        "cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput"
    ),
}