# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron Module"""

import torch

from megatron import get_args
from megatron import mpu


class MegatronModule(torch.nn.Module):
    """Megatron specific extensions of torch Module."""

    def __init__(self):
        super(MegatronModule, self).__init__()

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """Use this function to override the state dict for
        saving checkpoints."""
        return self.state_dict(destination, prefix, keep_vars)


class PipelinedMegatronModule(MegatronModule):
    """Pipelining specific extensions of MegatronModule."""

    def __init__(self, share_word_embeddings=True):
        super(PipelinedMegatronModule, self).__init__()
        args = get_args()
        self.share_word_embeddings = share_word_embeddings

    def word_embeddings_weight(self):
        """Return the shared word-embedding weight on the first or last stage."""
        if mpu.is_pipeline_first_stage():
            return self.language_model.embedding.word_embeddings.weight
        if mpu.is_pipeline_last_stage():
            if not self.share_word_embeddings:
                raise Exception('word_embeddings_weight() called for last stage, '
                                'but share_word_embeddings is false')
            return self.word_embeddings.weight
        raise Exception('word_embeddings_weight() should be '
                        'called for first and last stage only')

    def initialize_word_embeddings(self, init_method_normal):
        args = get_args()
        if not self.share_word_embeddings:
            raise Exception('initialize_word_embeddings() was called but '
                            'share_word_embeddings is false')
        # Parameters are shared between the word embeddings layer and the heads at
        # the end of the model. In a pipelined setup with more than one stage, the
        # initial embedding layer and the head are on different workers, so we do
        # the following:
        # 1. Create a second copy of word_embeddings on the last stage, with initial
        #    parameters of 0.0.
        # 2. Do an all-reduce between the first and last stage to ensure that the
        #    two copies of word_embeddings start off with the same parameter values.
        # 3. In the training loop, do an all-reduce between the grads of the two
        #    word_embeddings layers to ensure that every applied weight update is
        #    the same on both stages (see the illustrative sketch at the end of
        #    this file).
        if mpu.is_pipeline_last_stage():
            if not mpu.is_pipeline_first_stage():
                self._word_embeddings_for_head_key = 'word_embeddings_for_head'
                # If first and last stages are different, set word_embeddings
                # weights to 0 here, then copy first stage's weights using all_reduce
                # below.
                self.word_embeddings = mpu.VocabParallelEmbedding(
                    args.padded_vocab_size, args.hidden_size,
                    init_method=init_method_normal(args.init_method_std))
                self.word_embeddings.weight.data.fill_(0)
        # Ensure that first and last stages have the same initial parameter values.
        if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
            torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                         group=mpu.get_embedding_group())
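

# Illustrative sketch only, not part of the original module: step 3 of the
# scheme described in initialize_word_embeddings() is performed in the training
# loop rather than here. Assuming `model` is a PipelinedMegatronModule whose
# backward pass has already populated gradients, the grad all-reduce could look
# roughly like the hypothetical helper below.
def allreduce_word_embedding_grads(model):
    """All-reduce the word-embedding grads between the first and last pipeline
    stages so that the tied embedding copies receive identical weight updates."""
    if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
        weight = model.word_embeddings_weight()
        # The grad can be None, e.g. before the first backward pass.
        if weight.grad is not None:
            torch.distributed.all_reduce(weight.grad,
                                         group=mpu.get_embedding_group())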