hunyuan_v1.py 39.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# coding=utf-8
# Copyright 2024 The HunYuan team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only HunYuan model compatible with HuggingFace weights."""
26

27
28
import typing
from collections.abc import Callable, Iterable
29
from itertools import islice
30
from typing import Any
31
32
33
34
35
36
37
38

import regex as re
import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
39
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
40
41
42
43
44
45
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
46
from vllm.model_executor.layers.activation import SiluAndMul
47
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
48
from vllm.model_executor.layers.layernorm import RMSNorm
49
50
51
52
53
54
55
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
56
from vllm.model_executor.layers.logits_processor import LogitsProcessor
57
from vllm.model_executor.layers.quantization import QuantizationConfig
58
59
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
60
61
62
    ParallelLMHead,
    VocabParallelEmbedding,
)
63
from vllm.model_executor.model_loader.weight_utils import (
64
65
66
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
67
68
from vllm.sequence import IntermediateTensors

69
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
70
71
72
73
74
75
76
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    is_pp_missing_parameter,
    make_layers,
    maybe_prefix,
)
77
78


79
80
81
82
83
84
85
86
87
88
89
90
91
def _is_moe(config: PretrainedConfig) -> bool:
    """Return True when ``config`` describes a mixture-of-experts model.

    ``config.num_experts`` may be a single int or a list of ints. The model
    counts as MoE when that int (or the largest list entry) exceeds one.

    The model is treated as dense when:
    - the attribute is missing;
    - it is an empty list;
    - the list holds non-int entries;
    - it has any other type.
    """
    experts = getattr(config, "num_experts", None)
    if isinstance(experts, int):
        return experts > 1
    if not isinstance(experts, list) or not experts:
        return False
    # Never call max() on a list that contains non-integer entries.
    if not all(isinstance(count, int) for count in experts):
        return False
    return max(experts) > 1


92
93
94
95
96
97
98
99
100
101
102
103
def _get_cla_factor(config: PretrainedConfig) -> int:
    """Return the cross-layer-attention (CLA) sharing factor for ``config``.

    When ``config.use_cla`` is falsy (or missing), every layer computes its
    own key/value and the factor is 1. Otherwise the factor is
    ``config.cla_share_factor``, defaulting to 1 when that attribute is absent.
    """
    cla_enabled = getattr(config, "use_cla", False)
    return getattr(config, "cla_share_factor", 1) if cla_enabled else 1


class HunYuanMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(SiluAndMul(gate_up_proj(x)))``.

    The gate and up projections are fused into one
    ``MergedColumnParallelLinear`` whose output holds two
    ``intermediate_size`` halves. ``SiluAndMul`` combines the two halves,
    and the row-parallel ``down_proj`` maps the result back to
    ``hidden_size``.

    Args:
        hidden_size: Model hidden dimension (input and output size).
        intermediate_size: Output size of each of the gate/up projections.
        hidden_act: Activation name; only ``"silu"`` is supported.
        quant_config: Optional quantization config for both projections.
        bias: Whether the projections carry a bias.
        prefix: Module-name prefix for the projection layers.
        reduce_results: Forwarded to ``down_proj``. Callers that perform
            their own tensor-parallel all-reduce pass ``False``; the shared
            expert in ``HunYuanSparseMoeBlock`` does this.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        # Fail fast, before allocating any (possibly large) parallel weights.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
            reduce_results=reduce_results,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x``."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class HunYuanAttention(nn.Module):
    """Self-attention with a fused QKV projection and rotary embeddings.

    When ``config.use_qk_norm`` is set, per-head RMSNorm is applied to the
    queries and keys *after* the rotary embedding. ``forward`` also returns
    the rotated but un-normalized key, together with the value. Subsequent
    cross-layer-attention layers (``HunYuanCrossAttention``) reuse these.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        layer_id: int = -1,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Resolve head_dim in order, skipping missing or empty (None/0) values:
        # 1. config.head_dim
        # 2. config.attention_head_dim
        # 3. hidden_size // num_heads
        self.head_dim = (
            getattr(config, "head_dim", None)
            or getattr(config, "attention_head_dim", None)
            or self.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.use_qk_norm = getattr(config, "use_qk_norm", False)
        self.layer_id = layer_id

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        if self.use_qk_norm:
            self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_states: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Compute self-attention over ``hidden_states``.

        Args:
            positions: Token positions fed to the rotary embedding.
            hidden_states: Input activations.
            kv_states: Ignored. Accepted so that this module has the same
                call signature as ``HunYuanCrossAttention``.

        Returns:
            ``(output, (key, value))``. ``key`` is taken after the rotary
            embedding but before the optional key RMSNorm.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        # Keep the pre-norm key for CLA consumers, which apply their own
        # key_layernorm to it.
        ori_k = k
        if self.use_qk_norm:
            q = self.query_layernorm(
                q.view(-1, self.num_heads, self.head_dim).contiguous()
            )
            k = self.key_layernorm(
                k.view(-1, self.num_kv_heads, self.head_dim).contiguous()
            )

        attn_output = self.attn(q, k, v)
        # Flatten per-head output back to 2-D (leading dim of q) for o_proj.
        attn_output = attn_output.view(q.shape[0], -1)
        output, _ = self.o_proj(attn_output)
        return output, (ori_k, v)


class HunYuanCrossAttention(nn.Module):
    """Cross-layer attention (CLA): own queries, borrowed keys/values.

    This module only owns a query projection. The key and value come from
    ``kv_states``, which an earlier ``HunYuanAttention`` layer produced. The
    underlying ``Attention`` is built with ``AttentionType.ENCODER_DECODER``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        layer_id: int = -1,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # Resolve head_dim exactly like HunYuanAttention so that both layer
        # kinds agree on head_dim. Missing or empty (None/0) values are
        # skipped, in order:
        # 1. config.head_dim
        # 2. config.attention_head_dim
        # 3. hidden_size // num_heads
        self.head_dim = (
            getattr(config, "head_dim", None)
            or getattr(config, "attention_head_dim", None)
            or self.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.use_qk_norm = getattr(config, "use_qk_norm", False)
        self.layer_id = layer_id

        # The query projection must emit num_heads * head_dim features.
        # o_proj consumes exactly that many, and the per-head view in
        # forward() relies on it. This equals hidden_size only when
        # num_heads * head_dim == hidden_size.
        self.q_proj = ColumnParallelLinear(
            hidden_size,
            self.total_num_heads * self.head_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            attn_type=AttentionType.ENCODER_DECODER,
        )

        if self.use_qk_norm:
            self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_states: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Attend with fresh queries over the reused ``kv_states``.

        Args:
            positions: Token positions fed to the rotary embedding.
            hidden_states: Input activations (only used for queries).
            kv_states: ``(key, value)`` returned by the KV-producing layer.
                The key already has its rotary embedding applied and has
                not been QK-normed. This argument is required.

        Returns:
            ``(output, kv_states)``, passing the borrowed key/value through
            unchanged so later CLA layers in the group can reuse them.
        """
        assert kv_states is not None
        ori_k, v = kv_states
        k = ori_k
        q, _ = self.q_proj(hidden_states)
        # The reused key is already rotated. rotary_emb still requires a key
        # argument, so a throwaway buffer is passed and its result discarded.
        # TODO: avoid this redundant allocation and rotation.
        k_tmp = torch.empty_like(k)
        q, _ = self.rotary_emb(positions, q, k_tmp)
        if self.use_qk_norm:
            q = self.query_layernorm(
                q.view(-1, self.num_heads, self.head_dim).contiguous()
            )
            k = self.key_layernorm(
                k.view(-1, self.num_kv_heads, self.head_dim).contiguous()
            )

        attn_output = self.attn(q, k, v)
        # Flatten per-head output back to 2-D (leading dim of q) for o_proj.
        attn_output = attn_output.view(q.shape[0], -1)
        output, _ = self.o_proj(attn_output)
        return output, (ori_k, v)


class HunYuanSparseMoeBlock(nn.Module):
    """Routed mixture-of-experts MLP, optionally with a dense shared expert.

    A replicated linear ``gate`` produces the router logits, and
    ``SharedFusedMoE`` runs the ``top_k`` routed experts. When
    ``config.use_mixed_mlp_moe`` is positive, a ``HunYuanMLP`` shared expert
    is also passed to ``SharedFusedMoE``, and its output is added to the
    routed output. The sum is all-reduced across tensor-parallel ranks when
    ``tp_size > 1``.

    Raises:
        ValueError: If the tensor-parallel size exceeds ``config.num_experts``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        layer_id: int = -1,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Get layer_id topk if config.moe_topk is a list
        if isinstance(config.moe_topk, list):
            assert layer_id >= 0
            assert len(config.moe_topk) > layer_id
            top_k = config.moe_topk[layer_id]
        else:
            top_k = config.moe_topk

        # Routed experts use moe_intermediate_size when the config provides it
        # (scalar or per-layer list); otherwise intermediate_size.
        intermediate_size = config.intermediate_size
        moe_intermediate_size = getattr(config, "moe_intermediate_size", None)
        if moe_intermediate_size is not None:
            if isinstance(moe_intermediate_size, int):
                intermediate_size = moe_intermediate_size
            else:
                # Same bounds checks as the per-layer moe_topk lookup above;
                # a default layer_id of -1 would silently pick the last entry.
                assert layer_id >= 0
                assert len(moe_intermediate_size) > layer_id
                intermediate_size = moe_intermediate_size[layer_id]

        # Load balancing settings.
        vllm_config = get_current_vllm_config()
        eplb_config = vllm_config.parallel_config.eplb_config
        self.enable_eplb = enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )
        if config.use_mixed_mlp_moe > 0:
            # Get layer_id num_shared_expert if config.num_shared_expert is
            # a list.
            if isinstance(config.num_shared_expert, list):
                assert layer_id >= 0
                assert len(config.num_shared_expert) > layer_id
                num_shared_expert = config.num_shared_expert[layer_id]
            else:
                num_shared_expert = config.num_shared_expert

            # Give the shared expert a real prefix. Without one, its
            # projections are named ".gate_up_proj"/".down_proj" and
            # prefix-based quantization settings cannot target them.
            self.shared_mlp = HunYuanMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size * num_shared_expert,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_mlp",
            )
        else:
            self.shared_mlp = None

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_mlp,
            num_experts=self.n_routed_experts,
            top_k=top_k,
            hidden_size=config.hidden_size,
            intermediate_size=intermediate_size,
            reduce_results=False,
            renormalize=top_k > 1,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route ``hidden_states`` through the experts; output keeps the input shape."""
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        # With a shared expert, the experts call returns a pair; sum the two.
        if self.shared_mlp is not None:
            final_hidden_states = final_hidden_states[0] + final_hidden_states[1]

        # Both the experts and the shared MLP were built with
        # reduce_results=False, so the TP reduction happens once here.
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states.view(orig_shape)


class HunYuanDecoderLayer(nn.Module):
    """One transformer block: pre-norm attention, then pre-norm MLP.

    With CLA enabled (``_get_cla_factor(config) > 1``), layers whose id is
    not a multiple of the CLA factor use ``HunYuanCrossAttention``. These
    layers reuse the ``kv_states`` passed to ``forward``; every other layer
    uses ``HunYuanAttention``. The MLP is a ``HunYuanSparseMoeBlock`` for
    MoE configs and a dense ``HunYuanMLP`` otherwise.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        layer_id: int = -1,
        enable_eplb: bool = False,
    ) -> None:
        super().__init__()
        assert layer_id >= 0
        self.layer_id = layer_id
        self.hidden_size = config.hidden_size
        self.intermediate_size = (
            config.intermediate_size
            if isinstance(config.intermediate_size, int)
            else config.intermediate_size[layer_id]
        )
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None
        ):
            # Work on a copy. Writing into config.rope_scaling would mutate
            # the shared model config once per layer as a side effect.
            rope_scaling = dict(rope_scaling)
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings
            )
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )

        # CLA consumers (layer_id not a multiple of the factor) only project
        # queries and borrow K/V from the preceding producer layer.
        cla_factor = _get_cla_factor(config)
        reuses_kv = layer_id >= 0 and layer_id % cla_factor != 0
        attn_cls = HunYuanCrossAttention if reuses_kv else HunYuanAttention
        self.self_attn = attn_cls(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            layer_id=layer_id,
        )

        if _is_moe(config):
            self.mlp = HunYuanSparseMoeBlock(
                config=config,
                quant_config=quant_config,
                layer_id=layer_id,
                prefix=f"{prefix}.mlp",
                enable_eplb=enable_eplb,
            )
        else:
            self.mlp = HunYuanMLP(
                hidden_size=self.hidden_size,
                intermediate_size=self.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                bias=getattr(config, "mlp_bias", False),
                prefix=f"{prefix}.mlp",
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        kv_states: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run attention and MLP for one layer.

        Args:
            positions: Token positions for the rotary embedding.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual stream; ``None`` on the first layer
                of a stage, in which case ``hidden_states`` seeds it.
            kv_states: ``(key, value)`` to reuse. This is required when this
                layer is a CLA consumer and ignored otherwise.

        Returns:
            ``(hidden_states, residual, kv_states)``, where ``kv_states`` is
            the key/value pair returned by the attention module.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns the normed states and
            # the updated residual.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states, ori_kv_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_states=kv_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual, ori_kv_states


@support_torch_compile
class HunYuanModel(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, this pipeline stage's decoder layers, and final norm.

        ``embed_tokens`` exists on the first PP rank, and also on the last
        rank when word embeddings are tied. ``norm`` exists only on the last
        rank. Every other rank holds ``PPMissingLayer`` placeholders.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        eplb_config = vllm_config.parallel_config.eplb_config
        enable_eplb = vllm_config.parallel_config.enable_eplb
        self.num_redundant_experts = eplb_config.num_redundant_experts

        self.config = config
        self.quant_config = quant_config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: HunYuanDecoderLayer(
                config=config,
                # The absolute layer index is the last dotted component of
                # the per-layer prefix.
                layer_id=int(prefix.split(".")[-1]),
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
                enable_eplb=enable_eplb,
            ),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids`` via ``embed_tokens``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline stage's decoder layers.

        The first PP rank embeds ``input_ids``, unless ``inputs_embeds`` is
        given. Later ranks resume from ``intermediate_tensors``. Non-last
        ranks return the hidden states and residual for the next stage; the
        last rank returns the final-normed hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        use_cla = getattr(self.config, "use_cla", False)
        cla_factor = _get_cla_factor(self.config)
        prev_kv_states = None
        # Enumerate with absolute layer ids. HunYuanDecoderLayer decides
        # producer vs. consumer from its absolute layer_id % cla_factor, so the
        # loop must use the same index, not a stage-relative one.
        for layer_id, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer),
            start=self.start_layer,
        ):
            hidden_states, residual, kv_states = layer(
                positions,
                hidden_states,
                residual,
                prev_kv_states,
            )

            # A producer layer refreshes the shared K/V for its CLA group.
            # Consumer layers keep prev_kv_states untouched, so every consumer
            # in the group sees the producer's K/V. Previously it was reset to
            # None after each consumer, which broke cla_share_factor > 2.
            # Without CLA, prev_kv_states stays None.
            if use_cla and layer_id % cla_factor == 0:
                prev_kv_states = kv_states

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def _split_qkv_weight(self, qkv: torch.Tensor):
        """Reorder a grouped fused-QKV weight into ``[Q; K; V]`` row blocks.

        The rows of ``qkv`` are assumed to be grouped per KV head. For each KV
        head, the layout is its ``num_attention_heads // num_kv_heads`` query
        heads, then one key head, then one value head, each ``head_dim`` rows
        of width ``hidden_size``. The result concatenates all query rows, then
        all key rows, then all value rows. This matches the ``q``/``k``/``v``
        shard order that ``load_weights`` uses for ``.qkv_proj``.
        """
        num_attention_heads = self.config.num_attention_heads
        num_kv_heads = getattr(
            self.config, "num_key_value_heads", self.config.num_attention_heads
        )
        num_key_value_groups = num_attention_heads // num_kv_heads
        hidden_size = self.config.hidden_size

        # Same head_dim resolution as the attention modules, skipping
        # missing or empty (None/0) values, in order:
        # 1. config.head_dim
        # 2. config.attention_head_dim
        # 3. hidden_size // num_heads
        attention_head_dim = (
            getattr(self.config, "head_dim", None)
            or getattr(self.config, "attention_head_dim", None)
            or hidden_size // num_attention_heads
        )

        qkv = qkv.reshape(
            num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size
        )
        q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1)
        q = q.reshape(-1, hidden_size)
        k = k.reshape(-1, hidden_size)
        v = v.reshape(-1, hidden_size)
        return torch.concat((q, k, v))

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-parameter mapping for routed expert weights.

        Each entry is ``(param_name, weight_name, expert_id, shard_id)`` and
        covers weights, fp8 weight scales and fp8 activation scales. The
        mapping includes the redundant EPLB experts. Dense (non-MoE) configs
        get an empty list.
        """
        if not _is_moe(self.config):
            return []
        return SharedFusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        cla_factor = _get_cla_factor(self.config)
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        num_attention_heads = self.config.num_attention_heads
742
743
744
        num_kv_heads = getattr(
            self.config, "num_key_value_heads", self.config.num_attention_heads
        )
745
746
747
748
749
750
        split_params_mapping = [
            (".gate_up_proj", ".gate_and_up_proj", 2, [(1, 1), (0, 1)], None),
            (
                ".qkv_proj",
                ".qkv_proj",
                num_attention_heads + num_kv_heads * 2,
751
                [("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)],
752
753
754
755
756
                self._split_qkv_weight,
            ),
        ]

        params_dict = dict(self.named_parameters())
757
758
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
759
760
761
762
763
764
765
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "gate_proj_bias" in name:
                name = name.replace("gate_proj_bias", "gate_proj.bias")
            if "up_proj_bias" in name:
                name = name.replace("up_proj_bias", "up_proj.bias")
766
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
767
768
769
770
771
772
773
774
775
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            if self.quant_config is not None and (
776
777
                scale_name := self.quant_config.get_cache_scale(name)
            ):
778
779
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
780
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                continue

            is_found = False
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "mlp.experts" in name:
                    continue
                # cross layer only have q_proj, skip qkv pack
                if weight_name == ".q_proj":
                    match = re.search(r"layers\.\d+", name)
                    if match:
                        layer_id = int(match.group(0).split(".")[-1])
                        if cla_factor > 1 and layer_id % cla_factor != 0:
                            continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
809
                loaded_params.add(name)
810
811
812
813
814
815
                is_found = True
                break
            if is_found:
                continue

            for (
816
817
818
819
820
                param_name,
                weight_name,
                den,
                split_param,
                func,
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
            ) in split_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                assert loaded_weight.shape[0] % den == 0
                units = loaded_weight.shape[0] // den

                param = params_dict[name]
                weight_loader = param.weight_loader
                offset = 0
                for shard_id, num in split_param:
                    new_offset = offset + num * units
                    if func:
841
842
843
                        weight_loader(
                            param, func(loaded_weight)[offset:new_offset], shard_id
                        )
844
                    else:
845
                        weight_loader(param, loaded_weight[offset:new_offset], shard_id)
846
847
848
849
850
851
852
                    offset = new_offset

                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
853
                is_expert_weight = False
854
855
856
857
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
858
859
860
861
862
863
864
865
                    # this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True

                    # Do not modify `name` since the loop may continue here
                    # Instead, create a new variable
                    name_mapped = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name_mapped, self):
866
                        continue
867
868
869
870
                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
871
872
873
                    weight_loader = typing.cast(
                        Callable[..., bool], param.weight_loader
                    )
874
                    success = weight_loader(
875
876
                        param,
                        loaded_weight,
877
                        name_mapped,
878
879
                        shard_id=shard_id,
                        expert_id=expert_id,
880
                        return_success=True,
881
                    )
882
883
884
                    if success:
                        name = name_mapped
                        break
885
                else:
886
887
888
889
890
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue
891
892
893
894
895
896
897
898
899
900
901
902
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    if "mlp.gate.wg." in name:
                        name = name.replace("wg.", "")

                    param = params_dict[name]
903
904
905
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
906
                    weight_loader(param, loaded_weight)
907
908
909
910
            loaded_params.add(name)
        return loaded_params


class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP):
    """Shared causal-LM wrapper for HunYuan V1 checkpoints.

    Wraps a ``HunYuanModel`` backbone. On the last pipeline-parallel rank it
    also builds the LM head and logits processor; other ranks get a
    ``PPMissingLayer`` placeholder instead of an LM head.
    """

    # Maps each fused vLLM module name to the checkpoint modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        # NOTE(review): the backbone prefix is hard-coded to "model" and does
        # not incorporate ``prefix`` (unlike ``lm_head`` below). Confirm this
        # is intended before nesting this class under another module.
        self.model = HunYuanModel(vllm_config=vllm_config, prefix="model")
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                # Reuse the input embedding matrix as the output projection.
                self.lm_head.weight = self.model.embed_tokens.weight

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                config.vocab_size, scale=logit_scale
            )
        else:
            # Only the last PP rank owns the LM head; other ranks hold a
            # placeholder module and create no logits processor.
            self.lm_head = PPMissingLayer()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged.

        Per the annotation this is either a tensor or ``IntermediateTensors``
        (the latter presumably on non-final PP ranks — confirm in
        ``HunYuanModel.forward``).
        """
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project ``hidden_states`` to vocabulary logits through the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype, device: torch.device
    ) -> IntermediateTensors:
        """Allocate zero-filled PP hand-off buffers.

        Returns ``IntermediateTensors`` holding ``hidden_states`` and
        ``residual``, each of shape ``(batch_size, config.hidden_size)``.
        """
        return IntermediateTensors(
            {
                "hidden_states": torch.zeros(
                    (batch_size, self.config.hidden_size), dtype=dtype, device=device
                ),
                "residual": torch.zeros(
                    (batch_size, self.config.hidden_size), dtype=dtype, device=device
                ),
            }
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        When embeddings are tied, ``lm_head.*`` entries are skipped because the
        head shares the embedding weight. Returns the set of loaded names.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up input embeddings via the backbone."""
        return self.model.embed_input_ids(input_ids)


class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts):
    """HunYuan V1 MoE causal LM exposing ``MixtureOfExperts`` metadata.

    Collects the expert modules of every local ``HunYuanSparseMoeBlock`` and
    records logical/physical/redundant expert counts taken from them.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Set MoE hyperparameters
        self.expert_weights = []
        self.num_expert_groups = 1
        self.moe_layers = []
        example_layer = None
        for layer in self.model.layers:
            # Layers owned by other PP ranks are placeholders; skip them.
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, HunYuanDecoderLayer)
            if isinstance(layer.mlp, HunYuanSparseMoeBlock):
                example_layer = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_layer is None:
            raise RuntimeError("No HunYuanMoE layer found in model.layers.")

        # Counts come from the last local MoE block found; this assumes all
        # MoE blocks share the same expert configuration (TODO confirm).
        self.num_moe_layers = len(self.moe_layers)
        self.num_logical_experts = example_layer.n_logical_experts
        self.num_physical_experts = example_layer.n_physical_experts
        self.num_local_physical_experts = example_layer.n_local_physical_experts
        self.num_routed_experts = example_layer.n_routed_experts
        self.num_redundant_experts = example_layer.n_redundant_experts

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate new physical-expert counts to every local MoE block.

        Args:
            num_physical_experts: New total number of physical experts.
            num_local_physical_experts: Physical experts on this rank; must
                equal the current value.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # ``__init__`` skips PP placeholder layers; do the same here so
            # ``layer.mlp`` is only read from real decoder layers.
            if isinstance(layer, PPMissingLayer):
                continue
            if isinstance(layer.mlp, HunYuanSparseMoeBlock):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the backbone's expert weight-name mapping."""
        return self.model.get_expert_mapping()

class HunYuanDenseV1Base(HunyuanV1ModelBase):
    """Base class for dense (non-MoE) HunYuan V1 causal LMs.

    Adds nothing beyond ``HunyuanV1ModelBase``; its constructor (same
    keyword-only ``vllm_config``/``prefix`` signature) is inherited as-is.
    """


class HunYuanDenseV1ForCausalLM(HunYuanDenseV1Base):
    """Dense HunYuan V1 causal LM (concrete model class)."""


class HunYuanMoEV1ForCausalLM(HunYuanMoEV1Base):
    """Mixture-of-experts HunYuan V1 causal LM (concrete model class)."""