# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from https://github.com/vllm-project/vllm/blob/94d8ec8d2bcb4ec55e33022b313c7e978edf05e1/vllm/model_executor/models/bamba.py
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only NemotronH model."""
20

21
22
import typing
from collections.abc import Callable, Iterable
23
from itertools import islice
Luis Vega's avatar
Luis Vega committed
24
25
26
27

import torch
from torch import nn

28
from vllm.compilation.decorators import support_torch_compile
29
from vllm.config import CacheConfig, ModelConfig, VllmConfig
30
31
32
from vllm.config.parallel import ParallelConfig
from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size
from vllm.distributed.communication_op import tensor_model_parallel_all_gather
Luis Vega's avatar
Luis Vega committed
33
34
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.activation import ReLUSquaredActivation
35
from vllm.model_executor.layers.attention import Attention
36
from vllm.model_executor.layers.fused_moe import (
37
    GateLinear,
38
39
40
    SharedFusedMoE,
    activation_without_mul,
)
Luis Vega's avatar
Luis Vega committed
41
from vllm.model_executor.layers.layernorm import RMSNorm
42
43
44
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
45
    ReplicatedLinear,
46
47
    RowParallelLinear,
)
Luis Vega's avatar
Luis Vega committed
48
from vllm.model_executor.layers.logits_processor import LogitsProcessor
49
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
50
from vllm.model_executor.layers.mamba.mamba_utils import (
51
52
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
53
54
55
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
Luis Vega's avatar
Luis Vega committed
56
57
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
58
59
60
    ParallelLMHead,
    VocabParallelEmbedding,
)
61
from vllm.model_executor.model_loader.weight_utils import (
62
63
64
65
66
67
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.interfaces import (
    HasInnerState,
    IsHybrid,
68
    MixtureOfExperts,
69
    SupportsLoRA,
70
    SupportsMambaPrefixCaching,
71
72
73
    SupportsPP,
    SupportsQuant,
)
Luis Vega's avatar
Luis Vega committed
74
from vllm.model_executor.models.utils import (
75
76
    AutoWeightsLoader,
    WeightsMapper,
77
    is_pp_missing_parameter,
78
79
80
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
81
    sequence_parallel_chunk,
82
)
Luis Vega's avatar
Luis Vega committed
83
from vllm.sequence import IntermediateTensors
84
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
Luis Vega's avatar
Luis Vega committed
85
86
87
88
89
90


class NemotronHMLP(nn.Module):
    """Non-gated two-projection MLP used by NemotronH dense and shared experts.

    Computes ``down_proj(act(up_proj(x)))`` where ``act`` is
    ``ReLUSquaredActivation``. Both projections share the same bias,
    quantization and tensor-parallel settings.

    Args:
        config: Model config; accepted for interface compatibility and not
            read by this module.
        hidden_size: Input and output feature size.
        intermediate_size: Output size of ``up_proj`` / input size of
            ``down_proj``.
        quant_config: Optional quantization config for both projections.
        bias: Whether both projections carry a bias term.
        reduce_results: Forwarded to the ``RowParallelLinear`` down
            projection.
        is_sequence_parallel: Passed as ``disable_tp`` to both projections.
        prefix: Parameter-name prefix for this module.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        hidden_size: int,
        intermediate_size: int,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        reduce_results: bool = True,
        is_sequence_parallel: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Settings shared verbatim by both projections.
        common = {
            "bias": bias,
            "quant_config": quant_config,
            "disable_tp": is_sequence_parallel,
        }
        self.up_proj = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            prefix=f"{prefix}.up_proj",
            **common,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
            **common,
        )
        self.act_fn = ReLUSquaredActivation()

    def forward(self, x: torch.Tensor):
        """Project up, apply the activation, project back down."""
        projected, _ = self.up_proj(x)
        activated = self.act_fn(projected)
        out, _ = self.down_proj(activated)
        return out


