"""Inference-only Jamba model."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import JambaConfig

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                    MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
                                      _get_graph_batch_size)

from .interfaces import HasInnerState, SupportsLoRA

# Type alias for one entry of the ``kv_caches`` list passed to
# ``JambaForCausalLM.forward``: a pair of tensors (presumably the key and
# value caches of one attention layer).
KVCache = Tuple[torch.Tensor, torch.Tensor]


class JambaMoE(nn.Module):
    """Jamba feed-forward block backed by a fused mixture-of-experts kernel.

    Each token is routed by a bias-free linear router to ``top_k`` of
    ``num_total_experts`` experts, executed by :class:`FusedMoE`. When there
    is only a single expert no router is created; every token gets a constant
    logit of 1 for that expert. :class:`JambaMLP` relies on this to reuse the
    class as a dense MLP.

    Args:
        config: Model config. Supplies the default expert count and top-k
            (``num_experts``, ``num_experts_per_tok``), ``hidden_size`` and
            ``intermediate_size``.
        num_experts: Overrides ``config.num_experts`` when truthy.
        top_k: Overrides ``config.num_experts_per_tok`` when truthy.
        params_dtype: Parameter dtype forwarded to the router and experts.
        tp_size: Tensor-parallel size forwarded to :class:`FusedMoE`.
        quant_config: Quantization config applied to the experts only.
    """

    def __init__(self,
                 config: JambaConfig,
                 num_experts: Optional[int] = None,
                 top_k: Optional[int] = None,
                 params_dtype: Optional[torch.dtype] = None,
                 tp_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.num_total_experts = num_experts or config.num_experts
        self.top_k = top_k or config.num_experts_per_tok
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        if self.num_total_experts > 1:
            # The router is built without quantization (quant_config=None),
            # even when the experts are quantized.
            self.router = ReplicatedLinear(self.hidden_size,
                                           self.num_total_experts,
                                           bias=False,
                                           quant_config=None,
                                           params_dtype=params_dtype)

        self.experts = FusedMoE(self.num_total_experts,
                                self.top_k,
                                self.hidden_size,
                                self.intermediate_size,
                                tp_size=tp_size,
                                params_dtype=params_dtype,
                                reduce_results=True,
                                renormalize=False,
                                use_grouped_topk=False,
                                quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the expert FFN token-wise; the output keeps the input shape."""
        orig_shape = hidden_states.shape
        # Flatten all leading dimensions so that each row is one token.
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, num_total_experts)
        if self.num_total_experts > 1:
            router_logits, _ = self.router(hidden_states)
        else:
            # Single expert: constant logits send every token to expert 0.
            router_logits = torch.ones((hidden_states.shape[0], 1),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
        hidden_states = self.experts(hidden_states, router_logits)
        return hidden_states.view(orig_shape)


class JambaMLP(JambaMoE):
    """Dense Jamba feed-forward block.

    This is a :class:`JambaMoE` with exactly one expert and ``top_k=1``. No
    router is built, so the single expert behaves as a plain MLP.
    ``JambaForCausalLM.load_weights`` depends on this layout: it remaps dense
    ``feed_forward`` checkpoint weights onto ``feed_forward.experts.0``.

    Args:
        config: Model config (``hidden_size``, ``intermediate_size``).
        params_dtype: Parameter dtype forwarded to the expert.
        tp_size: Tensor-parallel size forwarded to the fused expert.
        quant_config: Quantization config for the expert weights.
    """

    def __init__(self,
                 config: JambaConfig,
                 params_dtype: Optional[torch.dtype] = None,
                 tp_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__(config,
                         num_experts=1,
                         top_k=1,
                         params_dtype=params_dtype,
                         tp_size=tp_size,
                         quant_config=quant_config)


class JambaMambaDecoderLayer(nn.Module):
    """Jamba decoder layer: pre-norm Mamba mixer followed by a pre-norm FFN.

    The FFN is a :class:`JambaMoE` when ``config.layers_num_experts`` assigns
    this layer more than one expert, and a dense :class:`JambaMLP` otherwise.
    ``cache_config`` is accepted so that every decoder layer can be built with
    the same keyword arguments, but this layer does not use it.
    """

    def __init__(self,
                 config: JambaConfig,
                 layer_idx: int,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        self.mamba = MambaMixer(
            hidden_size=config.hidden_size,
            ssm_state_size=config.mamba_d_state,
            conv_kernel_size=config.mamba_d_conv,
            intermediate_size=config.mamba_expand * config.hidden_size,
            time_step_rank=config.mamba_dt_rank,
            use_conv_bias=config.mamba_conv_bias,
            use_bias=config.mamba_proj_bias,
            use_rms_norm=True,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.hidden_act)

        num_experts = config.layers_num_experts[layer_idx]
        ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
        self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        **kwargs,
    ):
        """Run mixer + FFN and return ``(hidden_states, residual)``.

        On the first layer ``residual`` is None: the raw input becomes the
        residual. Later layers call the two-argument RMSNorm form, which
        returns ``(normed, updated_residual)``. Extra keyword arguments (for
        example ``positions`` or ``kv_cache``, which attention layers need)
        are accepted and ignored.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.mamba(hidden_states, attn_metadata,
                                   mamba_cache_params)
        # Fully Connected
        hidden_states, residual = self.pre_ff_layernorm(
            hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


class JambaAttentionDecoderLayer(nn.Module):
    """Jamba decoder layer: pre-norm self-attention followed by a pre-norm FFN.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    split as well when there are at least as many of them as ranks; otherwise
    every rank holds one replicated KV head. The ``positions`` argument is
    accepted but never used by this layer.
    """

    def __init__(
        self,
        config: JambaConfig,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        hidden = config.hidden_size
        self.hidden_size = hidden
        world = get_tensor_model_parallel_world_size()

        n_heads = config.num_attention_heads
        self.total_num_heads = n_heads
        assert n_heads % world == 0
        self.num_heads = n_heads // world

        n_kv_heads = config.num_key_value_heads
        self.total_num_kv_heads = n_kv_heads
        # KV heads are either replicated (fewer heads than ranks) or sharded
        # (at least as many heads as ranks); each case needs an even split.
        if n_kv_heads < world:
            assert world % n_kv_heads == 0
        else:
            assert n_kv_heads % world == 0
        self.num_kv_heads = max(1, n_kv_heads // world)

        self.head_dim = hidden // n_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden,
            self.head_dim,
            n_heads,
            n_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(n_heads * self.head_dim,
                                        hidden,
                                        bias=False,
                                        quant_config=quant_config)
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
        )

        ffn_cls = (JambaMoE
                   if config.layers_num_experts[layer_idx] > 1 else JambaMLP)
        self.feed_forward = ffn_cls(config, quant_config=quant_config)
        eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(hidden, eps=eps)
        self.pre_ff_layernorm = RMSNorm(hidden, eps=eps)

    def self_attention(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Project to q/k/v, attend, and project back to the hidden size."""
        fused, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused.split(split_sizes, dim=-1)
        projected, _ = self.o_proj(
            self.attn(query, key, value, kv_cache, attn_metadata))
        return projected

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        """Run attention + FFN and return ``(hidden_states, residual)``."""
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attended = self.self_attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # Fully Connected
        normed, residual = self.pre_ff_layernorm(attended, residual)
        return self.feed_forward(normed), residual


# Decoder layer class for each value that may appear in
# ``config.layers_block_type``.
ALL_DECODER_LAYER_TYPES = dict(
    attention=JambaAttentionDecoderLayer,
    mamba=JambaMambaDecoderLayer,
)


class JambaModel(nn.Module):
    """Jamba decoder stack (no LM head).

    Consists of three parts:

    * a vocab-parallel token embedding;
    * one decoder layer per entry of ``config.layers_block_type``, each an
      attention or Mamba layer per ``ALL_DECODER_LAYER_TYPES``;
    * a final RMSNorm.

    With a LoRA config, the embedding table gains
    ``lora_extra_vocab_size * (max_loras or 1)`` extra rows.
    """

    def __init__(
        self,
        config: JambaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        decoder_layers = []
        for i in range(config.num_hidden_layers):
            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
            decoder_layers.append(
                layer_class(config,
                            layer_idx=i,
                            cache_config=cache_config,
                            quant_config=quant_config))
        self.layers = nn.ModuleList(decoder_layers)
        self.final_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        mamba_cache_params: MambaCacheParams,
    ) -> torch.Tensor:
        """Run the decoder stack and return the final-normed hidden states.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions, forwarded to every layer.
            kv_caches: KV caches indexed by attention-layer ordinal (assumed
                to hold one entry per attention layer, in layer order).
            attn_metadata: Metadata shared by all layers.
            mamba_cache_params: Mamba state cache. ``at_layer_idx(k)`` selects
                the state of the k-th Mamba layer.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        # Attention and Mamba layers index their caches by their ordinal
        # among layers of the same kind. Counting while iterating derives
        # those indices from the actual layer layout. The alternative is to
        # recompute them from attn_layer_offset / attn_layer_period. Both
        # agree when attention layers sit at offset + k * period, which is
        # presumably what layers_block_type encodes (confirm with JambaConfig).
        attn_idx = 0
        mamba_idx = 0
        for layer in self.layers:
            kv_cache = None
            layer_mamba_cache_params = None
            if isinstance(layer, JambaAttentionDecoderLayer):
                kv_cache = kv_caches[attn_idx]
                attn_idx += 1
            if isinstance(layer, JambaMambaDecoderLayer):
                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                    mamba_idx)
                mamba_idx += 1

            # Each layer type ignores the cache argument it does not use
            # (both accept **kwargs).
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                kv_cache=kv_cache,
                attn_metadata=attn_metadata,
                residual=residual,
                mamba_cache_params=layer_mamba_cache_params)
        hidden_states, _ = self.final_layernorm(hidden_states, residual)
        return hidden_states


class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
    """Jamba causal language model for vLLM inference.

    Wraps :class:`JambaModel` with an LM head, a logits processor and a
    sampler. It also owns the :class:`MambaCacheManager` that holds Mamba
    state between forward calls; the manager is created lazily on the first
    ``forward``. Construction fails if prefix caching is enabled.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config
        assert not cache_config.enable_prefix_caching, \
            "Jamba currently does not support prefix caching"

        super().__init__()
        self.config = config
        self.scheduler_config = scheduler_config
        self.model = JambaModel(config,
                                cache_config=cache_config,
                                quant_config=quant_config,
                                lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        # Mamba state kept between steps; created lazily by forward().
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[KVCache],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs):
        """Run the model and return the final hidden states.

        The first call creates the Mamba cache: one slot per ``"mamba"``
        entry of ``layers_block_type``. Its batch capacity comes from
        ``_get_graph_batch_size(max_num_seqs)``, or from
        ``max(_BATCH_SIZES_TO_CAPTURE) + 2`` when there is no scheduler
        config.

        ``intermediate_tensors`` is unused. Extra keyword arguments go to
        ``MambaCacheManager.current_run_tensors``.
        """
        if self.mamba_cache is None:
            max_batch_size = (_get_graph_batch_size(
                self.scheduler_config.max_num_seqs) if self.scheduler_config
                              else max(_BATCH_SIZES_TO_CAPTURE) + 2)

            layers_type = self.config.layers_block_type
            num_mamba_layers = sum(layer_type == "mamba"
                                   for layer_type in layers_type)

            self.mamba_cache = MambaCacheManager(
                self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
                *self._get_mamba_cache_shape())
        (
            mamba_cache_tensors,
            state_indices_tensor,
        ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
                                                 **kwargs)
        mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
                                              mamba_cache_tensors[1],
                                              state_indices_tensor)
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, mamba_cache_params)
        return hidden_states

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the Mamba cache manager.

        ``self.mamba_cache`` only exists after the first ``forward`` call.
        """
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the Mamba cache manager.

        ``self.mamba_cache`` only exists after the first ``forward`` call.
        """
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def _get_mamba_cache_shape(
            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
        """Return the per-rank ``(conv_state_shape, temporal_state_shape)``."""
        world_size = get_tensor_model_parallel_world_size()
        # Mamba inner dimension (expand * hidden), sharded over TP ranks.
        intermediate_size = (self.config.mamba_expand *
                             self.config.hidden_size // world_size)
        conv_state_shape = (
            intermediate_size,
            self.config.mamba_d_conv - 1,
        )
        temporal_state_shape = (
            intermediate_size,
            self.config.mamba_d_state,
        )
        return conv_state_shape, temporal_state_shape

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` checkpoint pairs into this model.

        Names are normalized before lookup:

        * ``rotary_emb.inv_freq`` entries are skipped;
        * ``A_log`` is renamed to ``A``;
        * ``.self_attn`` is removed, because the attention projections live
          directly on the decoder layer;
        * dense ``feed_forward`` weights (not expert/router weights) are
          redirected to ``feed_forward.experts.0``, because
          :class:`JambaMLP` is a single-expert :class:`JambaMoE`.

        Loading then proceeds in three ways:

        * separate ``q/k/v_proj`` weights are loaded as shards of the fused
          ``qkv_proj``;
        * expert weights go through the FusedMoE expert mapping;
        * all other weights use the parameter's ``weight_loader``, or
          ``default_weight_loader`` when it has none.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if "A_log" in name:
                name = name.replace("A_log", "A")

            if ".self_attn." in name:
                name = name.replace(".self_attn", "")

            if "feed_forward" in name and not _is_moe_layer(name):
                ## map MLP layers to expert with ID=0
                name = name.replace("feed_forward", "feed_forward.experts.0")

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if 'experts' in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE: `name` is already rewritten at this point. This
                # `continue` therefore lets the loop finish without `break`
                # and fall into the `else` branch below, whose identical
                # bias check skips the weight.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for (
                        param_name,
                        weight_name,
                        expert_id,
                        shard_id,
                ) in expert_params_mapping:
                    if weight_name not in name:
                        continue

                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)


def _is_moe_layer(name: str):
    return any(
        [experts_name in name for experts_name in [
            "experts",
            "router",
        ]])