# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron Module"""

import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter

from megatron import get_args
from megatron import mpu


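# Tensor-type tuples used for the isinstance checks in the fp32/fp16/bf16
# conversion helpers below.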
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)


def param_is_not_shared(param):
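    """Return True unless the parameter has been marked as shared
    (see the `shared` attribute set on the tied word embeddings below)."""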
    return not hasattr(param, 'shared') or not param.shared


class MegatronModule(torch.nn.Module):
    """Megatron specific extensions of torch Module with support
    for pipelining."""

    def __init__(self, share_word_embeddings=True):
        super(MegatronModule, self).__init__()
        self.share_word_embeddings = share_word_embeddings


    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """Use this function to override the state dict for
        saving checkpoints."""
        return self.state_dict(destination, prefix, keep_vars)


    def word_embeddings_weight(self):
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            return self.language_model.embedding.word_embeddings.weight
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            if not self.share_word_embeddings:
                raise Exception('word_embeddings_weight() called for last '
                                'stage, but share_word_embeddings is false')
            return self.word_embeddings.weight
        raise Exception('word_embeddings_weight() should be '
                        'called for first and last stage only')


    def initialize_word_embeddings(self, init_method_normal):
        args = get_args()
        if not self.share_word_embeddings:
            raise Exception('initialize_word_embeddings() was called but '
                            'share_word_embeddings is false')

        # This function just initializes the word embeddings in the final stage
        # when we are using pipeline parallelism. If we aren't using pipeline
        # parallelism, there is nothing to do.
        if args.pipeline_model_parallel_size == 1:
            return

        # Parameters are shared between the word embeddings layer and the
        # heads at the end of the model. In a pipelined setup with more than
        # one stage, the initial embedding layer and the head are on different
        # workers, so we do the following:
        # 1. Create a second copy of word_embeddings on the last stage, with
        #    initial parameters of 0.0.
        # 2. Do an all-reduce between the first and last stage to ensure that
        #    the two copies of word_embeddings start off with the same
        #    parameter values.
        # 3. In the training loop, do an all-reduce between the grads of
        #    the two word_embeddings layers to ensure that every applied
        #    weight update is the same on both stages.
        if mpu.is_pipeline_last_stage():
            assert not mpu.is_pipeline_first_stage()
            self._word_embeddings_for_head_key = 'word_embeddings_for_head'
            # set word_embeddings weights to 0 here, then copy first
            # stage's weights using all_reduce below.
            self.word_embeddings = mpu.VocabParallelEmbedding(
                args.padded_vocab_size, args.hidden_size,
                init_method=init_method_normal(args.init_method_std))
            self.word_embeddings.weight.data.fill_(0)
            self.word_embeddings.weight.shared = True
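            # Marking the weight as shared lets helpers such as
            # param_is_not_shared() above identify this copy as a tied
            # parameter (e.g. so it is not double-counted when computing
            # gradient norms).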

        # Ensure that first and last stages have the same initial parameter
        # values.
        if torch.distributed.is_initialized():
            if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
                torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                             group=mpu.get_embedding_group())
        else:
            print("WARNING! Distributed processes aren't initialized, so "
                  "word embeddings in the last layer are not initialized. "
                  "If you are just manipulating a model this is fine, but "
                  "this needs to be handled manually. If you are training "
                  "something is definitely wrong.")


def conversion_helper(val, conversion):
    """Apply conversion to val. Recursively apply conversion if `val`
    #is a nested tuple/list structure."""
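    # For example, conversion_helper((1.0, [2.0, 3.0]), lambda x: 2 * x)
    # returns (2.0, [4.0, 6.0]), preserving the nesting.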
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    rtn = [conversion_helper(v, conversion) for v in val]
    if isinstance(val, tuple):
        rtn = tuple(rtn)
    return rtn


def fp32_to_float16(val, float16_convertor):
    """Convert fp32 `val` to fp16/bf16"""
    def half_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, _FLOAT_TYPES):
            val = float16_convertor(val)
        return val
    return conversion_helper(val, half_conversion)


def float16_to_fp32(val):
    """Convert fp16/bf16 `val` to fp32"""
    def float_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
            val = val.float()
        return val
    return conversion_helper(val, float_conversion)


class Float16Module(MegatronModule):

    def __init__(self, module, args):
        super(Float16Module, self).__init__()

        if args.fp16:
            self.add_module('module', module.half())
            def float16_convertor(val):
                return val.half()
        elif args.bf16:
            self.add_module('module', module.bfloat16())
            def float16_convertor(val):
                return val.bfloat16()
        else:
            raise Exception('Float16Module expects either args.fp16 or '
                            'args.bf16 to be set')

        self.float16_convertor = float16_convertor


    def forward(self, *inputs, **kwargs):
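        # Inputs are converted to fp16/bf16 only on the first pipeline stage
        # (they arrive in fp32); outputs are converted back to fp32 only on
        # the last stage, where the loss is computed downstream.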
        if mpu.is_pipeline_first_stage():
            inputs = fp32_to_float16(inputs, self.float16_convertor)
        outputs = self.module(*inputs, **kwargs)
        if mpu.is_pipeline_last_stage():
            outputs = float16_to_fp32(outputs)
        return outputs


    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return self.module.state_dict(destination, prefix, keep_vars)


    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        return self.module.state_dict_for_save_checkpoint(destination, prefix,
                                                          keep_vars)


    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)
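

# Illustrative usage sketch (not part of the original module): wrapping a model
# in Float16Module when mixed precision is enabled. `build_model` and the
# forward arguments below are hypothetical placeholders; the real names depend
# on the model being wrapped.
#
#     args = get_args()
#     model = build_model(args)              # hypothetical constructor
#     if args.fp16 or args.bf16:
#         model = Float16Module(model, args)
#     output = model(tokens, position_ids, attention_mask)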