# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Jamba model."""

from collections.abc import Iterable
from typing import Optional
Mor Zusman's avatar
Mor Zusman committed
6
7
8
9
10
11

import torch
from torch import nn
from transformers import JambaConfig

from vllm.attention.layer import Attention
12
from vllm.config import CacheConfig, VllmConfig
13
from vllm.distributed import get_tensor_model_parallel_world_size
14
from vllm.distributed.parallel_state import get_pp_group
15
from vllm.model_executor.layers.fused_moe import FusedMoE
Mor Zusman's avatar
Mor Zusman committed
16
from vllm.model_executor.layers.layernorm import RMSNorm
17
from vllm.model_executor.layers.linear import (QKVParallelLinear,
Mor Zusman's avatar
Mor Zusman committed
18
19
20
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
21
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
22
23
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
                                               PoolingType)
24
from vllm.model_executor.layers.quantization import QuantizationConfig
Mor Zusman's avatar
Mor Zusman committed
25
26
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
27
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
28
from vllm.model_executor.models.llama import LlamaMLP as JambaMLP
29
30
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                    MambaCacheParams)
Mor Zusman's avatar
Mor Zusman committed
31
from vllm.model_executor.sampling_metadata import SamplingMetadata
32
from vllm.sequence import IntermediateTensors
33
from vllm.utils import LayerBlockType
Mor Zusman's avatar
Mor Zusman committed
34

35
36
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
                         SupportsV0Only)
37
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
38
39
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
40

Mor Zusman's avatar
Mor Zusman committed
41
42
43

