# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""

import math

import torch
try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
    # Smoke-test apex by building a tiny FusedLayerNorm (imported above under
    # the alias ``LayerNorm``). Importing apex can presumably succeed even
    # when its compiled extension is unusable, so construction is what
    # surfaces the failure and routes us to the fallback below.
    _ = LayerNorm(8, eps=1e-5)

except Exception:
    print('WARNING: APEX is not available, using torch.nn.LayerNorm instead of apex.normalization.FusedLayerNorm!')
    from torch.nn import LayerNorm

from megatron import get_args
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from megatron import mpu
from megatron.module import MegatronModule


""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [b, s, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters (read from the global
            arguments via `get_args()`)
        attention_mask_func: a function that takes `unmasked-attention-scores`
            with size [b, np, s, s] and an `attention-mask` and will apply
            the masking. The function should return a masked score of the
            same size [b, np, s, s].
               masked-attention-scores = attention_mask_func(
                                     unmasked-attention-scores, attention-mask)
"""


class ParallelMLP(MegatronModule):
    """Tensor-parallel feed-forward block.

    The hidden dimension h of the input is widened to 4h by a column-parallel
    projection, passed through ``mlp_activation_func``, narrowed back to h by
    a row-parallel projection, and finally run through dropout with rate
    ``args.hidden_dropout``.
    """

    def __init__(self, mlp_activation_func, init_method,
                 output_layer_init_method):
        super().__init__()
        args = get_args()
        model_dim = args.hidden_size
        ffn_dim = 4 * model_dim

        # h -> 4h. The output is left partitioned (gather_output=False).
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            model_dim, ffn_dim,
            gather_output=False,
            init_method=init_method)

        self.activation_func = mlp_activation_func

        # 4h -> h. Consumes the partitioned activations directly.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            ffn_dim, model_dim,
            input_is_parallel=True,
            init_method=output_layer_init_method)

        self.dropout = torch.nn.Dropout(args.hidden_dropout)

    def forward(self, hidden_states):
        """Map ``hidden_states`` of size [b, s, h] to a tensor of size [b, s, h]."""
        # [b, s, 4hp]: this rank's slice of the widened activations.
        widened = self.activation_func(self.dense_h_to_4h(hidden_states))
        # [b, s, h]
        narrowed = self.dense_4h_to_h(widened)
        return self.dropout(narrowed)


class ParallelSelfAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size. Each model-parallel rank holds
    np = n/p attention heads and hp = h/p of the hidden dimension.

    Arguments:
        attention_mask_func: callable ``(scores, mask) -> masked_scores``
            applied to the raw [b, np, s, s] attention scores.
        init_method: initializer for the fused query/key/value projection.
        output_layer_init_method: initializer for the output projection.
        layer_number: position of this layer in the stack (callers in this
            file pass 1-based values); clamped to at least 1. Only used when
            query-key layer scaling is enabled.
    """

    def __init__(self, attention_mask_func, init_method,
                 output_layer_init_method, layer_number):
        super(ParallelSelfAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16

        self.attention_mask_func = attention_mask_func
        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        # Layer scaling divides the scores by ``layer_number`` and multiplies
        # them back right before the softmax, so the softmax is forced into
        # fp32 (presumably to keep the re-scaled fp16 values in range).
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

        # Per attention head and per partition values.
        world_size = mpu.get_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            args.hidden_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer: one projection yields query, key and value
        # (3h). stride=3 presumably keeps each rank's q/k/v slices aligned.
        self.query_key_value = mpu.ColumnParallelLinear(
            args.hidden_size,
            3 * args.hidden_size,
            stride=3,
            gather_output=False,
            init_method=init_method)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            args.hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method)
        self.output_dropout = torch.nn.Dropout(args.hidden_dropout)

    def _transpose_for_scores(self, tensor):
        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
        size [b, np, s, hn].
        """
        new_tensor_shape = tensor.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        tensor = tensor.view(*new_tensor_shape)
        return tensor.permute(0, 2, 1, 3)

    def _get_query_key_value(self, hidden_states):
        """Get query, key, and value and transpose to
        get size [b, np, s, hn].
        """
        # Fused q/k/v projection. [b, s, 3 * hp]
        mixed_x_layer = self.query_key_value(hidden_states)
        (mixed_query_layer,
         mixed_key_layer,
         mixed_value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)

        # Reshape and transpose [b, np, s, hn]
        query_layer = self._transpose_for_scores(mixed_query_layer)
        key_layer = self._transpose_for_scores(mixed_key_layer)
        value_layer = self._transpose_for_scores(mixed_value_layer)

        return query_layer, key_layer, value_layer

    def _get_unmasked_attention_scores(self, query_layer, key_layer):
        """Unmasked attention scores with size [b, np, s, s].

        Computes q k^T / (coeff * sqrt(hn)), where coeff is ``layer_number``
        when query-key layer scaling is enabled and 1 otherwise. The divisor
        is split evenly between q and k (each divided by its square root)
        instead of being applied to the product.
        """
        coeff = 1
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
        norm_factor = math.sqrt(coeff *
                                math.sqrt(self.hidden_size_per_attention_head))
        # Raw attention scores. [b, np, s, s]
        return torch.matmul(query_layer / norm_factor,
                            key_layer.transpose(-1, -2) / norm_factor)

    def _get_attention_probs(self, attention_scores):
        """Attention probabilities with dropout. The output has
        the size [b, np, s, s].
        """
        # Undo the ``layer_number`` division applied when the raw scores
        # were computed, so only the plain 1/sqrt(hn) scaling remains.
        if self.apply_query_key_layer_scaling:
            attention_scores = attention_scores * self.layer_number
        # Attention probabilities. [b, np, s, s]
        # Functional softmax: same result as torch.nn.Softmax(dim=-1) without
        # constructing a new module on every forward pass.
        attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        return attention_probs

    def _get_attended_context(self, attention_probs, value_layer):
        """Final attended tensor, transposed back to [b, s, hp]."""
        # Context layer.
        # [b, np, s, hn]
        context_layer = torch.matmul(attention_probs, value_layer)
        # [b, s, np, hn]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        # [b, s, hp]
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer

    def _get_output(self, context_layer):
        """Output layer with dropout."""
        # Output. [b, s, h]
        output = self.dense(context_layer)
        output = self.output_dropout(output)

        return output

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        """Run self-attention over ``hidden_states``.

        Arguments:
            hidden_states: input of size [b, s, h].
            attention_mask: mask handed to ``attention_mask_func`` after
                being sliced to the current key length.
            layer_past: optional ``(past_key, past_value)`` pair; each is
                concatenated in front of this step's keys/values along the
                sequence dimension (dim -2).
            get_key_value: when True, also return the concatenated
                ``(key, value)`` pair.

        Returns:
            Output of size [b, s, h], or ``[output, (key, value)]`` when
            ``get_key_value`` is True.
        """
        # hidden_states: [b, s, h]

        # Attention heads. [b, np, s, hn]
        query_layer, key_layer, value_layer = self._get_query_key_value(
            hidden_states)

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=-2)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=-2)
        if get_key_value:
            present = (key_layer, value_layer)

        # Raw attention scores. [b, np, s, s]
        attention_scores = self._get_unmasked_attention_scores(
            query_layer, key_layer)

        # fp32 conversion.
        if self.fp16 and self.attention_softmax_in_fp32:
            attention_scores = attention_scores.float()

        # Apply attention mask. [b, np, s, s]
        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    # Cached decoding: keep only the mask row of the newest
                    # key position and re-insert the query dimension. This
                    # assumes a single new query token per step -- TODO
                    # confirm for multi-token queries.
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]
        attention_scores = self.attention_mask_func(attention_scores,
                                                    attention_mask)

        # Attention probabilities. [b, np, s, s]
        attention_probs = self._get_attention_probs(attention_scores)

        # fp16 conversion
        if self.fp16 and self.attention_softmax_in_fp32:
            attention_probs = attention_probs.half()

        # Context layer. [b, s, hp]
        context_layer = self._get_attended_context(attention_probs, value_layer)

        # Output. [b, s, h]
        output = self._get_output(context_layer)

        if get_key_value:
            output = [output, present]

        return output


class ParallelTransformerLayer(MegatronModule):
    """One transformer block: self-attention followed by an MLP.

    Input and output both have size [b, s, h]. Each sub-layer is preceded by
    a layer norm and wrapped in a residual connection. The residual branch
    carries the raw sub-layer input, or the layer-normed input when
    ``args.apply_residual_connection_post_layernorm`` is set.
    """

    def __init__(self, attention_mask_func, mlp_activation_func,
                 init_method, output_layer_init_method, layer_number):
        args = get_args()

        super().__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm = (
            args.apply_residual_connection_post_layernorm)

        hidden = args.hidden_size
        ln_eps = args.layernorm_epsilon

        # Submodules are built in a fixed order (input norm, attention,
        # post-attention norm, MLP), which fixes parameter registration order.

        # Normalizes the block input ahead of self-attention.
        self.input_layernorm = LayerNorm(hidden, eps=ln_eps)

        self.attention = ParallelSelfAttention(
            attention_mask_func, init_method, output_layer_init_method,
            layer_number)

        # Normalizes the attention residual sum ahead of the MLP.
        self.post_attention_layernorm = LayerNorm(hidden, eps=ln_eps)

        self.mlp = ParallelMLP(mlp_activation_func, init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        """Apply the block to ``hidden_states`` ([b, s, h]).

        Returns a [b, s, h] tensor, or ``[output, presents]`` when
        ``get_key_value`` is True (``presents`` comes from the attention).
        """
        residual_from_norm = self.apply_residual_connection_post_layernorm

        # Pre-attention layer norm, then self-attention.
        attn_in = self.input_layernorm(hidden_states)
        attn_result = self.attention(attn_in,
                                     attention_mask,
                                     layer_past=layer_past,
                                     get_key_value=get_key_value)
        if get_key_value:
            attn_out, presents = attn_result
        else:
            attn_out = attn_result

        # First residual connection, then the post-attention layer norm.
        attn_skip = attn_in if residual_from_norm else hidden_states
        layernorm_input = attn_skip + attn_out
        mlp_in = self.post_attention_layernorm(layernorm_input)

        # MLP and second residual connection.
        mlp_out = self.mlp(mlp_in)
        mlp_skip = mlp_in if residual_from_norm else layernorm_input
        output = mlp_skip + mlp_out

        return [output, presents] if get_key_value else output


class ParallelTransformer(MegatronModule):
    """Transformer class.

    A stack of ``args.num_layers`` transformer layers followed by a final
    layer norm. Only ``args.num_unique_layers`` layers are instantiated
    (default: all of them); positions in the stack are mapped onto them by
    ``args.param_sharing_style`` (``'grouped'`` or ``'spaced'``), see
    ``_get_layer_index``.
    """

    def __init__(self, attention_mask_func, mlp_activation_func,
                 init_method, output_layer_init_method):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        # Store activation checkpointing flags.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers

        # Number of layers:
        self.num_layers = args.num_layers
        self.num_unique_layers = args.num_unique_layers
        if self.num_unique_layers is None:
            self.num_unique_layers = self.num_layers
        assert self.num_layers % self.num_unique_layers == 0, \
            'number of layers should be divisible by number of unique layers'
        self.param_sharing_style = args.param_sharing_style

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                attention_mask_func, mlp_activation_func,
                init_method, output_layer_init_method, layer_number)
        # Layers receive 1-based layer numbers.
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1) for i in range(self.num_unique_layers)])

        # Print layer ordering.
        if self.num_layers != self.num_unique_layers:
            if torch.distributed.get_rank() == 0:
                print('> will be using the following layer ordering:')
                for i in range(self.num_layers):
                    print('   layer id: {:3d} --> unique layer id: '
                          '{:3d}'.format(i, self._get_layer_index(i)),
                          flush=True)

        # Final layer norm before output.
        self.final_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

    def _get_layer_index(self, layer_number):
        """Map a 0-based stack position to an index into ``self.layers``.

        'grouped': positions cycle through the unique layers
            (0, 1, ..., u-1, 0, 1, ...).
        'spaced': each unique layer is reused for num_layers // u
            consecutive positions.

        Raises:
            AssertionError: for any other ``param_sharing_style``.
        """
        if self.param_sharing_style == 'grouped':
            return layer_number % self.num_unique_layers
        if self.param_sharing_style == 'spaced':
            return layer_number // (self.num_layers // self.num_unique_layers)
        # Explicit raise instead of ``assert False``: asserts are stripped
        # under ``python -O`` and this method would then return None.
        raise AssertionError(
            'should not be here: unknown param_sharing_style {!r}'.format(
                self.param_sharing_style))

    def _get_layer(self, layer_number):
        """Return the (possibly shared) layer module for a 0-based position."""
        return self.layers[self._get_layer_index(layer_number)]

    def _checkpointed_forward(self, hidden_states, attention_mask):
        """Forward method with activation checkpointing.

        Layers run in chunks of ``checkpoint_num_layers``; each chunk is one
        ``mpu.checkpoint`` call. The last chunk is clamped to ``num_layers``
        so a chunk size that does not divide the layer count never runs
        positions beyond the end of the stack.

        Raises:
            ValueError: if ``checkpoint_num_layers`` is not positive (the
                chunk loop would otherwise never terminate).
        """
        if self.checkpoint_num_layers < 1:
            raise ValueError(
                'checkpoint_num_layers must be a positive integer, '
                'got {}'.format(self.checkpoint_num_layers))

        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, inputs[1])
                return x_
            return custom_forward

        chunk_start = 0
        while chunk_start < self.num_layers:
            chunk_end = min(chunk_start + self.checkpoint_num_layers,
                            self.num_layers)
            hidden_states = mpu.checkpoint(
                custom(chunk_start, chunk_end),
                hidden_states, attention_mask)
            chunk_start = chunk_end

        return hidden_states

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        """Run all layers and the final layer norm.

        Arguments:
            hidden_states: input of size [b, s, h].
            attention_mask: passed unchanged to every layer.
            layer_past: optional per-position list of cached key/value pairs;
                requires ``get_key_value``.
            get_key_value: when True, also return the list of per-layer
                key/value pairs. Not supported with activation checkpointing.

        Returns:
            The normalized output, or ``[output, presents]`` when
            ``get_key_value`` is True.
        """

        # Checks
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert not self.checkpoint_activations, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        # Final layer norm.
        output = self.final_layernorm(hidden_states)
        if get_key_value:
            output = [output, presents]

        return output