mamba2.py 14.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""PyTorch MAMBA2 model."""
4
5
from collections.abc import Iterable
from typing import Optional
6
7
8
9
10

import torch
from torch import nn
from transformers import MambaConfig

Chen Zhang's avatar
Chen Zhang committed
11
from vllm import envs
12
from vllm.attention.backends.abstract import AttentionMetadata
13
from vllm.compilation.decorators import support_torch_compile
14
from vllm.config import CacheConfig, ModelConfig, VllmConfig
15
from vllm.distributed.parallel_state import get_pp_group
16
from vllm.forward_context import get_forward_context
17
18
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
19
20
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)
21
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
22
from vllm.model_executor.layers.mamba.mamba_utils import (
23
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
24
25
26
27
28
29
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState,
Chen Zhang's avatar
Chen Zhang committed
30
                                                   IsAttentionFree)
31
32
33
34
35
36
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                    MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType

37
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
38
39
40
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

41
KVCache = tuple[torch.Tensor, torch.Tensor]
42
43
44
45
46
47


class Mamba2DecoderLayer(nn.Module):
    """One Mamba2 block: an RMSNorm followed by a ``MambaMixer2``.

    The layer does not fold the residual back in itself. It returns
    ``(mixer_output, residual)``, and the next layer (or the backbone's
    final norm) combines the two through the two-argument form of
    ``RMSNorm``.
    """

    def __init__(self,
                 config: MambaConfig,
                 model_config: Optional[ModelConfig] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        """Build the pre-mixer norm and the Mamba2 mixer.

        Args:
            config: HF model config. Reads hidden_size, state_size,
                conv_kernel, intermediate_size (or expand), use_conv_bias,
                use_bias, n_groups, num_heads, head_dim, layer_norm_epsilon
                and hidden_act.
            model_config: vLLM model config, forwarded to the mixer.
            cache_config: vLLM cache config, forwarded to the mixer.
            quant_config: Optional quantization config for the mixer.
            prefix: Parameter-name prefix of this layer; the mixer is
                registered under ``f"{prefix}.mixer"``.
        """
        super().__init__()
        self.config = config
        # Prefer an explicit intermediate_size from the config and fall back
        # to expand * hidden_size when the config does not define one.
        intermediate_size = getattr(config, "intermediate_size",
                                    config.expand * config.hidden_size)
        self.mixer = MambaMixer2(hidden_size=config.hidden_size,
                                 ssm_state_size=config.state_size,
                                 conv_kernel_size=config.conv_kernel,
                                 intermediate_size=intermediate_size,
                                 use_conv_bias=config.use_conv_bias,
                                 use_bias=config.use_bias,
                                 n_groups=config.n_groups,
                                 num_heads=config.num_heads,
                                 head_dim=config.head_dim,
                                 rms_norm_eps=config.layer_norm_epsilon,
                                 activation=config.hidden_act,
                                 model_config=model_config,
                                 cache_config=cache_config,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.mixer")

        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        mamba_cache_params: Optional[MambaCacheParams],
        mamba2_metadata: Optional[Mamba2Metadata],
        **kwargs,
    ):
        """Normalize the input and run the Mamba2 mixer on it.

        Args:
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual stream; ``None`` for the first layer
                executed, in which case the un-normalized input becomes the
                residual.
            mamba_cache_params: Per-layer state cache. ``None`` when the
                backbone has no cache params (e.g. under the V1 engine).
            mamba2_metadata: Prepared Mamba2 metadata; ``None`` under V1.
            **kwargs: Extra keyword arguments (e.g. ``positions``) are
                accepted and ignored.

        Returns:
            ``(mixer_output, residual)``.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, residual = self.norm(hidden_states, residual)

        # The mixer is expected to fill this preallocated buffer in place;
        # it is returned as the layer output.
        output = torch.empty_like(hidden_states)
        self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
        return output, residual