class JambaMoE(nn.Module):
    """Mixture-of-experts feed-forward block for Jamba decoder layers.

    A ``ReplicatedLinear`` router scores every token against the experts and
    a ``FusedMoE`` layer applies them. The router is only created when more
    than one expert is configured; with a single expert, constant logits of
    one are fed to ``FusedMoE`` instead.
    """

    def __init__(self,
                 config: JambaConfig,
                 num_experts: Optional[int] = None,
                 top_k: Optional[int] = None,
                 params_dtype: Optional[torch.dtype] = None,
                 tp_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the router (if needed) and the fused expert layer.

        Args:
            config: Model config; supplies ``num_experts``,
                ``num_experts_per_tok``, ``hidden_size`` and
                ``intermediate_size``.
            num_experts: Overrides ``config.num_experts`` when truthy.
            top_k: Overrides ``config.num_experts_per_tok`` when truthy.
            params_dtype: Forwarded to the router and to ``FusedMoE``.
            tp_size: Forwarded to ``FusedMoE``.
            quant_config: Forwarded to ``FusedMoE`` only; the router is
                always built with ``quant_config=None``.
            prefix: Module-name prefix; the experts are registered under
                ``f"{prefix}.experts"``.
        """
        super().__init__()
        # `or` means a falsy override (None or 0) falls back to the config.
        self.num_total_experts = num_experts or config.num_experts
        self.top_k = top_k or config.num_experts_per_tok
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        if self.num_total_experts > 1:
            self.router = ReplicatedLinear(self.hidden_size,
                                           self.num_total_experts,
                                           bias=False,
                                           quant_config=None,
                                           params_dtype=params_dtype)

        self.experts = FusedMoE(self.num_total_experts,
                                self.top_k,
                                self.hidden_size,
                                self.intermediate_size,
                                tp_size=tp_size,
                                params_dtype=params_dtype,
                                reduce_results=True,
                                renormalize=False,
                                use_grouped_topk=False,
                                quant_config=quant_config,
                                prefix=f"{prefix}.experts")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and restore the input shape.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_size``.

        Returns:
            The expert output viewed back to the shape of ``hidden_states``.
        """
        orig_shape = hidden_states.shape
        # Flatten every leading dimension into a single token dimension.
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (batch * sequence_length, n_experts)
        if self.num_total_experts > 1:
            router_logits, _ = self.router(hidden_states)
        else:
            # No router exists for a single expert: use constant logits.
            router_logits = torch.ones((hidden_states.shape[0], 1),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
        hidden_states = self.experts(hidden_states, router_logits)
        return hidden_states.view(orig_shape)
Mor Zusman's avatar
Mor Zusman committed
89
90
91
92
93
94
95
96


class JambaMambaDecoderLayer(nn.Module):
    """Jamba decoder layer whose token mixer is a ``MambaMixer``.

    The layer applies ``input_layernorm`` -> mamba mixer ->
    ``pre_ff_layernorm`` -> feed-forward. The feed-forward block is a
    ``JambaMoE`` when the layer is configured with more than one expert and
    a dense ``JambaMLP`` otherwise.
    """

    def __init__(self,
                 config: JambaConfig,
                 layer_idx: int,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 is_lora_enabled: Optional[bool] = False,
                 prefix: str = "",
                 **kwargs) -> None:
        """Create the mixer, the feed-forward block and both layer norms.

        Args:
            config: Model config supplying the mamba and MLP hyperparameters.
            layer_idx: Index into ``config.layers_num_experts``.
            cache_config: Accepted for signature parity with the attention
                layer; not used here.
            quant_config: Forwarded to the feed-forward block.
            is_lora_enabled: Forwarded to ``MambaMixer``.
            prefix: Module-name prefix; the feed-forward block is registered
                under ``f"{prefix}.feed_forward"``.
            **kwargs: Ignored.
        """
        super().__init__()
        self.config = config
        self.is_lora_enabled = is_lora_enabled
        self.mamba = MambaMixer(
            hidden_size=config.hidden_size,
            ssm_state_size=config.mamba_d_state,
            conv_kernel_size=config.mamba_d_conv,
            intermediate_size=config.mamba_expand * config.hidden_size,
            time_step_rank=config.mamba_dt_rank,
            use_conv_bias=config.mamba_conv_bias,
            use_bias=config.mamba_proj_bias,
            use_rms_norm=True,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.hidden_act,
            is_lora_enabled=self.is_lora_enabled,
        )

        # The per-layer expert count selects MoE vs. dense feed-forward.
        num_experts = config.layers_num_experts[layer_idx]
        if num_experts > 1:
            self.feed_forward = JambaMoE(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.feed_forward",
            )
        else:
            self.feed_forward = JambaMLP(
                config.hidden_size,
                config.intermediate_size,
                config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.feed_forward",
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        **kwargs,
    ):
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            hidden_states: Input activations.
            residual: Running residual, or ``None`` for the first layer, in
                which case the input itself becomes the residual.
            mamba_cache_params: Cache parameters passed to the mamba mixer.
            **kwargs: Ignored; lets callers pass attention-only arguments.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns the updated residual
            # alongside the normalized activations.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.mamba(hidden_states, mamba_cache_params)
        # Fully Connected
        hidden_states, residual = self.pre_ff_layernorm(
            hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


class JambaAttentionDecoderLayer(nn.Module):
    """Jamba decoder layer whose token mixer is self-attention.

    The layer applies ``input_layernorm`` -> self-attention ->
    ``pre_ff_layernorm`` -> feed-forward. The feed-forward block is a
    ``JambaMoE`` when the layer is configured with more than one expert and
    a dense ``JambaMLP`` otherwise.
    """

    def __init__(self,
                 config: JambaConfig,
                 layer_idx: int,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 **kwargs) -> None:
        """Create the attention projections, feed-forward block and norms.

        Args:
            config: Model config supplying attention and MLP hyperparameters.
            layer_idx: Index into ``config.layers_num_experts``.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear and feed-forward layers.
            prefix: Module-name prefix for the ``attn`` and ``feed_forward``
                submodules.
            **kwargs: Ignored; absorbs arguments meant for other layer
                types (e.g. ``is_lora_enabled``).
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        config.hidden_size,
                                        bias=False,
                                        quant_config=quant_config)

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )

        # The per-layer expert count selects MoE vs. dense feed-forward.
        num_experts = config.layers_num_experts[layer_idx]
        if num_experts > 1:
            self.feed_forward = JambaMoE(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.feed_forward",
            )
        else:
            self.feed_forward = JambaMLP(
                config.hidden_size,
                config.intermediate_size,
                config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.feed_forward",
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)

    def self_attention(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Project to q/k/v, attend, and apply the output projection.

        ``positions`` is accepted but not used by this method.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Forwarded to ``self_attention``.
            hidden_states: Input activations.
            residual: Running residual, or ``None`` for the first layer, in
                which case the input itself becomes the residual.
            **kwargs: Ignored; lets callers pass mamba-only arguments.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns the updated residual
            # alongside the normalized activations.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.self_attention(
            positions=positions,
            hidden_states=hidden_states,
        )
        # Fully Connected
        hidden_states, residual = self.pre_ff_layernorm(
            hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


# Lookup table from a block-type string to the decoder layer class that
# implements it.
ALL_DECODER_LAYER_TYPES = {
    "attention": JambaAttentionDecoderLayer,
    "mamba": JambaMambaDecoderLayer,
}


class JambaModel(nn.Module):
    """Jamba decoder stack: token embedding, hybrid layers and final norm.

    Each layer is either an attention or a mamba decoder layer, selected per
    index from ``config.layers_block_type``. Under pipeline parallelism only
    the layers in ``[start_layer, end_layer)`` are run by this rank.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, decoder layers and final layer norm.

        Args:
            vllm_config: Supplies the HF model config plus the cache,
                quantization and LoRA configs.
            prefix: Module-name prefix; layers are created under
                ``f"{prefix}.layers"``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        # Extra embedding rows reserved for LoRA-added tokens (0 w/o LoRA).
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}

        def get_layer(prefix: str):
            # The layer index is the last dotted component of the prefix.
            layer_idx = int(prefix.rsplit(".", 1)[1])
            layer_class = ALL_DECODER_LAYER_TYPES[
                config.layers_block_type[layer_idx]]
            return layer_class(config,
                               layer_idx,
                               cache_config,
                               quant_config=quant_config,
                               prefix=prefix,
                               **extra_kwargs)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        self.final_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's decoder layers.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Passed to every layer.
            mamba_cache_params: Cache parameters, indexed per mamba layer
                via ``at_layer_idx``.
            intermediate_tensors: Required on non-first pipeline ranks;
                carries ``hidden_states`` and ``residual`` from the previous
                rank.
            inputs_embeds: Optional precomputed embeddings (first rank only).

        Returns:
            The final-normed hidden states on the last pipeline rank,
            otherwise an ``IntermediateTensors`` holding ``hidden_states``
            and ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Counts mamba layers seen so far on this rank; attention layers do
        # not consume a mamba cache slot and receive None.
        mamba_cache_index = 0
        for layer in self.layers[self.start_layer:self.end_layer]:
            layer_mamba_cache_params = None
            if isinstance(layer, JambaMambaDecoderLayer):
                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                    mamba_cache_index)
                mamba_cache_index += 1

            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
                mamba_cache_params=layer_mamba_cache_params)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.final_layernorm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return ``FusedMoE``'s expert parameter mapping for this model."""
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each weight is tried, in order, against the stacked (fused qkv /
        gate_up) mapping, then the expert mapping, and finally loaded by its
        own name.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The (possibly rewritten) names of the weights that were handled
            without being skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Expert weights are handled by the expert mapping below.
                if 'experts' in name:
                    continue
                # NOTE(review): `name` is rewritten before the skip checks,
                # so a skipped candidate carries the substituted name into
                # the remaining iterations and the fallback branches.
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping consumed the weight: try the experts.
                for (
                        param_name,
                        weight_name,
                        expert_id,
                        shard_id,
                ) in expert_params_mapping:
                    if weight_name not in name:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Plain parameter: load it under its own name. The
                    # `continue`s here advance the outer `weights` loop.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

