mamba.py 9.56 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""PyTorch MAMBA model."""
4

5
from collections.abc import Iterable
6
from itertools import islice
7
8
9
10
11

import torch
from torch import nn
from transformers import MambaConfig

12
from vllm.compilation.decorators import support_torch_compile
13
from vllm.config import CacheConfig, ModelConfig, VllmConfig
14
from vllm.distributed.parallel_state import get_pp_group
15
16
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
17
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
18
from vllm.model_executor.layers.mamba.mamba_utils import (
19
20
21
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
22
from vllm.model_executor.layers.quantization import QuantizationConfig
23
from vllm.model_executor.layers.vocab_parallel_embedding import (
24
25
26
    ParallelLMHead,
    VocabParallelEmbedding,
)
27
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
28
29
30
from vllm.model_executor.models.interfaces import (
    HasInnerState,
    IsAttentionFree,
31
    SupportsMambaPrefixCaching,
32
33
    SupportsPP,
)
34
35
from vllm.sequence import IntermediateTensors

36
37
38
39
40
41
42
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)

# Alias for a pair of tensors.
# NOTE(review): not referenced by the code visible in this file (the model is
# attention-free); presumably kept for external importers — confirm before
# removing.
KVCache = tuple[
    torch.Tensor,
    torch.Tensor,
]


class MambaDecoderLayer(nn.Module):
    """One Mamba / Falcon-Mamba block: a pre-mixer RMSNorm followed by a mixer.

    Residual handling follows the two-argument ``RMSNorm`` convention: on the
    first layer ``residual`` is ``None`` and the incoming hidden states become
    the residual; afterwards ``self.norm(hidden_states, residual)`` returns the
    normalized states together with the updated residual (presumably vLLM's
    fused add+norm path — the add itself is not visible in this file).
    """

    def __init__(
        self,
        config: MambaConfig,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        is_lora_enabled: bool | None = False,
        prefix: str = "",
    ) -> None:
        """Build the mixer and the pre-mixer norm.

        Args:
            config: HF Mamba config. ``model_type == "falcon_mamba"`` turns on
                the mixer's internal RMS norms (without learned weights).
            model_config: forwarded to ``MambaMixer``.
            cache_config: forwarded to ``MambaMixer``.
            quant_config: accepted for signature parity with other decoder
                layers; not used by this layer.
            is_lora_enabled: forwarded to ``MambaMixer``.
            prefix: parameter-name prefix; the mixer is created under
                ``f"{prefix}.mixer"``.
        """
        super().__init__()
        self.config = config
        self.is_falcon_mamba = config.model_type == "falcon_mamba"
        self.is_lora_enabled = is_lora_enabled
        # ``mixer_rms_eps`` is only read for Falcon-Mamba; plain Mamba passes
        # ``use_rms_norm=False`` below, so no epsilon is supplied for it.
        mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None

        self.mixer = MambaMixer(
            hidden_size=config.hidden_size,
            ssm_state_size=config.state_size,
            conv_kernel_size=config.conv_kernel,
            intermediate_size=config.intermediate_size,
            time_step_rank=config.time_step_rank,
            use_conv_bias=config.use_conv_bias,
            use_bias=config.use_bias,
            # Falcon-Mamba enables the inner norms but they carry no weight.
            use_rms_norm=self.is_falcon_mamba,
            rms_norm_has_weight=not self.is_falcon_mamba,
            rms_norm_eps=mixer_rms_eps,
            activation=config.hidden_act,
            is_lora_enabled=self.is_lora_enabled,
            model_config=model_config,
            cache_config=cache_config,
            prefix=f"{prefix}.mixer",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the pre-norm and the mixer.

        Args:
            hidden_states: output of the previous layer (or the embeddings).
            residual: running residual; ``None`` for the first layer.
            **kwargs: ignored; lets the caller pass extras such as
                ``positions`` uniformly to every layer.

        Returns:
            ``(mixer_output, residual)`` to feed into the next layer.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, residual = self.norm(hidden_states, residual)

        # The mixer fills a caller-provided buffer instead of returning a
        # tensor, so allocate one with the input's shape/dtype/device.
        output = torch.empty_like(hidden_states)
        self.mixer(hidden_states, output)
        return output, residual


@support_torch_compile
class MambaModel(nn.Module):
    """Mamba backbone: token embeddings, decoder layers, and a final RMSNorm.

    Pipeline-parallel aware: only layers ``[start_layer, end_layer)`` are run
    on this rank, and non-last ranks return ``hidden_states``/``residual`` as
    ``IntermediateTensors`` for the next rank.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        # Any truthy LoRA config marks LoRA as enabled for every layer.
        is_lora_enabled = bool(vllm_config.lora_config)

        self.config = config
        self.vocab_size = config.vocab_size

        self.embeddings = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )

        # make_layers reports this rank's [start_layer, end_layer) range;
        # layers outside it are presumably placeholders (see make_layers).
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MambaDecoderLayer(
                config,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                is_lora_enabled=is_lora_enabled,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the backbone.

        Args:
            input_ids: token ids; embedded only on the first PP rank when
                ``inputs_embeds`` is not given.
            positions: passed through to every layer (the layers ignore it).
            intermediate_tensors: ``hidden_states``/``residual`` from the
                previous PP rank; required on non-first ranks.
            inputs_embeds: optional precomputed embeddings (first rank only).

        Returns:
            Final normalized hidden states on the last PP rank, otherwise an
            ``IntermediateTensors`` for the next rank.
        """
        # NOTE: parameter annotations above are inspected by
        # ``support_torch_compile``; keep them unchanged.
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions=positions, hidden_states=hidden_states, residual=residual
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm_f(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load backbone weights.

        Args:
            weights: ``(name, tensor)`` pairs, names relative to this module.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Checkpoints name the SSM parameter ``A_log``; the module's
            # parameter is ``A`` (its weight loader presumably converts).
            if "A_log" in name:
                name = name.replace("A_log", "A")
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Parameters owned by other pipeline ranks are not present here.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class MambaForCausalLM(
    nn.Module, HasInnerState, IsAttentionFree, SupportsPP, SupportsMambaPrefixCaching
):
    """Mamba causal LM: a ``MambaModel`` backbone plus an LM head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Initialize nn.Module before assigning any attribute on self.
        super().__init__()
        config = vllm_config.model_config.hf_config

        self.config = config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.scheduler_config = vllm_config.scheduler_config

        self.backbone = MambaModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "backbone")
        )

        if config.tie_word_embeddings:
            # Reuse the input embedding module as the output projection.
            self.lm_head = self.backbone.embeddings
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.backbone.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.backbone.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone; extra ``kwargs`` are accepted and ignored.

        Returns the final hidden states on the last PP rank, otherwise the
        backbone's ``IntermediateTensors``.
        """
        hidden_states = self.backbone(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the Mamba-1 state dtypes derived from model/cache configs."""
        return MambaStateDtypeCalculator.mamba1_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return the Mamba-1 state shapes for this model and TP size."""
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config

        return MambaStateShapeCalculator.mamba1_state_shape(
            tp_world_size=parallel_config.tensor_parallel_size,
            intermediate_size=hf_config.intermediate_size,
            state_size=hf_config.state_size,
            conv_kernel=hf_config.conv_kernel,
        )

    # NOTE(review): ``self.mamba_cache`` is never assigned in this class, so
    # the two methods below raise AttributeError unless something outside
    # this file attaches it. Confirm whether they are still needed.
    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to ``self.mamba_cache`` (see note above)."""
        return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to ``self.mamba_cache`` (see note above)."""
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocabulary logits through ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load all weights via ``AutoWeightsLoader``; returns loaded names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)