# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Constants"""

from enum import Enum

import paddle

import transformer_engine_paddle as tex


class FP8FwdTensors(Enum):
    """Named indices of forward-pass FP8 tensors.

    Used as named indices on the `scale`, `scale_inv`,
    and `amax` tensors in the `FP8TensorMeta` class."""
    GEMM1_INPUT = 0   # input activation of the first GEMM
    GEMM1_WEIGHT = 1  # weight of the first GEMM
    GEMM1_OUTPUT = 2  # output of the first GEMM
    GEMM2_INPUT = 3   # input activation of the second GEMM
    GEMM2_WEIGHT = 4  # weight of the second GEMM
    GEMM2_OUTPUT = 5  # output of the second GEMM


class FP8BwdTensors(Enum):
    """Named indices of backward-pass FP8 tensors.

    Used as named indices on the `scale`, `scale_inv`,
    and `amax` tensors in the `FP8TensorMeta` class."""
    GRAD_OUTPUT1 = 0  # gradient w.r.t. the output of the first GEMM
    GRAD_INPUT1 = 1   # gradient w.r.t. the input of the first GEMM
    GRAD_OUTPUT2 = 2  # gradient w.r.t. the output of the second GEMM
    GRAD_INPUT2 = 3   # gradient w.r.t. the input of the second GEMM


# Map from paddle dtype to TE dtype.
# (A regular comment, not a bare triple-quoted string: a module-level string
# literal here is an executed no-op expression, not attached documentation —
# pylint W0105 pointless-string-statement.)
TE_DType = {
    paddle.uint8: tex.DType.kByte,
    paddle.int32: tex.DType.kInt32,
    paddle.float32: tex.DType.kFloat32,
    paddle.float16: tex.DType.kFloat16,
    paddle.bfloat16: tex.DType.kBFloat16,
}

# Supported attention-mask modes.
AttnMaskTypes = ("causal", "padding", "no_mask")

# Supported attention types.
AttnTypes = ("self", "cross")

# Supported transformer layer types.
LayerTypes = ("encoder", "decoder")

# Valid `parallel_mode` values for a GEMM: sharded by row, by column, or not sharded.
GemmParallelModes = ("row", "column", None)

# Alias for paddle's collective process-group class (used as the distributed
# group type throughout this package).
dist_group_type = paddle.distributed.collective.Group

# Function names used to identify activation-recompute execution — presumably
# matched against frame names when inspecting the call stack; verify against callers.
RecomputeFunctionNames = ('unpack', 'backward')