chatglm.py 20.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
2
# Adapted from
3
4
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
5
import json
6
from typing import Iterable, Optional, Set, Tuple, Union
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
7
8
9
10

import torch
from torch import nn
from torch.nn import LayerNorm
zhuwenwen's avatar
zhuwenwen committed
11
import os
zhuwenwen's avatar
zhuwenwen committed
12
import re
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
13

14
from vllm.attention import Attention
15
from vllm.compilation.decorators import support_torch_compile
16
from vllm.config import CacheConfig, VllmConfig
17
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
18
from vllm.model_executor.layers.activation import SiluAndMul
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
19
from vllm.model_executor.layers.layernorm import RMSNorm
20
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
21
22
                                               QKVParallelLinear,
                                               RowParallelLinear)
23
from vllm.model_executor.layers.logits_processor import LogitsProcessor
24
from vllm.model_executor.layers.quantization import QuantizationConfig
25
from vllm.model_executor.layers.rotary_embedding import get_rope
26
from vllm.model_executor.layers.vocab_parallel_embedding import (
27
    ParallelLMHead, VocabParallelEmbedding)
28
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
29
from vllm.model_executor.sampling_metadata import SamplingMetadata
30
from vllm.sequence import IntermediateTensors
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
31
from vllm.transformers_utils.configs import ChatGLMConfig
32

33
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
34
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
35
36
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
37

zhuwenwen's avatar
zhuwenwen committed
38
from vllm import _custom_ops as ops
39
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
40

41

GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
42
43
class GLMAttention(nn.Module):
    """Self-attention sub-layer of a GLM transformer block.

    Owns the fused QKV projection, the rotary embedding, the attention
    backend and the output projection. Head counts are derived from the
    model config and the tensor-parallel world size.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the attention sub-modules.

        Args:
            config: Model configuration supplying head counts, bias flags
                and rotary settings.
            cache_config: Forwarded unchanged to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``;
                when given, its name is also recorded in ``quant_method``.
            prefix: Dotted module path used to name the sub-modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Query heads have to split evenly over the tensor-parallel ranks.
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        # Multi-query attention takes its KV head count from
        # `multi_query_group_num`; otherwise KV heads match the query heads.
        self.multi_query_attention = config.multi_query_attention
        if config.multi_query_attention:
            self.total_num_kv_heads = config.multi_query_group_num
        else:
            self.total_num_kv_heads = config.num_attention_heads

        if self.total_num_kv_heads < world_size:
            # Fewer KV heads than ranks: the KV heads are replicated across
            # the tensor-parallel GPUs, so the rank count must be a multiple.
            assert world_size % self.total_num_kv_heads == 0
        else:
            # At least one KV head per rank: the KV heads are partitioned
            # across the tensor-parallel GPUs.
            assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        # Per-rank split sizes of the fused QKV output.
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        qkv_has_bias = config.add_bias_linear or config.add_qkv_bias
        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_has_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_base = 10000 * getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        # NOTE: THUDM/cogagent-9b-20241220 uses original_rope=False,
        # which is equivalent to is_neox_style=True
        neox_style = not config.original_rope
        # Rotary embedding is applied to half of each head's channels.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=rope_base,
            is_neox_style=neox_style,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        # `quant_config` is only stored on the instance when one was given.
        if quant_config is None:
            self.quant_method = None
        else:
            self.quant_method = quant_config.get_name()
            self.quant_config = quant_config

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run QKV projection, rotary embedding, attention and output proj.

        Args:
            hidden_states: Input activations fed to the fused QKV projection.
            position_ids: Positions passed to the rotary embedding.

        Returns:
            The output of the ``dense`` projection.
        """
        fused_qkv, _ = self.query_key_value(hidden_states)
        # NOTE: an FA_PAD code path that sliced the last 32 channels off the
        # fused QKV output is currently disabled here.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        attended = self.attn(query, key, value)
        projected, _ = self.dense(attended)
        return projected


class GLMMLP(nn.Module):
    """Feed-forward sub-layer of a GLM block.

    The input is projected by a merged column-parallel layer into two
    ``ffn_hidden_size`` outputs, passed through ``SiluAndMul``, and then
    projected back to ``hidden_size`` by a row-parallel layer.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the two projections and the activation.

        Args:
            config: Supplies ``hidden_size``, ``ffn_hidden_size`` and
                ``add_bias_linear``.
            quant_config: Forwarded to both linear layers.
            prefix: Dotted module path used to name the sub-modules.
        """
        super().__init__()

        use_bias = config.add_bias_linear
        self.add_bias = use_bias

        # Up-projection: two stacked outputs of width ffn_hidden_size.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_h_to_4h",
        )

        self.activation_func = SiluAndMul()

        # Down-projection back to the model width.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h",
        )

    def forward(self, hidden_states):
        """Apply up-projection, activation and down-projection in sequence."""
        expanded, _ = self.dense_h_to_4h(hidden_states)
        activated = self.activation_func(expanded)
        contracted, _ = self.dense_4h_to_h(activated)
        return contracted


