mamba.py 12.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""PyTorch MAMBA model."""
4
5
from collections.abc import Iterable
from typing import Optional
6
7
8
9
10

import torch
from torch import nn
from transformers import MambaConfig

11
from vllm import envs
12
from vllm.compilation.decorators import support_torch_compile
13
from vllm.config import CacheConfig, ModelConfig, VllmConfig
14
from vllm.distributed.parallel_state import get_pp_group
15
16
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
17
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
18
from vllm.model_executor.layers.mamba.mamba_utils import (
19
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
20
21
22
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
23
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
24
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
25
from vllm.model_executor.models.interfaces import (HasInnerState,
26
                                                   IsAttentionFree, SupportsPP)
27
28
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                    MambaCacheParams)
29
from vllm.sequence import IntermediateTensors
30
from vllm.utils import LayerBlockType
31

32
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
33
34
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
35

36
# Type alias for a pair of tensors. Nothing in the visible part of this
# module references it; presumably it is kept so that code importing
# ``KVCache`` from here keeps working — TODO confirm before removing.
KVCache = tuple[torch.Tensor, torch.Tensor]
37
38
39
40
41
42


class MambaDecoderLayer(nn.Module):
    """One Mamba decoder layer: an RMSNorm followed by a ``MambaMixer``.

    For Falcon-Mamba checkpoints (``config.model_type == "falcon_mamba"``)
    the mixer is additionally built with an internal RMS norm that has no
    learnable weight and uses ``config.mixer_rms_eps`` as its epsilon.
    """

    def __init__(self,
                 config: MambaConfig,
                 model_config: Optional[ModelConfig] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 is_lora_enabled: Optional[bool] = False,
                 prefix: str = "") -> None:
        """Build the pre-norm and the mixer from ``config``.

        Args:
            config: Hugging Face Mamba config supplying the layer sizes,
                activation and norm epsilons.
            model_config: Forwarded unchanged to ``MambaMixer``.
            cache_config: Forwarded unchanged to ``MambaMixer``.
            quant_config: Accepted but not used by this layer (it is not
                forwarded to the mixer).
            is_lora_enabled: Forwarded to ``MambaMixer``.
            prefix: Parameter-name prefix for this layer; the mixer is
                registered under ``f"{prefix}.mixer"``.
        """
        super().__init__()
        self.config = config
        self.is_falcon_mamba = config.model_type == "falcon_mamba"
        self.is_lora_enabled = is_lora_enabled
        # The mixer's internal RMS norm is enabled only for Falcon-Mamba
        # (see ``use_rms_norm`` below), so its epsilon is only read from the
        # config in that case.
        mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
        self.mixer = MambaMixer(hidden_size=config.hidden_size,
                                ssm_state_size=config.state_size,
                                conv_kernel_size=config.conv_kernel,
                                intermediate_size=config.intermediate_size,
                                time_step_rank=config.time_step_rank,
                                use_conv_bias=config.use_conv_bias,
                                use_bias=config.use_bias,
                                use_rms_norm=self.is_falcon_mamba,
                                rms_norm_has_weight=not self.is_falcon_mamba,
                                rms_norm_eps=mixer_rms_eps,
                                activation=config.hidden_act,
                                is_lora_enabled=self.is_lora_enabled,
                                model_config=model_config,
                                cache_config=cache_config,
                                prefix=f"{prefix}.mixer")

        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        mamba_cache_params: Optional[MambaCacheParams],
        **kwargs,
    ):
        """Apply the pre-norm and the mixer to ``hidden_states``.

        Args:
            hidden_states: Input activations for this layer.
            residual: Residual stream from the previous layer, or ``None``
                for the first layer, in which case the un-normed
                ``hidden_states`` become the residual.
            mamba_cache_params: Per-layer cache parameters forwarded to the
                mixer; ``MambaModel.forward`` passes ``None`` when it has no
                cache parameters.
            **kwargs: Extra keyword arguments (``MambaModel`` passes
                ``positions``) are accepted and ignored.

        Returns:
            ``(output, residual)``. The mixer is expected to write its result
            into ``output`` in place — its own return value is discarded.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            # With a residual, the norm returns both the normed activations
            # and the updated residual.
            hidden_states, residual = self.norm(hidden_states, residual)

        output = torch.empty_like(hidden_states)
        self.mixer(hidden_states, output, mamba_cache_params)
        return output, residual
87
88


