# utils.py
1
from typing import Dict, Iterator, List, Tuple, Union
2

3
import torch
4
import torch.nn as nn
5

6
from colossalai.tensor.colo_tensor import ColoTensor
7
8


def all_gather_simulator(target_pair):
    """
    Simulate an all-gather on one tensor dimension and return its new shard list.

    Non-contiguous layouts are not allowed (all-gather(S012) -> S02 is NOT allowed),
    so an all-gather can only release the innermost mesh axis, i.e. the last entry
    of the shard list:

        all-gather(S01) -> S0

    Argument:
        target_pair (Tuple[int, List[int]]): the tensor dimension being gathered and
            the list of logical mesh axes currently sharding that dimension.

    Returns:
        A new list holding every mesh axis except the last one (empty input gives []).
    """
    _tensor_dim, sharded_axes = target_pair
    # Slicing builds a fresh list, so the caller's shard list is not touched.
    remaining_axes = sharded_axes[:-1]
    return remaining_axes


def all_to_all_simulator(f_target_pair, b_target_pair):
    """
    Simulate an all-to-all between two tensor dimensions and return their new shard lists.

    Shard lists in decreasing order (such as S10) are banned, so
    all-to-all(S0, S1) -> RS01 is NOT allowed. Consequently:

    - If the back (second) dimension is sharded, its mesh axes are appended to the
      front dimension's list and the back dimension becomes replicated, e.g.
          all-to-all(S0, S1) -> [S01, R]
          all-to-all(R, S1)  -> [S1, R]
    - Otherwise the front dimension's mesh axes move to the back dimension, e.g.
          all-to-all(S0, R)  -> [R, S0]

    Argument:
        f_target_pair (Tuple[int, List[int]]): front tensor dimension and the logical
            mesh axes sharding it.
        b_target_pair (Tuple[int, List[int]]): back tensor dimension and the logical
            mesh axes sharding it.

    Returns:
        Tuple[List[int], List[int]]: the new (front, back) shard lists. Both are
        freshly built lists; the input lists are never modified.
    """
    _, f_shard_list = f_target_pair
    _, b_shard_list = b_target_pair

    if not b_shard_list:
        # Back dim is replicated: move every front mesh axis to the back dim.
        return [], list(f_shard_list)

    # Back dim is sharded: append its mesh axes after the front dim's axes.
    return list(f_shard_list) + list(b_shard_list), []


def shard_simulator(target_pair, legal_sharding_dims):
    """
    Enumerate the shard lists reachable by sharding one tensor dimension further.

    A shard operation has zero communication cost. Shard lists in decreasing order
    (such as S10) are banned, so shard(S0) -> S10 is NOT allowed. A candidate mesh
    axis is therefore only appended when it is strictly greater than the current
    innermost axis:

        shard(R)  -> S0      (any legal axis may be used on a replicated dim)
        shard(S0) -> S01

    NOTE(review): this function does not itself reject gaps such as S0 -> S02; that
    presumably relies on the caller's ``legal_sharding_dims`` — confirm.

    Argument:
        target_pair (Tuple[int, List[int]]): tensor dimension to shard and the logical
            mesh axes already sharding it.
        legal_sharding_dims (Iterable[int]): candidate mesh axes to append.

    Returns:
        List[List[int]]: one new shard list per accepted candidate, in the order the
        candidates were given.
    """
    _, current_axes = target_pair
    return [
        current_axes + [axis]
        for axis in legal_sharding_dims
        # keep the shard list strictly increasing
        if not current_axes or axis > current_axes[-1]
    ]


def mix_gather_simulator(f_target_pair, b_target_pair):
    """
    Simulate a mix-gather over a front and a back tensor dimension.

    Assume the tensor-dim indices of the f and b target pairs are 'f' and 'b':

        S0S1 => Input: (f, [0]), (b, [1])      Output: [b, f], [1, 0]
        S1S0 => Input: (f, [1]), (b, [0])      Output: [b, f], [0, 1]
        S01R => Input: (f, [0, 1]), (b, [])    Output: [f], [1, 1]
        RS01 => Input: (f, []), (b, [0, 1])    Output: [b], [1, 1]
        S10R => Input: (f, [1, 0]), (b, [])    Output: [f], [0, 0]
        RS10 => Input: (f, []), (b, [1, 0])    Output: [b], [0, 0]

    Argument:
        f_target_pair (Tuple[int, List[int]]): front tensor dim and its mesh axes.
        b_target_pair (Tuple[int, List[int]]): back tensor dim and its mesh axes.

    Returns:
        Tuple[List[int], List[int]]: the tensor dim index/indices involved, and two
        0/1 flags derived from the ordering of the mesh axes (see table above).

    Raises:
        ValueError: if neither dimension is sharded, or if only one dimension is
            sharded but by fewer than two mesh axes (no ordering can be derived).
    """
    f_index, f_shard_list = f_target_pair[0], f_target_pair[1]
    b_index, b_shard_list = b_target_pair[0], b_target_pair[1]

    if f_shard_list and b_shard_list:
        # Both dims sharded: lexicographic list comparison picks the leading axis.
        leading_dim = b_shard_list > f_shard_list
        return [b_index, f_index], [int(leading_dim), int(leading_dim ^ 1)]

    # Exactly one dim is sharded (front checked first, as before); the order of
    # its first two mesh axes decides the flags.
    for index, shard_list in ((f_index, f_shard_list), (b_index, b_shard_list)):
        if shard_list:
            if len(shard_list) < 2:
                raise ValueError(
                    f"mix-gather on a single tensor dim needs at least two mesh axes, "
                    f"got {shard_list} for dim {index}"
                )
            leading_dim = shard_list[0] < shard_list[1]
            return [index], [int(leading_dim), int(leading_dim)]

    raise ValueError("mix-gather requires at least one sharded tensor dim, got two replicated dims")