class GLMBlock(nn.Module):
    """One transformer layer: norm, self-attention, norm, MLP.

    Each of the two sub-layers is wrapped in a residual connection. The
    config flag ``apply_residual_connection_post_layernorm`` selects whether
    the residual branch is taken after or before the preceding layer norm.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the norms, the attention sub-layer and the MLP.

        Args:
            config: Model configuration (norm type/epsilon, residual flags).
            cache_config: Forwarded to ``GLMAttention``.
            quant_config: Forwarded to ``GLMAttention`` and ``GLMMLP``.
            prefix: Dotted module path used to name the sub-modules.
        """
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)
        self.fp32_residual_connection = config.fp32_residual_connection

        # Both norms share the same class, chosen by `config.rmsnorm`.
        if config.rmsnorm:
            norm_cls = RMSNorm
        else:
            norm_cls = LayerNorm

        # Norm applied to the block input.
        self.input_layernorm = norm_cls(config.hidden_size,
                                        eps=config.layernorm_epsilon)

        self.self_attention = GLMAttention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.self_attention",
        )
        # Stored from the config; not used by this block's forward pass.
        self.hidden_dropout = config.hidden_dropout

        # Norm applied to the attention output (plus residual).
        self.post_attention_layernorm = norm_cls(
            config.hidden_size, eps=config.layernorm_epsilon)

        self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the layer to ``hidden_states``.

        Args:
            hidden_states: Block input, [num_tokens, h].
            position_ids: Positions forwarded to the attention sub-layer.

        Returns:
            The block output after the second residual connection.
        """
        post_norm_residual = self.apply_residual_connection_post_layernorm

        # --- attention sub-layer ---
        normed_input = self.input_layernorm(hidden_states)
        attn_out = self.self_attention(
            hidden_states=normed_input,
            position_ids=position_ids,
        )
        attn_residual = normed_input if post_norm_residual else hidden_states
        attn_sum = attn_residual + attn_out

        # --- MLP sub-layer ---
        normed_attn = self.post_attention_layernorm(attn_sum)
        mlp_residual = normed_attn if post_norm_residual else attn_sum
        return self.mlp(normed_attn) + mlp_residual


class GLMTransformer(nn.Module):
    """Stack of ``GLMBlock`` layers with an optional final layer norm."""

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the layer stack.

        Args:
            config: Model configuration (layer count, norm settings).
            cache_config: Forwarded to every ``GLMBlock``.
            quant_config: Forwarded to every ``GLMBlock``.
            prefix: Dotted module path used to name the sub-modules.
        """
        super().__init__()
        self.post_layer_norm = config.post_layer_norm
        self.num_layers = config.num_layers

        def build_block(prefix: str) -> GLMBlock:
            # Factory handed to `make_layers`; called with each layer's prefix.
            return GLMBlock(config,
                            cache_config,
                            quant_config,
                            prefix=prefix)

        # `make_layers` also reports the [start, end) layer range that
        # `forward` iterates over.
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.num_layers,
            build_block,
            prefix=f"{prefix}.layers",
        )

        if self.post_layer_norm:
            # Final layer norm before output.
            if config.rmsnorm:
                final_norm_cls = RMSNorm
            else:
                final_norm_cls = LayerNorm
            self.final_layernorm = final_norm_cls(
                config.hidden_size, eps=config.layernorm_epsilon)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the layers in ``[start_layer, end_layer)``.

        Returns:
            On the last pipeline-parallel rank, the hidden states (after the
            final layer norm when ``post_layer_norm`` is set). On any other
            rank, an ``IntermediateTensors`` wrapping the hidden states.
        """
        active_layers = self.layers[self.start_layer:self.end_layer]
        for block in active_layers:
            hidden_states = block(hidden_states=hidden_states,
                                  position_ids=position_ids)

        if get_pp_group().is_last_rank:
            if self.post_layer_norm:
                hidden_states = self.final_layernorm(hidden_states)
            return hidden_states

        return IntermediateTensors({"hidden_states": hidden_states})


308
@support_torch_compile
class ChatGLMModel(nn.Module, SupportsQuant):
    """ChatGLM backbone: token embedding, transformer encoder, output layer.

    Also implements checkpoint loading, including an optional weight layout
    conversion that is enabled through the ``LLAMA_NN`` environment variable.
    """

    packed_modules_mapping = {
        "linear_proj.merged_proj":
        ["linear_proj.gate_proj", "linear_proj.dense_h_to_4h"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone from a vLLM config.

        Args:
            vllm_config: Supplies the HF model config, the cache config and
                the quantization config.
            prefix: Dotted module path used to name the sub-modules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.embedding")

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config,
                                      cache_config,
                                      quant_config,
                                      prefix=f"{prefix}.encoder")

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.output_layer")

        self.make_empty_intermediate_tensors = (
            self.encoder.make_empty_intermediate_tensors)

        # The quantization config is only stored here when one was given.
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            self.quant_config = quant_config

        # Feature switches read once from the environment at construction.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings for ``input_ids``."""
        return self.embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the encoder for this pipeline-parallel rank.

        Args:
            input_ids: Token ids; embedded on the first rank when
                ``inputs_embeds`` is not given.
            positions: Position ids forwarded to the encoder.
            intermediate_tensors: Required on non-first ranks; its
                ``"hidden_states"`` entry is the encoder input.
            inputs_embeds: Optional precomputed embeddings used on the first
                rank instead of embedding ``input_ids``.
            **kwargs: Accepted and ignored.

        Returns:
            Whatever the encoder returns for this rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=hidden_states,
            position_ids=positions,
        )

        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
            ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load it as a regular parameter.
                if "rotary_pos_emb.inv_freq" in name:
                    continue
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        # The layout conversion only applies to unquantized weights.
        if self.use_llama_nn and self.quant_method is None:
            self._apply_llama_nn_layout(params_dict, loaded_params)

        return loaded_params

    def _apply_llama_nn_layout(
        self,
        params_dict: dict[str, nn.Parameter],
        loaded_params: Set[str],
    ) -> None:
        """Convert selected loaded weights with ``ops.trans_w16_gemm``.

        The ``LM_NN`` environment variable and the list of weights to convert
        are decided once, before any weight is touched. Previously ``LM_NN``
        was rewritten for every parameter while iterating a ``set``, so its
        final value depended on iteration order and was normally reset to
        ``'0'`` even after the LM head had been converted.

        Args:
            params_dict: Parameter name to parameter mapping of this module.
            loaded_params: Names of the parameters that were just loaded.
        """
        layer_keys = [
            "self_attention.query_key_value.weight",
            "self_attention.dense.weight",
            "mlp.dense_h_to_4h.weight",
            "mlp.dense_4h_to_h.weight",
        ]

        # The LM head is converted only when its second dimension is 4096
        # (presumably the hidden size of the supported checkpoints — TODO
        # confirm). LM_NN records that decision for code outside this file.
        lm_head_nn = any(
            "lm_head.weight" in name and params_dict[name].shape[1] == 4096
            for name in loaded_params)
        if lm_head_nn:
            layer_keys.append("lm_head.weight")
        os.environ['LM_NN'] = '1' if lm_head_nn else '0'

        for name in loaded_params:
            # Plain substring matching: the keys are literal dotted names,
            # not regular expressions.
            if not any(key in name for key in layer_keys):
                continue

            weight = params_dict[name]
            _weight = torch.zeros_like(weight.data)
            ori_shape = _weight.shape

            # NOTE(review): `trans_w16_gemm` presumably writes a transposed
            # layout of `weight.data` into `_weight` — confirm against the
            # custom op. The result is copied back in place and then viewed
            # with the two leading dimensions swapped.
            ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0],
                               _weight.shape[1])
            weight.data.copy_(_weight)
            weight.data = weight.data.reshape(ori_shape[1], -1)

GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
464

465
class ChatGLMBaseModel(nn.Module):
    """Shared wrapper around a ChatGLM transformer with an LM head.

    Holds the transformer backbone, exposes its output layer as ``lm_head``
    and provides logits computation and checkpoint loading.
    """

    # Strips the ".word_embeddings" substring from checkpoint weight names.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".word_embeddings": ""}, )

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[ChatGLMModel] = ChatGLMModel,
    ) -> None:
        """Build the backbone and the logits processor.

        Args:
            vllm_config: Supplies the model, quantization and LoRA configs.
            prefix: Dotted module path of this model.
            transformer_type: Class used to construct the backbone.
        """
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config

        self.config = hf_config
        self.lora_config = vllm_config.lora_config
        self.multimodal_config = model_config.multimodal_config
        self.quant_config = vllm_config.quant_config
        self.max_position_embeddings = getattr(hf_config,
                                               "max_sequence_length", 8192)

        backbone_prefix = maybe_prefix(prefix, "transformer")
        self.transformer = transformer_type(vllm_config=vllm_config,
                                            prefix=backbone_prefix)

        # With tied embeddings the output layer shares the embedding weight.
        if self.config.tie_word_embeddings:
            self.transformer.output_layer.weight = (
                self.transformer.embedding.weight)
        self.lm_head = self.transformer.output_layer
        self.logits_processor = LogitsProcessor(hf_config.padded_vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return the logits processor's output for ``hidden_states``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``weights`` through ``AutoWeightsLoader`` with the name mapper."""
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper)
511
512


513
514
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
                         SupportsQuant):
    """Text-only ChatGLM causal language model."""

    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Construct the model, refusing configs that declare vision support.

        Raises:
            RuntimeError: If the HF config has a ``vision_config`` attribute;
                the message names the override that selects the vision model.
        """
        config = vllm_config.model_config.hf_config
        has_vision = hasattr(config, "vision_config")
        if has_vision:
            hf_overrides = {"architectures": ["GLM4VForCausalLM"]}
            message = (
                "The configuration of this model indicates that it supports "
                "vision inputs, but you instantiated the text-only version "
                "of this model. Please use the vision model by setting "
                f"`--hf-overrides '{json.dumps(hf_overrides)}'`")
            raise RuntimeError(message)

        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the transformer backbone and return its result."""
        return self.transformer(input_ids, positions, intermediate_tensors,
                                inputs_embeds)