# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Enums for e2e transformer"""
import torch
Jan Bielak's avatar
Jan Bielak committed
7
import torch.distributed
8
import transformer_engine_torch as tex
Przemek Tredak's avatar
Przemek Tredak committed
9
10
11
12
13
14
15
16
17


"""
This is a map: torch.dtype -> int
Used for passing dtypes into cuda
extension. Has one to one mapping
with enum in transformer_engine.h
"""
TE_DType = {
cyanguwa's avatar
cyanguwa committed
18
    torch.uint8: tex.DType.kByte,
19
20
    torch.float8_e4m3fn: tex.DType.kFloat8E4M3,
    torch.float8_e5m2: tex.DType.kFloat8E5M2,
yuguo's avatar
yuguo committed
21
    torch.int8: tex.DType.kInt8,
Przemek Tredak's avatar
Przemek Tredak committed
22
23
24
25
26
27
    torch.int32: tex.DType.kInt32,
    torch.float32: tex.DType.kFloat32,
    torch.half: tex.DType.kFloat16,
    torch.bfloat16: tex.DType.kBFloat16,
}

28
29
30
31
32
33
"""
This is a map: int -> torch.dtype
Used for resolving cuda extension types to torch.
Has one to one mapping with enum in
transformer_engine.h
"""
34
35
36
37
TE_DType_To_Torch = {
    tex.DType.kByte: torch.uint8,
    tex.DType.kFloat8E4M3: torch.float8_e4m3fn,
    tex.DType.kFloat8E5M2: torch.float8_e5m2,
38
    tex.DType.kInt8: torch.int8,
39
40
41
42
43
44
    tex.DType.kInt32: torch.int32,
    tex.DType.kFloat32: torch.float32,
    tex.DType.kFloat16: torch.half,
    tex.DType.kBFloat16: torch.bfloat16,
}

# Recognized attention-mask type identifiers. Order is preserved from the
# original definition; callers are expected to validate against this tuple.
AttnMaskTypes = (
    "no_mask",
    "padding",
    "causal",
    "padding_causal",
    "causal_bottom_right",
    "padding_causal_bottom_right",
    "arbitrary",
)

# Supported attention flavors: self-attention and cross-attention.
AttnTypes = ("self", "cross")

# Recognized attention-bias type identifiers (including "no_bias" for none).
AttnBiasTypes = ("pre_scale_bias", "post_scale_bias", "no_bias", "alibi")

# Recognized QKV memory-layout identifiers. Order is preserved from the
# original definition. Names ending in "paged_kv_*" denote paged-KV-cache
# variants; the letter groups (sbhd / bshd / thd) name tensor layouts —
# presumably (seq, batch, head, dim) orderings, verify against the kernels.
QKVLayouts = (
    # sbhd-family layouts
    "sb3hd",
    "sbh3d",
    "sbhd_sb2hd",
    "sbhd_sbh2d",
    "sbhd_sbhd_sbhd",
    # bshd-family layouts
    "bs3hd",
    "bsh3d",
    "bshd_bs2hd",
    "bshd_bsh2d",
    "bshd_bshd_bshd",
    # thd-family (packed / variable-length) layouts
    "t3hd",
    "th3d",
    "thd_t2hd",
    "thd_th2d",
    "thd_thd_thd",
    # mixed Q-vs-KV layouts
    "sbhd_bshd_bshd",
    "bshd_sbhd_sbhd",
    "thd_bshd_bshd",
    "thd_sbhd_sbhd",
    # paged-KV-cache variants
    "paged_kv_bshd_bshd_bshd",
    "paged_kv_bshd_sbhd_sbhd",
    "paged_kv_sbhd_bshd_bshd",
    "paged_kv_sbhd_sbhd_sbhd",
    "paged_kv_thd_bshd_bshd",
    "paged_kv_thd_sbhd_sbhd",
)

# Transformer layer roles.
LayerTypes = ("encoder", "decoder")

# Valid GEMM parallel modes; None presumably means "not parallelized" —
# verify against the modules that consume this constant.
GemmParallelModes = ("row", "column", None)

# Alias for the torch distributed process-group type, so annotations in the
# rest of the package don't repeat the full dotted path.
dist_group_type = torch.distributed.ProcessGroup

# Number of elements sharing one scale factor in MXFP8 block scaling.
MXFP8_BLOCK_SCALING_SIZE = 32

# Number of elements sharing one scale factor in NVFP4 block scaling.
NVFP4_BLOCK_SCALING_SIZE = 16