granitemoehybrid.py 25.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Inference-only GraniteMoeHybrid model."""
4

5
# Added by the IBM Team, 2025
6
from collections.abc import Iterable
7
8
9
10
11
12

import torch
from torch import nn
from transformers import GraniteMoeHybridConfig

from vllm.attention.layer import Attention
13
from vllm.compilation.decorators import support_torch_compile
14
from vllm.config import CacheConfig, ModelConfig, VllmConfig
15
from vllm.distributed import get_tensor_model_parallel_world_size
16
17
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
18
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
19
from vllm.model_executor.layers.logits_processor import LogitsProcessor
20
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
21
from vllm.model_executor.layers.mamba.mamba_utils import (
22
23
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
24
25
26
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
27
28
29
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
30
31
32
    ParallelLMHead,
    VocabParallelEmbedding,
)
33
34
35
36
37
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

from .granitemoe import GraniteMoeMoE
from .granitemoeshared import GraniteMoeSharedMLP
38
39
40
41
42
43
44
45
from .interfaces import (
    HasInnerState,
    IsHybrid,
    SupportsLoRA,
    SupportsMambaPrefixCaching,
    SupportsPP,
    SupportsQuant,
)
46
47
48
49
50
51
52
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
53
54
55


class GraniteMoeHybridMambaDecoderLayer(nn.Module):
    """GraniteMoeHybrid decoder layer that mixes tokens with a Mamba2 block.

    The layer is two pre-norm sub-blocks, each followed by a residual add
    whose update is scaled by ``config.residual_multiplier``:

    1. ``input_layernorm`` -> ``MambaMixer2``
    2. ``post_attention_layernorm`` -> routed MoE and/or shared MLP
    """

    def __init__(
        self,
        config: GraniteMoeHybridConfig,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.residual_multiplier = config.residual_multiplier

        d_model = config.hidden_size
        eps = config.rms_norm_eps

        # NOTE(review): the submodule attribute is ``mamba`` while its prefix
        # ends in ``.mixer``; kept unchanged, confirm this is intended.
        self.mamba = MambaMixer2(
            hidden_size=d_model,
            ssm_state_size=config.mamba_d_state,
            conv_kernel_size=config.mamba_d_conv,
            intermediate_size=config.mamba_expand * d_model,
            use_conv_bias=config.mamba_conv_bias,
            use_bias=config.mamba_proj_bias,
            n_groups=config.mamba_n_groups,
            num_heads=config.mamba_n_heads,
            head_dim=config.mamba_d_head,
            rms_norm_eps=eps,
            activation=config.hidden_act,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mixer",
        )

        # Routed experts are only built when the config declares some.
        has_experts = getattr(config, "num_local_experts", 0) > 0
        self.block_sparse_moe = (
            GraniteMoeMoE(
                num_experts=config.num_local_experts,
                top_k=config.num_experts_per_tok,
                hidden_size=d_model,
                intermediate_size=config.intermediate_size,
                quant_config=quant_config,
                prefix=f"{prefix}.block_sparse_moe",
            )
            if has_experts
            else None
        )

        # The shared (always-on) MLP is disabled by a zero/absent size.
        if getattr(config, "shared_intermediate_size", 0) == 0:
            self.shared_mlp = None
        else:
            self.shared_mlp = GraniteMoeSharedMLP(
                config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp"
            )

        self.input_layernorm = RMSNorm(d_model, eps=eps)
        self.post_attention_layernorm = RMSNorm(d_model, eps=eps)

    def _feed_forward(self, normed: torch.Tensor) -> torch.Tensor:
        """Apply the routed MoE and/or the shared MLP to normalised input.

        When neither submodule is configured, ``normed`` is returned as-is.
        """
        moe = self.block_sparse_moe
        if self.shared_mlp is None:
            return normed if moe is None else moe(normed)
        if moe is None:
            return self.shared_mlp(normed)
        # block_sparse_moe is treated as modifying its input in place, so it
        # gets a copy and ``normed`` stays intact for the shared MLP.
        routed = moe(normed.clone())
        return routed + self.shared_mlp(normed)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the Mamba2 and feed-forward sub-blocks.

        The incoming ``residual`` is not read; extra keyword arguments (such
        as ``positions``) are accepted and ignored.

        Returns:
            ``(output, skip)`` where ``skip`` is the hidden state entering the
            feed-forward sub-block.
        """
        skip = hidden_states
        mixed = self.mamba(self.input_layernorm(hidden_states))
        hidden_states = skip + mixed * self.residual_multiplier

        skip = hidden_states
        ffn_out = self._feed_forward(self.post_attention_layernorm(hidden_states))
        return skip + ffn_out * self.residual_multiplier, skip


