mamba.py 9.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""PyTorch MAMBA model."""
4

5
from collections.abc import Iterable
6
from itertools import islice
7
8
9
10
11

import torch
from torch import nn
from transformers import MambaConfig

12
from vllm.compilation.decorators import support_torch_compile
13
from vllm.config import CacheConfig, ModelConfig, VllmConfig
14
from vllm.distributed.parallel_state import get_pp_group
15
16
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
17
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
18
from vllm.model_executor.layers.mamba.mamba_utils import (
19
20
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
21
22
23
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
24
from vllm.model_executor.layers.quantization import QuantizationConfig
25
from vllm.model_executor.layers.vocab_parallel_embedding import (
26
27
28
    ParallelLMHead,
    VocabParallelEmbedding,
)
29
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
30
31
32
from vllm.model_executor.models.interfaces import (
    HasInnerState,
    IsAttentionFree,
33
    SupportsMambaPrefixCaching,
34
35
    SupportsPP,
)
36
37
from vllm.sequence import IntermediateTensors

38
39
40
41
42
43
44
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
45

46
# Type alias for a pair of tensors. It is not referenced anywhere else in this
# module as shown; presumably kept so external imports keep working.
# TODO(review): confirm there are no importers before removing it.
KVCache = tuple[torch.Tensor, torch.Tensor]
47
48
49


class MambaDecoderLayer(nn.Module):
    """A single Mamba decoder layer: RMSNorm applied before a ``MambaMixer``.

    The residual stream is carried between layers instead of being added back
    inside this layer. On the first layer ``residual`` is ``None``, so it is
    seeded from the incoming hidden states. On later layers the two-argument
    ``RMSNorm`` call returns both the normalized input and the updated
    residual.
    """

    def __init__(
        self,
        config: MambaConfig,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        is_lora_enabled: bool | None = False,
        prefix: str = "",
    ) -> None:
        """Create the mixer and the pre-mixer norm from a HF ``MambaConfig``.

        Args:
            config: HF model config. When ``config.model_type`` is
                ``"falcon_mamba"``, the mixer's internal RMSNorm is enabled
                without a weight and ``config.mixer_rms_eps`` is read.
            model_config: Passed through to ``MambaMixer``.
            cache_config: Passed through to ``MambaMixer``.
            quant_config: Accepted for signature parity with other layers;
                this layer does not use it.
            is_lora_enabled: Passed through to ``MambaMixer``.
            prefix: Parameter-name prefix; the mixer is registered under
                ``f"{prefix}.mixer"``.
        """
        super().__init__()
        self.config = config
        self.is_falcon_mamba = config.model_type == "falcon_mamba"
        self.is_lora_enabled = is_lora_enabled

        falcon = self.is_falcon_mamba
        # The mixer-internal RMSNorm is only switched on for Falcon-Mamba, so
        # its epsilon is only read from the config in that case.
        eps_for_mixer = None
        if falcon:
            eps_for_mixer = config.mixer_rms_eps

        self.mixer = MambaMixer(
            hidden_size=config.hidden_size,
            ssm_state_size=config.state_size,
            conv_kernel_size=config.conv_kernel,
            intermediate_size=config.intermediate_size,
            time_step_rank=config.time_step_rank,
            use_conv_bias=config.use_conv_bias,
            use_bias=config.use_bias,
            use_rms_norm=falcon,
            rms_norm_has_weight=not falcon,
            rms_norm_eps=eps_for_mixer,
            activation=config.hidden_act,
            is_lora_enabled=self.is_lora_enabled,
            model_config=model_config,
            cache_config=cache_config,
            prefix=f"{prefix}.mixer",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ):
        """Normalize, run the mixer, and return ``(mixer_output, residual)``.

        Extra keyword arguments (for example ``positions``) are accepted and
        ignored.
        """
        if residual is not None:
            hidden_states, residual = self.norm(hidden_states, residual)
        else:
            # First layer: the un-normalized input becomes the residual.
            residual, hidden_states = hidden_states, self.norm(hidden_states)

        mixer_out = torch.empty_like(hidden_states)
        # The mixer's return value is discarded; it is expected to fill
        # ``mixer_out`` in place.
        self.mixer(hidden_states, mixer_out)
        return mixer_out, residual
99
100


101
@support_torch_compile
class MambaModel(nn.Module):
    """Mamba backbone: token embeddings, a stack of decoder layers, final norm.

    With pipeline parallelism, only the first PP rank embeds tokens. Every
    rank except the last returns ``IntermediateTensors`` holding
    ``hidden_states`` and ``residual``. The last rank applies ``norm_f`` and
    returns the normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_enabled = bool(vllm_config.lora_config)

        self.config = config
        self.vocab_size = config.vocab_size

        self.embeddings = VocabParallelEmbedding(self.vocab_size, config.hidden_size)

        def build_layer(prefix: str) -> MambaDecoderLayer:
            # Factory handed to make_layers; it is called once per layer with
            # that layer's own parameter-name prefix.
            return MambaDecoderLayer(
                config,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                is_lora_enabled=lora_enabled,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, build_layer, prefix=f"{prefix}.layers"
        )

        self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embeddings(input_ids)

    # NOTE(review): @support_torch_compile presumably inspects this signature;
    # keep the parameter names and annotations unchanged.
    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run this pipeline stage's slice of layers.

        Precomputed ``inputs_embeds`` take precedence over ``input_ids`` on
        the first rank. Later ranks must receive ``intermediate_tensors``.
        """
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            # Later stages resume from the previous stage's output.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        elif inputs_embeds is None:
            hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            hidden_states = inputs_embeds
            residual = None

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions=positions, hidden_states=hidden_states, residual=residual
            )

        if pp_group.is_last_rank:
            hidden_states, _ = self.norm_f(hidden_states, residual)
            return hidden_states

        return IntermediateTensors(
            {"hidden_states": hidden_states, "residual": residual}
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` pairs into this module's parameters.

        Returns the set of parameter names that were actually loaded.
        """
        params = dict(self.named_parameters())
        done: set[str] = set()
        for raw_name, tensor in weights:
            # Names containing "A_log" are mapped onto the module's "A"
            # parameter. Presumably that parameter's weight_loader applies
            # any needed transform (TODO confirm in MambaMixer).
            name = raw_name.replace("A_log", "A")
            # GPTQ checkpoints may carry bias tensors this model has no slot
            # for; parameters owned by other PP ranks are skipped as well.
            is_extra_gptq_bias = name.endswith(".bias") and name not in params
            if is_extra_gptq_bias or is_pp_missing_parameter(name, self):
                continue

            param = params[name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, tensor)
            done.add(name)
        return done

