mamba.py 12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""PyTorch MAMBA model."""
4
5
from collections.abc import Iterable
from typing import Optional
6
7
8
9
10

import torch
from torch import nn
from transformers import MambaConfig

11
from vllm import envs
12
from vllm.config import CacheConfig, VllmConfig
13
from vllm.distributed.parallel_state import get_pp_group
14
15
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
16
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
17
18
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateShapeCalculator)
19
20
21
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
22
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
23
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
24
from vllm.model_executor.models.interfaces import (HasInnerState,
25
                                                   IsAttentionFree, SupportsPP)
26
27
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                    MambaCacheParams)
28
29
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
30
from vllm.utils import LayerBlockType
31

32
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
33
34
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
35

36
# Type alias for a pair of tensors. Nothing in the code visible here refers
# to it; presumably it is kept for importers that expect a (key, value) cache
# pair — TODO confirm before removing.
KVCache = tuple[torch.Tensor, torch.Tensor]
37
38
39
40
41
42
43


class MambaDecoderLayer(nn.Module):
    """A single Mamba decoder block: an RMSNorm followed by a ``MambaMixer``.

    The residual stream is threaded through ``forward``. On the first layer
    (``residual is None``) the raw input becomes the residual. Otherwise the
    norm is called with ``(hidden_states, residual)`` and returns the updated
    pair.
    """

    def __init__(self,
                 config: MambaConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 is_lora_enabled: Optional[bool] = False,
                 prefix: str = "") -> None:
        """Build the mixer and the pre-mixer norm from ``config``.

        Args:
            config: HF Mamba config. ``model_type == "falcon_mamba"`` enables
                the FalconMamba-specific mixer norm.
            cache_config: Currently unused by this layer.
            quant_config: Currently unused by this layer (it is not
                forwarded to the mixer).
            is_lora_enabled: Forwarded to ``MambaMixer``.
            prefix: Module-name prefix; the mixer is registered under
                ``f"{prefix}.mixer"``.
        """
        super().__init__()
        self.config = config
        self.is_falcon_mamba = config.model_type == "falcon_mamba"
        self.is_lora_enabled = is_lora_enabled

        # FalconMamba: the mixer applies an RMSNorm *without* a learnable
        # weight, using ``config.mixer_rms_eps``. Plain Mamba: the mixer norm
        # is disabled, so no epsilon is needed.
        mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
        self.mixer = MambaMixer(hidden_size=config.hidden_size,
                                ssm_state_size=config.state_size,
                                conv_kernel_size=config.conv_kernel,
                                intermediate_size=config.intermediate_size,
                                time_step_rank=config.time_step_rank,
                                use_conv_bias=config.use_conv_bias,
                                use_bias=config.use_bias,
                                use_rms_norm=self.is_falcon_mamba,
                                rms_norm_has_weight=not self.is_falcon_mamba,
                                rms_norm_eps=mixer_rms_eps,
                                activation=config.hidden_act,
                                is_lora_enabled=self.is_lora_enabled,
                                prefix=f"{prefix}.mixer")

        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        mamba_cache_params: Optional[MambaCacheParams],
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the pre-norm and then the mixer.

        Args:
            hidden_states: Input activations for this layer.
            residual: Running residual, or ``None`` for the first layer.
            mamba_cache_params: Per-layer cache handle passed to the mixer.
                May be ``None``: ``MambaModel`` passes ``None`` when it was
                not given any cache params.
            **kwargs: Ignored. ``MambaModel`` passes ``positions`` here,
                which this layer does not use.

        Returns:
            ``(hidden_states, residual)`` to feed into the next layer.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, residual = self.norm(hidden_states, residual)

        hidden_states = self.mixer(hidden_states, mamba_cache_params)
        return hidden_states, residual


class MambaModel(nn.Module):
    """Mamba backbone: token embeddings, decoder layers partitioned across
    pipeline-parallel ranks, and a final RMSNorm applied on the last rank."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        is_lora_enabled = bool(lora_config)

        self.config = hf_config

        # With LoRA, the embedding table gets extra rows:
        # lora_extra_vocab_size multiplied by max_loras (1 if max_loras is
        # falsy).
        extra_vocab = 0
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        self.org_vocab_size = hf_config.vocab_size
        self.vocab_size = self.org_vocab_size + extra_vocab

        self.embeddings = VocabParallelEmbedding(
            self.vocab_size,
            hf_config.hidden_size,
            org_num_embeddings=hf_config.vocab_size,
        )

        # make_layers presumably calls the factory with a ``prefix=``
        # keyword argument, so this parameter must keep the name ``prefix``.
        def build_layer(prefix: str):
            return MambaDecoderLayer(hf_config,
                                     cache_config=cache_config,
                                     quant_config=quant_config,
                                     is_lora_enabled=is_lora_enabled,
                                     prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers")

        self.norm_f = RMSNorm(hf_config.hidden_size,
                              eps=hf_config.layer_norm_epsilon)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mamba_cache_params: Optional[MambaCacheParams] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the decoder stack.

        The first PP rank starts from ``inputs_embeds`` (or embeds
        ``input_ids``). Later ranks resume from ``intermediate_tensors``.
        Non-last ranks return an ``IntermediateTensors`` holding
        ``hidden_states`` and ``residual``, despite the annotation. The last
        rank returns the output of ``norm_f``.
        """
        pp_group = get_pp_group()

        if pp_group.is_first_rank:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer_idx in range(self.start_layer, self.end_layer):
            # Cache params are indexed relative to this rank's first layer.
            per_layer_cache = (None if mamba_cache_params is None else
                               mamba_cache_params.at_layer_idx(
                                   layer_idx - self.start_layer))
            hidden_states, residual = self.layers[layer_idx](
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
                mamba_cache_params=per_layer_cache)

        if not pp_group.is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm_f(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Returns the set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # Checkpoint names containing "A_log" are looked up under "A"
            # (a no-op for names without "A_log").
            name = name.replace("A_log", "A")

            # Skip extra biases (e.g. from GPTQ checkpoints) that have no
            # matching parameter, and parameters owned by other PP ranks.
            is_extra_bias = name.endswith(".bias") and name not in params_dict
            if is_extra_bias or is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            loader_fn = getattr(param, "weight_loader", default_weight_loader)
            loader_fn(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params

187

188
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
    """Mamba causal language model: a ``MambaModel`` backbone plus an LM head.

    Under the V0 engine (``envs.VLLM_USE_V1`` false) the model owns a
    ``MambaCacheManager``, created lazily on the first ``forward`` call.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config

        assert not cache_config.enable_prefix_caching, \
            "Mamba does not support prefix caching"

        # Assign attributes only after nn.Module.__init__ has set up the
        # module's internal state (the conventional, safe order).
        super().__init__()
        self.scheduler_config = vllm_config.scheduler_config
        self.config = config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.backbone = MambaModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "backbone"))

        # NOTE(review): this adds a single lora_extra_vocab_size block, while
        # MambaModel sizes its embedding with max_loras such blocks — confirm
        # the difference is intended.
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        if config.tie_word_embeddings:
            self.lm_head = self.backbone.embeddings
        else:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size,
            )

        # V0 only: created lazily in forward() and reused between steps.
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.backbone.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.backbone.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        """Run the backbone and, under V0, manage the Mamba cache.

        Under V0, ``kwargs`` are forwarded to
        ``MambaCacheManager.current_run_tensors``; under V1 they are ignored
        and no cache params are passed. Returns whatever the backbone
        returns: hidden states on the last PP rank, otherwise
        ``IntermediateTensors``.
        """
        mamba_cache_params = None
        if not envs.VLLM_USE_V1:
            if self.mamba_cache is None:
                num_layers = self.model_config.get_num_layers_by_block_type(
                    self.vllm_config.parallel_config, LayerBlockType.mamba)
                state_shape = self.get_mamba_state_shape_from_config(
                    self.vllm_config)
                # The cache dtype follows the LM head weight dtype.
                self.mamba_cache = MambaCacheManager(self.vllm_config,
                                                     self.lm_head.weight.dtype,
                                                     num_layers, *state_shape)

            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

        return self.backbone(input_ids, positions, mamba_cache_params,
                             intermediate_tensors, inputs_embeds)

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Compute the per-layer Mamba1 state shapes for this config.

        Delegates to ``MambaStateShapeCalculator.mamba1_state_shape`` using
        the TP world size and the HF config's sizes.
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config

        return MambaStateShapeCalculator.mamba1_state_shape(
            tp_world_size=parallel_config.tensor_parallel_size,
            intermediate_size=hf_config.intermediate_size,
            state_size=hf_config.state_size,
            conv_kernel=hf_config.conv_kernel,
            use_v1=envs.VLLM_USE_V1)

    def _require_mamba_cache(self) -> MambaCacheManager:
        """Return the V0 cache manager, failing clearly if it does not exist.

        Raises:
            RuntimeError: if ``forward`` has not yet created the cache, or
                the V1 engine is in use (which never creates it).
        """
        if self.mamba_cache is None:
            raise RuntimeError(
                "Mamba cache is not initialized: it is created on the first "
                "forward() call under the V0 engine and is never created "
                "when VLLM_USE_V1 is enabled.")
        return self.mamba_cache

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the V0 cache manager's method of the same name."""
        cache = self._require_mamba_cache()
        return cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the V0 cache manager's method of the same name."""
        cache = self._require_mamba_cache()
        return cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``; return the
        names of the loaded parameters."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)