class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
    """GraniteMoeHybrid decoder layer that mixes tokens with self-attention.

    The layer is two pre-norm sub-blocks, each followed by a residual add
    whose update is scaled by ``config.residual_multiplier``:

    1. ``input_layernorm`` -> ``GraniteMoeHybridAttention``
    2. ``post_attention_layernorm`` -> routed MoE and/or shared MLP
    """

    def __init__(
        self,
        config: GraniteMoeHybridConfig,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.residual_multiplier = config.residual_multiplier

        self.self_attn = GraniteMoeHybridAttention(
            config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        # Routed experts are only built when the config declares some.
        self.block_sparse_moe = None
        if getattr(config, "num_local_experts", 0) > 0:
            self.block_sparse_moe = GraniteMoeMoE(
                num_experts=config.num_local_experts,
                top_k=config.num_experts_per_tok,
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                quant_config=quant_config,
                prefix=f"{prefix}.block_sparse_moe",
            )

        # The shared (always-on) MLP is disabled by a zero/absent size.
        self.shared_mlp = (
            None
            if getattr(config, "shared_intermediate_size", 0) == 0
            else GraniteMoeSharedMLP(
                config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp"
            )
        )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def _feed_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the routed MoE and/or the shared MLP to normalised input.

        When neither submodule is configured, the input is returned as-is.
        """
        if self.shared_mlp is None:
            if self.block_sparse_moe is None:
                return hidden_states
            return self.block_sparse_moe(hidden_states)
        if self.block_sparse_moe is None:
            return self.shared_mlp(hidden_states)
        # block_sparse_moe is treated as modifying its input in place, so it
        # gets a copy and the original stays intact for the shared MLP.
        moe_hidden_states = self.block_sparse_moe(hidden_states.clone())
        return moe_hidden_states + self.shared_mlp(hidden_states)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the attention and feed-forward sub-blocks.

        The incoming ``residual`` is not read.

        Returns:
            ``(output, residual)`` where ``residual`` is the hidden state
            entering the feed-forward sub-block.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = residual + hidden_states * self.residual_multiplier

        residual = hidden_states
        hidden_states = self._feed_forward(self.post_attention_layernorm(hidden_states))
        hidden_states = residual + hidden_states * self.residual_multiplier

        return hidden_states, residual


class GraniteMoeHybridAttention(nn.Module):
    """Tensor-parallel multi-head / grouped-query self-attention.

    Q, K and V come from one fused ``QKVParallelLinear`` projection.
    Rotary embeddings are applied only when
    ``config.position_embedding_type == "rope"``.
    ``config.attention_multiplier`` is handed to ``Attention`` as its third
    positional argument (its scale).
    """

    def __init__(
        self,
        config: GraniteMoeHybridConfig,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.causal = True
        self.hidden_size = config.hidden_size
        self.attention_bias = config.attention_bias
        self.attention_multiplier = config.attention_multiplier

        self.total_num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.total_num_heads
        self.total_num_kv_heads = config.num_key_value_heads

        # Shard heads across tensor-parallel ranks.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        # KV heads are partitioned when there are at least as many as ranks,
        # otherwise replicated; either way the counts must divide evenly.
        assert (
            self.total_num_kv_heads % tp_size == 0
            if self.total_num_kv_heads >= tp_size
            else tp_size % self.total_num_kv_heads == 0
        )
        self.num_key_value_heads = max(1, self.total_num_kv_heads // tp_size)

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=self.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=self.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        uses_rope = config.position_embedding_type == "rope"
        self.rotary_emb = (
            get_rope(
                self.head_dim,
                max_position=config.max_position_embeddings,
                rope_parameters=config.rope_parameters,
                is_neox_style=True,
            )
            if uses_rope
            else None
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.attention_multiplier,
            num_kv_heads=self.num_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, optionally apply rotary embeddings, attend, and
        project back with ``o_proj``."""
        qkv, _ = self.qkv_proj(hidden_states)
        q_width = self.num_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim
        query, key, value = qkv.split([q_width, kv_width, kv_width], dim=-1)

        if self.rotary_emb is not None:
            query, key = self.rotary_emb(positions, query, key)

        attn_out = self.attn(query, key, value)
        del query, key, value

        return self.o_proj(attn_out)[0]


# Lookup from an entry of ``config.layer_types`` to the decoder-layer class
# that implements it.
ALL_DECODER_LAYER_TYPES = dict(
    attention=GraniteMoeHybridAttentionDecoderLayer,
    mamba=GraniteMoeHybridMambaDecoderLayer,
)


324
@support_torch_compile
class GraniteMoeHybridModel(nn.Module):
    """GraniteMoeHybrid backbone.

    The backbone consists of token embeddings, a stack of attention and
    Mamba2 decoder layers (the kind of each is chosen by
    ``config.layer_types``) and a final RMSNorm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        self.embedding_multiplier = config.embedding_multiplier

        def get_layer(prefix: str):
            # The last dotted component of the prefix is the global layer
            # index, which selects the layer kind from config.layer_types.
            layer_idx = int(prefix.rsplit(".", 1)[1])
            layer_class = ALL_DECODER_LAYER_TYPES[config.layer_types[layer_idx]]
            return layer_class(
                config,
                layer_idx,
                model_config,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
            )

        # Only layers in [start_layer, end_layer) are run on this
        # pipeline-parallel rank (see forward).
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the (unscaled) token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline rank's slice of the decoder stack.

        On the first rank the input is ``inputs_embeds`` if given, otherwise
        the embedding of ``input_ids``. In both cases it is then scaled by
        ``embedding_multiplier``. Later ranks resume from
        ``intermediate_tensors``.

        Returns:
            ``IntermediateTensors`` (hidden_states, residual) on non-last
            ranks; the final-normed hidden states on the last rank.

        Raises:
            RuntimeError: on a non-first rank if ``intermediate_tensors`` is
                None.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            # Scale caller-supplied embeddings too: embed_input_ids returns
            # unscaled embeddings, so both input paths must agree here.
            hidden_states = hidden_states * self.embedding_multiplier
            residual = None
        else:
            if intermediate_tensors is None:
                raise RuntimeError("Intermediate tensors may not be None!")
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Iterate only the layers owned by this rank; the remaining entries of
        # self.layers are placeholders for other pipeline-parallel ranks.
        for layer in self.layers[self.start_layer : self.end_layer]:
            hidden_states, residual = layer(
                positions=positions, hidden_states=hidden_states, residual=residual
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Map per-expert checkpoint names onto the fused expert parameters.

        Each entry is ``(param_name, weight_name, expert_id, shard_id)``.
        ``weight_name`` is a substring such as
        ``"block_sparse_moe.experts.0.gate_proj."`` to look for in a checkpoint
        name. ``param_name`` is the fused-parameter prefix that replaces it:
        ``...w13_`` for gate/up projections and ``...w2_`` for down.
        """
        ckpt_gate_proj_name = "gate_proj"
        ckpt_down_proj_name = "down_proj"
        ckpt_up_proj_name = "up_proj"
        num_experts = self.config.num_local_experts

        return [
            # (param_name, weight_name, expert_id, shard_id)
            (
                "block_sparse_moe.experts.w13_"
                if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
                else "block_sparse_moe.experts.w2_",
                f"block_sparse_moe.experts.{expert_id}.{weight_name}.",
                expert_id,
                shard_id,
            )
            for expert_id in range(num_experts)
            for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each checkpoint name is handled by the first matching rule:

        - ``A_log`` is renamed to ``A``.
        - KV-cache quantization scales are loaded.
        - Per-expert tensors matched by ``get_expert_mapping`` are loaded.
        - The fused HF expert layout (``input_linear`` / ``output_linear``
          / ``router.layer``) is loaded.
        - Separate q/k/v projections are loaded into ``qkv_proj``.
        - Anything else is loaded by name.

        Parameters of layers on other pipeline-parallel ranks are skipped.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()

        def _load(n, p):
            # Skip parameters of layers that live on other PP ranks.
            if is_pp_missing_parameter(n, self):
                return
            param = params_dict[n]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, p)
            loaded_params.add(n)

        def _load_shard(n, p, shard_id):
            # Skip parameters of layers that live on other PP ranks.
            if is_pp_missing_parameter(n, self):
                return
            param = params_dict[n]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, p, shard_id)
            loaded_params.add(n)

        def _load_expert(n, p, name, shard_id, expert_id):
            # Skip parameters of layers that live on other PP ranks.
            if is_pp_missing_parameter(n, self):
                return
            param = params_dict[n]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, p, name, shard_id=shard_id, expert_id=expert_id)
            loaded_params.add(n)

        def _load_quant_expert(name, loaded_weight):
            """Try to load a per-expert checkpoint tensor.

            Returns the mapped parameter name on success, otherwise None.
            """
            for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
                if weight_name not in name:
                    continue

                name_mapped = name.replace(weight_name, param_name)

                # Skip layers on other devices.
                if is_pp_missing_parameter(name_mapped, self):
                    continue

                param = params_dict[name_mapped]
                # getattr so a parameter without a custom loader does not raise.
                weight_loader = getattr(param, "weight_loader", None)
                if weight_loader is not None and weight_loader(
                    param,
                    loaded_weight,
                    name_mapped,
                    shard_id=shard_id,
                    expert_id=expert_id,
                    return_success=True,
                ):
                    # Record the parameter so loaded-weight tracking sees it.
                    loaded_params.add(name_mapped)
                    return name_mapped
            return None

        for n, p in weights:
            if "A_log" in n:
                n = n.replace("A_log", "A")

            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(n)
            ):
                # Loading kv cache quantization scales: use the tensor itself
                # if 0-dim, otherwise its first element.
                _load(scale_name, p if p.dim() == 0 else p[0])
                continue

            if _load_quant_expert(n, p):
                continue

            # Logic analogous to: https://github.com/vllm-project/vllm/blob/f49e5aff11c986ed4d45202b1716c5d74786efa9/vllm/model_executor/models/granitemoeshared.py#L215
            # Mapping different experts' layout:
            #  from HF (input_linear, output_linear, router)
            #  to vLLM (experts_w13({e}.w1, {e}.w3), experts_w2({e}.w2), gate)
            # The renaming and parameter loading logic is the same for weight
            # and weight_scale tensors so we can reuse them without issues.
            if n.endswith(
                (
                    ".block_sparse_moe.input_linear.weight",
                    ".block_sparse_moe.input_linear.weight_scale",
                )
            ):
                fused_name = n.replace(".input_linear.", ".experts.w13_")
                for e in range(p.size(0)):
                    w1_name = n.replace(
                        ".block_sparse_moe.input_linear.weight",
                        f".block_sparse_moe.experts.{e}.w1.weight",
                    )
                    w3_name = n.replace(
                        ".block_sparse_moe.input_linear.weight",
                        f".block_sparse_moe.experts.{e}.w3.weight",
                    )
                    # Expert e's input_linear holds w1 then w3 along dim 0.
                    w1_param, w3_param = p[e].chunk(2, dim=0)
                    _load_expert(
                        fused_name, w1_param, w1_name, shard_id="w1", expert_id=e
                    )
                    _load_expert(
                        fused_name, w3_param, w3_name, shard_id="w3", expert_id=e
                    )
            elif n.endswith(
                (
                    ".block_sparse_moe.output_linear.weight",
                    ".block_sparse_moe.output_linear.weight_scale",
                )
            ):
                fused_name = n.replace(".output_linear.", ".experts.w2_")
                for e in range(p.size(0)):
                    w2_name = n.replace(
                        ".block_sparse_moe.output_linear.weight",
                        f".block_sparse_moe.experts.{e}.w2.weight",
                    )
                    _load_expert(fused_name, p[e], w2_name, shard_id="w2", expert_id=e)
            elif n.endswith(".block_sparse_moe.router.layer.weight"):
                gate_name = n.replace(
                    ".block_sparse_moe.router.layer.weight",
                    ".block_sparse_moe.gate.weight",
                )
                _load(gate_name, p)
            else:
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    if weight_name in n:
                        _load_shard(
                            n.replace(weight_name, param_name), p, shard_id=shard_id
                        )
                        break
                else:
                    _load(n, p)

        return loaded_params