127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
class NemotronHMoE(nn.Module):
    """Mixture-of-experts mixer used by NemotronH "E" layers.

    Routing and experts:
      - A float32 ``GateLinear`` router feeds a ``SharedFusedMoE``.
      - The MoE uses sigmoid scoring and grouped top-k with an
        ``e_score_correction_bias``.
      - Experts are non-gated (``is_act_and_mul=False``); their activation
        is derived from ``config.mlp_hidden_act``.
      - Optional shared experts are added to the routed output.

    Latent MoE: when ``config.moe_latent_size`` is set, the routed experts
    run in that smaller space. The input is mapped through
    ``fc1_latent_proj`` (as ``routed_input_transform``). The routed output is
    mapped back through ``fc2_latent_proj``.

    Args:
        config: Model (or per-layer) config.
        quant_config: Optional quantization config for projections/experts.
        parallel_config: Optional parallel config. When ``None``, the
            sequence-parallel MoE path and EPLB are disabled and no redundant
            experts are allocated.
        prefix: Parameter-name prefix for this module.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        quant_config: QuantizationConfig | None = None,
        parallel_config: ParallelConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor

        self.ep_group = get_ep_group().device_group
        self.ep_rank = self.ep_group.rank()
        self.ep_size = self.ep_group.size()
        self.n_routed_experts: int = config.n_routed_experts
        self.n_shared_experts: int = config.n_shared_experts

        # Latent MoE: routed experts operate in `moe_latent_size` features
        # instead of `hidden_size`.
        self.use_latent_moe: bool = getattr(config, "moe_latent_size", None) is not None
        self.moe_hidden_size: int = (
            config.moe_latent_size if self.use_latent_moe else config.hidden_size
        )

        # `parallel_config` is optional in the signature, so it must not be
        # dereferenced unconditionally. Without one, fall back to no
        # sequence-parallel MoE, no EPLB and no redundant experts.
        if parallel_config is None:
            self.is_sequence_parallel = False
            self.enable_eplb = False
            self.n_redundant_experts = 0
        else:
            self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
            # Load balancing settings.
            self.enable_eplb = parallel_config.enable_eplb
            self.n_redundant_experts = (
                parallel_config.eplb_config.num_redundant_experts
            )

        self.gate = GateLinear(
            config.hidden_size,
            config.n_routed_experts,
            out_dtype=torch.float32,
            force_fp32_compute=True,
            prefix=f"{prefix}.gate",
        )
        # Per-expert routing bias; handed to SharedFusedMoE below as
        # `e_score_correction_bias`.
        self.gate.e_score_correction_bias = nn.Parameter(
            torch.empty(config.n_routed_experts, dtype=torch.float32)
        )

        # Expert-parallel bookkeeping. Logical experts plus redundant
        # replicas form the physical experts, which are divided across the
        # EP group with integer division.
        self.n_logical_experts = self.n_routed_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        if config.n_shared_experts is None or config.n_shared_experts == 0:
            self.shared_experts = None
        else:
            intermediate_size = (
                config.moe_shared_expert_intermediate_size * config.n_shared_experts
            )
            # reduce_results=False: forward() performs a single cross-rank
            # reduction on the combined (routed + shared) output.
            self.shared_experts = NemotronHMLP(
                config=config,
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                quant_config=quant_config,
                reduce_results=False,
                is_sequence_parallel=self.is_sequence_parallel,
                prefix=f"{prefix}.shared_experts",
            )

        if self.use_latent_moe:
            self.fc1_latent_proj = ReplicatedLinear(
                input_size=config.hidden_size,
                output_size=self.moe_hidden_size,
                bias=config.mlp_bias,
                quant_config=quant_config,
                disable_tp=self.is_sequence_parallel,
                prefix=f"{prefix}.fc1_latent_proj",
            )
            self.fc2_latent_proj = ReplicatedLinear(
                input_size=self.moe_hidden_size,
                output_size=config.hidden_size,
                bias=config.mlp_bias,
                quant_config=quant_config,
                disable_tp=self.is_sequence_parallel,
                prefix=f"{prefix}.fc2_latent_proj",
            )
        else:
            self.fc1_latent_proj = None
            self.fc2_latent_proj = None

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_experts,
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=self.moe_hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func="sigmoid",
            e_score_correction_bias=self.gate.e_score_correction_bias,
            activation=activation_without_mul(config.mlp_hidden_act),
            is_act_and_mul=False,  # non-gated MoE
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
            routed_input_transform=self.fc1_latent_proj,
            router_logits_dtype=self.gate.out_dtype,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MoE block to a 2-D ``(num_tokens, hidden_dim)`` tensor.

        Returns a tensor of the same shape. When tensor parallelism is in
        use, the cross-rank combination happens here:
          - Sequence-parallel mode: the output is all-gathered and trimmed to
            ``num_tokens``.
          - Otherwise: ``maybe_all_reduce_tensor_model_parallel`` is applied
            when ``tp_size > 1``.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        if self.is_sequence_parallel:
            # Each rank processes its own chunk of tokens; gathered below.
            hidden_states = sequence_parallel_chunk(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

        # SharedFusedMoE handles:
        #   - shared experts (with original hidden_states)
        #   - routed_input_transform (fc1_latent_proj) for latent MoE
        #   - multistream parallelism between shared and routed experts
        shared_output, final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )

        # Fix FP16 overflow (pattern borrowed from DeepseekV2; see
        # DeepseekV2DecoderLayer). In fp16 the routed output is left unscaled
        # and the shared output is divided by routed_scaling_factor instead.
        # NOTE(review): no other layer in this file rescales the residual
        # stream to compensate. The fp16 output is therefore roughly
        # 1/routed_scaling_factor of the other dtypes' output. Verify against
        # an fp16 reference before relying on this path.
        if hidden_states.dtype != torch.float16:
            final_hidden_states *= self.routed_scaling_factor
        elif self.shared_experts is not None:
            shared_output *= 1.0 / self.routed_scaling_factor

        # TODO: See SharedFusedMoE.apply_routed_input_transform
        # for bandwidth optimization
        if self.use_latent_moe:
            # Map the routed output from the latent space back to hidden_size.
            final_hidden_states, _ = self.fc2_latent_proj(final_hidden_states)

        if self.shared_experts is not None:
            final_hidden_states += shared_output

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            final_hidden_states = final_hidden_states[:num_tokens]
        elif self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
                final_hidden_states
            )

        return final_hidden_states.view(num_tokens, hidden_dim)


Luis Vega's avatar
Luis Vega committed
283
284
285
286
287
class NemotronHMLPDecoderLayer(nn.Module):
    """Pre-norm decoder layer whose mixer is a dense ``NemotronHMLP`` ("-").

    ``model_config``, ``cache_config`` and ``parallel_config`` are accepted
    for a uniform constructor signature across layer types but are not used.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        parallel_config: ParallelConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # The model-level config is stored; per-layer overrides are only
        # used for building the submodules below.
        self.config = config

        # Zero-based position of this layer among the "-" layers of the
        # (model-level) hybrid pattern. It selects the per-layer width when
        # `intermediate_size` is a list.
        mlp_position = config.hybrid_override_pattern[: layer_idx + 1].count("-") - 1

        # Heterogeneous models may provide a per-layer config.
        per_layer_getter = getattr(config, "get_nemotron_h_config_for_layer", None)
        layer_cfg = config
        if per_layer_getter:
            layer_cfg = per_layer_getter(layer_idx)

        sizes = layer_cfg.intermediate_size
        if not isinstance(sizes, list):
            intermediate_size = sizes
        elif len(sizes) == 1:
            # A single-element list applies to every MLP layer.
            intermediate_size = sizes[0]
        else:
            intermediate_size = sizes[mlp_position]

        self.mixer = NemotronHMLP(
            layer_cfg,
            hidden_size=layer_cfg.hidden_size,
            intermediate_size=intermediate_size,
            quant_config=quant_config,
            bias=layer_cfg.mlp_bias,
            prefix=f"{prefix}.mixer",
        )
        self.norm = RMSNorm(layer_cfg.hidden_size, eps=layer_cfg.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ):
        """Normalize the input, run the MLP, and return ``(output, residual)``.

        On the first layer (``residual is None``) the raw input becomes the
        residual. Otherwise the two-argument ``RMSNorm`` call returns both
        the normalized states and the updated residual.
        """
        if residual is not None:
            hidden_states, residual = self.norm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)

        return self.mixer(hidden_states), residual


