# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Methods needed for distributed training."""

import os
import warnings
from contextlib import contextmanager
from typing import Any, Iterator, Optional, Union, Tuple

import paddle

import paddle.distributed.fleet.base.topology as tp
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.layers.mpu import mp_ops

from .constants import dist_group_type

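# Axis along which a parallel weight is split, per backend weight layout. The two
# backends store the linear weight transposed relative to each other, so "row"
# and "column" parallelism shard opposite axes.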
_weight_split_axis = {
    "transformer_engine": {"row": 1, "column": 0},
    "paddle": {"row": 0, "column": 1},
}


def get_tp_group_and_world_size(
    tp_group: Union[dist_group_type, None], enable_tp: bool = True
) -> Tuple[Union[dist_group_type, None], int]:
    """Get TP group and world size using Fleet API"""
    if not (paddle.distributed.is_initialized() and enable_tp):
        return None, 1
    model_parallel_group = (
        tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group() if tp_group is None else tp_group
    )
    world_size = (
        tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size()
        if tp_group is None
        else tp_group.nranks
    )
    """
    When using TP, the NCCL communication needs to be scheduled
    before the GEMM for a guaranteed overlap. From the host side
    in TE, the comm calls are always launched first, but to ensure
    that the GEMM isn't scheduled first, the environment variable
    `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to force a
    single channel.
    """
    num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
    if num_cuda_work_queues != 1:
        warnings.warn(
            "To guarantee overlapping TP and SP collectives with the backward"
            "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
        )
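        # e.g. launch with the variable pinned (train.py is a placeholder script):
        #   CUDA_DEVICE_MAX_CONNECTIONS=1 python -m paddle.distributed.launch train.py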

    return model_parallel_group, world_size

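# Example (illustrative, assumes fleet hybrid parallelism was initialized via
# paddle.distributed.fleet.init):
#
#     tp_group, tp_size = get_tp_group_and_world_size(None, enable_tp=True)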

@contextmanager
def track_rng_state(enable: bool, **kwargs) -> Iterator[None]:
    """
    Applies get_rng_state_tracker().rng_state() to the context.
    If not enabled, it does nothing.
    """
    if enable:
        with get_rng_state_tracker().rng_state(**kwargs):
            yield
    else:
        yield

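# Example (illustrative): draw dropout randomness from the fleet RNG tracker so
# it stays consistent across TP ranks:
#
#     with track_rng_state(enable=True):
#         x = paddle.nn.functional.dropout(x, p=0.1)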

def set_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool, axis: int) -> None:
    """Set distributed attributes for the input tensor"""
    tensor.is_distributed = is_parallel
    if is_parallel:
        tensor.split_axis = axis


def set_weight_tensor_dist_attr(
    tensor: paddle.Tensor, is_parallel: bool, parallel_mode: Optional[str], backend: str
) -> None:
    """Set distributed attributes for the weight tensor"""
    if not is_parallel or parallel_mode is None:
        return
    set_tensor_dist_attr(tensor, is_parallel, axis=_weight_split_axis[backend][parallel_mode])


def allreduce(
    input_: paddle.Tensor,
    tp_group: Optional[dist_group_type] = None,
    sync_op: bool = True,
) -> Tuple[paddle.Tensor, Any]:
    """All-reduce the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if tp_group is None or tp_group.nranks == 1:
        return input_, None

    # All-reduce.
    if sync_op:
        output = mp_ops._mp_allreduce(
            input_,
            group=tp_group,
            use_calc_stream=True,
            use_model_parallel=True,
        )
        return output, None

    wait_handle = paddle.distributed.all_reduce(
        input_,
        op=paddle.distributed.ReduceOp.SUM,
        group=tp_group,
        sync_op=False,
    )

    output = input_

    return output, wait_handle

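# Example (illustrative): overlap an async all-reduce with independent work.
#
#     out, handle = allreduce(x, tp_group, sync_op=False)
#     y = independent_compute(x)  # placeholder for work to overlap
#     if handle is not None:
#         handle.wait()           # `out` is valid only after the wait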

def allgather(
    input_: paddle.Tensor,
    tp_group: Optional[dist_group_type] = None,
    sync_op: bool = True,
) -> Tuple[paddle.Tensor, Any]:
    """All-gather the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if tp_group is None or tp_group.nranks == 1:
        return input_, None

    parallelism = tp_group.nranks
    output_shape = input_.shape
    output_shape[0] = output_shape[0] * parallelism
    output = paddle.empty(shape=output_shape, dtype=input_.dtype)
    wait_handle = tp_group.process_group.all_gather_into_tensor(output, input_, sync_op)
    if sync_op:
        wait_handle.wait()
        return output, None
    return output, wait_handle

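# Example (illustrative): gather sequence-parallel shards along dim 0. With the
# default sync_op=True the returned handle is always None.
#
#     full_seq, _ = allgather(local_shard, tp_group)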

def reduce_scatter(
    input_: paddle.Tensor,
    tp_group: Optional[dist_group_type] = None,
    sync_op: bool = True,
) -> Tuple[paddle.Tensor, Any]:
    """Reduce-scatter the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if tp_group is None or tp_group.nranks == 1:
        return input_, None

    parallelism = tp_group.nranks
    output_shape = input_.shape
    assert input_.shape[0] % parallelism == 0, (
        f"Input sequence length {input_.shape[0]} can't be divided "
        f"exactly by sequence parallelism {parallelism}"
    )
    output_shape[0] = output_shape[0] // parallelism
    output = paddle.empty(shape=output_shape, dtype=input_.dtype)
    wait_handle = paddle.distributed.stream.reduce_scatter(
        output, input_, op=paddle.distributed.ReduceOp.SUM, group=tp_group, sync_op=sync_op
    )
    if sync_op:
        return output, None
    return output, wait_handle

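# Example (illustrative): the reduction dual of allgather above; dim 0 of the
# input must be divisible by the TP world size.
#
#     local_shard, _ = reduce_scatter(full_seq, tp_group)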

def identity(
    input_: paddle.Tensor,
    tp_group: Optional[dist_group_type] = None,
) -> paddle.Tensor:
    """
    Identity when forward.
    Allreduce across model parallel group when backward.
    """
    output = mp_ops._c_identity(input_, group=tp_group)

    return output

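# Example (illustrative): apply identity() to the input of a column-parallel
# layer so its gradient is all-reduced across the TP group on backward:
#
#     hidden = identity(hidden, tp_group)
#     out = column_parallel_linear(hidden)  # placeholder parallel layer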

def mark_as_sequence_parallel_parameter(parameter: paddle.Tensor):
    """
    Set the sequence_parallel attribute on the input tensor. It is used for
    registering allreduce hooks in PaddleNLP sequence parallel training.
    """
    setattr(parameter, "sequence_parallel", True)
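

# Example (illustrative): LayerNorm parameters are replicated under sequence
# parallelism, so mark them for PaddleNLP's gradient allreduce hooks:
#
#     mark_as_sequence_parallel_parameter(layer_norm.weight)
#     mark_as_sequence_parallel_parameter(layer_norm.bias)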