merge_mp_partitions.py 14 KB
Newer Older
Mohammad's avatar
Mohammad committed
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
Mohammad's avatar
Mohammad committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Merge model parallel partitions."""
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
17
18

import os
19
import re
Mohammad's avatar
Mohammad committed
20
21
22
23
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
24
25
26
import torch

from megatron import mpu
Jared Casper's avatar
Jared Casper committed
27
from megatron.checkpointing import load_checkpoint, save_checkpoint
Mohammad's avatar
Mohammad committed
28
29
from megatron.checkpointing import ensure_directory_exists
from megatron.checkpointing import get_checkpoint_name
Jared Casper's avatar
Jared Casper committed
30
from megatron.checkpointing import get_checkpoint_version
Mohammad's avatar
Mohammad committed
31
from megatron.checkpointing import get_checkpoint_tracker_filename
Jared Casper's avatar
Jared Casper committed
32
from megatron.global_vars import set_global_variables, get_args
Mohammad's avatar
Mohammad committed
33
from megatron.global_vars import rebuild_tokenizer
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109


def split_into_partitions(tensor, num_partitions, partition_dim, stride):
    """Split ``tensor`` along ``partition_dim`` into strided partitions.

    The dimension is first cut into ``num_partitions * stride`` equally
    sized chunks. Partition ``r`` is then the concatenation of chunks
    ``r, r + num_partitions, r + 2 * num_partitions, ...``.

    Args:
        tensor: tensor to split.
        num_partitions: number of partitions to produce.
        partition_dim: dimension along which to split.
        stride: number of chunks each partition is assembled from.

    Returns:
        A list with ``num_partitions`` tensors.
    """
    # Both sizes come from mpu.utils.divide, same as the sizes used on the
    # merge side (presumably it rejects non-divisible sizes -- TODO confirm).
    size_per_partition = mpu.utils.divide(tensor.size(partition_dim),
                                          num_partitions)
    chunk_size = mpu.utils.divide(size_per_partition, stride)

    chunks = torch.split(tensor, chunk_size, dim=partition_dim)

    # Every num_partitions-th chunk, starting at the rank, belongs to that rank.
    return [torch.cat(chunks[rank::num_partitions], dim=partition_dim)
            for rank in range(num_partitions)]


def merge_partitions(merged, partitions, partition_dim, stride):
    """Merge ``partitions`` along ``partition_dim`` into ``merged`` in place.

    With ``stride == 1`` the partitions are simply concatenated. Otherwise
    each partition is cut into ``stride`` chunks and the chunks of all
    partitions are interleaved (chunk 0 of every partition, then chunk 1 of
    every partition, ...) before concatenation.

    If the concatenated size exceeds ``merged``'s size along
    ``partition_dim``, a warning is printed and only the leading part that
    fits is copied into ``merged``.

    Args:
        merged: destination tensor; overwritten with the merged values.
        partitions: list of tensors that all have the same size along
            ``partition_dim``.
        partition_dim: dimension along which to merge.
        stride: number of chunks per partition.

    Returns:
        None.
    """
    shard_count = len(partitions)

    # All shards must agree on their extent along the partition dimension.
    shard_size = None
    for shard in partitions:
        extent = shard.size(partition_dim)
        if shard_size is None:
            shard_size = extent
        else:
            assert shard_size == extent

    def _write_merged(pieces):
        # Concatenate ``pieces`` and store the result in ``merged``.
        with torch.no_grad():
            combined_size = shard_size * shard_count
            target_size = merged.size(partition_dim)
            if combined_size != target_size:
                print('     ***WARNING*** sizes do not match. Will cut '
                      'the merged partitions by {} along dimension {} '
                      'to reduce the size from {} to {} ...'.format(
                          combined_size - target_size, partition_dim,
                          combined_size, target_size))
                joined = torch.cat(pieces, dim=partition_dim)
                # Keep only the leading slice that fits the destination.
                leading = torch.split(joined, target_size,
                                      dim=partition_dim)[0]
                assert leading.size(partition_dim) == target_size
                merged.data.copy_(leading.data)
                return
            torch.cat(pieces, dim=partition_dim, out=merged)

    # Unit stride: plain concatenation.
    if stride == 1:
        _write_merged(partitions)
        return

    # Non-unit stride: cut every shard into chunks, then interleave them so
    # that chunk k of every shard precedes chunk k + 1 of any shard.
    chunk_size = mpu.utils.divide(shard_size, stride)
    interleaved = None
    for rank, shard in enumerate(partitions):
        shard_chunks = torch.split(shard, chunk_size, dim=partition_dim)
        if interleaved is None:
            interleaved = [0] * (shard_count * len(shard_chunks))
        interleaved[rank::shard_count] = shard_chunks

    _write_merged(interleaved)
    return


Mohammad's avatar
Mohammad committed
110
def get_model(model_type):
    """Build a half-precision model of the requested type.

    The provider is imported lazily so that only the module needed for the
    selected model type is loaded.

    Args:
        model_type: one of 'BERT', 'GPT', 'RACE', 'MNLI' or 'QQP'.

    Returns:
        The model returned by the selected provider, converted with
        ``.half()``.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    if model_type == 'BERT':
        from pretrain_bert import model_provider
    elif model_type == 'GPT':
        from pretrain_gpt import model_provider
    elif model_type == 'RACE':
        from tasks.race.finetune import model_provider
    elif model_type in ('MNLI', 'QQP'):
        # Membership test: comparing the string against a list with ``==``
        # is never true, which made both types fall through to the error.
        num_classes = 3 if model_type == 'MNLI' else 2
        from megatron.model.classification import Classification

        def model_provider():
            return Classification(num_classes=num_classes, num_tokentypes=2)
    else:
        raise ValueError('unrecognized model type: {}'.format(model_type))

    model = model_provider()
    model = model.half()

    return model