89
@support_torch_compile
class MambaModel(nn.Module):
    """Mamba backbone: token embeddings, a stack of ``MambaDecoderLayer``s
    and a final RMSNorm (``norm_f``).

    Pipeline-parallel aware: ``forward`` only runs the layers in
    ``[start_layer, end_layer)`` returned by ``make_layers``, and non-last
    ranks hand ``hidden_states``/``residual`` on as ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        is_lora_enabled = bool(lora_config)

        self.config = config
        # LoRA can add extra vocabulary entries: ``lora_extra_vocab_size``
        # scaled by ``max_loras`` (or 1 when that is unset/zero).
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embeddings = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MambaDecoderLayer(config,
                                             model_config=model_config,
                                             cache_config=cache_config,
                                             quant_config=quant_config,
                                             is_lora_enabled=is_lora_enabled,
                                             prefix=prefix),
            prefix=f"{prefix}.layers")

        self.norm_f = RMSNorm(config.hidden_size,
                              eps=config.layer_norm_epsilon)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mamba_cache_params: Optional[MambaCacheParams] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the backbone.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank unless
                ``inputs_embeds`` is given.
            positions: Passed through to each layer (which ignores it).
            mamba_cache_params: Cache parameters for this rank's layers, or
                ``None`` to run every layer without cache parameters.
            intermediate_tensors: Required on non-first ranks; supplies
                ``hidden_states`` and ``residual`` from the previous rank.
            inputs_embeds: Optional precomputed embeddings (first rank only).

        Returns:
            On the last pipeline rank, the hidden states after ``norm_f``.
            On any other rank, an ``IntermediateTensors`` holding
            ``hidden_states`` and ``residual`` (the ``torch.Tensor``
            annotation does not reflect this case).
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]

            # Cache entries are indexed relative to this rank's first layer
            # (presumably the cache only holds this rank's layers).
            layer_cache_params = None
            if mamba_cache_params is not None:
                layer_cache_params = mamba_cache_params.at_layer_idx(
                    i - self.start_layer)

            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
                mamba_cache_params=layer_cache_params)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm_f(hidden_states, residual)

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this backbone.

        Checkpoint names containing ``A_log`` are rewritten to use ``A`` so
        they match the parameter names of this module. Bias tensors with no
        matching parameter (extra biases in GPTQ checkpoints) and parameters
        that belong to other pipeline ranks are skipped. Each parameter's
        own ``weight_loader`` is used when present, otherwise
        ``default_weight_loader``.

        Returns:
            The set of parameter names that were actually loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "A_log" in name:
                name = name.replace("A_log", "A")
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

194

195
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
    """Causal language model built on ``MambaModel`` (``self.backbone``).

    Adds the LM head and logits processor. When ``envs.VLLM_USE_V1`` is
    false, ``forward`` lazily creates a ``MambaCacheManager`` on first use
    and reuses it on later calls.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config

        assert not cache_config.enable_prefix_caching, \
            "Mamba does not support prefix caching"

        # Initialise nn.Module before assigning any attribute on self.
        super().__init__()
        self.scheduler_config = vllm_config.scheduler_config
        self.config = config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config

        self.backbone = MambaModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "backbone"))

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        if config.tie_word_embeddings:
            # Reuse the backbone's input embedding as the output projection.
            self.lm_head = self.backbone.embeddings
        else:
            # We need bigger padding if using lora for kernel compatibility.
            padding_size = (lora_config.lora_vocab_padding_size
                            if lora_config else DEFAULT_VOCAB_PADDING_SIZE)
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=padding_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

        # Mamba state cache, created lazily by forward() when
        # envs.VLLM_USE_V1 is false and reused across subsequent steps.
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.backbone.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.backbone.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        """Run the backbone, supplying Mamba cache tensors under V0.

        When ``envs.VLLM_USE_V1`` is false, the ``MambaCacheManager`` is
        created on the first call (sized from the config helpers below) and
        ``**kwargs`` are forwarded to its ``current_run_tensors``. Otherwise
        the backbone runs with ``mamba_cache_params=None``.

        Returns:
            Whatever ``MambaModel.forward`` returns: hidden states on the
            last pipeline rank, ``IntermediateTensors`` elsewhere.
        """
        mamba_cache_params = None
        if not envs.VLLM_USE_V1:
            if self.mamba_cache is None:
                num_layers = self.model_config.get_num_layers_by_block_type(
                    self.vllm_config.parallel_config, LayerBlockType.mamba)
                state_shape = self.get_mamba_state_shape_from_config(
                    self.vllm_config)
                state_dtype = self.get_mamba_state_dtype_from_config(
                    self.vllm_config)
                self.mamba_cache = MambaCacheManager(self.vllm_config,
                                                     num_layers, *state_shape,
                                                     *state_dtype)

            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

        hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
                                      intermediate_tensors, inputs_embeds)

        return hidden_states

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the Mamba-1 state dtypes derived from the model dtype and
        the cache config's ``mamba_cache_dtype`` / ``mamba_ssm_cache_dtype``.
        """
        return MambaStateDtypeCalculator.mamba1_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return the Mamba-1 state shapes computed from the HF config, the
        tensor-parallel size, and whether the V1 engine is in use.
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config

        return MambaStateShapeCalculator.mamba1_state_shape(
            tp_world_size=parallel_config.tensor_parallel_size,
            intermediate_size=hf_config.intermediate_size,
            state_size=hf_config.state_size,
            conv_kernel=hf_config.conv_kernel,
            use_v1=envs.VLLM_USE_V1)

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the Mamba cache manager.

        Requires ``self.mamba_cache`` to exist, i.e. ``forward`` has already
        run with ``envs.VLLM_USE_V1`` false.
        """
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the Mamba cache manager.

        Requires ``self.mamba_cache`` to exist, i.e. ``forward`` has already
        run with ``envs.VLLM_USE_V1`` false.
        """
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project final hidden states to vocabulary logits via ``lm_head``."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights using ``AutoWeightsLoader``; returns the
        set of loaded parameter names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)