granitemoehybrid.py 25.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Inference-only GraniteMoeHybrid model."""
4

5
# Added by the IBM Team, 2025
6
from collections.abc import Iterable
7
8
9
10
11
12

import torch
from torch import nn
from transformers import GraniteMoeHybridConfig

from vllm.attention.layer import Attention
13
from vllm.compilation.decorators import support_torch_compile
14
from vllm.config import CacheConfig, ModelConfig, VllmConfig
15
from vllm.distributed import get_tensor_model_parallel_world_size
16
17
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
18
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
19
from vllm.model_executor.layers.logits_processor import LogitsProcessor
20
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
21
from vllm.model_executor.layers.mamba.mamba_utils import (
22
23
24
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
25
26
27
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
28
29
30
    ParallelLMHead,
    VocabParallelEmbedding,
)
31
32
33
34
35
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

from .granitemoe import GraniteMoeMoE
from .granitemoeshared import GraniteMoeSharedMLP
36
37
38
39
40
41
42
43
from .interfaces import (
    HasInnerState,
    IsHybrid,
    SupportsLoRA,
    SupportsMambaPrefixCaching,
    SupportsPP,
    SupportsQuant,
)
44
45
46
47
48
49
50
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
51
52
53


class GraniteMoeHybridMambaDecoderLayer(nn.Module):
    """Decoder layer whose token mixer is a Mamba2 block (``MambaMixer2``).

    Structure: pre-norm Mamba mixer followed by a scaled residual add, then a
    pre-norm feed-forward stage made of an optional sparse MoE
    (``block_sparse_moe``) and/or an optional shared MLP (``shared_mlp``),
    followed by another scaled residual add.
    """

    def __init__(
        self,
        config: GraniteMoeHybridConfig,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        # Each sub-block's output is multiplied by this before the residual add.
        self.residual_multiplier = config.residual_multiplier

        self.mamba = MambaMixer2(
            hidden_size=config.hidden_size,
            ssm_state_size=config.mamba_d_state,
            conv_kernel_size=config.mamba_d_conv,
            intermediate_size=config.mamba_expand * config.hidden_size,
            use_conv_bias=config.mamba_conv_bias,
            use_bias=config.mamba_proj_bias,
            n_groups=config.mamba_n_groups,
            num_heads=config.mamba_n_heads,
            head_dim=config.mamba_d_head,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.hidden_act,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mixer",
        )

        # The sparse MoE is only built when the config declares experts.
        self.block_sparse_moe = None
        if getattr(config, "num_local_experts", 0) > 0:
            self.block_sparse_moe = GraniteMoeMoE(
                num_experts=config.num_local_experts,
                top_k=config.num_experts_per_tok,
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                quant_config=quant_config,
                prefix=f"{prefix}.block_sparse_moe",
            )

        # The shared MLP is only built when it has a nonzero width.
        self.shared_mlp = (
            None
            if getattr(config, "shared_intermediate_size", 0) == 0
            else GraniteMoeSharedMLP(
                config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp"
            )
        )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def _feed_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MoE and/or shared MLP; identity when neither exists."""
        if self.shared_mlp is None:
            if self.block_sparse_moe is not None:
                hidden_states = self.block_sparse_moe(hidden_states)
            return hidden_states
        if self.block_sparse_moe is None:
            return self.shared_mlp(hidden_states)
        # The MoE gets its own copy because it modifies its input in place,
        # while the shared MLP still needs the unmodified activations.
        moe_hidden_states = self.block_sparse_moe(hidden_states.clone())
        return moe_hidden_states + self.shared_mlp(hidden_states)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the Mamba mixer and the feed-forward stage.

        Args:
            hidden_states: input activations of this layer.
            residual: accepted for signature parity with the attention layer
                but not read; the residual is taken from ``hidden_states``.
            **kwargs: ignored (the model also passes ``positions``, which the
                Mamba mixer does not use).

        Returns:
            ``(hidden_states, residual)``: ``hidden_states`` already contains
            the residual sum; ``residual`` is the input of the feed-forward
            stage.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        output = self.mamba(hidden_states)
        hidden_states = residual + output * self.residual_multiplier

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self._feed_forward(hidden_states)
        hidden_states = residual + hidden_states * self.residual_multiplier

        return hidden_states, residual


