replace_policy.py 930 Bytes
Newer Older
aiss's avatar
aiss committed
1
2
3
4
5
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

aiss's avatar
aiss committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from .containers import HFGPT2LayerPolicy
from .containers import HFBertLayerPolicy
from .containers import BLOOMLayerPolicy
from .containers import HFGPTJLayerPolicy
from .containers import HFGPTNEOLayerPolicy
from .containers import GPTNEOXLayerPolicy
from .containers import HFOPTLayerPolicy
from .containers import MegatronLayerPolicy
from .containers import HFDistilBertLayerPolicy
from .containers import HFCLIPLayerPolicy
from .containers import UNetPolicy
from .containers import VAEPolicy

# Transformer-based layer policies (all imported from `.containers` above).
#
# NOTE(review): the code that iterates this list lives outside this file.
# The element order below is kept exactly as in the original, in case the
# consumer stops at the first matching policy. Confirm before reordering or
# inserting entries.
replace_policies = [
    HFBertLayerPolicy,
    HFGPTNEOLayerPolicy,
    GPTNEOXLayerPolicy,
    HFGPTJLayerPolicy,
    MegatronLayerPolicy,
    HFGPT2LayerPolicy,
    BLOOMLayerPolicy,
    HFOPTLayerPolicy,
    HFCLIPLayerPolicy,
    HFDistilBertLayerPolicy,
]

# Policies for non-transformer-based modules (imported from `.containers`).
generic_policies = [
    UNetPolicy,
    VAEPolicy,
]