# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from https://github.com/vllm-project/vllm/blob/94d8ec8d2bcb4ec55e33022b313c7e978edf05e1/vllm/model_executor/models/bamba.py
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only NemotronH model."""
from collections.abc import Iterable
from typing import Optional

import torch
from torch import nn

from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
                                                   SupportsLoRA, SupportsPP,
                                                   SupportsQuant)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                    MambaCacheParams)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader, is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import NemotronHConfig
from vllm.utils import LayerBlockType


class NemotronHMLP(nn.Module):
    """Feed-forward mixer used by the ``-`` (MLP) layers of NemotronH.

    Computes ``down_proj(act(up_proj(x)))`` where ``act`` is
    ``ReLUSquaredActivation``; there is no gate projection.

    Args:
        config: Model config; reads ``hidden_size``, ``intermediate_size``
            and (for per-layer sizes) ``hybrid_override_pattern``.
        layer_idx: Index of this layer within ``hybrid_override_pattern``.
        quant_config: Optional quantization config for both projections.
        bias: Whether the up/down projections carry a bias term.
        prefix: Parameter-name prefix of this module.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # ``intermediate_size`` may be a scalar, a one-element list shared by
        # every MLP layer, or a list with one entry per MLP layer. In the
        # last case this layer's entry is its ordinal among the "-" positions
        # of the pattern, counted up to and including ``layer_idx``.
        if isinstance(config.intermediate_size, list):
            if len(config.intermediate_size) == 1:
                intermediate_size = config.intermediate_size[0]
            else:
                pattern_prefix = config.hybrid_override_pattern[:layer_idx +
                                                                1]
                mlp_index = pattern_prefix.count("-") - 1
                intermediate_size = config.intermediate_size[mlp_index]
        else:
            intermediate_size = config.intermediate_size

        self.up_proj = ColumnParallelLinear(
            input_size=config.hidden_size,
            output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=config.hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = ReLUSquaredActivation()

    def forward(self, x: torch.Tensor):
        """Apply up-projection, squared-ReLU activation, down-projection."""
        # The parallel linear layers return (output, bias); only the output
        # is used here.
        x, _ = self.up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


class NemotronHMLPDecoderLayer(nn.Module):
    """Decoder layer for a ``-`` position: pre-norm followed by an MLP mixer.

    ``model_config`` and ``cache_config`` are accepted for signature parity
    with the other decoder-layer types and are not used here.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        self.mixer = NemotronHMLP(
            config,
            quant_config=quant_config,
            bias=config.mlp_bias,
            prefix=f"{prefix}.mixer",
            layer_idx=layer_idx,
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        """Normalize, run the MLP, and return ``(mixer_output, residual)``.

        On the first layer (``residual is None``) the raw input becomes the
        residual. Otherwise ``self.norm`` is called with the residual and
        returns the updated ``(hidden_states, residual)`` pair. Extra keyword
        arguments (positions, mamba cache params, ...) are ignored.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, residual = self.norm(hidden_states, residual)

        hidden_states = self.mixer(hidden_states)
        return hidden_states, residual


class NemotronHMambaDecoderLayer(nn.Module):
    """Decoder layer for an ``M`` position: pre-norm followed by MambaMixer2.

    The Mamba inner width is ``mamba_num_heads * mamba_head_dim``.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.mixer = MambaMixer2(
            hidden_size=config.hidden_size,
            ssm_state_size=config.ssm_state_size,
            conv_kernel_size=config.conv_kernel,
            intermediate_size=config.mamba_num_heads * config.mamba_head_dim,
            use_conv_bias=config.use_conv_bias,
            use_bias=config.use_bias,
            n_groups=config.n_groups,
            num_heads=config.mamba_num_heads,
            head_dim=config.mamba_head_dim,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.mamba_hidden_act,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mixer",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        **kwargs,
    ):
        """Normalize, run the Mamba mixer, return ``(output, residual)``.

        ``mamba_cache_params`` / ``mamba2_metadata`` are forwarded to the
        mixer unchanged; the model passes ``None`` for both under V1.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, residual = self.norm(hidden_states, residual)

        # The mixer's return value is not used: its result is read back from
        # the preallocated ``output`` buffer passed to it.
        output = torch.empty_like(hidden_states)
        self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
        return output, residual


class NemotronHAttention(nn.Module):
    """Tensor-parallel self-attention mixer used by ``*`` layers.

    Uses a fused QKV projection, a vLLM ``Attention`` kernel, and a
    row-parallel output projection. No positional embedding is applied in
    this module. ``layer_idx`` and ``model_config`` are accepted for
    signature parity and are unused.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # Prefer an explicit head_dim from the config; otherwise derive it.
        if hasattr(config, "head_dim") and config.head_dim is not None:
            self.head_dim = config.head_dim
        else:
            self.head_dim = config.hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Project to q/k/v, attend, and project back to ``hidden_size``."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class NemotronHAttentionDecoderLayer(nn.Module):
    """Decoder layer for a ``*`` position: pre-norm then self-attention."""

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.mixer = NemotronHAttention(
            config,
            layer_idx,
            model_config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.mixer",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        """Normalize, attend, and return ``(attention_output, residual)``.

        ``positions`` is accepted for interface parity but not forwarded to
        the attention mixer.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, residual = self.norm(hidden_states, residual)

        hidden_states = self.mixer(hidden_states=hidden_states)
        return hidden_states, residual


# Dispatch table used when building the decoder stack: the character at each
# position of ``config.hybrid_override_pattern`` selects the layer class for
# that position ("M" = Mamba2 mixer, "-" = MLP, "*" = attention).
ALL_DECODER_LAYER_TYPES = {
    "M": NemotronHMambaDecoderLayer,
    "-": NemotronHMLPDecoderLayer,
    "*": NemotronHAttentionDecoderLayer,
}


@support_torch_compile
class NemotronHModel(nn.Module):
    """NemotronH backbone: token embeddings, hybrid decoder stack, final norm.

    Layer ``i`` of the stack is built from the class that
    ``ALL_DECODER_LAYER_TYPES`` maps ``config.hybrid_override_pattern[i]`` to.
    Under pipeline parallelism only ``layers[start_layer:end_layer]`` are real
    modules on this rank; the rest are placeholders created by
    ``make_layers``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: NemotronHConfig = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        # LoRA adapters can contribute extra vocabulary rows to the embedding.
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        def get_layer(prefix: str):
            # The layer index is the last dotted component of the prefix.
            layer_idx = int(prefix.rsplit(".", 1)[1])
            layer_class = ALL_DECODER_LAYER_TYPES[
                config.hybrid_override_pattern[layer_idx]]
            return layer_class(
                config,
                layer_idx,
                model_config,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            len(config.hybrid_override_pattern),
            get_layer,
            prefix=f"{prefix}.layers")
        self.make_empty_intmd_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size)

        self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this pipeline stage's slice of the decoder stack.

        Returns ``IntermediateTensors`` with ``hidden_states`` and
        ``residual`` on every rank but the last; on the last rank returns
        the hidden states after the final norm.
        """
        attn_metadata = get_forward_context().attn_metadata

        if not envs.VLLM_USE_V1:
            mamba2_metadata = prepare_mamba2_metadata(
                chunk_size=self.config.chunk_size,
                attn_metadata=attn_metadata,
            )
        else:
            # v1 gets mamba2_metadata from the forward context instead.
            mamba2_metadata = None

        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            # Later pipeline stages continue from the previous stage's
            # hidden states AND residual; the residual must not be reset.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Only this rank's layers are real modules; calling the PP
        # placeholders outside [start_layer, end_layer) would break the
        # (hidden_states, residual) unpacking below.
        local_layers = self.layers[self.start_layer:self.end_layer]
        # The V0 mamba cache is sized from get_num_layers_by_block_type(
        # parallel_config, mamba) (presumably this rank's mamba layers only),
        # so a mamba layer's cache slot is its local index minus the number
        # of non-mamba layers preceding it on this rank.
        num_non_mamba_layers = 0
        for i, layer in enumerate(local_layers):
            layer_mamba_cache_params = None
            if (isinstance(layer, NemotronHMambaDecoderLayer)
                    and mamba_cache_params is not None):
                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                    i - num_non_mamba_layers)
            else:
                num_non_mamba_layers += 1

            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
                mamba_cache_params=layer_mamba_cache_params,
                mamba2_metadata=mamba2_metadata,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm_f(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this backbone.

        Renames checkpoint-specific names (``embeddings`` -> ``embed_tokens``,
        ``A_log`` -> ``A``), casts ``A``/``D``/``dt_bias`` tensors to float32,
        routes ``q_proj``/``k_proj``/``v_proj`` into the fused ``qkv_proj``
        shards, and skips parameters owned by other pipeline-parallel ranks.

        Returns:
            The set of parameter names that were loaded.
        """
        # Shard id passed to the fused qkv_proj weight loader for each
        # checkpoint projection name.
        qkv_shard_ids = {
            "q_proj": "q",
            "k_proj": "k",
            "v_proj": "v",
        }

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "embeddings" in name:
                name = name.replace("embeddings", "embed_tokens")

            # Checkpoint tensor ``A_log`` is loaded into parameter ``A``.
            if "A_log" in name:
                name = name.replace("A_log", "A")
                loaded_weight = loaded_weight.to(torch.float32)

            if "D" in name or "dt_bias" in name:
                loaded_weight = loaded_weight.to(torch.float32)

            shard_proj = next(
                (proj for proj in qkv_shard_ids if proj in name), None)
            if shard_proj is not None:
                name = name.replace(shard_proj, "qkv_proj")

            # Parameters of layers assigned to another PP rank do not exist
            # on this rank.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_proj is not None:
                param.weight_loader(param, loaded_weight,
                                    qkv_shard_ids[shard_proj])
            else:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

            loaded_params.add(name)
        return loaded_params


class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                           IsHybrid, SupportsQuant):
    """NemotronH causal LM: ``NemotronHModel`` backbone plus an LM head.

    Under V0 (``envs.VLLM_USE_V1`` false) the Mamba state is kept in a
    ``MambaCacheManager`` that is created lazily on the first forward call.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the dtypes of the Mamba cache tensors.

        Delegates to ``MambaStateDtypeCalculator.mamba2_state_dtype`` using
        the model dtype and the configured mamba/SSM cache dtypes.
        """
        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
        use_v1: bool = True,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Calculate shapes for Mamba's convolutional and state caches.

        Args:
            vllm_config: vLLM config
            use_v1: Get shapes for V1 (or V0)

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
            - temporal_state_shape: Shape for state space model cache
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config
        # Mamba inner width, matching the mixer's intermediate_size.
        intermediate_size = hf_config.mamba_num_heads * hf_config.mamba_head_dim

        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=intermediate_size,
            tp_world_size=parallel_config.tensor_parallel_size,
            n_groups=hf_config.n_groups,
            num_heads=hf_config.mamba_num_heads,
            head_dim=hf_config.mamba_head_dim,
            state_size=hf_config.ssm_state_size,
            conv_kernel=hf_config.conv_kernel,
            use_v1=use_v1,
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config
        assert not cache_config.enable_prefix_caching, \
            "NemotronH currently does not support prefix caching"

        self.quant_config = vllm_config.quant_config

        super().__init__()
        self.config = config
        self.scheduler_config = scheduler_config
        self.model = NemotronHModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "model"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        # Used to track and store by the Mamba cache between steps (V0 only;
        # created lazily in forward()).
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)

        self.make_empty_intmd_tensors = (self.model.make_empty_intmd_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        """Run the backbone; under V0 also prepare per-step Mamba cache params.

        ``kwargs`` are forwarded to ``MambaCacheManager.current_run_tensors``
        (V0 only). Returns whatever ``NemotronHModel.forward`` returns.
        """
        mamba_cache_params = None
        if not envs.VLLM_USE_V1:
            if self.mamba_cache is None:
                num_mamba_layers = (
                    self.model_config.get_num_layers_by_block_type(
                        self.vllm_config.parallel_config,
                        LayerBlockType.mamba))
                mamba_state_shape = self.get_mamba_state_shape_from_config(
                    self.vllm_config, use_v1=False)
                mamba_state_dtype = self.get_mamba_state_dtype_from_config(
                    self.vllm_config)
                self.mamba_cache = MambaCacheManager(self.vllm_config,
                                                     num_mamba_layers,
                                                     *mamba_state_shape,
                                                     *mamba_state_dtype)

            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

        hidden_states = self.model(input_ids, positions, mamba_cache_params,
                                   intermediate_tensors, inputs_embeds)

        return hidden_states

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the V0 Mamba cache manager.

        NOTE(review): assumes ``self.mamba_cache`` was already created by a
        prior V0 forward call.
        """
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the V0 Mamba cache manager (same caveat as above)."""
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Rename ``backbone`` -> ``model`` in checkpoint names, then load.

        Returns the set of loaded parameter names from ``AutoWeightsLoader``.
        """
        # update name in weights before passing to loader
        updated_weights = [(name.replace("backbone", "model"), loaded_weight)
                           for name, loaded_weight in weights]
        loader = AutoWeightsLoader(self)
        return loader.load_weights(updated_weights)