# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Methods needed for distributed training."""

import os
import warnings
from contextlib import contextmanager
from typing import Any, Optional, Union, Tuple

import paddle

import paddle.distributed.fleet.base.topology as tp
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.layers.mpu import mp_ops

try:
    # This feature is not supported as of Paddle 2.6.
    from paddle.distributed.fleet.meta_parallel import (
        PipelineParallelMicroStepLocations,
        register_global_pipeline_parallel_hook,
    )
except ImportError:
    print("Cannot find register_global_pipeline_parallel_hook !")
    register_global_pipeline_parallel_hook = None

from .constants import dist_group_type

_weight_split_axis = {
    "transformer_engine": {"row": 1, "column": 0},
    "paddle": {"row": 0, "column": 1},
}


def get_tp_group_and_world_size(
    tp_group: Union[dist_group_type, None], enable_tp: bool = True
) -> Tuple[Union[dist_group_type, None], int]:
    """Get TP group and world size using Fleet API"""
    if not (paddle.distributed.is_initialized() and enable_tp):
        return None, 1
    model_parallel_group = (
        tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group() if tp_group is None else tp_group
    )
    world_size = (
        tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size()
        if tp_group is None
        else tp_group.nranks
    )
    """
    When using TP, the NCCL communication needs to be scheduled
    before the GEMM for a guaranteed overlap. From the host side
    in TE, the comm calls are always launched first, but to ensure
    that the GEMM isn't scheduled first, the environment variable
    `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to force a
    single channel.
    """
    num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
    if num_cuda_work_queues != 1:
        warnings.warn(
            "To guarantee overlapping TP and SP collectives with the backward "
            "GEMMs, set the environment variable CUDA_DEVICE_MAX_CONNECTIONS=1."
        )

    return model_parallel_group, world_size
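

# A minimal usage sketch (not called by the library): how a caller might set up
# hybrid parallelism and fetch the TP group. The strategy degrees are
# assumptions for illustration, and in practice CUDA_DEVICE_MAX_CONNECTIONS is
# usually exported before the process launches.
def _example_get_tp_group():
    from paddle.distributed import fleet

    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"  # force a single work queue
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
    fleet.init(is_collective=True, strategy=strategy)
    return get_tp_group_and_world_size(None, enable_tp=True)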


def is_pp_enabled() -> bool:
    """Check if pipeline parallel is enabled"""
    if not paddle.distributed.is_initialized():
        return False

    return tp._HYBRID_PARALLEL_GROUP.get_pipe_parallel_world_size() > 1


def register_pp_fwd_begin_hook(forward_begin_hook):
    """Register the pp hook if register_global_pipeline_parallel_hook exist"""
    if register_global_pipeline_parallel_hook is not None:
        register_global_pipeline_parallel_hook(
            PipelineParallelMicroStepLocations.FORWARD_BEGIN, forward_begin_hook
        )
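

# A minimal usage sketch: registering a hook that runs at the beginning of every
# pipeline micro-step forward. The hook's argument list is an assumption here;
# this example accepts and ignores whatever Paddle passes.
def _example_register_pp_hook():
    def _on_forward_begin(*args, **kwargs):
        pass  # e.g., swap in per-micro-step state before the forward chunk

    if is_pp_enabled():
        register_pp_fwd_begin_hook(_on_forward_begin)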


@contextmanager
def track_rng_state(enable: bool, **kwargs) -> None:
    """
    Applies get_rng_state_tracker().rng_state() to the context.
    If not enabled, it does nothing.
    """
    if enable:
        with get_rng_state_tracker().rng_state(**kwargs):
            yield
    else:
        yield
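

# A minimal usage sketch: drawing dropout randomness from the tracked RNG state
# so tensor-parallel ranks stay reproducible. The tracker name "global_seed" is
# an assumption; it must match a seed registered on the tracker beforehand.
def _example_tracked_dropout(x: paddle.Tensor) -> paddle.Tensor:
    with track_rng_state(enable=True, name="global_seed"):
        return paddle.nn.functional.dropout(x, p=0.1)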


def set_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool, axis: int) -> None:
    """Set distributed attributes for the input tensor"""
    tensor.is_distributed = is_parallel
    if is_parallel:
        tensor.split_axis = axis


def set_weight_tensor_dist_attr(
    tensor: paddle.Tensor, is_parallel: bool, parallel_mode: Optional[str], backend: str
) -> None:
    """Set distributed attributes for the weight tensor"""
    if not is_parallel or parallel_mode is None:
        return
    set_tensor_dist_attr(tensor, is_parallel, axis=_weight_split_axis[backend][parallel_mode])
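

# A minimal usage sketch: tagging a column-parallel weight in the
# transformer_engine layout, which splits columns along axis 0 per
# `_weight_split_axis`. The shape is made up for illustration.
def _example_mark_weight() -> paddle.Tensor:
    weight = paddle.create_parameter(shape=[1024, 4096], dtype="float32")
    set_weight_tensor_dist_attr(
        weight, is_parallel=True, parallel_mode="column", backend="transformer_engine"
    )
    return weight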


def allreduce(
    input_: paddle.Tensor,
    tp_group: Optional[dist_group_type] = None,
    sync_op: bool = True,
) -> Tuple[paddle.Tensor, Any]:
    """All-reduce the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if tp_group is None or tp_group.nranks == 1:
        return input_, None

    # All-reduce.
    if sync_op:
        output = mp_ops._mp_allreduce(
            input_,
            group=tp_group,
            use_calc_stream=True,
            use_model_parallel=True,
        )
        return output, None

    wait_handle = paddle.distributed.all_reduce(
        input_,
        op=paddle.distributed.ReduceOp.SUM,
        group=tp_group,
        sync_op=False,
    )

    output = input_

    return output, wait_handle
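

# A minimal usage sketch: launching a non-blocking all-reduce, overlapping it
# with independent compute, and waiting on the handle before using the result.
def _example_async_allreduce(grad, other, tp_group):
    out, handle = allreduce(grad, tp_group=tp_group, sync_op=False)
    other = other * 2.0  # independent work that overlaps with the collective
    if handle is not None:  # handle is None on the single-GPU bypass path
        handle.wait()
    return out, other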


def allgather(
    input_: paddle.Tensor,
    tp_group: Optional[dist_group_type] = None,
    sync_op: bool = True,
    axis: int = 0,
) -> Tuple[paddle.Tensor, Any]:
    """All-gather the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if tp_group is None or tp_group.nranks == 1:
        return input_, None

    parallelism = tp_group.nranks
    output_shape = input_.shape
    output_shape[axis] = output_shape[axis] * parallelism
    output = paddle.empty(shape=output_shape, dtype=input_.dtype)
    wait_handle = tp_group.process_group.all_gather_into_tensor(output, input_, sync_op)
    if sync_op:
        wait_handle.wait()
        return output, None
    return output, wait_handle
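

# A minimal usage sketch: gathering sequence-parallel activations along the
# sequence axis (axis 0) before an op that needs the full sequence.
def _example_allgather_seq(local_chunk, tp_group):
    full, _ = allgather(local_chunk, tp_group=tp_group, sync_op=True, axis=0)
    return full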


def reduce_scatter(
    input_: paddle.Tensor,
    tp_group: Optional[dist_group_type] = None,
    sync_op: bool = True,
) -> Tuple[paddle.Tensor, Any]:
    """Reduce-scatter the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if tp_group is None or tp_group.nranks == 1:
        return input_, None

    parallelism = tp_group.nranks
    output_shape = input_.shape
    assert input_.shape[0] % parallelism == 0, (
        f"Input sequence length {input_.shape[0]} can't be divided "
        f"exactly by sequence parallelism {parallelism}"
    )
    output_shape[0] = output_shape[0] // parallelism
    output = paddle.empty(shape=output_shape, dtype=input_.dtype)
    wait_handle = paddle.distributed.stream.reduce_scatter(
        output, input_, op=paddle.distributed.ReduceOp.SUM, group=tp_group, sync_op=sync_op
    )
    if sync_op:
        return output, None
    return output, wait_handle
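

# A minimal usage sketch: the inverse of the gather above, summing partial
# results across ranks while scattering the sequence dimension back.
def _example_reduce_scatter_seq(partial, tp_group):
    local_sum, _ = reduce_scatter(partial, tp_group=tp_group, sync_op=True)
    return local_sum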


def identity(
    input_: paddle.Tensor,
    tp_group: Optional[dist_group_type] = None,
) -> paddle.Tensor:
    """
    Identity when forward.
    Allreduce across model parallel group when backward.
    """
    output = mp_ops._c_identity(input_, group=tp_group)

    return output
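

# A minimal usage sketch: the forward of a column-parallel linear starts with
# `identity`, which is a no-op here but all-reduces the input gradient in
# backward.
def _example_column_parallel_matmul(x, weight, tp_group):
    x = identity(x, tp_group=tp_group)
    return paddle.matmul(x, weight)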


def mark_as_sequence_parallel_parameter(parameter: paddle.Tensor):
    """
    Set sequence_parallel attribute to input tensor. It is used for registering allreduce
    hooks in PaddleNLP sequence parallel training.
    """
    setattr(parameter, "sequence_parallel", True)
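

# A minimal usage sketch: layernorm parameters are replicated across TP ranks
# but see only sequence-local gradients, so PaddleNLP's hooks need them marked.
def _example_mark_layernorm(layernorm: paddle.nn.LayerNorm):
    mark_as_sequence_parallel_parameter(layernorm.weight)
    mark_as_sequence_parallel_parameter(layernorm.bias)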