def get_parallel_checkpoint_name(path):
    """Return the checkpoint name and iteration recorded under ``path``.

    The iteration is read as an integer from the tracker file located by
    ``get_checkpoint_tracker_filename(path)`` and must be positive.

    Args:
        path: checkpoint directory passed on to the checkpointing helpers.

    Returns:
        A ``(checkpoint_name, iteration)`` tuple, where the name comes from
        ``get_checkpoint_name(path, iteration)``.
    """
    with open(get_checkpoint_tracker_filename(path), 'r') as tracker:
        iteration = int(tracker.read().strip())
    assert iteration > 0

    return get_checkpoint_name(path, iteration), iteration


def test_split_merge():
    """Split a small tensor, merge it back, and print the round-trip error.

    Uses 2 partitions with stride 3 along dimension 0; the printed maximum
    absolute error is expected to be zero.
    """
    print('testing split and merge ...')

    # [QKV.ROW-COL]: entry "q.rc" encodes block q (1-3), row r (1-4) and
    # column c (1-5), giving a 12 x 5 tensor.
    table = [[float(f'{q}.{r}{c}') for c in range(1, 6)]
             for q in range(1, 4)
             for r in range(1, 5)]
    reference = torch.FloatTensor(table)

    partition_count, dim, stride = 2, 0, 3
    shards = split_into_partitions(reference, partition_count, dim, stride)

    rebuilt = torch.zeros_like(reference)
    merge_partitions(rebuilt, shards, dim, stride)

    max_error = (rebuilt - reference).abs().max()
    print('  > max error (should be zero): {}'.format(max_error))


