hunyuan_v1.py 38.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# coding=utf-8
# Copyright 2024 The HunYuan team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only HunYuan model compatible with HuggingFace weights."""
26

27
28
import typing
from collections.abc import Callable, Iterable
29
from itertools import islice
30
31
32
33
34
35
36

import regex as re
import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.compilation.decorators import support_torch_compile
37
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
38
39
40
41
42
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_world_size,
)
43
from vllm.model_executor.layers.activation import SiluAndMul
44
from vllm.model_executor.layers.attention import Attention
45
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
46
from vllm.model_executor.layers.layernorm import RMSNorm
47
48
49
50
51
52
53
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
54
from vllm.model_executor.layers.logits_processor import LogitsProcessor
55
from vllm.model_executor.layers.quantization import QuantizationConfig
56
57
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
58
59
60
    ParallelLMHead,
    VocabParallelEmbedding,
)
61
from vllm.model_executor.model_loader.weight_utils import (
62
63
64
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
65
from vllm.sequence import IntermediateTensors
66
from vllm.v1.attention.backend import AttentionType
67

68
69
70
71
72
73
74
75
from .interfaces import (
    EagleModelMixin,
    MixtureOfExperts,
    SupportsEagle,
    SupportsEagle3,
    SupportsLoRA,
    SupportsPP,
)
76
77
78
79
80
81
82
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    is_pp_missing_parameter,
    make_layers,
    maybe_prefix,
)
83
84


85
86
87
88
89
90
91
92
93
94
95
96
97
def _is_moe(config: PretrainedConfig) -> bool:
    """Return True when ``config`` describes a mixture-of-experts model.

    ``config.num_experts`` may be missing, a single int, or a per-layer list
    of ints. The model counts as MoE when the (largest) expert count exceeds
    one. A missing attribute, an empty list, a list containing non-int
    entries, or any other type is treated as dense.
    """
    experts = getattr(config, "num_experts", None)
    if isinstance(experts, list):
        # Only a non-empty, all-int list is considered; anything else is dense.
        if not experts or not all(isinstance(e, int) for e in experts):
            return False
        return max(experts) > 1
    if isinstance(experts, int):
        return experts > 1
    return False


98
99
100
101
102
103
104
105
106
107
108
109
def _get_cla_factor(config: PretrainedConfig) -> int:
    """Return the cross-layer-attention (CLA) share factor for ``config``.

    When ``config.use_cla`` is falsy or absent, CLA is disabled and the factor
    is 1. Otherwise it is ``config.cla_share_factor``, defaulting to 1 when
    that attribute is absent.
    """
    cla_enabled = getattr(config, "use_cla", False)
    return getattr(config, "cla_share_factor", 1) if cla_enabled else 1