class NemotronHMoEDecoderLayer(nn.Module):
    """Pre-norm decoder layer whose mixer is a ``NemotronHMoE`` block ("E").

    ``model_config`` and ``cache_config`` are accepted for a uniform
    constructor signature across layer types but are not used.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        parallel_config: ParallelConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        # Heterogeneous models may provide a per-layer config; only the MoE
        # mixer is built from it (the norm below uses the model config).
        per_layer_getter = getattr(config, "get_nemotron_h_config_for_layer", None)
        moe_cfg = config
        if per_layer_getter:
            moe_cfg = per_layer_getter(layer_idx)

        self.mixer = NemotronHMoE(
            moe_cfg,
            quant_config=quant_config,
            parallel_config=parallel_config,
            prefix=f"{prefix}.mixer",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ):
        """Normalize the input, run the MoE, and return ``(output, residual)``.

        On the first layer (``residual is None``) the raw input becomes the
        residual. Otherwise the two-argument ``RMSNorm`` call returns both
        the normalized states and the updated residual.
        """
        if residual is not None:
            hidden_states, residual = self.norm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)

        return self.mixer(hidden_states), residual


class NemotronHMambaDecoderLayer(nn.Module):
    """Pre-norm decoder layer whose mixer is a ``MambaMixer2`` block ("M").

    The mixer's inner width is ``mamba_num_heads * mamba_head_dim``.
    ``layer_idx`` and ``parallel_config`` are accepted for a uniform
    constructor signature across layer types but are not used.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        parallel_config: ParallelConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        num_heads = config.mamba_num_heads
        head_dim = config.mamba_head_dim
        eps = config.layer_norm_epsilon

        self.mixer = MambaMixer2(
            hidden_size=config.hidden_size,
            ssm_state_size=config.ssm_state_size,
            conv_kernel_size=config.conv_kernel,
            intermediate_size=num_heads * head_dim,
            use_conv_bias=config.use_conv_bias,
            use_bias=config.use_bias,
            n_groups=config.n_groups,
            num_heads=num_heads,
            head_dim=head_dim,
            rms_norm_eps=eps,
            activation=config.mamba_hidden_act,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mixer",
        )

        self.norm = RMSNorm(config.hidden_size, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ):
        """Normalize the input, run the Mamba mixer, return ``(output, residual)``.

        On the first layer (``residual is None``) the raw input becomes the
        residual. Otherwise the two-argument ``RMSNorm`` call returns both
        the normalized states and the updated residual.
        """
        if residual is not None:
            hidden_states, residual = self.norm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)

        return self.mixer(hidden_states), residual
