# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""PyTorch MAMBA model."""

from collections.abc import Iterable
from itertools import islice
from typing import Optional

import torch
from torch import nn
from transformers import MambaConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (
    HasInnerState,
    IsAttentionFree,
    SupportsPP,
)
from vllm.sequence import IntermediateTensors

from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)

# Alias for a pair of tensors.
# NOTE(review): not referenced by the code in this file; confirm there are no
# external importers before removing it.
KVCache = tuple[torch.Tensor, torch.Tensor]


class MambaDecoderLayer(nn.Module):
    """A single Mamba block: an RMSNorm followed by a ``MambaMixer``.

    For Falcon-Mamba configs (``config.model_type == "falcon_mamba"``) the
    mixer's internal RMSNorm is enabled without a learned weight and uses
    ``config.mixer_rms_eps``; for every other model type it is disabled.
    """

    def __init__(
        self,
        config: MambaConfig,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        is_lora_enabled: Optional[bool] = False,
        prefix: str = "",
    ) -> None:
        """Build the mixer and the input norm.

        Args:
            config: HF Mamba config supplying sizes, activation and eps values.
            model_config: Forwarded to ``MambaMixer``.
            cache_config: Forwarded to ``MambaMixer``.
            quant_config: Accepted for signature compatibility; this layer
                does not use it.
            is_lora_enabled: Stored on the layer and forwarded to the mixer.
            prefix: Parameter-name prefix; the mixer is created with
                ``f"{prefix}.mixer"``.
        """
        super().__init__()
        self.config = config

        falcon = config.model_type == "falcon_mamba"
        self.is_falcon_mamba = falcon
        self.is_lora_enabled = is_lora_enabled

        # ``mixer_rms_eps`` is only read for Falcon-Mamba configs.
        # NOTE(review): other models pass ``None`` together with
        # ``use_rms_norm=False``; presumably the mixer ignores it — confirm.
        norm_eps = config.mixer_rms_eps if falcon else None

        # Assigned before ``self.norm`` so the submodules keep their original
        # registration order.
        self.mixer = MambaMixer(
            hidden_size=config.hidden_size,
            ssm_state_size=config.state_size,
            conv_kernel_size=config.conv_kernel,
            intermediate_size=config.intermediate_size,
            time_step_rank=config.time_step_rank,
            use_conv_bias=config.use_conv_bias,
            use_bias=config.use_bias,
            use_rms_norm=falcon,
            rms_norm_has_weight=not falcon,
            rms_norm_eps=norm_eps,
            activation=config.hidden_act,
            is_lora_enabled=is_lora_enabled,
            model_config=model_config,
            cache_config=cache_config,
            prefix=f"{prefix}.mixer",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        """Normalize the input, run the mixer, and return ``(output, residual)``.

        When ``residual`` is None, the un-normalized input becomes the
        residual. Otherwise ``self.norm`` is called with the residual and
        returns a ``(normalized, residual)`` pair. Extra keyword arguments
        (``MambaModel`` passes ``positions``) are ignored.
        """
        if residual is not None:
            normed, residual = self.norm(hidden_states, residual)
        else:
            normed, residual = self.norm(hidden_states), hidden_states

        # The mixer is handed a preallocated buffer of the same shape/dtype;
        # presumably it writes its result into it in place.
        mixer_out = torch.empty_like(normed)
        self.mixer(normed, mixer_out)
        return mixer_out, residual


@support_torch_compile
class MambaModel(nn.Module):
    """Mamba backbone: token embedding, a stack of ``MambaDecoderLayer``s and
    a final RMSNorm (``norm_f``).

    Only layers in ``[start_layer, end_layer)`` (as returned by
    ``make_layers``) run on this pipeline rank; non-first ranks receive
    ``hidden_states``/``residual`` through ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        is_lora_enabled = bool(lora_config)

        self.config = config

        # With LoRA, grow the embedding table by ``lora_extra_vocab_size``
        # rows per LoRA slot (``max_loras``, with a falsy value counted as 1).
        lora_vocab = 0
        if lora_config:
            lora_vocab = lora_config.lora_extra_vocab_size * (
                lora_config.max_loras or 1
            )

        self.org_vocab_size = config.vocab_size
        self.vocab_size = config.vocab_size + lora_vocab

        self.embeddings = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        # ``make_layers`` calls the factory with ``prefix=...`` as a keyword,
        # so the parameter name must stay ``prefix``.
        def build_layer(prefix: str) -> MambaDecoderLayer:
            return MambaDecoderLayer(
                config,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                is_lora_enabled=is_lora_enabled,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )

        self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the backbone.

        The first pipeline rank starts from ``inputs_embeds`` (if given) or
        the embedding of ``input_ids``; later ranks resume from
        ``intermediate_tensors``. Non-last ranks return ``IntermediateTensors``
        holding ``hidden_states`` and ``residual``; the last rank returns the
        output of ``norm_f``.
        """
        pp_group = get_pp_group()

        if pp_group.is_first_rank:
            hidden_states = (
                self.get_input_embeddings(input_ids)
                if inputs_embeds is None
                else inputs_embeds
            )
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions=positions, hidden_states=hidden_states, residual=residual
            )

        if not pp_group.is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm_f(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Checkpoint names containing ``A_log`` are loaded into the parameter
        whose name has ``A`` in its place. Returns the set of parameter names
        that were loaded.
        """
        params = dict(self.named_parameters())
        loaded: set[str] = set()

        for raw_name, tensor in weights:
            # ``str.replace`` is a no-op when ``A_log`` is absent.
            # NOTE(review): any log->linear conversion presumably happens in
            # the parameter's ``weight_loader``; confirm in MambaMixer.
            name = raw_name.replace("A_log", "A")

            # Skip loading extra bias for GPTQ models, and parameters that
            # ``is_pp_missing_parameter`` reports as absent on this rank.
            is_extra_bias = name.endswith(".bias") and name not in params
            if is_extra_bias or is_pp_missing_parameter(name, self):
                continue

            param = params[name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, tensor)
            loaded.add(name)

        return loaded


class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
    """Mamba causal language model: a ``MambaModel`` backbone plus an LM head.

    The LM head is the backbone's embedding module when
    ``config.tie_word_embeddings`` is set, otherwise a separate
    ``ParallelLMHead``. Prefix caching is rejected at construction time.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        # NOTE(review): ``assert`` is skipped under ``python -O``; consider an
        # explicit exception if this check must always run.
        assert not cache_config.enable_prefix_caching, (
            "Mamba does not support prefix caching"
        )

        super().__init__()
        # Plain attributes are assigned only after ``nn.Module.__init__`` has
        # set up the module's internal bookkeeping.
        self.scheduler_config = vllm_config.scheduler_config
        self.config = config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config

        self.backbone = MambaModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "backbone")
        )

        # NOTE(review): this adds ``lora_extra_vocab_size`` once, whereas
        # MambaModel sizes its embedding with
        # ``lora_extra_vocab_size * (max_loras or 1)``; confirm the difference
        # is intended when the embeddings are tied.
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        if config.tie_word_embeddings:
            # Reuse the backbone's embedding module as the output projection.
            self.lm_head = self.backbone.embeddings
        else:
            # We need bigger padding if using lora for kernel compatibility.
            padding_size = (
                lora_config.lora_vocab_padding_size
                if lora_config
                else DEFAULT_VOCAB_PADDING_SIZE
            )
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=padding_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.vocab_size
        )

        # Pipeline-parallel placeholder factory is shared with the backbone.
        self.make_empty_intermediate_tensors = (
            self.backbone.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.backbone.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Run the backbone and return its output unchanged.

        Extra keyword arguments are accepted and ignored.
        """
        return self.backbone(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the two Mamba-1 state dtypes.

        They are computed by ``MambaStateDtypeCalculator.mamba1_state_dtype``
        from the model dtype and the cache config's ``mamba_cache_dtype`` and
        ``mamba_ssm_cache_dtype`` settings.
        """
        return MambaStateDtypeCalculator.mamba1_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return the two Mamba-1 state shapes.

        They are computed by ``MambaStateShapeCalculator.mamba1_state_shape``
        from the tensor-parallel size and the HF config's
        ``intermediate_size``, ``state_size`` and ``conv_kernel``.
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config

        return MambaStateShapeCalculator.mamba1_state_shape(
            tp_world_size=parallel_config.tensor_parallel_size,
            intermediate_size=hf_config.intermediate_size,
            state_size=hf_config.state_size,
            conv_kernel=hf_config.conv_kernel,
        )

    def _require_mamba_cache(self):
        """Return ``self.mamba_cache`` or raise a descriptive AttributeError.

        Nothing in this class assigns ``mamba_cache``; without this check the
        hooks below fail with a bare missing-attribute error.
        """
        cache = getattr(self, "mamba_cache", None)
        if cache is None:
            raise AttributeError(
                f"{type(self).__name__} has no 'mamba_cache': "
                "copy_inputs_before_cuda_graphs and "
                "get_seqlen_agnostic_capture_inputs require it, but this "
                "model never sets one."
            )
        return cache

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to ``self.mamba_cache.copy_inputs_before_cuda_graphs``.

        Raises:
            AttributeError: if no ``mamba_cache`` has been set on the model.
        """
        return self._require_mamba_cache().copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs
        )

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to ``self.mamba_cache.get_seqlen_agnostic_capture_inputs``.

        Raises:
            AttributeError: if no ``mamba_cache`` has been set on the model.
        """
        return self._require_mamba_cache().get_seqlen_agnostic_capture_inputs(
            batch_size
        )

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocabulary logits through ``lm_head``."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights with ``AutoWeightsLoader``.

        Returns the set of loaded parameter names.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)