sharding_configs.py 4.35 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import jax
from itertools import product
from transformer_engine.jax.sharding import ShardingType
from transformer_engine.jax.softmax import SoftmaxType
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType\


class ShardingConfigs(object):
    
    def __init__(self, num_gpus=len(jax.devices())):
        super().__init__()
        if num_gpus < 2:
            raise ValueError(f"ShardingConfig: Need at least 2 GPUs.")
        
        self.device_count = min(num_gpus, 8)
        mesh_configs = [
            ((self.device_count, 1),    ("dp", None), ShardingType.DP),
            ((self.device_count, 1),    ("tp", None), ShardingType.TP_COL),
            ((self.device_count, 1),    ("tp", None), ShardingType.TP_ROW)
        ]
        if self.device_count >= 4:
            mesh_configs += [
                ((self.device_count//2, 2), ("dp", "tp"), ShardingType.DP_TP_COL),
                ((self.device_count//2, 2), ("dp", "tp"), ShardingType.DP_TP_ROW),
            ]
        if self.device_count >= 6:
            mesh_configs += [
                ((2, self.device_count//2), ("dp", "tp"), ShardingType.DP_TP_COL),
                ((2, self.device_count//2), ("dp", "tp"), ShardingType.DP_TP_ROW),
            ]

        layernorm_collectives = {
            ShardingType.DP        : {'all-reduce': 2, 'other': 0},
            ShardingType.TP_COL    : {'all-reduce': 0, 'other': 0},
            ShardingType.DP_TP_COL : {'all-reduce': 2, 'other': 0}
        }
        self.layernorm_refs = [
            mesh_config + (layernorm_collectives[mesh_config[2]], ) \
                for mesh_config in mesh_configs \
                    if mesh_config[2] not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW)
        ]

        self.softmax_types = [
            SoftmaxType.SCALED,
            SoftmaxType.SCALED_MASKED,
            SoftmaxType.SCALED_UPPER_TRIANG_MASKED
        ]
        softmax_collectives = {
            ShardingType.DP        : {'all-reduce': 1, 'other': 0},
            ShardingType.TP_COL    : {'all-reduce': 1, 'other': 0},
            ShardingType.TP_ROW    : {'all-reduce': 1, 'other': 0},
            ShardingType.DP_TP_COL : {'all-reduce': 1, 'other': 0},
            ShardingType.DP_TP_ROW : {'all-reduce': 1, 'other': 0}
        }
        self.softmax_refs = [
            mesh_config + (softmax_collectives[mesh_config[2]], ) for mesh_config in mesh_configs
        ]

        self.self_attn_bias_types = [
            AttnBiasType.NO_BIAS,
            AttnBiasType.PRE_SCALE_BIAS,
            AttnBiasType.POST_SCALE_BIAS
        ]
        self.self_attn_mask_types = [
            AttnMaskType.CAUSAL_MASK,
            AttnMaskType.PADDING_MASK,
            AttnMaskType.NO_MASK
        ]
        self_attn_collectives = {
            ShardingType.DP : {
                AttnBiasType.NO_BIAS         : {'all-reduce': 1, 'other': 0},
                AttnBiasType.PRE_SCALE_BIAS  : {'all-reduce': 2, 'other': 0},
                AttnBiasType.POST_SCALE_BIAS : {'all-reduce': 2, 'other': 0},
            },
            ShardingType.TP_COL : {
                AttnBiasType.NO_BIAS         : {'all-reduce': 1, 'other': 0},
                AttnBiasType.PRE_SCALE_BIAS  : {'all-reduce': 1, 'other': 0},
                AttnBiasType.POST_SCALE_BIAS : {'all-reduce': 1, 'other': 0}
            },
            ShardingType.DP_TP_COL : {
                AttnBiasType.NO_BIAS         : {'all-reduce': 1, 'other': 0},
                AttnBiasType.PRE_SCALE_BIAS  : {'all-reduce': 2, 'other': 0},
                AttnBiasType.POST_SCALE_BIAS : {'all-reduce': 2, 'other': 0}
            },
        }
        self.self_attn_refs = [
            mesh_config + (bias_type, self_attn_collectives[mesh_config[2]][bias_type]) \
                for mesh_config, bias_type in product(mesh_configs, self.self_attn_bias_types) \
                    if mesh_config[2] not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW)
        ]

        self.cross_attn_mask_types = [
            AttnMaskType.PADDING_MASK,
            AttnMaskType.NO_MASK
        ]
        self.cross_attn_refs = [
            mesh_config + ({'all-reduce': 1, 'other': 0}, ) \
                for mesh_config in mesh_configs \
                    if mesh_config[2] not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW)
        ]