Luis Vega's avatar
Luis Vega committed
429
430
431
432
433
434
435


class NemotronHAttention(nn.Module):
    """Tensor-parallel multi-head / grouped-query self-attention.

    Query heads are sharded across TP ranks. KV heads are sharded when there
    are at least as many as TP ranks; otherwise they are replicated, with at
    least one per rank. No rotary or other positional transform is applied
    to q/k in this module. An optional per-layer ``sliding_window`` is read
    from ``config``. ``layer_idx`` and ``model_config`` are accepted but not
    used.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads: always an even split across TP ranks.
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        # KV heads: partitioned when there are enough of them, otherwise
        # replicated so that every rank still holds one.
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads < tp_size:
            assert tp_size % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Prefer an explicit head_dim; otherwise derive it from hidden_size.
        explicit_head_dim = getattr(config, "head_dim", None)
        if explicit_head_dim is None:
            self.head_dim = config.hidden_size // self.total_num_heads
        else:
            self.head_dim = explicit_head_dim

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            # Per-layer sliding window (heterogeneous models); None if unset.
            per_layer_sliding_window=getattr(config, "sliding_window", None),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Fused QKV projection, attention, then output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        return self.o_proj(self.attn(q, k, v))[0]


class NemotronHAttentionDecoderLayer(nn.Module):
    """Pre-norm decoder layer whose mixer is ``NemotronHAttention`` ("*").

    ``parallel_config`` is accepted for a uniform constructor signature
    across layer types but is not used.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        parallel_config: ParallelConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Heterogeneous models may provide a per-layer config; only the
        # attention mixer is built from it (the norm uses the model config).
        per_layer_getter = getattr(config, "get_nemotron_h_config_for_layer", None)
        attn_cfg = config
        if per_layer_getter:
            attn_cfg = per_layer_getter(layer_idx)

        self.mixer = NemotronHAttention(
            config=attn_cfg,
            layer_idx=layer_idx,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mixer",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ):
        """Normalize the input, run attention, and return ``(output, residual)``.

        ``positions`` is accepted but not forwarded to the mixer. On the
        first layer (``residual is None``) the raw input becomes the
        residual. Otherwise the two-argument ``RMSNorm`` call returns both
        the normalized states and the updated residual.
        """
        if residual is not None:
            hidden_states, residual = self.norm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)

        return self.mixer(hidden_states=hidden_states), residual