191

192
193
194
class MambaForCausalLM(
    nn.Module, HasInnerState, IsAttentionFree, SupportsPP, SupportsMambaPrefixCaching
):
    """Causal-LM head around a :class:`MambaModel` exposed as ``backbone``.

    Adds an LM head, which is the backbone's embedding table when
    ``tie_word_embeddings`` is set, plus a ``LogitsProcessor``. It also
    exposes class-level Mamba-1 state dtype/shape/copy helpers.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config

        # Initialize nn.Module before assigning any attribute. Assigning
        # earlier relies on Module.__setattr__ tolerating a half-constructed
        # module.
        super().__init__()
        self.scheduler_config = vllm_config.scheduler_config
        self.config = config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config

        self.backbone = MambaModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "backbone")
        )

        if config.tie_word_embeddings:
            # Reuse the input embedding table as the output projection.
            self.lm_head = self.backbone.embeddings
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.backbone.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the backbone's embedding table."""
        return self.backbone.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ):
        """Run the backbone; extra keyword arguments are ignored.

        Returns whatever the backbone returns: normalized hidden states on the
        last PP rank, ``IntermediateTensors`` on earlier ranks.
        """
        hidden_states = self.backbone(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the two Mamba-1 state dtypes.

        They are computed by ``MambaStateDtypeCalculator.mamba1_state_dtype``
        from the model dtype and the cache config's ``mamba_cache_dtype`` /
        ``mamba_ssm_cache_dtype`` settings.
        """
        return MambaStateDtypeCalculator.mamba1_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return the two Mamba-1 state shapes for this model.

        They are computed from the HF config sizes and the tensor-parallel
        world size (presumably per TP rank; see
        ``MambaStateShapeCalculator.mamba1_state_shape``).
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config

        return MambaStateShapeCalculator.mamba1_state_shape(
            tp_world_size=parallel_config.tensor_parallel_size,
            intermediate_size=hf_config.intermediate_size,
            state_size=hf_config.state_size,
            conv_kernel=hf_config.conv_kernel,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        """Return the pair of Mamba-1 state copy functions."""
        return MambaStateCopyFuncCalculator.mamba1_state_copy_func()

    def _get_mamba_cache(self, caller: str):
        """Return ``self.mamba_cache`` or raise a descriptive AttributeError.

        Nothing in this class assigns ``mamba_cache``. It must be attached
        externally before the CUDA-graph helpers that use it can be called.
        """
        cache = getattr(self, "mamba_cache", None)
        if cache is None:
            raise AttributeError(
                f"{type(self).__name__}.{caller} requires a 'mamba_cache' "
                "attribute, but none has been set on this model."
            )
        return cache

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the attached ``mamba_cache``.

        Raises AttributeError if no ``mamba_cache`` has been attached.
        """
        cache = self._get_mamba_cache("copy_inputs_before_cuda_graphs")
        return cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the attached ``mamba_cache``.

        Raises AttributeError if no ``mamba_cache`` has been attached.
        """
        cache = self._get_mamba_cache("get_seqlen_agnostic_capture_inputs")
        return cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights recursively via ``AutoWeightsLoader``.

        Returns the set of loaded parameter names.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)