# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Methods needed for distributed training."""

import os
import warnings
from contextlib import contextmanager
from typing import Any, Optional, Union, Tuple

import paddle

import paddle.distributed.fleet.base.topology as tp
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.layers.mpu import mp_ops

from .constants import dist_group_type

_weight_split_axis = {
    'transformer_engine': {
        'row': 1,
        'column': 0
    },
    'paddle': {
        'row': 0,
        'column': 1
    }
}
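
# Note: the split axes above differ between backends because the two weight
# layouts are transposed relative to each other. The table is consistent with
# TE kernels expecting [out_features, in_features] while paddle.nn.Linear
# stores [in_features, out_features]: a column-parallel weight is split along
# axis 0 for 'transformer_engine' but along axis 1 for 'paddle'.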


def get_tp_group_and_world_size(tp_group: Union[dist_group_type, None],
                                enable_tp: bool = True) -> Tuple[Union[dist_group_type, None], int]:
    """Get TP group and world size using Fleet API"""
    if not (paddle.distributed.is_initialized() and enable_tp):
        return None, 1
    model_parallel_group = (tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group()
                            if tp_group is None else tp_group)
    world_size = (tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size()
                  if tp_group is None else tp_group.nranks)
    """
    When using TP, the NCCL communication needs to be scheduled
    before the GEMM for a guaranteed overlap. From the host side
    in TE, the comm calls are always launched first, but to ensure
    that the GEMM isn't scheduled first, the environment variable
    `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to force a
    single channel.
    """
    num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
    if num_cuda_work_queues != 1:
        warnings.warn("To guarantee overlapping TP and SP collectives with the backward"
                      "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1")

    return model_parallel_group, world_size
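

# Usage sketch (illustrative only, not part of this module's API). Assumes the
# training script has already initialized hybrid parallelism, e.g. via
# fleet.init(is_collective=True, strategy=dist_strategy):
#
#     tp_group, tp_size = get_tp_group_and_world_size(None, enable_tp=True)
#     if tp_size > 1:
#         ...  # shard parameters / schedule collectives over tp_group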


@contextmanager
def track_rng_state(enable: bool, **kwargs) -> None:
    """
    Applies get_rng_state_tracker().rng_state() to the context.
    If not enabled, it does nothing.
    """
    if enable:
        with get_rng_state_tracker().rng_state(**kwargs):
            yield
    else:
        yield
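

# Example sketch (illustrative helper, not part of this module's API): keeping
# dropout randomness under the tracker so RNG state stays reproducible across
# TP ranks. "local_seed" is a hypothetical state name, assumed to have been
# registered earlier via get_rng_state_tracker().add("local_seed", seed).
def _tracked_dropout_example(x: paddle.Tensor, enable_tp: bool) -> paddle.Tensor:
    with track_rng_state(enable=enable_tp, name="local_seed"):
        return paddle.nn.functional.dropout(x, p=0.1, training=True)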


def set_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool, axis: int) -> None:
    """Set distributed attributes for the input tensor"""
    tensor.is_distributed = is_parallel
    if is_parallel:
        tensor.split_axis = axis


def set_weight_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool,
                                parallel_mode: Optional[str], backend: str) -> None:
    """Set distributed attributes for the weight tensor"""
    if not is_parallel or parallel_mode is None:
        return
    set_tensor_dist_attr(tensor, is_parallel, axis=_weight_split_axis[backend][parallel_mode])
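

# Example (illustrative): marking a TE-backend column-parallel weight so that
# downstream tooling knows it is sharded along the output dimension:
#
#     set_weight_tensor_dist_attr(linear.weight, is_parallel=True,
#                                 parallel_mode='column',
#                                 backend='transformer_engine')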


def allreduce(
    input_: paddle.Tensor,
    tp_group: Optional[dist_group_type] = None,
    sync_op: bool = True,
) -> Tuple[paddle.Tensor, Any]:
    """All-reduce the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if tp_group is None or tp_group.nranks == 1:
        # Match the declared (output, handle) return shape on the bypass path.
        return input_, None

    # All-reduce.
    if sync_op:
        output = mp_ops._mp_allreduce(
            input_,
            group=tp_group,
            use_calc_stream=True,
            use_model_parallel=True,
        )
        return output, None

    wait_handle = paddle.distributed.all_reduce(
        input_,
        op=paddle.distributed.ReduceOp.SUM,
        group=tp_group,
        sync_op=False,
    )

    output = input_

    return output, wait_handle
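

# Overlap sketch (illustrative helper, not part of this module's API): launch
# the all-reduce asynchronously, do independent work, then block on the handle
# before consuming the reduced tensor.
def _allreduce_overlap_example(x: paddle.Tensor, y: paddle.Tensor,
                               tp_group: Optional[dist_group_type]) -> paddle.Tensor:
    out, handle = allreduce(x, tp_group, sync_op=False)
    y = y * 2.0  # independent compute that can overlap with the collective
    if handle is not None:
        handle.wait()  # block until the asynchronous all-reduce completes
    return out + y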


def identity(
    input_: paddle.Tensor,
    tp_group: Optional[dist_group_type] = None,
) -> paddle.Tensor:
    """
    Identity when forward.
    Allreduce across model parallel group when backward.
    """
    output = mp_ops._c_identity(input_, group=tp_group)

    return output
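

# Typical use (sketch): apply to the input of a column-parallel GEMM so the
# forward pass is untouched while the input gradient is all-reduced across the
# TP group in backward:
#
#     inp = identity(inp, tp_group)
#     out = paddle.matmul(inp, column_parallel_weight)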