# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import enum
from typing import Union

from cutlass_library import *

#
#   Extend the CUTLASS library with vLLM-specific data types, kernel
#   schedules, and the lookup tables that it is missing for them.
#


class VLLMDataType(enum.Enum):
    """Extra data types that vLLM layers on top of CUTLASS's ``DataType``.

    Members are used as keys alongside ``DataType`` members in the lookup
    tables of this module; their values are assigned by ``enum_auto()``.
    """

    # NOTE(review): presumably an unsigned 4-bit integer stored with a bias
    # of 8 (its scalar-type tag is "vllm::kU4B8") -- confirm.
    u4b8 = enum_auto()

    # NOTE(review): presumably an unsigned 8-bit integer stored with a bias
    # of 128 (its scalar-type tag is "vllm::kU8B128") -- confirm.
    u8b128 = enum_auto()

class MixedInputKernelScheduleType(enum.Enum):
    """Kernel schedules vLLM selects for mixed-input GEMM kernels.

    Each member is mapped to a fully-qualified CUTLASS C++ schedule type name
    by ``VLLMKernelScheduleTag``.
    """

    TmaWarpSpecialized = enum_auto()
    TmaWarpSpecializedPingpong = enum_auto()
    TmaWarpSpecializedCooperative = enum_auto()


# Short human-readable names for every data type: all of CUTLASS's
# ``DataTypeNames`` entries plus the vLLM-specific types. vLLM entries come
# last, so they would win on any key collision.
VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = {
    **DataTypeNames,  # type: ignore
    VLLMDataType.u4b8: "u4b8",
    VLLMDataType.u8b128: "u8b128",
}

# Fully-qualified C++ type name for every data type: all of CUTLASS's
# ``DataTypeTag`` entries plus the C++ types backing the vLLM-specific
# data types.
VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
    **DataTypeTag,  # type: ignore
    VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
    VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t",
}

# Size of every data type: all of CUTLASS's ``DataTypeSize`` entries plus the
# vLLM-specific types.
# NOTE(review): the values 4 and 8 for the 4-/8-bit vLLM types suggest this
# table (like the CUTLASS table it extends) is measured in bits -- confirm.
VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
    **DataTypeSize,  # type: ignore
    VLLMDataType.u4b8: 4,
    VLLMDataType.u8b128: 8,
}

# C++ spelling of the ``vllm::`` scalar-type constant for each supported data
# type. Only the types listed here have an entry; unlike the tables above,
# this one does not inherit any CUTLASS defaults.
VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
    VLLMDataType.u4b8: "vllm::kU4B8",
    VLLMDataType.u8b128: "vllm::kU8B128",
    DataType.u4: "vllm::kU4",
    DataType.u8: "vllm::kU8",
    DataType.s4: "vllm::kS4",
    DataType.s8: "vllm::kS8",
    DataType.f16: "vllm::kFloat16",
    DataType.bf16: "vllm::kBfloat16",
}

# C++ spelling of the PyTorch ``at::ScalarType`` enumerator for each supported
# CUTLASS data type. The annotation admits ``VLLMDataType`` keys, but none are
# present: the vLLM-specific types have no entry in this table.
VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
    DataType.u8: "at::ScalarType::Byte",
    DataType.s8: "at::ScalarType::Char",
    DataType.e4m3: "at::ScalarType::Float8_e4m3fn",
    DataType.s32: "at::ScalarType::Int",
    DataType.f16: "at::ScalarType::Half",
    DataType.bf16: "at::ScalarType::BFloat16",
    DataType.f32: "at::ScalarType::Float",
}

# Fully-qualified C++ schedule type for every kernel schedule: all of
# CUTLASS's ``KernelScheduleTag`` entries plus the mixed-input schedules
# defined by ``MixedInputKernelScheduleType``.
VLLMKernelScheduleTag: dict[Union[MixedInputKernelScheduleType,
                                  KernelScheduleType], str] = {
    **KernelScheduleTag,  # type: ignore
    MixedInputKernelScheduleType.TmaWarpSpecialized:
    "cutlass::gemm::KernelTmaWarpSpecialized",
    MixedInputKernelScheduleType.TmaWarpSpecializedPingpong:
    "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
    MixedInputKernelScheduleType.TmaWarpSpecializedCooperative:
    "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
}