linear.py 1.21 KB
Newer Older
1
from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
2

3
4
from .colo_module import ColoModule

5

6
class ColoLinear(ColoModule):
    """Sharding description for a linear layer's ``weight`` and ``bias``.

    Only ``ComputePattern.TP1D`` is handled. For it, two layouts are
    registered: ``'row'`` (the default) and ``'col'``.
    """

    def __init__(self):
        super().__init__()
        # The parameter names whose sharding this description controls.
        self._register_shard_params(['weight', 'bias'])

    def register(self, compute_pattern, pg: ProcessGroup):
        """Register the sharding layouts for ``compute_pattern``.

        Does nothing if ``compute_pattern`` is already in
        ``self._allowed_patterns``. Only ``ComputePattern.TP1D`` is
        supported; any other pattern is silently ignored.

        Args:
            compute_pattern: The ``ComputePattern`` to enable.
            pg: Process group forwarded to the pattern-specific setup; its
                ``tp_world_size()`` supplies the shard count.
        """
        if compute_pattern in self._allowed_patterns:
            return
        # NOTE(review): unsupported patterns fall through without error —
        # confirm callers don't rely on an exception here.
        if compute_pattern == ComputePattern.TP1D:
            self._set_TP1D(pg)

    def _set_TP1D(self, pg):
        """Register the TP1D row and column layouts and make ``'row'`` default.

        Args:
            pg: Process group; ``pg.tp_world_size()`` is used as the number of
                partitions for every sharded parameter.
        """
        compute_pattern = ComputePattern.TP1D
        # Queried once: every sharded parameter uses the same partition count.
        num_shards = pg.tp_world_size()

        # TP1D Row Linear: weight split along its last dim; bias gets no
        # ShardSpec (None). Presumably ShardSpec(dims, num_partitions), and
        # NOTE(review): if weight is (out_features, in_features) as in
        # torch.nn.Linear, this splits the input features — confirm.
        self._register_allowed_patterns(
            compute_pattern=compute_pattern,
            dist_specs={
                'weight': ShardSpec([-1], [num_shards]),
                'bias': None,
            },
            mode='row',
        )

        # TP1D Col Linear: weight and bias both split along dim 0.
        self._register_allowed_patterns(
            compute_pattern=compute_pattern,
            dist_specs={
                'weight': ShardSpec([0], [num_shards]),
                'bias': ShardSpec([0], [num_shards]),
            },
            mode='col',
        )

        self._set_default(compute_pattern=compute_pattern, target_mode='row')