Mohammad's avatar
Mohammad committed
178
179
180
181
182
def get_mp_merge_args(parser):
    """Provide extra arguments required for merging.

    Adds an 'mp merge' argument group with:
      --model-type: required; one of BERT, GPT, RACE, MNLI, QQP.
      --target-pipeline-model-parallel-size: int, defaults to 1.

    Args:
        parser: the argument parser to extend.

    Returns:
        The same ``parser`` object, so it can be used as an extra-args
        provider.
    """
    group = parser.add_argument_group(title='mp merge')

    group.add_argument('--model-type', type=str, required=True,
                       choices=['BERT', 'GPT', 'RACE', 'MNLI', 'QQP'],
                       help='Type of the model.')
    group.add_argument('--target-pipeline-model-parallel-size', type=int,
                       default=1,
                       help='Degree of pipeline model parallelism in output '
                            'model.')

    return parser


def main():
    """Merge tensor-model-parallel checkpoint shards and save the result.

    Steps, in order:
      1. Parse arguments (with the ``get_mp_merge_args`` extras) and build
         one model with tensor-parallel size 1 to receive the merged
         weights.
      2. Build one model per tensor-parallel rank and load that rank's
         checkpoint from ``args.load``.
      3. Walk the parameters of all models in lockstep: parameters without
         a ``tensor_model_parallel`` attribute are copied from rank 0, the
         others are merged with ``merge_partitions``.
      4. If ``--target-pipeline-model-parallel-size`` is greater than 1,
         copy the merged weights into one model per pipeline rank.
      5. Save every resulting model with ``save_checkpoint``.

    Exits with status 1 when the input checkpoint itself uses pipeline
    model parallelism, which is not supported.
    """

    # Arguments do sanity checks on the world size, but we don't care,
    # so trick it into thinking we are plenty of processes
    os.environ["WORLD_SIZE"] = f'{2**31}'

    # Args
    set_global_variables(extra_args_provider=get_mp_merge_args,
                         args_defaults={'use_cpu_initialization': True,
                                        'micro_batch_size': 1,
                                        'no_load_optim': True,
                                        'no_load_rng': True,
                                        'no_save_optim': True,
                                        'no_save_rng': True,
                                        'save_interval': 1})
    args = get_args()

    if args.pipeline_model_parallel_size > 1:
        print("Checkpoints with pipeline model parallelism are not currently supported.")
        # Report failure to the caller; a bare exit() would return status 0
        # and depends on the site module providing the builtin.
        sys.exit(1)

    model_type = args.model_type
    orig_tensor_model_parallel_size = args.tensor_model_parallel_size
    args.tensor_model_parallel_size = 1
    tokenizer = rebuild_tokenizer(args)

    print('\n merging model parallel partitions ...')
    print(' > number of partitions: {}'.format(orig_tensor_model_parallel_size))
    print(' > checkpoint path: {}'.format(args.load))
    print(' > model parameters:')
    print('    number of tokens ................ {} '.format(
        tokenizer.vocab_size))
    print('    number of layers ................ {}'.format(args.num_layers))
    print('    hidden size ..................... {}'.format(args.hidden_size))
    print('    number of attention heads ....... {}'.format(
        args.num_attention_heads))
    print('    maximum position embeddings ..... {}'.format(
        args.max_position_embeddings))

    # Full model.
    print('> building the full model ...')
    mpu.initialize.set_tensor_model_parallel_world_size(1)
    mpu.initialize.set_tensor_model_parallel_rank(0)
    mpu.initialize.set_pipeline_model_parallel_world_size(1)
    mpu.initialize.set_pipeline_model_parallel_rank(0)
    merged_model = get_model(model_type)

    # Build and load partitions.
    partitions = []
    iteration = 0
    args.tensor_model_parallel_size = orig_tensor_model_parallel_size
    # Called for its effect on the global state; the return value is unused.
    # NOTE(review): presumably refreshes tokenizer/vocab settings for the
    # restored tensor-parallel size -- confirm against rebuild_tokenizer.
    rebuild_tokenizer(args)
    mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
    for rank in range(args.tensor_model_parallel_size):
        # Reset these since load_checkpoint asserts they are 0, but we are loading
        # multiple checkpoints in the same process and they get set each time
        args.consumed_train_samples = 0
        args.consumed_valid_samples = 0

        mpu.initialize.set_tensor_model_parallel_rank(rank)
        checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
        model_ = get_model(model_type)
        print(f'> loading {checkpoint_name} ...')
        load_checkpoint(model_, None, None)
        print(f'> checkpoint version {get_checkpoint_version()}')
        partitions.append(model_)

    # Walk the merged model and every partition simultaneously, one
    # parameter at a time.
    partition_param_iters = [partition.named_parameters()
                             for partition in partitions]
    for named_params in zip(merged_model.named_parameters(),
                            *partition_param_iters):
        name, merged_param = named_params[0]
        print(' > working on {} ...'.format(name))
        print('     merged         type: {}, size: {}'.format(
            merged_param.dtype, list(merged_param.size())))

        # Collect this parameter from every rank and check the names agree.
        partitions_param = []
        for rank, (partition_name, partition_param) in enumerate(
                named_params[1:]):
            assert partition_name == name
            partitions_param.append(partition_param)
            print('     partition {}    type: {}, size: {}'.format(
                rank, partition_param.dtype, list(partition_param.size())))

        # For the non-parallel parameters, simply copy the rank 0 values.
        if not hasattr(merged_param, 'tensor_model_parallel'):
            print('     none-parallel parameter, simple copy from rank 0')
            with torch.no_grad():
                merged_param.data.copy_(partitions_param[0].data)
        # For parallel parameters, merge the values
        else:
            dim = merged_param.partition_dim
            stride = merged_param.partition_stride
            print(f'     parallel parameter merge with stride {stride} along '
                  f'dimention {dim}')
            merge_partitions(merged_param,
                             partitions_param,
                             dim,
                             stride)

    # From here on ``partitions`` holds the models to save (one per
    # pipeline rank) instead of the loaded tensor-parallel shards.
    partitions = []
    args.tensor_model_parallel_size = 1
    args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size

    assert args.num_layers % args.pipeline_model_parallel_size == 0, \
        'num_layers must be divisible by target pipeline model parallel size'
    layers_per_part = args.num_layers // args.pipeline_model_parallel_size

    # Return value unused; see the note on the earlier call.
    rebuild_tokenizer(args)
    mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
    mpu.initialize.set_tensor_model_parallel_rank(0)
    mpu.initialize.set_pipeline_model_parallel_world_size(args.pipeline_model_parallel_size)

    # regex to parse out layer number from param name
    # (raw string: '\.' is an invalid escape in a regular string literal)
    layer_re = re.compile(r'layers\.([0-9]+)')

    if args.pipeline_model_parallel_size > 1:
        merged_params = dict(merged_model.named_parameters())

        for rank in range(args.pipeline_model_parallel_size):
            mpu.initialize.set_pipeline_model_parallel_rank(rank)
            model = get_model(model_type)

            def update_layer_num(m):
                # TODO! This assumes no interleaved pipeline execution
                layer = int(m.group(1))
                layer += rank * layers_per_part
                return f'layers.{layer}'

            for dst_name, partition_param in model.named_parameters():
                if dst_name == "word_embeddings.weight":
                    # See comment in MegatronModule.initialize_word_embeddings()
                    src_name = "language_model.embedding.word_embeddings.weight"
                else:
                    # Translate destination layer number (0-N for each partition)
                    # to source layer number (single-model layer number)
                    src_name = layer_re.sub(update_layer_num, dst_name)
                print(f" > copying {src_name} to {dst_name} in rank {rank}'s model")
                partition_param.data.copy_(merged_params[src_name].data)

            partitions.append(model)
    else:
        partitions = [merged_model]

    for rank, model in enumerate(partitions):
        mpu.initialize.set_pipeline_model_parallel_rank(rank)
        print(f"> saving rank {rank}'s model")
        save_checkpoint(iteration, model, None, None)

    print('done :-)')


if __name__ == '__main__':
    # Run the merge only when this file is executed as a script.
    main()