# The function is credited to PyTorch Team
def named_params_with_colotensor(
    module: nn.Module,
    prefix: str = "",
    recurse: bool = True,
) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
    r"""Return an iterator over module parameters, including ColoTensor attributes.

    Yields both the name of each parameter and the parameter itself. ColoTensor
    attributes (found in each module's instance ``__dict__``) are yielded first,
    each object at most once; then the regular ``nn.Parameter``\ s are yielded via
    ``module.named_parameters``.

    Args:
        prefix (str): prefix to prepend to all parameter names.
        recurse (bool): if True, then yields parameters of this module
            and all submodules. Otherwise, yields only parameters that
            are direct members of this module. Both options apply to the
            ColoTensor pass and to the nn.Parameter pass.

    Yields:
        (string, Union[Tensor, ColoTensor]): Tuple containing
            the name and parameter (or ColoTensor parameter)

    Example:

        >>> model = torch.nn.Linear(*linear_size)
        >>> delattr(model, "weight")
        >>> setattr(model, "weight", ColoTensor(...))
        >>> for name, param in named_params_with_colotensor(model):
        ...     if name in ['weight']:
        ...         print(param.size())

    """
    modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]

    memo = set()
    for mod_prefix, mod in modules:
        # ColoTensors set as plain attributes are looked up directly in the
        # module's instance dict; `memo` prevents yielding a shared one twice.
        for name, val in vars(mod).items():
            if isinstance(val, ColoTensor) and val not in memo:
                memo.add(val)
                full_name = mod_prefix + ("." if mod_prefix else "") + name
                yield full_name, val

    # Find all nn.Parameters, honouring the same prefix/recurse options as above
    # (previously these were ignored: always recursive and never prefixed).
    yield from module.named_parameters(prefix=prefix, recurse=recurse)

def _convert_tensor(tensor: torch.Tensor) -> ColoTensor:
    """Construct and return a ``ColoTensor`` from ``tensor``."""
    colo_tensor = ColoTensor(tensor)
    return colo_tensor


def convert_parameter(module: torch.nn.Module, param_name: str):
    """Replace the tensor attribute ``module.<param_name>`` with a ColoTensor, in place.

    Args:
        module: module owning the attribute.
        param_name: name of the attribute to convert.

    Raises:
        ValueError: if the attribute does not exist, is not a ``torch.Tensor``,
            or is not contiguous.
    """
    # Validate before touching the module.
    if not hasattr(module, param_name):
        raise ValueError(f"module: {module} does not have parameter with name: {param_name}")

    original = getattr(module, param_name)
    if not isinstance(original, torch.Tensor):
        found_type = type(original).__name__
        raise ValueError(f"Expected {type(module).__name__}.{param_name} to be a Tensor, but found {found_type}")

    if not original.is_contiguous():
        raise ValueError(f"param: {param_name} is not a contiguous Tensor")

    replacement = _convert_tensor(original)

    # The attribute may be an nn.Parameter, and a ColoTensor (not an
    # nn.Parameter) cannot be assigned over it, so remove it first.
    delattr(module, param_name)
    setattr(module, param_name, replacement)


def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
    """
    Convert the negative dim keys of ``dim_partition_dict`` to positive ones.

    The dict is modified in place and also returned. A negative key ``d`` is
    replaced by ``dim_size + d`` and keeps its own mesh list. If that positive key
    already exists, its value is overwritten by the negative key's mesh list; use
    ``merge_same_dim_mesh_list`` to combine such entries instead.

    Args:
        dim_size: offset added to negative dims (the tensor's number of dimensions).
        dim_partition_dict: mapping from tensor dim to the mesh axes sharding it.

    Returns:
        The same ``dim_partition_dict`` object with its negative keys rewritten.
    """
    # Snapshot the negative keys first: the dict cannot change size while iterating.
    negative_dims = [dim for dim in dim_partition_dict if dim < 0]
    for dim in negative_dims:
        # pop() returns this key's own mesh list (the old code reused a stale loop
        # variable and assigned the last entry's list to every converted key).
        dim_partition_dict[dim_size + dim] = dim_partition_dict.pop(dim)
    return dim_partition_dict


def merge_same_dim_mesh_list(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
    """
    Merge the keys of ``dim_partition_dict`` that point to the same tensor dimension.

    Negative keys are converted to ``dim_size + key``. When several keys map to the
    same dimension, their mesh lists are concatenated in the dict's iteration order.

    For example:
        dim_partition_dict: {1: [0], -1: [1]} for a 2d tensor (dim_size=2). Dims 1 and -1
        point to the same physical position, so the result is {1: [0, 1]}.

    Args:
        dim_size: offset added to negative dims (the tensor's number of dimensions).
        dim_partition_dict: mapping from tensor dim to the mesh axes sharding it.

    Returns:
        A new dict with new lists; the input dict and its lists are not modified.
    """
    converted_dim_partition_dict: Dict[int, List[int]] = {}
    for dim, mesh_list in dim_partition_dict.items():
        if dim < 0:
            dim = dim_size + dim
        # Always extend a fresh list so the caller's mesh lists are never mutated
        # (previously the first list was stored by reference and then extended).
        converted_dim_partition_dict.setdefault(dim, []).extend(mesh_list)

    return converted_dim_partition_dict