"research/slim/deployment/model_deploy_test.py" did not exist on "afdcf7d4270f91f8b2be38c77f6662c270e8b6ee"
mappings.py 8.14 KB
Newer Older
1
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from .initialize import (
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_rank,
)
from .utils import split_tensor_along_last_dim


def _reduce(input_):
    """All-reduce the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    # All-reduce.
    torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())

    return input_


def _split_along_last_dim(input_):
    """Split the tensor along its last dimension and keep the
    corresponding slice."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Split along last dimension.
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = get_tensor_model_parallel_rank()
    output = input_list[rank].contiguous()

    return output

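# For example, with a tensor-parallel world size of 2 and an activation of
# shape [s, b, h] (h divisible by 2), each rank keeps a contiguous
# [s, b, h/2] slice of the last dimension.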

def _split_along_first_dim(input_):
    """Split the tensor along its first dimension and keep the
    corresponding slice."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Split along first dimension.
    dim_size = input_.size()[0]
    assert dim_size % world_size == 0, \
        "First dimension of the tensor should be divisible by tensor parallel size"
    local_dim_size = dim_size // world_size
    rank = get_tensor_model_parallel_rank()
    dim_offset = rank * local_dim_size

    output = input_[dim_offset:dim_offset+local_dim_size].contiguous()

    return output


def _gather_along_last_dim(input_):
    """Gather tensors and concatinate along the last dimension."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = get_tensor_model_parallel_rank()

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output


def _gather_along_first_dim(input_):
    """Gather tensors and concatinate along the first dimension."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    dim_size[0] = dim_size[0] * world_size

    output = torch.empty(dim_size, dtype=input_.dtype,
                         device=torch.cuda.current_device())
    torch.distributed._all_gather_base(output, input_.contiguous(),
                                       group=get_tensor_model_parallel_group())

    return output

def _reduce_scatter_along_first_dim(input_):
    """Reduce-scatter the input tensor across model parallel group."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    assert dim_size[0] % world_size == 0, \
        "First dimension of the tensor should be divisible by tensor parallel size"
    
Vijay Korthikanti's avatar
Vijay Korthikanti committed
127
128
    dim_size[0] = dim_size[0] // world_size
    output = torch.empty(dim_size, dtype=input_.dtype,
                         device=torch.cuda.current_device())
    torch.distributed._reduce_scatter_base(output, input_.contiguous(), 
                                           group=get_tensor_model_parallel_group())
    return output

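# Note: the _all_gather_base and _reduce_scatter_base calls above are private
# torch.distributed collectives; newer PyTorch releases expose the same
# functionality as all_gather_into_tensor and reduce_scatter_tensor.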

class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output)

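# _CopyToModelParallelRegion and _ReduceFromModelParallelRegion are conjugates:
# identity in forward / all-reduce in backward, and all-reduce in forward /
# identity in backward. They correspond to the f and g operators described in
# the Megatron-LM paper.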

class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the corresponding chuck to the rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_last_dim(grad_output)


class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from model parallel region and concatinate."""

    @staticmethod
    def symbolic(graph, input_):
        return _gather_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _gather_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split_along_last_dim(grad_output)


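# The *SequenceParallelRegion mappings below partition activations along the
# first (sequence) dimension. With sequence parallelism enabled, the single
# all-reduce of tensor parallelism is replaced by an all-gather when entering
# the tensor-parallel region and a reduce-scatter when leaving it.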
class _ScatterToSequenceParallelRegion(torch.autograd.Function):
    """Split the input and keep only the corresponding chuck to the rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_first_dim(grad_output)


class _GatherFromSequenceParallelRegion(torch.autograd.Function):
    """Gather the input from model parallel region and concatinate.""" #TODO

    @staticmethod
    def symbolic(graph, input_):
        return _gather_along_first_dim(input_)
    
    @staticmethod
    def forward(ctx, input_):
        return _gather_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce_scatter_along_first_dim(grad_output)


class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
    """Reduce scatter the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return _reduce_scatter_along_first_dim(input_)
    
    @staticmethod
    def forward(ctx, input_):
        return _reduce_scatter_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_first_dim(grad_output)


# -----------------
# Helper functions.
# -----------------

def copy_to_tensor_model_parallel_region(input_):
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_tensor_model_parallel_region(input_):
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_tensor_model_parallel_region(input_):
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_tensor_model_parallel_region(input_):
    return _GatherFromModelParallelRegion.apply(input_)


def scatter_to_sequence_parallel_region(input_):
    return _ScatterToSequenceParallelRegion.apply(input_)


def gather_from_sequence_parallel_region(input_):
    return _GatherFromSequenceParallelRegion.apply(input_)


def reduce_scatter_to_sequence_parallel_region(input_):
    return _ReduceScatterToSequenceParallelRegion.apply(input_)
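

# A minimal usage sketch: a column-parallel linear forward written with the
# helpers above. The function name `column_parallel_linear_forward` and the
# `weight_shard` argument are hypothetical placeholders for illustration;
# Megatron's actual parallel layers (ColumnParallelLinear, RowParallelLinear)
# are implemented separately.
def column_parallel_linear_forward(input_, weight_shard, gather_output=True):
    # Identity in the forward pass; all-reduces the gradient in the backward pass.
    input_parallel = copy_to_tensor_model_parallel_region(input_)
    # Each rank applies only its shard of the weight matrix,
    # shaped [output_size // world_size, input_size].
    output_parallel = torch.nn.functional.linear(input_parallel, weight_shard)
    if gather_output:
        # All-gather the partitioned outputs along the last dimension
        # (the gradient is split back in the backward pass).
        return gather_from_tensor_model_parallel_region(output_parallel)
    return output_parallel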