# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import enum

class LayerType(enum.Enum):
    """Enumerate the layer variants: plain encoder/decoder plus Retro variants.

    Members use explicit integer values (1-5) rather than ``enum.auto()``.
    NOTE(review): if these values are persisted anywhere (configs,
    checkpoints), they must stay stable -- confirm before renumbering.
    """

    encoder = 1
    decoder = 2
    retro_encoder = 3
    retro_decoder = 4
    retro_decoder_with_retriever = 5
class AttnType(enum.Enum):
    """Select the attention variant: self-attention or cross-attention."""

    # Explicit integer values; member order (self_attn, cross_attn) is
    # also the iteration order.
    self_attn = 1
    cross_attn = 2

class AttnMaskType(enum.Enum):
    """Select the attention-mask variant: padding mask or causal mask."""

    # Explicit integer values; member order (padding, causal) is also
    # the iteration order.
    padding = 1
    causal = 2

# For backward compatibility with old model checkpoints
from megatron.core.enums import ModelType