class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
    """Decoder layer whose token mixer is self-attention.

    Structure: pre-norm ``GraniteMoeHybridAttention`` followed by a scaled
    residual add, then a pre-norm feed-forward stage made of an optional
    sparse MoE (``block_sparse_moe``) and/or an optional shared MLP
    (``shared_mlp``), followed by another scaled residual add.
    """

    def __init__(
        self,
        config: GraniteMoeHybridConfig,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        # NOTE: layer_idx and model_config are accepted so both decoder layer
        # types share one constructor signature; they are not used here.
        super().__init__()
        self.hidden_size = config.hidden_size
        # Each sub-block's output is multiplied by this before the residual add.
        self.residual_multiplier = config.residual_multiplier

        self.self_attn = GraniteMoeHybridAttention(
            config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        # The sparse MoE is only built when the config declares experts.
        self.block_sparse_moe = None
        if getattr(config, "num_local_experts", 0) > 0:
            self.block_sparse_moe = GraniteMoeMoE(
                num_experts=config.num_local_experts,
                top_k=config.num_experts_per_tok,
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                quant_config=quant_config,
                prefix=f"{prefix}.block_sparse_moe",
            )

        # The shared MLP is only built when it has a nonzero width.
        self.shared_mlp = (
            None
            if getattr(config, "shared_intermediate_size", 0) == 0
            else GraniteMoeSharedMLP(
                config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp"
            )
        )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def _feed_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MoE and/or shared MLP; identity when neither exists."""
        if self.shared_mlp is None:
            if self.block_sparse_moe is not None:
                hidden_states = self.block_sparse_moe(hidden_states)
            return hidden_states
        if self.block_sparse_moe is None:
            return self.shared_mlp(hidden_states)
        # The MoE gets its own copy because it modifies its input in place,
        # while the shared MLP still needs the unmodified activations.
        moe_hidden_states = self.block_sparse_moe(hidden_states.clone())
        return moe_hidden_states + self.shared_mlp(hidden_states)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run self-attention and the feed-forward stage.

        Args:
            positions: token positions, forwarded to the attention block.
            hidden_states: input activations of this layer.
            residual: accepted for signature parity with the Mamba layer but
                not read; the residual is taken from ``hidden_states``.

        Returns:
            ``(hidden_states, residual)``: ``hidden_states`` already contains
            the residual sum; ``residual`` is the input of the feed-forward
            stage.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = residual + hidden_states * self.residual_multiplier

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self._feed_forward(hidden_states)
        hidden_states = residual + hidden_states * self.residual_multiplier

        return hidden_states, residual


class GraniteMoeHybridAttention(nn.Module):
    """Tensor-parallel self-attention used by the attention decoder layers.

    Q, K and V come from one fused ``QKVParallelLinear`` projection. RoPE is
    applied only when ``config.position_embedding_type == "rope"``.
    ``config.attention_multiplier`` is passed to ``Attention`` as its third
    positional argument (presumably the softmax scale, replacing the usual
    1/sqrt(head_dim) -- confirm against the ``Attention`` signature).
    """

    def __init__(
        self,
        config: GraniteMoeHybridConfig,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.causal = True
        self.hidden_size = config.hidden_size
        self.attention_bias = config.attention_bias
        self.attention_multiplier = config.attention_multiplier

        self.total_num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.total_num_heads
        self.total_num_kv_heads = config.num_key_value_heads

        # Tensor-parallel partitioning of the heads.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        if self.total_num_kv_heads >= tp_size:
            # At least as many KV heads as TP ranks: partition the KV heads
            # across the tensor-parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than TP ranks: replicate the KV heads across
            # the tensor-parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_key_value_heads = max(1, self.total_num_kv_heads // tp_size)

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=self.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=self.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        if config.position_embedding_type == "rope":
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=config.max_position_embeddings,
                rope_parameters=config.rope_parameters,
                is_neox_style=True,
            )
        else:
            # No positional rotation (e.g. NoPE configurations).
            self.rotary_emb = None

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.attention_multiplier,
            num_kv_heads=self.num_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, optionally apply RoPE, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Per-rank widths of the fused projection's Q, K and V sections.
        q_size = self.num_heads * self.head_dim
        kv_size = self.num_key_value_heads * self.head_dim
        query, key, value = qkv.split([q_size, kv_size, kv_size], dim=-1)

        if self.rotary_emb is not None:
            query, key = self.rotary_emb(positions, query, key)

        hidden_states = self.attn(query, key, value)
        del query, key, value

        # RowParallelLinear returns a tuple; element 0 is the output tensor.
        hidden_states = self.o_proj(hidden_states)[0]
        return hidden_states


# Maps each entry of ``config.layer_types`` to the decoder layer class that
# implements it.
ALL_DECODER_LAYER_TYPES = dict(
    attention=GraniteMoeHybridAttentionDecoderLayer,
    mamba=GraniteMoeHybridMambaDecoderLayer,
)


@support_torch_compile
class GraniteMoeHybridModel(nn.Module):
    """GraniteMoeHybrid backbone: embeddings, hybrid decoder stack, final norm.

    Each layer is an attention or a Mamba decoder layer, as selected by
    ``config.layer_types``. The model supports pipeline parallelism: only
    layers ``[start_layer, end_layer)`` are materialized on this stage.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        # Applied to the token embeddings at the start of forward().
        self.embedding_multiplier = config.embedding_multiplier

        def get_layer(prefix: str):
            # Expects prefixes ending in ".<layer index>"; the index selects
            # the layer type from config.layer_types.
            layer_idx = int(prefix.rsplit(".", 1)[1])
            layer_class = ALL_DECODER_LAYER_TYPES[config.layer_types[layer_idx]]
            return layer_class(
                config,
                layer_idx,
                model_config,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return raw token embeddings (``embedding_multiplier`` not applied)."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline stage's slice of the decoder stack.

        Returns:
            The final-normed hidden states on the last pipeline stage,
            otherwise ``IntermediateTensors`` for the next stage.

        Raises:
            RuntimeError: on a non-first stage without ``intermediate_tensors``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            # Scale on both paths: embeddings precomputed via
            # embed_input_ids() are raw, exactly like the ids path, so they
            # need the same multiplier.
            hidden_states = hidden_states * self.embedding_multiplier
            residual = None
        else:
            if intermediate_tensors is None:
                raise RuntimeError("Intermediate tensors may not be None!")
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Only layers owned by this pipeline stage are real decoder layers;
        # the others are placeholders and must not be called.
        for layer in self.layers[self.start_layer : self.end_layer]:
            hidden_states, residual = layer(
                positions=positions, hidden_states=hidden_states, residual=residual
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Map pre-split per-expert checkpoint names to fused vLLM params.

        Returns:
            ``(param_name, weight_name, expert_id, shard_id)`` tuples. Gate
            (``w1``) and up (``w3``) projections map to the fused ``w13_``
            parameter; the down projection (``w2``) maps to ``w2_``.
        """
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        # layers.0.block_sparse_moe.expert_0.input_linear.input_scale
        ckpt_gate_proj_name = "gate_proj"
        ckpt_down_proj_name = "down_proj"
        ckpt_up_proj_name = "up_proj"
        num_experts = self.config.num_local_experts

        return [
            # (param_name, weight_name, expert_id, shard_id)
            (
                "block_sparse_moe.experts.w13_"
                if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
                else "block_sparse_moe.experts.w2_",
                f"block_sparse_moe.experts.{expert_id}.{weight_name}.",
                expert_id,
                shard_id,
            )
            for expert_id in range(num_experts)
            for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this (possibly PP-sharded) model.

        Handles ``A_log`` -> ``A`` renaming, KV-cache quantization scales,
        pre-split per-expert (quantized) weights, HF's fused per-layer MoE
        tensors (``input_linear`` / ``output_linear`` / ``router``), and
        stacking separate ``q/k/v_proj`` into ``qkv_proj``. Tensors for layers
        owned by another pipeline stage are skipped.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()

        def _load(n, p):
            # Skip layers on other devices.
            if is_pp_missing_parameter(n, self):
                return
            param = params_dict[n]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, p)
            loaded_params.add(n)

        def _load_shard(n, p, shard_id):
            # Skip layers on other devices.
            if is_pp_missing_parameter(n, self):
                return
            param = params_dict[n]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, p, shard_id)
            loaded_params.add(n)

        def _load_expert(n, p, name, shard_id, expert_id):
            # Skip layers on other devices.
            if is_pp_missing_parameter(n, self):
                return
            param = params_dict[n]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, p, name, shard_id=shard_id, expert_id=expert_id)
            loaded_params.add(n)

        def _load_quant_expert(name, loaded_weight):
            """Try loading a pre-split expert tensor; return its param name."""
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping

                if weight_name not in name:
                    continue

                name_mapped = name.replace(weight_name, param_name)

                # Skip layers on other devices.
                if is_pp_missing_parameter(name_mapped, self):
                    continue

                param = params_dict[name_mapped]
                weight_loader = param.weight_loader
                success = False

                if weight_loader is not None:
                    success = weight_loader(
                        param,
                        loaded_weight,
                        name_mapped,
                        shard_id=shard_id,
                        expert_id=expert_id,
                        return_success=True,
                    )

                if success:
                    return name_mapped
            return None

        for n, p in weights:
            if "A_log" in n:
                n = n.replace("A_log", "A")

            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(n)
            ):
                # Loading kv cache quantization scales
                loaded_weight = p if p.dim() == 0 else p[0]
                _load(scale_name, loaded_weight)
                continue

            if _load_quant_expert(n, p):
                continue

            # Logic analogous to: https://github.com/vllm-project/vllm/blob/f49e5aff11c986ed4d45202b1716c5d74786efa9/vllm/model_executor/models/granitemoeshared.py#L215
            # Mapping different experts' layout:
            #  from HF (input_linear, output_linear, router)
            #  to vLLM (experts.w13_ ({e}.w1, {e}.w3), experts.w2_ ({e}.w2), gate)
            # The renaming and parameter loading logic is the same for weight
            # and weight_scale tensors so we can reuse them without issues.
            if n.endswith(
                (
                    ".block_sparse_moe.input_linear.weight",
                    ".block_sparse_moe.input_linear.weight_scale",
                )
            ):
                w13_name = n.replace(".input_linear.", ".experts.w13_")
                # Dim 0 indexes experts; each expert's slice stacks the
                # w1 (gate) half on top of the w3 (up) half.
                for e in range(p.size(0)):
                    w1_name = n.replace(
                        ".block_sparse_moe.input_linear.weight",
                        f".block_sparse_moe.experts.{e}.w1.weight",
                    )
                    w3_name = n.replace(
                        ".block_sparse_moe.input_linear.weight",
                        f".block_sparse_moe.experts.{e}.w3.weight",
                    )
                    w1_param, w3_param = p[e].chunk(2, dim=0)
                    _load_expert(
                        w13_name, w1_param, w1_name, shard_id="w1", expert_id=e
                    )
                    _load_expert(
                        w13_name, w3_param, w3_name, shard_id="w3", expert_id=e
                    )
            elif n.endswith(
                (
                    ".block_sparse_moe.output_linear.weight",
                    ".block_sparse_moe.output_linear.weight_scale",
                )
            ):
                w2_fused_name = n.replace(".output_linear.", ".experts.w2_")
                for e in range(p.size(0)):
                    w2_name = n.replace(
                        ".block_sparse_moe.output_linear.weight",
                        f".block_sparse_moe.experts.{e}.w2.weight",
                    )
                    _load_expert(
                        w2_fused_name, p[e], w2_name, shard_id="w2", expert_id=e
                    )
            elif n.endswith(".block_sparse_moe.router.layer.weight"):
                gate_name = n.replace(
                    ".block_sparse_moe.router.layer.weight",
                    ".block_sparse_moe.gate.weight",
                )
                _load(gate_name, p)
            else:
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    if weight_name in n:
                        _load_shard(
                            n.replace(weight_name, param_name), p, shard_id=shard_id
                        )
                        break
                else:
                    _load(n, p)

        return loaded_params


class GraniteMoeHybridForCausalLM(
    nn.Module,
    HasInnerState,
    SupportsLoRA,
    SupportsPP,
    IsHybrid,
    SupportsQuant,
    SupportsMambaPrefixCaching,
):
    """Causal LM head on top of ``GraniteMoeHybridModel``."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "conv1d": ["conv1d"],
        "in_proj": ["in_proj"],
        "input_linear": ["input_linear"],
    }
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the dtypes of the Mamba conv and SSM state caches."""
        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Calculate shapes for Mamba's convolutional and state caches.

        Args:
            vllm_config: vLLM config

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
            - temporal_state_shape: Shape for state space model cache
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config
        intermediate_size = hf_config.mamba_expand * hf_config.hidden_size

        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=intermediate_size,
            tp_world_size=parallel_config.tensor_parallel_size,
            n_groups=hf_config.mamba_n_groups,
            num_heads=hf_config.mamba_n_heads,
            head_dim=hf_config.mamba_d_head,
            state_size=hf_config.mamba_d_state,
            conv_kernel=hf_config.mamba_d_conv,
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config

        scheduler_config = vllm_config.scheduler_config
        self.quant_config = vllm_config.quant_config
        self.config = config
        self.scheduler_config = scheduler_config

        self.model = GraniteMoeHybridModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if config.tie_word_embeddings:
            # Share the output projection with the input embedding table.
            self.lm_head.weight = self.model.embed_tokens.weight
        # Granite divides the logits by config.logits_scaling.
        self.logits_processor = LogitsProcessor(
            config.vocab_size,
            config.vocab_size,
            scale=1 / self.config.logits_scaling,
        )

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's raw token embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone.

        Returns hidden states on the last pipeline stage, otherwise the
        ``IntermediateTensors`` handed to the next stage. Extra ``kwargs``
        are accepted and ignored.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project final hidden states to (scaled) vocabulary logits."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, delegating to submodule loaders."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)