# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Communications utilities."""


import torch

from megatron import mpu



# TODO: use functions from megatron/p2p
def recv_from_prev_pipeline_rank_(recv_buffer=None):
    """Receive from previous pipeline stage and update the
    input buffer inplace."""
    if not mpu.is_pipeline_first_stage():
        assert recv_buffer is not None
        recv_prev_op = torch.distributed.P2POp(
            torch.distributed.irecv, recv_buffer,
            mpu.get_pipeline_model_parallel_prev_rank())
        reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()



# TODO: use functions from megatron/p2p
def send_to_next_pipeline_rank(tensor=None):
    """Send output to the next pipeline stage."""
    if not mpu.is_pipeline_last_stage():
        assert tensor is not None
        send_next_op = torch.distributed.P2POp(
            torch.distributed.isend, tensor,
            mpu.get_pipeline_model_parallel_next_rank())
        reqs = torch.distributed.batch_isend_irecv([send_next_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()



def _is_cuda(tensor):
    """Check if a tensor is not none and is cuda."""
    assert tensor is not None
    assert tensor.is_cuda



def _is_cuda_contiguous(tensor):
    """Check if a tensor is not none, is cuda, and is contiguous."""
    _is_cuda(tensor)
    assert tensor.is_contiguous()



def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
    """Broadcast a tensor from last pipeline stage to all ranks."""

    is_last_stage = mpu.is_pipeline_last_stage()
    # If the first stage and the last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if mpu.is_pipeline_first_stage() and is_last_stage:
        return tensor

    if is_last_stage:
        _is_cuda_contiguous(tensor)
    else:
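        # Allocate an empty receive buffer on all other pipeline ranks.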
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())
    # Get the group and corresponding source rank.
    src = mpu.get_pipeline_model_parallel_last_rank()
    group = mpu.get_pipeline_model_parallel_group()
    torch.distributed.broadcast(tensor, src, group)

    return tensor



def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast tensor values from last stage into the first stage."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If the first stage and the last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return tensor
    # Only first and last stage pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        if is_last_stage:
            _is_cuda_contiguous(tensor)
        else:
            tensor = torch.empty(size,
                                 dtype=dtype,
                                 device=torch.cuda.current_device())
        src = mpu.get_pipeline_model_parallel_last_rank()
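        # The embedding group contains only the first- and last-stage ranks,
        # so this broadcast involves just those two stages.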
        group = mpu.get_embedding_group()
        # Broadcast from last stage into the first stage.
        torch.distributed.broadcast(tensor, src, group)
    else:
        tensor = None

    return tensor



def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from last stage into the first stage.
    Note that the input tensor is updated in place."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If the first stage and the last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return
    # Only first and last stage pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        _is_cuda(tensor)
        is_contiguous = tensor.is_contiguous()
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        if is_contiguous:
            tensor_ = tensor
        else:
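            # Non-contiguous tensor: the last stage broadcasts a contiguous
            # copy, while the first stage receives into a temporary buffer
            # that is copied back into `tensor` below.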
            if is_last_stage:
                tensor_ = tensor.contiguous()
            else:
                tensor_ = torch.empty(size,
                                      dtype=dtype,
                                      device=torch.cuda.current_device())
        # Broadcast from last stage into the first stage.
        torch.distributed.broadcast(tensor_, src, group)
        # Copy the staged values back into the original first-stage tensor.
        if is_first_stage and not is_contiguous:
            tensor[...] = tensor_



def broadcast_tensor(size, dtype, tensor=None, rank=0):
    """ Given size and type of a tensor on all ranks and the tensor value
        only on a specific rank, broadcast from that rank to all other ranks.
    """

    if torch.distributed.get_rank() == rank:
        _is_cuda_contiguous(tensor)
    else:
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())

    torch.distributed.broadcast(tensor, rank)

    return tensor



def broadcast_list(size, dtype, list_values=None, rank=0):
    """Broadcast a list of values with a given type."""

    tensor = None
    if torch.distributed.get_rank() == rank:
        tensor = torch.tensor(list_values, dtype=dtype,
                              device=torch.cuda.current_device())

    return broadcast_tensor(size, dtype, tensor=tensor, rank=rank)



def broadcast_int_list(size, int_list=None, rank=0):
    """Broadcast a list of interger values."""

    return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)



def broadcast_float_list(size, float_list=None, rank=0):
    """Broadcast a list of float values."""

    return broadcast_list(size, torch.float32, list_values=float_list,
                          rank=rank)
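

# Example usage (sketch): every rank calls the same broadcast helper with
# identical size/dtype arguments; only `rank` supplies the values, the other
# ranks pass None. The parameter names below (temperature, top_p) are
# hypothetical:
#
#     values = broadcast_float_list(2, float_list=[temperature, top_p], rank=0)
#     temperature, top_p = values[0].item(), values[1].item()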