# Decoder layer class for each character of `config.hybrid_override_pattern`
# (looked up per layer index when the model builds its layer stack):
#   "M" -> Mamba2 mixer layer
#   "-" -> dense MLP layer
#   "*" -> self-attention layer
#   "E" -> mixture-of-experts layer
ALL_DECODER_LAYER_TYPES: dict[str, type[nn.Module]] = {
    "M": NemotronHMambaDecoderLayer,
    "-": NemotronHMLPDecoderLayer,
    "*": NemotronHAttentionDecoderLayer,
    "E": NemotronHMoEDecoderLayer,
}


561
@support_torch_compile
class NemotronHModel(nn.Module):
    """NemotronH hybrid backbone: embeddings, decoder stack, final RMSNorm.

    Layer ``i`` is an instance of
    ``ALL_DECODER_LAYER_TYPES[config.hybrid_override_pattern[i]]``. Under
    pipeline parallelism only layers ``[self.start_layer, self.end_layer)``
    run on this rank. Non-final ranks return ``hidden_states``/``residual``
    as ``IntermediateTensors`` instead of applying the final norm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: NemotronHConfig = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.config = config

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )

        # "E" marks a MoE layer in the hybrid pattern.
        self.has_moe = "E" in config.hybrid_override_pattern

        def get_layer(prefix: str):
            # The layer prefix ends in ".<layer_idx>"; that index selects the
            # layer type from the hybrid pattern.
            layer_idx = int(prefix.rsplit(".", 1)[1])
            layer_class = ALL_DECODER_LAYER_TYPES[
                config.hybrid_override_pattern[layer_idx]
            ]
            return layer_class(
                config=config,
                layer_idx=layer_idx,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                parallel_config=parallel_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        The first PP rank starts from ``inputs_embeds`` (if given) or from
        embedded ``input_ids``. Later ranks resume from
        ``intermediate_tensors``. Non-last ranks return
        ``IntermediateTensors``. The last rank returns the final-normed
        hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm_f(hidden_states, residual)
        return hidden_states

    def is_spec_layer(self, config: NemotronHConfig, weight_name: str) -> bool:
        """Return True for MTP (speculative-decoding) weights (``mtp.`` prefix)."""
        return weight_name.startswith("mtp.")

    def _get_max_n_routed_experts(self) -> int:
        """Get max n_routed_experts from config or block_configs for puzzle models.

        For heterogeneous models with varying expert counts per layer,
        returns the MAX to ensure all expert weights can be loaded.
        Returns 0 when neither source provides an expert count.
        """
        # First try top-level attribute
        n_routed_experts = getattr(self.config, "n_routed_experts", None)
        if n_routed_experts is not None:
            return n_routed_experts

        # For puzzle models, get MAX from all MoE blocks in block_configs
        # (different layers may have different expert counts)
        max_experts = 0
        block_configs = getattr(self.config, "block_configs", None)
        if block_configs:
            for block in block_configs:
                if isinstance(block, dict):
                    if block.get("block_type") == "moe":
                        max_experts = max(max_experts, block.get("n_routed_experts", 0))
                else:
                    # HF converts dicts to objects with attributes
                    if getattr(block, "block_type", "") == "moe":
                        max_experts = max(
                            max_experts, getattr(block, "n_routed_experts", 0)
                        )
        return max_experts

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return ``(param_name, weight_name, expert_id, shard_id)`` tuples
        for routed-expert weights, or ``[]`` when the model has no MoE layers.
        """
        if self.has_moe:
            # (param_name, weight_name, expert_id, shard_id)
            expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
                # - FusedMoe.w1 (aka gate_proj) should be up_proj since that's
                #   what the activation is applied to
                # - FusedMoe.w3 (aka up_proj) should be ignored since we're
                #   using non-gated MoE
                self,
                ckpt_gate_proj_name="up_proj",
                ckpt_down_proj_name="down_proj",
                ckpt_up_proj_name="",
                num_experts=self._get_max_n_routed_experts(),
                # NOTE(review): this class never sets `num_redundant_experts`,
                # so this is 0 unless it is assigned externally. NemotronHMoE
                # sizes its experts with
                # `parallel_config.eplb_config.num_redundant_experts`; confirm
                # the two agree when redundant EPLB experts are enabled.
                num_redundant_experts=getattr(self, "num_redundant_experts", 0),
            )
            return expert_params_mapping

        return []

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this model's params.

        Weights are handled in this order:
          - MTP (``mtp.``) weights are skipped.
          - q/k/v projections are loaded as shards of the fused ``qkv_proj``.
          - Routed-expert weights go through the expert mapping.
          - Everything else goes through each parameter's ``weight_loader``
            (or ``default_weight_loader``).

        Weights owned by other pipeline stages are skipped, as are extra
        q/k/v biases that have no fused counterpart (GPTQ checkpoints).

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        expert_params_mapping = self.get_expert_mapping()

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "scale" in name or "zero_point" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            # Skip MTP/spec decode layers early (before stacked params mapping)
            if name.startswith("mtp."):
                continue

            # Stacked (fused) q/k/v params. Apply only the FIRST matching
            # entry and fully decide the weight's fate here. After
            # "q_proj" -> "qkv_proj" the new name also contains "v_proj".
            # Continuing the scan would rewrite it to "qkqkv_proj" and fall
            # through to the generic loader, which raises KeyError for the
            # GPTQ extra-bias case instead of skipping it.
            stacked = next(
                (entry for entry in stacked_params_mapping if entry[1] in name),
                None,
            )
            if stacked is not None:
                param_name, weight_name, shard_id = stacked
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                continue

            # Routed-expert weights: one checkpoint name may match several
            # mapping entries, so try each until a loader reports success.
            is_expert_weight = False
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue

                # Anyway, this is an expert weight and should not be
                # attempted to load as other weights later
                is_expert_weight = True

                # Do not modify `name` since the loop may continue here
                # Instead, create a new variable
                name_mapped = name.replace(weight_name, param_name)

                if is_pp_missing_parameter(name_mapped, self):
                    continue
                param = params_dict[name_mapped]
                # We should ask the weight loader to return success or not
                # here since otherwise we may skip experts with other
                # available replicas.
                weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
                success = weight_loader(
                    param,
                    loaded_weight,
                    name_mapped,
                    shard_id=shard_id,
                    expert_id=expert_id,
                    return_success=True,
                )
                if success:
                    name = name_mapped
                    break
            else:
                if is_expert_weight:
                    # Matched the expert mapping, but no entry's loader
                    # reported success; do not treat it as a regular weight.
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

            loaded_params.add(name)
        return loaded_params


