# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Communications utilities."""


import torch

from megatron import mpu



# TODO: use functions from megatron/p2p
def recv_from_prev_pipeline_rank_(recv_buffer=None):
    """Receive from previous pipeline stage and update the
    input buffer inplace."""
    if not mpu.is_pipeline_first_stage():
        assert recv_buffer is not None
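        # Post a single batched point-to-point receive from the previous
        # pipeline rank and wait for it to complete.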
        recv_prev_op = torch.distributed.P2POp(
            torch.distributed.irecv, recv_buffer,
            mpu.get_pipeline_model_parallel_prev_rank())
        reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()



# TODO: use functions from megatron/p2p
def send_to_next_pipeline_rank(tensor=None):
    """Send output to the next pipeline stage."""
    if not mpu.is_pipeline_last_stage():
        assert tensor is not None
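        # Post a single batched point-to-point send to the next pipeline
        # rank and wait for it to complete.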
        send_next_op = torch.distributed.P2POp(
            torch.distributed.isend, tensor,
            mpu.get_pipeline_model_parallel_next_rank())
        reqs = torch.distributed.batch_isend_irecv([send_next_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()



def _is_cuda(tensor):
    """Check if a tensor is not none and is cuda."""
    assert tensor is not None
    assert tensor.is_cuda



def _is_cuda_contiguous(tensor):
    """Check if a tensor is not none, is cuda, and is contiguous."""
    _is_cuda(tensor)
    assert tensor.is_contiguous()



def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
    """Broadcast a tensor from last pipeline stage to all ranks."""

    is_last_stage = mpu.is_pipeline_last_stage()
    # If the first and last stages are the same, there is no pipeline
    # parallelism and no communication is needed.
    if mpu.is_pipeline_first_stage() and is_last_stage:
        return tensor

    if is_last_stage:
        _is_cuda_contiguous(tensor)
    else:
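        # Non-last stages allocate an empty receive buffer with the
        # announced size and dtype.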
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())
    # Get the group and corresponding source rank.
    src = mpu.get_pipeline_model_parallel_last_rank()
    group = mpu.get_pipeline_model_parallel_group()
    torch.distributed.broadcast(tensor, src, group)

    return tensor



def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast tensor values from last stage into the first stage."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If the first and last stages are the same, there is no pipeline
    # parallelism and no communication is needed.
    if is_first_stage and is_last_stage:
        return tensor
    # Only the first and last pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        if is_last_stage:
            _is_cuda_contiguous(tensor)
        else:
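            # The first stage allocates an empty buffer to receive into;
            # only the last stage holds the actual values.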
            tensor = torch.empty(size,
                                 dtype=dtype,
                                 device=torch.cuda.current_device())
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        # Broadcast from last stage into the first stage.
        torch.distributed.broadcast(tensor, src, group)
    else:
        tensor = None

    return tensor



def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from last stage into the first stage.
    Note that the input tensor is updated in place."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If the first and last stages are the same, there is no pipeline
    # parallelism and no communication is needed.
    if is_first_stage and is_last_stage:
        return
    # Only the first and last pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        _is_cuda(tensor)
        is_contiguous = tensor.is_contiguous()
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
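        # The broadcast operates on a contiguous buffer; if the provided
        # tensor is not contiguous, stage a temporary workspace instead.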
        if is_contiguous:
            tensor_ = tensor
        else:
            if is_last_stage:
                tensor_ = tensor.contiguous()
            else:
                tensor_ = torch.empty(size,
                                      dtype=dtype,
                                      device=torch.cuda.current_device())
        # Broadcast from last stage into the first stage.
        torch.distributed.broadcast(tensor_, src, group)
        # If the first stage received into a temporary workspace, copy the
        # values back into the original (non-contiguous) tensor.
        if is_first_stage and not is_contiguous:
            tensor[...] = tensor_



def broadcast_tensor(size, dtype, tensor=None, rank=0):
    """ Given size and type of a tensor on all ranks and the tensor value
        only on a specific rank, broadcast from that rank to all other ranks.
    """

    if torch.distributed.get_rank() == rank:
        _is_cuda_contiguous(tensor)
    else:
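        # Ranks other than the source allocate an empty buffer to receive
        # the broadcast values.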
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())

    torch.distributed.broadcast(tensor, rank)

    return tensor



def broadcast_list(size, dtype, list_values=None, rank=0):
    """Broadcast a list of values with a given type."""

    tensor = None
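    # Only the source rank materializes the values as a CUDA tensor; the
    # other ranks allocate their receive buffer inside broadcast_tensor.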
    if torch.distributed.get_rank() == rank:
        tensor = torch.tensor(list_values, dtype=dtype,
                              device=torch.cuda.current_device())

    return broadcast_tensor(size, dtype, tensor=tensor, rank=rank)



def broadcast_int_list(size, int_list=None, rank=0):
    """Broadcast a list of interger values."""

    return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)



def broadcast_float_list(size, float_list=None, rank=0):
    """Broadcast a list of float values."""

    return broadcast_list(size, torch.float32, list_values=float_list,
                          rank=rank)
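

# Example usage (illustrative sketch; assumes torch.distributed and megatron's
# mpu are already initialized, and that `temperature`, `top_p`, and
# `context_length` are hypothetical values known only on rank 0):
#
#   values = broadcast_float_list(2, float_list=[temperature, top_p], rank=0)
#   lengths = broadcast_int_list(1, int_list=[context_length], rank=0)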