# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""

import math
import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.checkpointing import get_checkpoint_version
from megatron.model import import_layernorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
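    As a concrete (hypothetical) example: with h=1024, n=16 and p=2 model
    parallel partitions, hp = h/p = 512, np = n/p = 8 and hn = h/n = 64.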
"""

class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, apply a gelu nonlinearity, and project the
    state back into h hidden dimension. The output bias is returned
    separately so the caller can fuse it with dropout and the
    residual addition.
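    Schematically, the unpartitioned computation is roughly (W_1/b_1 and
    W_2/b_2 denote the weights and biases of the two projections):

        intermediate = gelu(hidden_states @ W_1 + b_1)    # [s, b, 4h]
        output = intermediate @ W_2                       # [s, b, h]
        return output, b_2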
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            4 * args.hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            4 * args.hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):

        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias


class ParallelSelfAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
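    As a concrete (hypothetical) example, with h=1024, n=16 and p=2
    tensor-model-parallel partitions, each rank computes np=8 heads of size
    hn=64, i.e. a 3*hp=1536-wide slice of the fused query/key/value
    projection below.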
    """

    def __init__(self, init_method, output_layer_init_method, layer_number):
        super(ParallelSelfAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            args.hidden_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        self.query_key_value = mpu.ColumnParallelLinear(
            args.hidden_size,
            3 * args.hidden_size,
            gather_output=False,
            init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16,
            args.scaled_upper_triang_masked_softmax_fusion,
            args.scaled_masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)
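        # Note: when apply_query_key_layer_scaling is set, the 1/sqrt(hn)
        # scale used for the raw scores is additionally divided by the layer
        # number, and the fused softmax multiplies the scores back by the
        # same coefficient. The result is unchanged mathematically, but the
        # pre-softmax values stay smaller, which helps fp16 stability.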

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs for different numbers of parallel partitions, but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            args.hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
        input_shape = mixed_layer.size()
        if num_splits_first:
            """[s, b, num_splits * np * hn] 
            -->(view) [s, b, num_splits, np, hn]
            -->(transpose) [s, b, np, num_splits, hn]
            -->(view) [s, b, np * num_splits * hn] """

            intermediate_shape = input_shape[:-1] +\
                (num_splits, self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
        else:
            """[s, b, np * hn * num_splits] 
            -->(view) [s, b, np, hn, num_splits]
            -->(transpose) [s, b, np, num_splits, hn]
            -->(view) [s, b, np * num_splits * hn] """

            intermediate_shape = input_shape[:-1] +\
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head, num_splits)

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
        mixed_layer = mixed_layer.view(*input_shape)
        
        return mixed_layer

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer, _ = self.query_key_value(hidden_states)

        checkpoint_version = get_checkpoint_version()
        if checkpoint_version is not None:
            if checkpoint_version == 0:
                # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
            elif checkpoint_version == 1.0:
                # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)
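            # Older checkpoints stored the fused QKV projection with a
            # different interleaving of the (q, k, v) and head dimensions,
            # so its output is permuted here to the current
            # [sq, b, (np * 3 * hn)] layout before splitting.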

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             3 * self.hidden_size_per_attention_head)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        (query_layer,
         key_layer,
         value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
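        # With the hypothetical sizes from the notation example at the top of
        # the file (b=8, sq=1024, np=8, hn=64), mixed_x_layer is viewed as
        # [1024, 8, 8, 192] and each of query/key/value comes out as
        # [1024, 8, 8, 64].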

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0]*output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))
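        # Note: with beta=0.0 the contents of the preallocated matmul_result
        # tensor are never read; the call reduces to alpha * (Q @ K^T), i.e.
        # the raw scores scaled by 1/norm_factor.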

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]
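        # During incremental decoding only the newest query position is
        # computed, so the mask is reduced to the single row for that
        # position (over the first sk key positions); on the first step the
        # full [sq, sk] block of the mask is kept.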

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, False)
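# The "# type:" comments above give torch.jit.script the argument types for
# these otherwise unannotated functions. Training and inference are scripted
# as separate fused variants so that the dropout training flag is a constant
# inside each scripted graph.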


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
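    Schematically, with the default pre-LayerNorm residual path:

        x_1 = x + dropout(attention(LN(x)) + attention_bias)
        x_2 = x_1 + dropout(mlp(LN(x_1)) + mlp_bias)

    When apply_residual_connection_post_layernorm is set, the residuals are
    taken from the LayerNorm outputs instead of their inputs.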
    """

    def __init__(self, init_method, output_layer_init_method, layer_number):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        # Layernorm on the input data.
        LayerNorm = import_layernorm(args.fp32_residual_connection)
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.attention = ParallelSelfAttention(init_method,
                                               output_layer_init_method,
                                               layer_number)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output.
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.attention(layernorm_output,
                           attention_mask,
                           layer_past=layer_past,
                           get_key_value=get_key_value)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)
        
        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(
                mlp_output,
                mlp_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        if get_key_value:
            output = [output, presents]

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.fp32_residual_connection = args.fp32_residual_connection

        # Store activation checkpointing flag.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers

        # Number of layers.
        assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
            'num_layers must be divisible by pipeline_model_parallel_size'
        self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method, output_layer_init_method, layer_number)
        offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])
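        # Example: with num_layers=24 and a pipeline-model-parallel size of
        # 4, each stage builds 6 layers and pipeline rank r owns the layers
        # numbered 6*r + 1 through 6*r + 6.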

        if mpu.is_pipeline_last_stage():
            # Final layer norm before output.
            LayerNorm = import_layernorm(args.fp32_residual_connection)
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, inputs[1])
                return x_
            return custom_forward

        # Make sure memory is freed.
        mpu.reset_checkpointed_activations_memory_buffer()
        l = 0
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                hidden_states, attention_mask)
            l += self.checkpoint_num_layers
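        # Each mpu.checkpoint call above runs checkpoint_num_layers layers
        # without saving their intermediate activations; they are recomputed
        # during the backward pass, trading extra compute for memory.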

        return hidden_states

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):

        # Checks.
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert not self.checkpoint_activations, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        if mpu.is_pipeline_first_stage():
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()

        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        # Final layer norm.
        if mpu.is_pipeline_last_stage():
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states
        if get_key_value:
            output = [output, presents]

        return output