# distributed.py
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Methods needed for distributed training."""

from contextlib import contextmanager
from typing import Generator, Optional, Union, Tuple

import paddle

import paddle.distributed.fleet.base.topology as tp
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.layers.mpu import mp_ops

from .constants import dist_group_type

_weight_split_axis = {
    'transformer_engine': {
        'row': 1,
        'column': 0
    },
    'paddle': {
        'row': 0,
        'column': 1
    }
}


def get_tp_group_and_world_size(tp_group: Union[dist_group_type, None],
                                enable_tp: bool = True) -> Tuple[Union[dist_group_type, None], int]:
    """Get TP group and world size using Fleet API.

    Args:
        tp_group: Explicit tensor-parallel group to use, or None to fall back
            to Fleet's hybrid-parallel model-parallel group.
        enable_tp: When false, tensor parallelism is treated as disabled.

    Returns:
        ``(None, 1)`` if ``paddle.distributed`` is not initialized or
        ``enable_tp`` is false. Otherwise returns ``(tp_group, tp_group.nranks)``
        for an explicit group, or the model-parallel group and world size
        reported by ``tp._HYBRID_PARALLEL_GROUP``.
    """
    # Evaluate is_initialized() first, exactly as before, then the flag.
    if not paddle.distributed.is_initialized() or not enable_tp:
        return None, 1

    # A caller-supplied group wins; its rank count is the world size.
    if tp_group is not None:
        return tp_group, tp_group.nranks

    # NOTE(review): assumes Fleet hybrid parallelism has been set up; this
    # attribute is used without a None check.
    hybrid_group = tp._HYBRID_PARALLEL_GROUP
    group = hybrid_group.get_model_parallel_group()
    size = hybrid_group.get_model_parallel_world_size()
    return group, size


@contextmanager
def track_rng_state(enable: bool, **kwargs) -> Generator[None, None, None]:
    """
    Applies get_rng_state_tracker().rng_state() to the context.
    If not enabled, it does nothing.

    Args:
        enable: When true, the ``with`` body runs inside
            ``get_rng_state_tracker().rng_state(**kwargs)``. When false, the
            body runs directly.
        **kwargs: Forwarded to ``rng_state()``. They are ignored when
            ``enable`` is false.

    Yields:
        None. The manager exists only for its enter/exit side effects.
    """
    if not enable:
        # Disabled: plain pass-through context, no tracker involvement.
        yield
        return
    with get_rng_state_tracker().rng_state(**kwargs):
        yield


def set_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool, axis: int) -> None:
    """Set distributed attributes for the input tensor.

    Always assigns ``tensor.is_distributed = is_parallel``. Only when
    ``is_parallel`` is truthy does it also record ``axis`` as
    ``tensor.split_axis``. Otherwise any existing ``split_axis`` is left alone.
    """
    tensor.is_distributed = is_parallel
    if not is_parallel:
        return
    tensor.split_axis = axis


def set_weight_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool,
                                parallel_mode: Optional[str], backend: str) -> None:
    """Set distributed attributes for the weight tensor.

    This is a no-op unless ``is_parallel`` is truthy and ``parallel_mode`` is
    not None. In that case the split axis is looked up in
    ``_weight_split_axis[backend][parallel_mode]``, so an unknown backend or
    mode raises ``KeyError``. The result is applied via ``set_tensor_dist_attr``.
    """
    if is_parallel and parallel_mode is not None:
        split_axis = _weight_split_axis[backend][parallel_mode]
        set_tensor_dist_attr(tensor, is_parallel, axis=split_axis)


def allreduce(
    input_: paddle.Tensor,
    tp_group: Optional[dist_group_type] = None,
) -> paddle.Tensor:
    """All-reduce the input tensor across model parallel group.

    Returns ``input_`` itself, with no communication, when ``tp_group`` is
    None or has ``nranks == 1``. Otherwise returns the result of
    ``mp_ops._mp_allreduce`` over ``tp_group``, called with
    ``use_calc_stream=True`` and ``use_model_parallel=True``.
    """
    # Nothing to reduce with: no group, or the group is a single rank.
    nothing_to_reduce = tp_group is None or tp_group.nranks == 1
    if nothing_to_reduce:
        return input_

    return mp_ops._mp_allreduce(
        input_,
        group=tp_group,
        use_calc_stream=True,
        use_model_parallel=True,
    )


def identity(
    input_: paddle.Tensor,
    tp_group: Optional[dist_group_type] = None,
) -> paddle.Tensor:
    """
    Pass ``input_`` through ``mp_ops._c_identity`` over ``tp_group``.

    Forward: identity. Backward: all-reduce across the model parallel group.
    Unlike ``allreduce``, there is no early return when ``tp_group`` is None;
    the call is always made.
    """
    return mp_ops._c_identity(input_, group=tp_group)