582
class GraniteMoeHybridForCausalLM(
    nn.Module,
    HasInnerState,
    SupportsLoRA,
    SupportsPP,
    IsHybrid,
    SupportsQuant,
    SupportsMambaPrefixCaching,
):
    """Causal-LM wrapper around ``GraniteMoeHybridModel``.

    Adds the LM head and a logits processor scaled by
    ``1 / config.logits_scaling``. It also exposes the Mamba2 state
    dtype/shape/copy helpers required by the hybrid-model interfaces.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "conv1d": ["conv1d"],
        "in_proj": ["in_proj"],
        "input_linear": ["input_linear"],
    }

    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the Mamba2 cache dtypes derived from the model dtype and
        the cache config's Mamba dtype settings."""
        cache_cfg = vllm_config.cache_config
        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            cache_cfg.mamba_cache_dtype,
            cache_cfg.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Calculate shapes for Mamba's convolutional and state caches.

        Args:
            vllm_config: vLLM config

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
            - temporal_state_shape: Shape for state space model cache
        """
        hf = vllm_config.model_config.hf_config
        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=hf.mamba_expand * hf.hidden_size,
            tp_world_size=vllm_config.parallel_config.tensor_parallel_size,
            n_groups=hf.mamba_n_groups,
            num_heads=hf.mamba_n_heads,
            head_dim=hf.mamba_d_head,
            state_size=hf.mamba_d_state,
            conv_kernel=hf.mamba_d_conv,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        """Return the standard Mamba2 state copy functions."""
        return MambaStateCopyFuncCalculator.mamba2_state_copy_func()

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_cfg = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.quant_config = vllm_config.quant_config
        self.config = hf_cfg
        self.scheduler_config = vllm_config.scheduler_config

        self.model = GraniteMoeHybridModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.lm_head = ParallelLMHead(
            hf_cfg.vocab_size,
            hf_cfg.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        # Tied embeddings: the LM head reuses the input embedding matrix.
        if hf_cfg.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(
            hf_cfg.vocab_size,
            hf_cfg.vocab_size,
            scale=1 / self.config.logits_scaling,
        )

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone; extra keyword arguments are ignored."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights by recursing into submodules with AutoWeightsLoader."""
        return AutoWeightsLoader(self).load_weights(weights)