# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Enums for e2e transformer"""
import torch
import torch.distributed
import transformer_engine_extensions as tex


"""
Map: torch.dtype -> tex.DType enum member.
Used when passing dtypes into the CUDA
extension; mirrors the DType enum declared
in transformer_engine.h.
"""
# Lookup table from torch.dtype keys to members of the extension's
# tex.DType enum. Only the torch dtypes listed here are mapped; any other
# dtype raises KeyError on lookup.
TE_DType = {
    torch.uint8: tex.DType.kByte,
    torch.int32: tex.DType.kInt32,
    torch.float32: tex.DType.kFloat32,
    torch.half: tex.DType.kFloat16,
    torch.bfloat16: tex.DType.kBFloat16,
}

# Attention mask type names (presumably the accepted values of an
# attention-mask option elsewhere in the package — confirm at call sites).
AttnMaskTypes = ("causal", "padding", "no_mask")

# Attention type names: attention within one sequence ("self") or
# across two ("cross") — naming-based reading, confirm at call sites.
AttnTypes = (
    "self",
    "cross",
)

# Attention bias type names; the "pre_"/"post_" prefixes presumably refer to
# whether the bias is applied before or after softmax scaling — confirm
# against the attention implementation.
AttnBiasTypes = ("pre_scale_bias", "post_scale_bias", "no_bias")

# Q/K/V tensor layout names. Each name is one to three underscore-separated
# parts (one per separately passed tensor); a digit such as "3" or "2" marks
# tensors packed together. The letters presumably encode axis order
# (e.g. s/b/h/d, t) — confirm against the attention implementation.
# Entries are grouped by leading prefix; order is significant to nothing
# visible here, but is kept unchanged.
QKVLayouts = (
    "sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd",
    "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
    "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd",
)

# Layer type names ("encoder" / "decoder").
LayerTypes = ("encoder", "decoder")

# GEMM parallel mode values: "row", "column", or None (no parallel split,
# presumably — confirm at call sites).
GemmParallelModes = (
    "row",
    "column",
    None,
)

# Alias for torch's distributed process-group class.
dist_group_type = torch.distributed.ProcessGroup