module.py 7.89 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4
5

"""Megatron Module"""

import torch
6
7
from torch.autograd import Variable
from torch.nn.parameter import Parameter
8

9
10
from megatron import get_args
from megatron import mpu
11
from megatron import core
12

13

14
15
# Legacy tensor type classes (CPU and CUDA variants) used with isinstance()
# to recognise a tensor's floating-point precision in fp32_to_float16 and
# float16_to_fp32.
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)

def param_is_not_shared(param):
    """Return True unless `param` has been flagged as shared.

    A parameter counts as shared only when it carries a truthy ``shared``
    attribute (for example the last-stage word-embedding copy, whose
    weight gets ``shared = True``). Objects without the attribute are
    treated as not shared.
    """
    shared_flag = getattr(param, 'shared', False)
    return not shared_flag



class MegatronModule(torch.nn.Module):
    """Megatron specific extensions of torch Module with support
    for pipelining.

    Subclasses are expected to set ``self.pre_process`` and, on stages where
    ``pre_process`` is true, ``self.language_model``; neither is set here,
    but both are read by the methods below.
    """

    def __init__(self, share_word_embeddings=True):
        super().__init__()
        # When True, the input word embeddings (first stage) and the output
        # head (last stage) are meant to use the same weights; see
        # initialize_word_embeddings().
        self.share_word_embeddings = share_word_embeddings

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Use this function to override the state dict for
        saving checkpoints.

        The default implementation simply returns ``self.state_dict``.
        """
        return self.state_dict(prefix=prefix, keep_vars=keep_vars)

    def word_embeddings_weight(self):
        """Return the word-embedding weight held by this pipeline stage.

        Stages with ``pre_process`` set return the input embedding weight of
        ``self.language_model``; other stages return the copy created by
        ``initialize_word_embeddings``.

        Raises:
            Exception: on a stage without ``pre_process`` when
                ``share_word_embeddings`` is false.
        """
        if self.pre_process:
            return self.language_model.embedding.word_embeddings.weight
        if not self.share_word_embeddings:
            raise Exception('word_embeddings_weight() called for last '
                            'stage, but share_word_embeddings is false')
        return self.word_embeddings.weight

    def initialize_word_embeddings(self, init_method_normal):
        """Create and synchronize word embeddings shared across stages.

        With pipeline parallelism, the last stage (when it does not also do
        pre-processing) gets its own zero-filled copy of the word embeddings,
        which is then all-reduced with the first stage's copy so both start
        identical. Position embeddings are synchronized between the encoder
        and decoder stages when a pipeline split rank is configured.

        Args:
            init_method_normal: factory called with ``args.init_method_std``
                to build the init method for the last-stage embedding copy.

        Raises:
            Exception: if ``share_word_embeddings`` is false.
        """
        # Validate before touching global state: this check needs no args.
        if not self.share_word_embeddings:
            raise Exception('initialize_word_embeddings() was called but '
                            'share_word_embeddings is false')

        args = get_args()

        # This function just initializes the word embeddings in the final stage
        # when we are using pipeline parallelism. Nothing to do if we aren't
        # using pipeline parallelism.
        if args.pipeline_model_parallel_size == 1:
            return

        # Parameters are shared between the word embeddings layers, and the
        # heads at the end of the model. In a pipelined setup with more than
        # one stage, the initial embedding layer and the head are on different
        # workers, so we do the following:
        # 1. Create a second copy of word_embeddings on the last stage, with
        #    initial parameters of 0.0.
        # 2. Do an all-reduce between the first and last stage to ensure that
        #    the two copies of word_embeddings start off with the same
        #    parameter values.
        # 3. In the training loop (outside this module), all-reduce the grads
        #    of the two word_embeddings layers so that every applied weight
        #    update is the same on both stages.
        if mpu.is_pipeline_last_stage() and not self.pre_process:
            assert not mpu.is_pipeline_first_stage()
            self._word_embeddings_for_head_key = 'word_embeddings_for_head'
            # set word_embeddings weights to 0 here, then copy first
            # stage's weights using all_reduce below.
            self.word_embeddings = core.tensor_parallel.VocabParallelEmbedding(
                args.padded_vocab_size, args.hidden_size,
                init_method=init_method_normal(args.init_method_std),
                params_dtype=args.params_dtype,
                use_cpu_initialization=args.use_cpu_initialization,
                perform_initialization=args.perform_initialization)
            self.word_embeddings.weight.data.fill_(0)
            # Marks the copy so param_is_not_shared() reports it as shared.
            self.word_embeddings.weight.shared = True

        # Zero out initial weights for decoder embedding.
        # NOTE: We don't currently support T5 with the interleaved schedule.
        if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
                self.pre_process:
            self.language_model.embedding.zero_parameters()

        if not torch.distributed.is_initialized():
            # Warn only once per process (flag stored on the class).
            if not getattr(MegatronModule, "embedding_warning_printed", False):
                print("WARNING! Distributed processes aren't initialized, so "
                      "word embeddings in the last layer are not initialized. "
                      "If you are just manipulating a model this is fine, but "
                      "this needs to be handled manually. If you are training "
                      "something is definitely wrong.")
                MegatronModule.embedding_warning_printed = True
            return

        # Ensure that first and last stages have the same initial parameter
        # values.
        if mpu.is_rank_in_embedding_group():
            torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                         group=mpu.get_embedding_group())

        # Ensure that encoder(first stage) and decoder(split stage) position
        # embeddings have the same initial parameter values
        # NOTE: We don't currently support T5 with the interleaved schedule.
        if mpu.is_rank_in_position_embedding_group() and \
                args.pipeline_model_parallel_split_rank is not None:
            # TODO: Support tokentype embedding.
            self.language_model.embedding.cuda()
            position_embeddings = self.language_model.embedding.position_embeddings
            torch.distributed.all_reduce(position_embeddings.weight.data,
                                         group=mpu.get_position_embedding_group())

def conversion_helper(val, conversion):
    """Apply `conversion` to every leaf of a possibly nested `val`.

    Tuples and lists are walked recursively and rebuilt; a tuple (including
    a tuple subclass) comes back as a plain ``tuple`` and a list as a plain
    ``list``. Any other object is treated as a leaf and passed to
    `conversion` directly.
    """
    if isinstance(val, (tuple, list)):
        converted = [conversion_helper(item, conversion) for item in val]
        return tuple(converted) if isinstance(val, tuple) else converted
    return conversion(val)


def fp32_to_float16(val, float16_convertor):
    """Convert fp32 tensors inside `val` to fp16/bf16.

    `val` may be a single object or a nested tuple/list; leaves are visited
    via `conversion_helper`. A leaf is converted with `float16_convertor`
    only when it (or its ``.data`` for Parameter/Variable leaves) is an
    instance of one of `_FLOAT_TYPES`; other leaves are returned unchanged.

    Args:
        val: object or nested tuple/list of objects.
        float16_convertor: callable mapping an fp32 tensor to its
            half-precision counterpart (e.g. ``lambda t: t.half()``).

    Returns:
        `val` with matching leaves converted, nesting rebuilt by
        `conversion_helper`.
    """
    def half_conversion(leaf):
        # Type-check through .data for Parameter/Variable, but convert the
        # original object. NOTE(review): in recent PyTorch isinstance(x,
        # Variable) is True for any Tensor, so plain tensors take this path
        # too — confirm for the pinned torch version.
        if isinstance(leaf, (Parameter, Variable)):
            typecheck = leaf.data
        else:
            typecheck = leaf
        if isinstance(typecheck, _FLOAT_TYPES):
            return float16_convertor(leaf)
        return leaf
    return conversion_helper(val, half_conversion)


def float16_to_fp32(val):
    """Convert fp16/bf16 tensors inside `val` back to fp32.

    `val` may be a single object or a nested tuple/list; leaves are visited
    via `conversion_helper`. A leaf is converted with ``.float()`` only when
    it (or its ``.data`` for Parameter/Variable leaves) is an instance of one
    of `_BF16_TYPES` or `_HALF_TYPES`; other leaves are returned unchanged.

    Args:
        val: object or nested tuple/list of objects.

    Returns:
        `val` with matching leaves converted, nesting rebuilt by
        `conversion_helper`.
    """
    def float_conversion(leaf):
        # Type-check through .data for Parameter/Variable, but convert the
        # original object. NOTE(review): in recent PyTorch isinstance(x,
        # Variable) is True for any Tensor — confirm for the pinned version.
        if isinstance(leaf, (Parameter, Variable)):
            typecheck = leaf.data
        else:
            typecheck = leaf
        if isinstance(typecheck, (_BF16_TYPES, _HALF_TYPES)):
            return leaf.float()
        return leaf
    return conversion_helper(val, float_conversion)



class Float16Module(MegatronModule):
    """Wrap `module` so it runs in fp16 or bf16 while callers see fp32.

    The wrapped module is cast once at construction (``module.half()`` when
    ``args.fp16`` is set, ``module.bfloat16()`` when ``args.bf16`` is set).
    In ``forward``, fp32 positional inputs are cast down on the first
    pipeline stage and half-precision outputs are cast back to fp32 on the
    last stage. State-dict methods delegate to the wrapped module, so keys
    carry no extra ``module.`` prefix.
    """

    def __init__(self, module, args):
        super().__init__()

        if args.fp16:
            self.add_module('module', module.half())

            def float16_convertor(val):
                return val.half()
        elif args.bf16:
            self.add_module('module', module.bfloat16())

            def float16_convertor(val):
                return val.bfloat16()
        else:
            # ValueError subclasses Exception, so existing handlers still work.
            raise ValueError('Float16Module requires args.fp16 or args.bf16 '
                             'to be set')

        # Applied to fp32 inputs on the first pipeline stage in forward().
        self.float16_convertor = float16_convertor

    def set_input_tensor(self, input_tensor):
        """Forward `input_tensor` to the wrapped module's set_input_tensor."""
        return self.module.set_input_tensor(input_tensor)

    def forward(self, *inputs, **kwargs):
        """Run the wrapped module with precision conversion at pipeline ends.

        Only positional ``inputs`` are converted (on the first stage);
        ``kwargs`` are passed through unchanged.
        """
        if mpu.is_pipeline_first_stage():
            inputs = fp32_to_float16(inputs, self.float16_convertor)
        outputs = self.module(*inputs, **kwargs)
        if mpu.is_pipeline_last_stage():
            outputs = float16_to_fp32(outputs)
        return outputs

    def state_dict(self, prefix='', keep_vars=False):
        """Return the wrapped module's state dict (no ``module.`` prefix)."""
        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Delegate checkpoint state-dict creation to the wrapped module."""
        return self.module.state_dict_for_save_checkpoint(prefix=prefix,
                                                          keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        """Load `state_dict` into the wrapped module.

        Returns whatever the wrapped module's ``load_state_dict`` returns;
        for a ``torch.nn.Module`` that is the missing/unexpected-keys report,
        which is useful when ``strict=False``.
        """
        return self.module.load_state_dict(state_dict, strict=strict)