class HunYuanMLP(nn.Module):
    """Gated-SiLU feed-forward block: fused gate/up projection, SiLU*mul, down.

    Args:
        hidden_size: Input and output width.
        intermediate_size: Width of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config for both projections.
        bias: Whether the projections carry a bias term.
        prefix: Parameter-name prefix used for weight loading.
        reduce_results: Forwarded to the down projection.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        # Gate and up projections share one column-parallel matmul.
        gate_up = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.gate_up_proj = gate_up
        down = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
            reduce_results=reduce_results,
        )
        self.down_proj = down
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        fused, _ = self.gate_up_proj(x)
        out, _ = self.down_proj(self.act_fn(fused))
        return out


class HunYuanAttention(nn.Module):
    """KV-producing self-attention layer for HunYuan.

    Projects hidden states to fused QKV, applies rotary embeddings and an
    optional per-head RMSNorm on q/k, runs attention, and projects back to
    ``hidden_size``. ``forward`` also returns the post-rotary key (taken
    before any q/k norm) together with the value, so later cross-layer
    attention layers can reuse them.

    Args:
        config: HF config; reads ``head_dim``/``attention_head_dim``,
            ``use_qk_norm``, ``rope_parameters`` and ``rms_norm_eps``.
        hidden_size: Model hidden dimension.
        num_heads: Total number of query heads (split across TP ranks).
        num_kv_heads: Total number of KV heads (partitioned or replicated
            across TP ranks).
        max_position_embeddings: Maximum position given to the rotary op.
        quant_config: Optional quantization config for the projections.
        bias: Whether the QKV and output projections use a bias.
        cache_config: KV-cache config forwarded to ``Attention``.
        prefix: Parameter-name prefix used for weight loading.
        layer_id: Index of the owning decoder layer (stored only).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        layer_id: int = -1,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # A config may carry head_dim / attention_head_dim set to None (or 0);
        # skip such values and fall back to hidden_size // num_heads.
        self.head_dim = (
            getattr(config, "head_dim", None)
            or getattr(config, "attention_head_dim", None)
            or self.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings
        self.use_qk_norm = getattr(config, "use_qk_norm", False)
        self.layer_id = layer_id

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        if self.use_qk_norm:
            self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_states: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Compute self-attention.

        ``kv_states`` is accepted so every attention variant shares one call
        signature; this layer computes its own K/V and ignores it.

        Returns:
            ``(output, (k, v))``, where ``k`` is the key after rotary
            embedding but before the optional key RMSNorm.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        # Keep the pre-norm key for reuse by cross-layer-attention layers,
        # which apply their own key_layernorm to it.
        ori_k = k
        if self.use_qk_norm:
            q = self.query_layernorm(
                q.view(-1, self.num_heads, self.head_dim).contiguous()
            )
            k = self.key_layernorm(
                k.view(-1, self.num_kv_heads, self.head_dim).contiguous()
            )

        attn_output = self.attn(q, k, v)
        # Flatten heads back to one row per token for o_proj.
        attn_output = attn_output.view(q.shape[0], -1)
        output, _ = self.o_proj(attn_output)
        return output, (ori_k, v)


class HunYuanCrossAttention(nn.Module):
    """Cross-layer-attention (CLA) layer that reuses K/V from another layer.

    Only queries are projected here. Keys and values arrive through
    ``forward``'s ``kv_states`` (the ``(k, v)`` returned by a KV-producing
    ``HunYuanAttention`` layer). The attention op is registered as
    ``AttentionType.ENCODER_DECODER``.

    Args:
        config: HF config; reads ``head_dim``/``attention_head_dim``,
            ``use_qk_norm``, ``rope_parameters`` and ``rms_norm_eps``.
        hidden_size: Model hidden dimension.
        num_heads: Total number of query heads (split across TP ranks).
        num_kv_heads: Total number of KV heads (partitioned or replicated
            across TP ranks).
        max_position_embeddings: Maximum position given to the rotary op.
        quant_config: Optional quantization config for the projections.
        bias: Whether the q and output projections use a bias.
        cache_config: KV-cache config forwarded to ``Attention``.
        prefix: Parameter-name prefix used for weight loading.
        layer_id: Index of the owning decoder layer (stored only).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        layer_id: int = -1,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # Resolve the head size the same way as the self-attention layer so
        # both layer types agree; a head_dim that is None/0 is ignored.
        self.head_dim = (
            getattr(config, "head_dim", None)
            or getattr(config, "attention_head_dim", None)
            or self.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings
        self.use_qk_norm = getattr(config, "use_qk_norm", False)
        self.layer_id = layer_id

        # Project to num_heads * head_dim (not hidden_size): this is what the
        # per-head view in forward() and o_proj's input size require. The two
        # are equal whenever head_dim == hidden_size // num_heads.
        self.q_proj = ColumnParallelLinear(
            hidden_size,
            self.total_num_heads * self.head_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            attn_type=AttentionType.ENCODER_DECODER,
        )

        if self.use_qk_norm:
            self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_states: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Attend with locally computed queries over the supplied K/V.

        Args:
            kv_states: ``(k, v)`` from a KV-producing layer; must not be None.

        Returns:
            ``(output, kv_states)``. The incoming ``(k, v)`` is passed
            through unchanged.
        """
        assert kv_states is not None
        ori_k, v = kv_states  # K/V from the last KV-producing layer
        k = ori_k
        q, _ = self.q_proj(hidden_states)
        # TODO: redundant rotary work. rotary_emb also takes a key, but k is
        # already rotated by the producing layer, so a scratch tensor is
        # passed in its place and that output is discarded.
        k_tmp = torch.empty_like(k)
        q, _ = self.rotary_emb(positions, q, k_tmp)
        if self.use_qk_norm:
            q = self.query_layernorm(
                q.view(-1, self.num_heads, self.head_dim).contiguous()
            )
            k = self.key_layernorm(
                k.view(-1, self.num_kv_heads, self.head_dim).contiguous()
            )

        attn_output = self.attn(q, k, v)
        # Flatten heads back to one row per token for o_proj.
        attn_output = attn_output.view(q.shape[0], -1)
        output, _ = self.o_proj(attn_output)
        return output, (ori_k, v)


class HunYuanSparseMoeBlock(nn.Module):
    """Routed mixture-of-experts FFN with an optional shared dense expert.

    A replicated linear gate scores ``config.num_experts`` experts per token,
    and ``SharedFusedMoE`` dispatches to the ``top_k`` selected experts. When
    ``config.use_mixed_mlp_moe > 0``, a dense ``HunYuanMLP`` is built and
    handed to ``SharedFusedMoE`` as its shared expert.

    Args:
        config: HF config (``num_experts``, ``moe_topk``, ``intermediate_size``,
            ``moe_intermediate_size``, ``use_mixed_mlp_moe``,
            ``num_shared_expert``, ...). Per-layer list values are indexed by
            ``layer_id``.
        quant_config: Optional quantization config.
        layer_id: Index of the owning decoder layer.
        prefix: Parameter-name prefix used for weight loading.
        enable_eplb: Whether expert-parallel load balancing is enabled.

    Raises:
        ValueError: If the tensor-parallel size exceeds the number of experts.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        layer_id: int = -1,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Get layer_id topk if config.moe_topk is a list
        if isinstance(config.moe_topk, list):
            assert layer_id >= 0
            assert len(config.moe_topk) > layer_id
            top_k = config.moe_topk[layer_id]
        else:
            top_k = config.moe_topk

        # Dense FFN width for this layer. As in HunYuanDecoderLayer, accept a
        # per-layer list; otherwise a list would be forwarded as a width and
        # "list * num_shared_expert" below would repeat the list instead.
        dense_intermediate_size = (
            config.intermediate_size
            if isinstance(config.intermediate_size, int)
            else config.intermediate_size[layer_id]
        )

        # Routed experts prefer moe_intermediate_size (int or per-layer list).
        moe_intermediate_size = getattr(config, "moe_intermediate_size", None)
        if moe_intermediate_size is None:
            intermediate_size = dense_intermediate_size
        elif isinstance(moe_intermediate_size, int):
            intermediate_size = moe_intermediate_size
        else:
            intermediate_size = moe_intermediate_size[layer_id]

        # Load balancing settings: logical experts plus redundant replicas,
        # split evenly into a contiguous range per EP rank.
        vllm_config = get_current_vllm_config()
        eplb_config = vllm_config.parallel_config.eplb_config
        self.enable_eplb = enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )
        if config.use_mixed_mlp_moe > 0:
            # Get layer_id num_shared_expert if config.num_shared_expert is
            # a list.
            if isinstance(config.num_shared_expert, list):
                assert layer_id >= 0
                assert len(config.num_shared_expert) > layer_id
                num_shared_expert = config.num_shared_expert[layer_id]
            else:
                num_shared_expert = config.num_shared_expert

            self.shared_mlp = HunYuanMLP(
                hidden_size=config.hidden_size,
                intermediate_size=dense_intermediate_size * num_shared_expert,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_mlp",
            )
        else:
            self.shared_mlp = None

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_mlp,
            num_experts=self.n_routed_experts,
            top_k=top_k,
            hidden_size=config.hidden_size,
            intermediate_size=intermediate_size,
            renormalize=top_k > 1,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )

        return final_hidden_states.view(orig_shape)


class HunYuanDecoderLayer(nn.Module):
    """One HunYuan transformer block: attention followed by a dense or MoE FFN.

    With cross-layer attention (CLA), layers whose ``layer_id`` is a multiple
    of the CLA share factor compute their own K/V (``HunYuanAttention``). The
    layers in between use ``HunYuanCrossAttention`` and consume the K/V passed
    to ``forward`` as ``kv_states``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        layer_id: int = -1,
        enable_eplb: bool = False,
    ) -> None:
        super().__init__()
        assert layer_id >= 0
        self.layer_id = layer_id
        self.hidden_size = config.hidden_size
        # intermediate_size may be a per-layer list.
        self.intermediate_size = (
            config.intermediate_size
            if isinstance(config.intermediate_size, int)
            else config.intermediate_size[layer_id]
        )
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        cla_factor = _get_cla_factor(config)

        # Every cla_factor-th layer produces K/V; the others reuse it through
        # cross-attention (registered as AttentionType.ENCODER_DECODER).
        attn_cls = (
            HunYuanCrossAttention if layer_id % cla_factor != 0 else HunYuanAttention
        )
        self.self_attn = attn_cls(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            layer_id=layer_id,
        )

        if _is_moe(config):
            self.mlp = HunYuanSparseMoeBlock(
                config=config,
                quant_config=quant_config,
                layer_id=layer_id,
                prefix=f"{prefix}.mlp",
                enable_eplb=enable_eplb,
            )
        else:
            self.mlp = HunYuanMLP(
                hidden_size=self.hidden_size,
                intermediate_size=self.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                bias=getattr(config, "mlp_bias", False),
                prefix=f"{prefix}.mlp",
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        kv_states: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run attention and the FFN with pre-norm residual handling.

        Args:
            residual: Running residual; None on the first layer, in which case
                ``hidden_states`` becomes the residual.
            kv_states: K/V for a CLA cross-attention layer (ignored by
                KV-producing layers).

        Returns:
            ``(hidden_states, residual, kv_states)``, where ``kv_states`` is
            the ``(k, v)`` pair returned by the attention module.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states, ori_kv_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_states=kv_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual, ori_kv_states


579
580
581
582
583
584
585
586
587
588
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (xd, seq_len) if xdrope is enabled for hunyuan-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    }
)
589
class HunYuanModel(nn.Module, EagleModelMixin):
590
591
592
593
594
595
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build embeddings, this stage's decoder layers, and the final norm.

        Embeddings exist on the first pipeline stage (and on the last stage
        when word embeddings are tied); the final RMSNorm exists only on the
        last stage. Other stages get ``PPMissingLayer`` placeholders.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        parallel_config = vllm_config.parallel_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.num_redundant_experts = parallel_config.eplb_config.num_redundant_experts
        enable_eplb = parallel_config.enable_eplb

        self.config = hf_config
        self.quant_config = quant_config
        self.vocab_size = hf_config.vocab_size

        pp_group = get_pp_group()
        needs_embedding = pp_group.is_first_rank or (
            hf_config.tie_word_embeddings and pp_group.is_last_rank
        )
        if needs_embedding:
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                hf_config.hidden_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        def build_layer(layer_prefix: str):
            # The layer prefix is expected to end in the layer index.
            return HunYuanDecoderLayer(
                config=hf_config,
                layer_id=int(layer_prefix.split(".")[-1]),
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=layer_prefix,
                enable_eplb=enable_eplb,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = (
            RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
            if get_pp_group().is_last_rank
            else PPMissingLayer()
        )

633
    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings of ``input_ids`` from ``embed_tokens``."""
        embeddings = self.embed_tokens(input_ids)
        return embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder layers assigned to this pipeline-parallel stage.

        Args:
            input_ids: Token ids, embedded on the first stage when
                ``inputs_embeds`` is None.
            positions: Positions for the rotary embeddings.
            intermediate_tensors: ``hidden_states``/``residual`` from the
                previous stage; required on non-first stages.
            inputs_embeds: Optional precomputed embeddings used instead of
                ``input_ids`` on the first stage.

        Returns:
            ``IntermediateTensors`` on non-last stages. On the last stage,
            the normed hidden states, or ``(hidden_states, aux_hidden_states)``
            when auxiliary hidden states were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        use_cla = getattr(self.config, "use_cla", False)
        cla_factor = _get_cla_factor(self.config)
        # K/V from the most recent KV-producing layer, fed to the
        # cross-attention layers that follow it. Only hidden_states/residual
        # cross pipeline stages, so a stage whose first layer is a
        # cross-attention layer has no K/V to give it.
        prev_kv_states = None
        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
        for i, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            hidden_states, residual, kv_states = layer(
                positions,
                hidden_states,
                residual,
                prev_kv_states,
            )

            # Use the global layer index, matching how HunYuanDecoderLayer
            # chooses self- vs cross-attention. Only a KV-producing layer
            # replaces prev_kv_states. Cross-attention layers keep them, so
            # every consumer in a share group (including cla_factor > 2)
            # receives the producer's K/V. KV-producing layers ignore the
            # kv_states argument.
            if use_cla and (self.start_layer + i) % cla_factor == 0:
                prev_kv_states = kv_states

            self._maybe_add_hidden_state(
                aux_hidden_states, i + 1, hidden_states, residual
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def _split_qkv_weight(self, qkv: torch.Tensor):
        """Reorder a fused checkpoint QKV weight into ``[q; k; v]`` row blocks.

        The reshape below assumes the checkpoint groups rows per KV head: for
        each KV head, ``num_attention_heads // num_kv_heads`` query heads,
        then one key head, then one value head. Each head is ``head_dim`` rows
        by ``hidden_size`` columns. The result concatenates all query rows,
        then all key rows, then all value rows along dim 0.
        """
        config = self.config
        num_attention_heads = config.num_attention_heads
        num_kv_heads = getattr(config, "num_key_value_heads", num_attention_heads)
        num_key_value_groups = num_attention_heads // num_kv_heads
        hidden_size = config.hidden_size

        # Like HunYuanAttention, skip a head_dim that is present but None/0
        # instead of feeding None into reshape().
        attention_head_dim = (
            getattr(config, "head_dim", None)
            or getattr(config, "attention_head_dim", None)
            or hidden_size // num_attention_heads
        )

        qkv = qkv.reshape(
            num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size
        )
        q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1)
        q = q.reshape(-1, hidden_size)
        k = k.reshape(-1, hidden_size)
        v = v.reshape(-1, hidden_size)
        return torch.concat((q, k, v))

711
    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return ``(param_name, weight_name, expert_id, shard_id)`` entries.

        The entries cover expert weights, fp8 weight scales and fp8 activation
        scales. Dense (non-MoE) configs have no experts and get an empty list.
        """
        if not _is_moe(self.config):
            return []
        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
            num_redundant_experts=self.num_redundant_experts,
        )