@support_torch_compile
class Mamba2Model(nn.Module):
    """Mamba2 backbone: token embeddings, decoder layers and a final RMSNorm.

    With pipeline parallelism each rank executes only the layers in
    ``[self.start_layer, self.end_layer)`` reported by ``make_layers``.
    Non-last ranks return ``hidden_states``/``residual`` as
    ``IntermediateTensors`` instead of applying the final norm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        is_lora_enabled = bool(lora_config)
        # LoRA is not supported by this backbone.
        assert not is_lora_enabled

        self.config = config
        # Given the assertion above this is always 0; the expression only
        # matters if assertions are stripped (python -O).
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embeddings = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Mamba2DecoderLayer(config,
                                              model_config=model_config,
                                              cache_config=cache_config,
                                              quant_config=quant_config,
                                              prefix=prefix),
            prefix=f"{prefix}.layers")

        self.norm_f = RMSNorm(config.hidden_size,
                              eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the backbone.

        Args:
            input_ids: Token ids; used only on the first PP rank when
                ``inputs_embeds`` is not given.
            positions: Forwarded to each layer, which ignores it.
            mamba_cache_params: V0 state cache for this rank's layers, or a
                falsy value (``None`` under V1), in which case each layer
                receives ``None``.
            intermediate_tensors: Activations from the previous PP rank;
                required on every rank except the first.
            inputs_embeds: Optional precomputed embeddings (first rank only).

        Returns:
            The final-normed hidden states on the last PP rank, otherwise an
            ``IntermediateTensors`` with ``hidden_states`` and ``residual``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

        if not envs.VLLM_USE_V1:
            mamba2_metadata = prepare_mamba2_metadata(
                chunk_size=self.config.chunk_size,
                attn_metadata=attn_metadata,
            )
        else:
            # v1 get mamba2_metadata from forward_context
            mamba2_metadata = None

        # Execute only the layers owned by this pipeline rank. ``self.layers``
        # is indexed by global layer id (the local cache index
        # ``i - self.start_layer`` relies on this), so iterating from 0 would
        # produce negative cache indices and run placeholder layers on
        # non-first ranks. On a single rank this covers every layer.
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
                mamba_cache_params=mamba_cache_params.at_layer_idx(
                    i - self.start_layer) if mamba_cache_params else None,
                mamba2_metadata=mamba2_metadata)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm_f(hidden_states, residual)

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this backbone's parameters.

        Checkpoint names containing ``A_log`` are renamed to use ``A``. Two
        kinds of entries are skipped:

        * ``.bias`` entries with no matching parameter (extra GPTQ biases).
        * Entries flagged by ``is_pp_missing_parameter``.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "A_log" in name:
                name = name.replace("A_log", "A")

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
    """Mamba2 causal language model: ``Mamba2Model`` backbone plus LM head.

    Under the V0 engine (``envs.VLLM_USE_V1`` false) the recurrent state is
    held in a ``MambaCacheManager`` created lazily on the first forward
    call. Under V1 no cache params are passed to the backbone from here.
    """

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the two dtypes used for the Mamba state caches.

        Delegates to ``MambaStateDtypeCalculator.mamba2_state_dtype`` with
        the model dtype and the cache config's ``mamba_cache_dtype`` /
        ``mamba_ssm_cache_dtype`` settings. The result is splatted into
        ``MambaCacheManager`` right after the shapes from
        ``get_mamba_state_shape_from_config``, so the order presumably
        matches (conv, temporal).
        """
        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
        use_v1: bool = True,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Calculate shapes for Mamba's convolutional and state caches.

        Args:
            vllm_config: vLLM config
            use_v1: Get shapes for V1 (or V0)

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
            - temporal_state_shape: Shape for state space model cache
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config
        # Use the same rule as Mamba2DecoderLayer when it builds MambaMixer2.
        # Otherwise a config with an explicit intermediate_size would get
        # cache shapes that disagree with the mixer's weights.
        intermediate_size = getattr(hf_config, "intermediate_size",
                                    hf_config.expand * hf_config.hidden_size)

        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=intermediate_size,
            tp_world_size=parallel_config.tensor_parallel_size,
            n_groups=hf_config.n_groups,
            num_heads=hf_config.num_heads,
            head_dim=hf_config.head_dim,
            state_size=hf_config.state_size,
            conv_kernel=hf_config.conv_kernel,
            use_v1=use_v1,
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config
        assert not cache_config.enable_prefix_caching, \
            "Mamba does not support prefix caching"

        super().__init__()
        self.config = config
        self.vllm_config = vllm_config
        self.scheduler_config = scheduler_config
        self.model_config = vllm_config.model_config
        self.backbone = Mamba2Model(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "backbone"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if config.tie_word_embeddings:
            self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings)

        # Used to track and store by the Mamba cache between steps.
        # Created lazily in forward() under V0; stays None under V1.
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.backbone.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's token embeddings for ``input_ids``."""
        return self.backbone.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        """Run the backbone, creating the V0 Mamba cache on first use.

        ``**kwargs`` are forwarded to ``MambaCacheManager.current_run_tensors``
        under V0 and are otherwise unused.
        """
        if not envs.VLLM_USE_V1:
            if self.mamba_cache is None:
                num_mamba_layers = (
                    self.model_config.get_num_layers_by_block_type(
                        self.vllm_config.parallel_config,
                        LayerBlockType.mamba))
                mamba_state_shape = self.get_mamba_state_shape_from_config(
                    self.vllm_config, use_v1=False)
                mamba_state_dtype = self.get_mamba_state_dtype_from_config(
                    self.vllm_config)
                self.mamba_cache = MambaCacheManager(self.vllm_config,
                                                     num_mamba_layers,
                                                     *mamba_state_shape,
                                                     *mamba_state_dtype)

            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
        else:
            # NOTE: mamba_cache_params is not needed for v1
            mamba_cache_params = None

        hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
                                      intermediate_tensors, inputs_embeds)

        return hidden_states

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the V0 ``MambaCacheManager``.

        Only valid after forward() has created ``self.mamba_cache``; it is
        ``None`` before that and under V1.
        """
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the V0 ``MambaCacheManager`` (see above caveat:
        requires ``self.mamba_cache`` to have been created)."""
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``.

        Returns:
            The set of parameter names that were loaded.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)