module.py 7.85 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4
5

"""Megatron Module"""

import torch
6
7
from torch.autograd import Variable
from torch.nn.parameter import Parameter
8

9
from megatron import get_args
10
from megatron.core import mpu, tensor_parallel
11

12

13
14
# Legacy tensor-type classes, grouped per precision, that the precision
# conversion helpers in this module pass to isinstance(); each tuple covers
# both the CPU and the CUDA variant of the type.
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
16
17


mohammad's avatar
mohammad committed
18
19
20
21
22
23

def param_is_not_shared(param):
    """Report whether ``param`` lacks a truthy ``shared`` marker.

    Returns True when the object has no ``shared`` attribute at all, or when
    that attribute is falsy; returns False only for objects flagged as
    shared.
    """
    is_shared = getattr(param, 'shared', False)
    return not is_shared



24
class MegatronModule(torch.nn.Module):
    """Megatron specific extensions of torch Module with support
    for pipelining.

    NOTE(review): several methods read ``self.pre_process`` and
    ``self.language_model``, which are not assigned in this class;
    subclasses are presumably expected to define them.
    """

    def __init__(self, config=None, share_embeddings_and_output_weights=True):
        super(MegatronModule, self).__init__()
        # Model config; initialize_word_embeddings reads hidden_size and
        # init_method from it when building the last stage's embedding copy.
        self.config = config
        # Whether the output layer reuses the input word-embedding weight.
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Return the state dict to store when saving a checkpoint.

        Subclasses override this to customize what is saved; the default
        simply delegates to ``state_dict``.
        """
        return self.state_dict(prefix=prefix, keep_vars=keep_vars)

    def shared_embedding_or_output_weight(self):
        """Return the weight shared by the input embedding and output layer.

        When ``self.pre_process`` is true, this is the language model's word
        embedding weight. Otherwise it is ``self.word_embeddings.weight``,
        the copy created on the last stage by ``initialize_word_embeddings``.

        Raises:
            Exception: when ``pre_process`` is false and
                ``share_embeddings_and_output_weights`` is false.
        """
        if self.pre_process:
            return self.language_model.embedding.word_embeddings.weight
        else:
            if not self.share_embeddings_and_output_weights:
                raise Exception('shared_embedding_or_output_weight() called for last '
                                'stage, but share_embeddings_and_output_weights is false')
            return self.word_embeddings.weight

    def initialize_word_embeddings(self):
        """Create and synchronize the last pipeline stage's word embeddings.

        Does nothing beyond the argument check when
        ``args.pipeline_model_parallel_size == 1``. Issues collective
        all-reduces, so every rank in the embedding / position-embedding
        groups must call it.

        Raises:
            Exception: if ``share_embeddings_and_output_weights`` is false.
        """
        args = get_args()
        if not self.share_embeddings_and_output_weights:
            raise Exception('initialize_word_embeddings() was called but '
                            'share_embeddings_and_output_weights is false')

        # This function just initializes the word embeddings in the final stage
        # when we are using pipeline parallelism. Nothing to do if we aren't
        # using pipeline parallelism.
        if args.pipeline_model_parallel_size == 1:
            return

        # Parameters are shared between the word embeddings layers, and the
        # heads at the end of the model. In a pipelined setup with more than
        # one stage, the initial embedding layer and the head are on different
        # workers, so we do the following:
        # 1. Create a second copy of word_embeddings on the last stage, with
        #    initial parameters of 0.0.
        # 2. Do an all-reduce between the first and last stage to ensure that
        #    the two copies of word_embeddings start off with the same
        #    parameter values.
        # 3. In the training loop, the grads of the two word_embeddings layers
        #    are all-reduced between the stages so that every applied weight
        #    update is the same on both stages (done outside this method).
        if mpu.is_pipeline_last_stage() and not self.pre_process:
            assert not mpu.is_pipeline_first_stage()
            self._word_embeddings_for_head_key = 'word_embeddings_for_head'
            # set word_embeddings weights to 0 here, then copy first
            # stage's weights using all_reduce below.
            self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
                args.padded_vocab_size, self.config.hidden_size,
                config=self.config, init_method=self.config.init_method)
            self.word_embeddings.weight.data.fill_(0)
            # Flag the copy as shared; param_is_not_shared() returns False
            # for it.
            self.word_embeddings.weight.shared = True

        # Zero out initial weights for decoder embedding.
        # NOTE: We don't currently support T5 with the interleaved schedule.
        if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
                self.pre_process:
            self.language_model.embedding.zero_parameters()

        if not torch.distributed.is_initialized():
            # Warn only once per process: the flag is stored on the class.
            if not getattr(MegatronModule, "embedding_warning_printed", False):
                print("WARNING! Distributed processes aren't initialized, so "
                      "word embeddings in the last layer are not initialized. "
                      "If you are just manipulating a model this is fine, but "
                      "this needs to be handled manually. If you are training "
                      "something is definitely wrong.")
                MegatronModule.embedding_warning_printed = True
            return

        # Ensure that first and last stages have the same initial parameter
        # values. With the default SUM reduction, the zeroed last-stage copy
        # plus the first stage's weights yields the first stage's weights.
        if mpu.is_rank_in_embedding_group():
            torch.distributed.all_reduce(self.shared_embedding_or_output_weight().data,
                                         group=mpu.get_embedding_group())

        # Ensure that encoder(first stage) and decoder(split stage) position
        # embeddings have the same initial parameter values
        # NOTE: We don't currently support T5 with the interleaved schedule.
        if mpu.is_rank_in_position_embedding_group() and \
                args.pipeline_model_parallel_split_rank is not None:
            # TODO: Support tokentype embedding.
            self.language_model.embedding.cuda()
            position_embeddings = self.language_model.embedding.position_embeddings
            torch.distributed.all_reduce(position_embeddings.weight.data,
                                         group=mpu.get_position_embedding_group())
117

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
118

119
120
121
122
123
124
125
126
127
128
129
def conversion_helper(val, conversion):
    """Apply ``conversion`` to ``val``, recursing into tuples and lists.

    The nesting structure is preserved: lists come back as lists and tuples
    as tuples. Subclasses are rebuilt as their plain base type, so a
    namedtuple comes back as a plain tuple. Any object that is not a tuple
    or list is treated as a leaf and passed to ``conversion`` directly.

    Args:
        val: a leaf value or an arbitrarily nested tuple/list of values.
        conversion: callable applied to every leaf.

    Returns:
        The converted leaf, or a new list/tuple of converted elements.
    """
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    converted = [conversion_helper(v, conversion) for v in val]
    if isinstance(val, tuple):
        return tuple(converted)
    return converted


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
130
131
def fp32_to_float16(val, float16_convertor):
    """Cast every fp32 tensor inside ``val`` with ``float16_convertor``.

    ``val`` may be a single object or a nested tuple/list, which is walked
    with ``conversion_helper``. For ``Parameter``/``Variable`` leaves the
    type test is made on ``.data``, but the convertor is applied to the leaf
    itself. Leaves that are not fp32 tensors are returned unchanged.
    """
    def _cast_if_fp32(item):
        probe = item.data if isinstance(item, (Parameter, Variable)) else item
        if not isinstance(probe, _FLOAT_TYPES):
            return item
        return float16_convertor(item)

    return conversion_helper(val, _cast_if_fp32)


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
142
143
def float16_to_fp32(val):
    """Upcast every fp16 or bf16 tensor inside ``val`` to fp32.

    Nested tuples/lists are walked with ``conversion_helper``.
    ``Parameter``/``Variable`` leaves are type-checked through ``.data``.
    Matching leaves are converted with ``.float()``; everything else passes
    through untouched.
    """
    low_precision_types = (_BF16_TYPES, _HALF_TYPES)

    def _upcast_if_low_precision(item):
        probe = item.data if isinstance(item, (Parameter, Variable)) else item
        if isinstance(probe, low_precision_types):
            return item.float()
        return item

    return conversion_helper(val, _upcast_if_low_precision)



Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
class Float16Module(MegatronModule):
    """Wrap a module so it runs in fp16 or bf16.

    The wrapped module is cast with ``.half()`` or ``.bfloat16()`` according
    to ``args.fp16`` / ``args.bf16``. In ``forward``, fp32 positional inputs
    are cast down on the first pipeline stage and fp16/bf16 outputs are cast
    back to fp32 on the last pipeline stage. The state-dict methods delegate
    to the wrapped module, so saved keys carry no extra ``module.`` prefix
    from this wrapper.
    """

    def __init__(self, module, args):
        super(Float16Module, self).__init__()

        if args.fp16:
            self.add_module('module', module.half())
            def float16_convertor(val):
                return val.half()
        elif args.bf16:
            self.add_module('module', module.bfloat16())
            def float16_convertor(val):
                return val.bfloat16()
        else:
            raise Exception('Float16Module requires args.fp16 or args.bf16 '
                            'to be set')

        # Callable used by forward() to cast fp32 inputs to the chosen dtype.
        self.float16_convertor = float16_convertor

    def set_input_tensor(self, input_tensor):
        """Forward ``input_tensor`` to the wrapped module's ``set_input_tensor``."""
        return self.module.set_input_tensor(input_tensor)

    def forward(self, *inputs, **kwargs):
        """Run the wrapped module with pipeline-stage-aware dtype casting.

        Only positional ``inputs`` are cast; ``kwargs`` are passed through
        unchanged.
        """
        if mpu.is_pipeline_first_stage():
            inputs = fp32_to_float16(inputs, self.float16_convertor)
        outputs = self.module(*inputs, **kwargs)
        if mpu.is_pipeline_last_stage():
            outputs = float16_to_fp32(outputs)
        return outputs

    def state_dict(self, prefix='', keep_vars=False, *, destination=None):
        """Return the wrapped module's state dict.

        ``destination`` is accepted as a keyword so that
        ``torch.nn.Module.state_dict`` can recurse into this wrapper when it
        is nested inside another module; it is forwarded only when given, so
        existing calls reach the wrapped module exactly as before.
        """
        if destination is None:
            return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
        return self.module.state_dict(destination=destination, prefix=prefix,
                                      keep_vars=keep_vars)

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Delegate checkpoint state-dict construction to the wrapped module."""
        return self.module.state_dict_for_save_checkpoint(prefix=prefix,
                                                          keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` into the wrapped module.

        Returns whatever the wrapped module's ``load_state_dict`` returns.
        For a plain ``torch.nn.Module`` that is the missing/unexpected-keys
        result.
        """
        return self.module.load_state_dict(state_dict, strict=strict)