725

726
727
728
729
730
731
732
733
734
735
736
737
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        cla_factor = _get_cla_factor(self.config)
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        num_attention_heads = self.config.num_attention_heads
738
739
740
        num_kv_heads = getattr(
            self.config, "num_key_value_heads", self.config.num_attention_heads
        )
741
742
743
744
745
746
        split_params_mapping = [
            (".gate_up_proj", ".gate_and_up_proj", 2, [(1, 1), (0, 1)], None),
            (
                ".qkv_proj",
                ".qkv_proj",
                num_attention_heads + num_kv_heads * 2,
747
                [("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)],
748
749
750
751
752
                self._split_qkv_weight,
            ),
        ]

        params_dict = dict(self.named_parameters())
753
754
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
755
756
757
758
759
760
761
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "gate_proj_bias" in name:
                name = name.replace("gate_proj_bias", "gate_proj.bias")
            if "up_proj_bias" in name:
                name = name.replace("up_proj_bias", "up_proj.bias")
762
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
763
764
765
766
767
768
769
770
771
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            if self.quant_config is not None and (
772
773
                scale_name := self.quant_config.get_cache_scale(name)
            ):
774
775
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
776
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                continue

            is_found = False
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "mlp.experts" in name:
                    continue
                # cross layer only have q_proj, skip qkv pack
                if weight_name == ".q_proj":
                    match = re.search(r"layers\.\d+", name)
                    if match:
                        layer_id = int(match.group(0).split(".")[-1])
                        if cla_factor > 1 and layer_id % cla_factor != 0:
                            continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
805
                loaded_params.add(name)
806
807
808
809
810
811
                is_found = True
                break
            if is_found:
                continue

            for (
812
813
814
815
816
                param_name,
                weight_name,
                den,
                split_param,
                func,
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
            ) in split_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                assert loaded_weight.shape[0] % den == 0
                units = loaded_weight.shape[0] // den

                param = params_dict[name]
                weight_loader = param.weight_loader
                offset = 0
                for shard_id, num in split_param:
                    new_offset = offset + num * units
                    if func:
837
838
839
                        weight_loader(
                            param, func(loaded_weight)[offset:new_offset], shard_id
                        )
840
                    else:
841
                        weight_loader(param, loaded_weight[offset:new_offset], shard_id)
842
843
844
845
846
847
848
                    offset = new_offset

                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
849
                is_expert_weight = False
850
851
852
853
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
854
855
856
857
858
859
860
861
                    # this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True

                    # Do not modify `name` since the loop may continue here
                    # Instead, create a new variable
                    name_mapped = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name_mapped, self):
862
                        continue
863
864
865
866
                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
867
868
869
                    weight_loader = typing.cast(
                        Callable[..., bool], param.weight_loader
                    )
870
                    success = weight_loader(
871
872
                        param,
                        loaded_weight,
873
                        name_mapped,
874
875
                        shard_id=shard_id,
                        expert_id=expert_id,
876
                        return_success=True,
877
                    )
878
879
880
                    if success:
                        name = name_mapped
                        break
881
                else:
882
883
884
885
886
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue
887
888
889
890
891
892
893
894
895
896
897
898
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    if "mlp.gate.wg." in name:
                        name = name.replace("wg.", "")

                    param = params_dict[name]
899
900
901
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
902
                    weight_loader(param, loaded_weight)
903
904
905
906
            loaded_params.add(name)
        return loaded_params


class HunyuanV1ModelBase(
    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
    """Causal-LM wrapper shared by the dense and MoE HunYuan V1 models.

    Owns the decoder stack (``self.model``). On the last pipeline-parallel
    rank it also owns the LM head and the logits processor; on every other
    rank ``lm_head`` is a ``PPMissingLayer`` placeholder.
    """

    # Fused parameter name -> the separate projection names packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        # NOTE(review): the decoder prefix is hard-coded to "model" instead of
        # being derived from ``prefix`` (lm_head below uses maybe_prefix).
        # Confirm this is intended when this class is nested in another model.
        self.model = HunYuanModel(vllm_config=vllm_config, prefix="model")
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                # Output projection shares the input embedding matrix.
                self.lm_head.weight = self.model.embed_tokens.weight

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                config.vocab_size, scale=logit_scale
            )
        else:
            # Only the last pipeline stage holds the LM head.
            self.lm_head = PPMissingLayer()

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack and return its output unchanged.

        The result is either a tensor or ``IntermediateTensors``, presumably
        the latter on non-last pipeline ranks (TODO confirm in HunYuanModel).
        """
        model_output = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits through the LM head.

        Only usable on the last pipeline rank, the only rank on which
        ``logits_processor`` is created.
        """
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype, device: torch.device
    ) -> IntermediateTensors:
        """Allocate zero-filled ``hidden_states`` and ``residual`` buffers.

        Both have shape ``(batch_size, config.hidden_size)``.
        """
        return IntermediateTensors(
            {
                "hidden_states": torch.zeros(
                    (batch_size, self.config.hidden_size), dtype=dtype, device=device
                ),
                "residual": torch.zeros(
                    (batch_size, self.config.hidden_size), dtype=dtype, device=device
                ),
            }
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the names AutoWeightsLoader reports.

        With tied word embeddings, ``lm_head.*`` entries are skipped because the
        head reuses ``embed_tokens.weight`` (see ``__init__``).
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate input-embedding lookup to the decoder stack."""
        return self.model.embed_input_ids(input_ids)


class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts):
    """HunYuan V1 model with mixture-of-experts layers.

    Collects the MoE blocks present on this pipeline rank and exposes the
    expert-count metadata used by the ``MixtureOfExperts`` interface.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Set MoE hyperparameters
        self.expert_weights = []
        self.num_expert_groups = 1
        self.moe_layers = []
        example_layer = None
        for layer in self.model.layers:
            # Placeholders stand in for layers owned by other pipeline ranks.
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, HunYuanDecoderLayer)
            if isinstance(layer.mlp, HunYuanSparseMoeBlock):
                example_layer = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_layer is None:
            raise RuntimeError("No HunYuanMoE layer found in model.layers.")

        # Expert counts are read from the last local MoE block found; this
        # assumes every MoE block shares the same counts.
        self.num_moe_layers = len(self.moe_layers)
        self.num_logical_experts = example_layer.n_logical_experts
        self.num_physical_experts = example_layer.n_physical_experts
        self.num_local_physical_experts = example_layer.n_local_physical_experts
        self.num_routed_experts = example_layer.n_routed_experts
        self.num_redundant_experts = example_layer.n_redundant_experts

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate a new physical-expert layout to every local MoE block.

        Args:
            num_physical_experts: Total physical experts. The redundant count
                is derived as this minus ``num_logical_experts``.
            num_local_physical_experts: Physical experts hosted on this rank.
                Must equal the current value (asserted).
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # Skip pipeline-parallel placeholders before touching
            # ``layer.mlp``, exactly as __init__ does.
            if isinstance(layer, PPMissingLayer):
                continue
            if isinstance(layer.mlp, HunYuanSparseMoeBlock):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the decoder's expert mapping.

        Each entry is ``(param_name, weight_name, expert_id, shard_id)``.
        """
        return self.model.get_expert_mapping()

class HunYuanDenseV1Base(HunyuanV1ModelBase):
    """Dense HunYuan V1 model.

    Unlike ``HunYuanMoEV1Base``, this class adds no mixture-of-experts
    bookkeeping on top of ``HunyuanV1ModelBase``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Delegate all construction to ``HunyuanV1ModelBase``."""
        super().__init__(prefix=prefix, vllm_config=vllm_config)


class HunYuanDenseV1ForCausalLM(HunYuanDenseV1Base):
    """Concrete dense HunYuan V1 causal LM.

    All behavior is inherited from ``HunYuanDenseV1Base``. The class name is
    presumably what gets matched against the checkpoint's architecture
    string (TODO confirm against the model registry).
    """


class HunYuanMoEV1ForCausalLM(HunYuanMoEV1Base):
    """Concrete mixture-of-experts HunYuan V1 causal LM.

    All behavior is inherited from ``HunYuanMoEV1Base``. The class name is
    presumably what gets matched against the checkpoint's architecture
    string (TODO confirm against the model registry).
    """