enums.py 350 Bytes
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4

import enum

5
6
7
8
class ModelType(enum.Enum):
    """Top-level model architecture selector.

    Members:
        encoder_or_decoder: presumably a single-stack model (encoder-only or
            decoder-only) — confirm against callers.
        encoder_and_decoder: presumably a model with both an encoder and a
            decoder stack — confirm against callers.
    """

    # enum.auto() numbers members from 1 in definition order, so the values
    # stay 1 and 2 exactly as before (callers may compare or serialize them).
    encoder_or_decoder = enum.auto()
    encoder_and_decoder = enum.auto()

9
10
11
12
13
14
15
16
17
18
19
class LayerType(enum.Enum):
    """Tag distinguishing the two layer kinds.

    Members:
        encoder: presumably a layer in the encoder stack — confirm usage.
        decoder: presumably a layer in the decoder stack — confirm usage.
    """

    # auto() yields 1, then 2 — identical to the original explicit values.
    encoder = enum.auto()
    decoder = enum.auto()
 
class AttnType(enum.Enum):
    """Tag distinguishing the two attention kinds.

    Members:
        self_attn: presumably attention over the layer's own input sequence.
        cross_attn: presumably attention over a separate (e.g. encoder)
            sequence — confirm against callers.
    """

    # auto() yields 1, then 2 — identical to the original explicit values.
    self_attn = enum.auto()
    cross_attn = enum.auto()

class AttnMaskType(enum.Enum):
    """Tag selecting which attention-mask scheme to apply.

    Members:
        padding: presumably masks out padding positions only — confirm usage.
        causal: presumably masks future positions (autoregressive) — confirm
            usage.
    """

    # auto() yields 1, then 2 — identical to the original explicit values.
    padding = enum.auto()
    causal = enum.auto()