enums.py

# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import enum


# can we get rid of this?
# it's being used in pipeline schedules
class ModelType(enum.Enum):
    """Model Type

    encoder_or_decoder for bert, gpt etc
    encoder_and_decoder for multimodal , T5 etc
    """

    encoder_or_decoder = 1
    encoder_and_decoder = 2
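

# Illustrative sketch, not part of the upstream module: as the comment above
# notes, ModelType is consumed by pipeline schedules. A schedule (or model
# builder) might branch on it roughly like this. `needs_decoder_block` is a
# hypothetical helper added here only for demonstration.
def needs_decoder_block(model_type: ModelType) -> bool:
    """Hypothetical helper: only encoder-and-decoder models (e.g. T5) add
    cross-attention decoder blocks."""
    return model_type == ModelType.encoder_and_decoder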


# class LayerType(enum.Enum):
#     encoder = 1
#     decoder = 2


class AttnType(enum.Enum):
    """Attention type"""

    self_attn = 1
    cross_attn = 2


class AttnMaskType(enum.Enum):
    """Attention Mask Type"""

    padding = 1
    causal = 2
    no_mask = 3  # only used for Transformer Engine (TE)
    padding_causal = 4  # only used for thd (packed variable-length sequence) attention
    arbitrary = 5


class AttnBackend(enum.Enum):
    """Attention Backend"""

    flash = 1  # FlashAttention kernels
    fused = 2  # fused attention kernels (e.g. via Transformer Engine)
    unfused = 3  # unfused attention implementation
    local = 4  # local (non-Transformer Engine) implementation
    auto = 5  # choose a backend automatically
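

# Illustrative usage sketch, not part of the upstream module: resolving
# AttnBackend.auto to a concrete backend. `resolve_backend` is a hypothetical
# helper, and the fallback to `local` is an arbitrary choice made for this
# example; real selection logic would probe which kernels are available.
def resolve_backend(requested: AttnBackend) -> AttnBackend:
    """Hypothetical helper: map `auto` to a concrete backend."""
    if requested == AttnBackend.auto:
        # Assumption for this sketch: with no capability probing, fall back
        # to the local implementation.
        return AttnBackend.local
    return requested


if __name__ == "__main__":
    # Quick demonstration of the enums defined above.
    print(ModelType.encoder_and_decoder.name)         # "encoder_and_decoder"
    print(AttnMaskType.causal.value)                  # 2
    print(resolve_backend(AttnBackend.auto).name)     # "local"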