785
class NemotronHForCausalLM(
    nn.Module,
    HasInnerState,
    SupportsLoRA,
    SupportsPP,
    IsHybrid,
    SupportsQuant,
    MixtureOfExperts,
    SupportsMambaPrefixCaching,
):
    """NemotronH causal language model (hybrid Mamba2 model, optionally MoE).

    Wraps ``NemotronHModel`` with a ``ParallelLMHead`` and a
    ``LogitsProcessor``. It also exposes the Mamba-state and MoE metadata
    hooks required by the interfaces listed in the bases.
    """

    # Relevant only if self.has_moe is True
    is_non_gated_moe: bool = True

    # Rename checkpoint keys: the "backbone" prefix becomes "model", and the
    # substrings "A_log" / "embeddings" become "A" / "embed_tokens".
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"backbone": "model"},
        orig_to_new_substr={"A_log": "A", "embeddings": "embed_tokens"},
    )

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    # Skip MTP (Multi-Token Prediction) layers during LoRA loading
    lora_skip_prefixes = ["mtp."]

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the dtypes of the Mamba2 conv and SSM state caches.

        Delegates to ``MambaStateDtypeCalculator.mamba2_state_dtype``, passing
        the model dtype and the cache config's ``mamba_cache_dtype`` and
        ``mamba_ssm_cache_dtype``.
        """
        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Calculate shapes for Mamba's convolutional and state caches.

        Args:
            vllm_config: vLLM config

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
            - temporal_state_shape: Shape for state space model cache
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config
        # The Mamba inner width is the total size of all SSM heads.
        intermediate_size = hf_config.mamba_num_heads * hf_config.mamba_head_dim

        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=intermediate_size,
            tp_world_size=parallel_config.tensor_parallel_size,
            n_groups=hf_config.n_groups,
            num_heads=hf_config.mamba_num_heads,
            head_dim=hf_config.mamba_head_dim,
            state_size=hf_config.ssm_state_size,
            conv_kernel=hf_config.conv_kernel,
            num_spec=vllm_config.num_speculative_tokens,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        """Return the Mamba2 state copy functions from the shared calculator."""
        return MambaStateCopyFuncCalculator.mamba2_state_copy_func()

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the LM head and, if present, the MoE metadata.

        Args:
            vllm_config: Full vLLM configuration; ``model_config.hf_config``
                supplies the architecture hyperparameters.
            prefix: Parameter-name prefix for this module.

        Raises:
            RuntimeError: If the backbone reports ``has_moe`` but no
                ``NemotronHMoEDecoderLayer`` is found in ``model.layers``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.quant_config = vllm_config.quant_config
        self.config = config
        self.scheduler_config = vllm_config.scheduler_config

        self.model = NemotronHModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters
        if self.model.has_moe:
            self.expert_weights = []
            self.num_expert_groups = config.n_group
            self.moe_layers = []

            example_moe = None
            for layer in self.model.layers:
                if isinstance(layer, NemotronHMoEDecoderLayer):
                    # Keep overwriting so that the *last* MoE layer is used
                    # as the reference; earlier layers may be dense.
                    example_moe = layer.mixer
                    self.moe_layers.append(layer.mixer.experts)

            # Without this check, a missing MoE layer surfaces below as an
            # opaque AttributeError on None.
            if example_moe is None:
                raise RuntimeError(
                    "No NemotronHMoEDecoderLayer found in model.layers."
                )

            self.num_moe_layers = len(self.moe_layers)
            self.num_logical_experts = example_moe.n_logical_experts
            self.num_physical_experts = example_moe.n_physical_experts
            self.num_local_physical_experts = example_moe.n_local_physical_experts  # noqa: E501
            self.num_routed_experts = example_moe.n_routed_experts
            self.num_shared_experts = example_moe.n_shared_experts
            self.num_redundant_experts = example_moe.n_redundant_experts

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Record a new physical-expert count on the model and every MoE mixer.

        The redundant-expert count is recomputed as physical minus logical
        experts. Each MoE layer's experts module then rebuilds its expert
        map. The per-rank (local) physical expert count must not change.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            if isinstance(layer, NemotronHMoEDecoderLayer):
                moe = layer.mixer
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids using the backbone's embedding layer."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ):
        """Run the backbone and return its output unchanged.

        Extra keyword arguments are accepted for interface compatibility and
        ignored.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Checkpoint names are renamed with ``hf_to_vllm_mapper``.
        ``skip_prefixes=["mtp"]`` is presumably meant to drop MTP
        (Multi-Token Prediction) weights; confirm against AutoWeightsLoader.
        Returns the set of names that the loader reports as loaded.
        """
        loader = AutoWeightsLoader(self, skip_prefixes=["mtp"])
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)