Mor Zusman's avatar
Mor Zusman committed
443

444
class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                       IsHybrid, SupportsV0Only):
    """Jamba decoder stack with a language-model head.

    Owns a ``MambaCacheManager`` that is created lazily on the first
    ``forward`` call and reused on later calls.
    """

    # Checkpoint-name substrings rewritten before weights are loaded.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
        ".self_attn.": ".",
        ".A_log": ".A"
    }, )
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "in_proj": ["in_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack, LM head and logits processor.

        Args:
            vllm_config: Supplies the model, cache, LoRA and scheduler
                configs. Prefix caching must be disabled.
            prefix: Module-name prefix for the inner ``model``.
        """
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config
        assert not cache_config.enable_prefix_caching, \
            "Jamba currently does not support prefix caching"

        super().__init__()
        self.config = config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.scheduler_config = scheduler_config
        self.model = JambaModel(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        # Mamba cache manager that keeps state between steps; created on
        # the first forward() call.
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the inner model's token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        """Run the decoder stack, creating the mamba cache on first use.

        ``**kwargs`` are forwarded to
        ``MambaCacheManager.current_run_tensors``. The inner model's output
        is returned unchanged.
        """
        if self.mamba_cache is None:
            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                self.vllm_config.parallel_config, LayerBlockType.mamba)
            self.mamba_cache = MambaCacheManager(
                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
                *self._get_mamba_cache_shape())

        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

        hidden_states = self.model(input_ids, positions, mamba_cache_params,
                                   intermediate_tensors, inputs_embeds)
        return hidden_states

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the mamba cache manager.

        Dereferences ``self.mamba_cache``, which is ``None`` until
        ``forward`` has run once.
        """
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the mamba cache manager.

        Dereferences ``self.mamba_cache``, which is ``None`` until
        ``forward`` has run once.
        """
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def _get_mamba_cache_shape(
            self) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return ``(conv_state_shape, temporal_state_shape)``.

        The first dimension of both shapes is
        ``mamba_expand * hidden_size`` divided by the tensor-parallel
        world size.
        """
        world_size = get_tensor_model_parallel_world_size()
        hidden_size = self.config.hidden_size
        conv_state_shape = (
            self.config.mamba_expand * hidden_size // world_size,
            self.config.mamba_d_conv - 1,
        )
        temporal_state_shape = (
            self.config.mamba_expand * hidden_size // world_size,
            self.config.mamba_d_state,
        )
        return conv_state_shape, temporal_state_shape

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Apply the logits processor with the LM head to ``hidden_states``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights via ``AutoWeightsLoader`` using the name mapper."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the inner model's expert parameter mapping."""
        return self.model.get_expert_mapping()
562
563
564
565


class JambaForSequenceClassification(JambaForCausalLM):
    """Jamba model with a linear scoring head and a pooler.

    Extends ``JambaForCausalLM`` with a float32 ``score`` layer and a
    ``DispatchPooler`` providing "encode" and "classify" poolers.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the base model, then the score head and pooler.

        Args:
            vllm_config: Must carry a non-``None`` ``pooler_config`` on its
                model config; ``num_labels`` (and optionally ``score_bias``)
                are read from the HF config.
            prefix: Forwarded to ``JambaForCausalLM``.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        config = vllm_config.model_config.hf_config
        num_labels: int = config.num_labels
        # `score_bias` is optional in the config; default to no bias.
        score_bias: bool = getattr(config, 'score_bias', False)

        # TODO: The original reward weights have float32 accuracy data, we
        # would like to load them in fp32 to get that extra precision.
        # Currently weight_loader passes the weight which is already in bf16
        self.score = nn.Linear(
            config.hidden_size,
            num_labels,
            bias=score_bias,
            dtype=torch.float32,
        )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler({
            "encode":
            Pooler.for_encode(pooler_config),
            "classify":
            Pooler.for_classify(
                pooler_config,
                classifier=self.score,
                default_pooling_type=PoolingType.LAST,
                default_normalize=False,
                default_softmax=False,
            ),
        })