"vscode:/vscode.git/clone" did not exist on "c0a7b89d8e7e08af8c75139c8af6e105ac20112f"
mamba2.py 10.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""PyTorch MAMBA2 model."""
4

5
from collections.abc import Iterable
6
7
8
9
10

import torch
from torch import nn
from transformers import MambaConfig

11
from vllm.compilation.decorators import support_torch_compile
12
from vllm.config import CacheConfig, ModelConfig, VllmConfig
13
14
15
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
16
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
17
from vllm.model_executor.layers.mamba.mamba_utils import (
18
19
20
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
21
from vllm.model_executor.layers.quantization import QuantizationConfig
22
from vllm.model_executor.layers.vocab_parallel_embedding import (
23
24
25
26
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
27
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
28
from vllm.model_executor.models.interfaces import HasInnerState, IsAttentionFree
29
30
from vllm.sequence import IntermediateTensors

31
32
33
34
35
36
37
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
38

39
# Type alias for a pair of tensors.
# NOTE(review): not referenced anywhere in the visible part of this file —
# presumably kept so external code can still import it; verify before removing.
KVCache = tuple[torch.Tensor, torch.Tensor]
40
41
42


class Mamba2DecoderLayer(nn.Module):
    """One Mamba2 block: an RMSNorm followed by a ``MambaMixer2`` token mixer.

    The layer carries a residual stream alongside the hidden states. When a
    residual is supplied, ``self.norm`` is called with both tensors and returns
    an updated ``(hidden_states, residual)`` pair (presumably a fused
    residual-add + norm — see ``RMSNorm``).
    """

    def __init__(
        self,
        config: MambaConfig,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the mixer and norm from the HF config.

        Args:
            config: HF config providing the Mamba2 hyper-parameters.
            model_config: vLLM model config, forwarded to the mixer.
            cache_config: vLLM cache config, forwarded to the mixer.
            quant_config: Optional quantization config for the mixer.
            prefix: Parameter-name prefix; the mixer lives at ``{prefix}.mixer``.
        """
        super().__init__()
        self.config = config

        # An explicit ``intermediate_size`` on the config wins; otherwise it is
        # derived from the expansion factor. The fallback is evaluated eagerly,
        # just like any getattr default.
        inner_dim = getattr(
            config, "intermediate_size", config.expand * config.hidden_size
        )

        self.mixer = MambaMixer2(
            hidden_size=config.hidden_size,
            ssm_state_size=config.state_size,
            conv_kernel_size=config.conv_kernel,
            intermediate_size=inner_dim,
            use_conv_bias=config.use_conv_bias,
            use_bias=config.use_bias,
            n_groups=config.n_groups,
            num_heads=config.num_heads,
            head_dim=config.head_dim,
            rms_norm_eps=config.layer_norm_epsilon,
            activation=config.hidden_act,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mixer",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ):
        """Normalize, run the mixer, and return ``(mixer_output, residual)``.

        Args:
            hidden_states: Input activations for this layer.
            residual: Residual stream from the previous layer, or ``None`` for
                the first layer (the raw input then becomes the residual).
            **kwargs: Ignored (e.g. ``positions`` passed by the backbone).
        """
        if residual is not None:
            normed, residual = self.norm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.norm(hidden_states)

        # The mixer is handed a freshly allocated buffer and that same buffer is
        # returned, so it is expected to be filled in place by the mixer.
        mixer_out = torch.empty_like(normed)
        self.mixer(normed, mixer_out)
        return mixer_out, residual
90
91


92
@support_torch_compile
class Mamba2Model(nn.Module):
    """Mamba2 backbone: token embeddings, a stack of decoder layers, final norm.

    Pipeline parallelism is handled in ``forward``: the first PP rank embeds the
    tokens, other ranks resume from ``IntermediateTensors``, and only the last
    rank applies the final ``norm_f``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build embeddings, decoder layers and final norm from ``vllm_config``.

        Args:
            vllm_config: Full vLLM config; the HF config is read from
                ``vllm_config.model_config.hf_config``.
            prefix: Parameter-name prefix; layers live under ``{prefix}.layers``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        is_lora_enabled = bool(lora_config)
        # LoRA is not supported by this model.
        assert not is_lora_enabled

        self.config = config
        # Extra vocabulary slots reserved for LoRA adapters; this is 0 whenever
        # the assert above holds.
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embeddings = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        # make_layers also reports [start_layer, end_layer), presumably the
        # layer range owned by this pipeline-parallel rank.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Mamba2DecoderLayer(
                config,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the backbone.

        Args:
            input_ids: Token ids, embedded on the first PP rank unless
                ``inputs_embeds`` is given.
            positions: Forwarded to each layer (unused by the Mamba2 layers).
            intermediate_tensors: ``hidden_states``/``residual`` received from
                the previous PP rank; required on every non-first rank.
            inputs_embeds: Optional precomputed embeddings for the first rank.

        Returns:
            ``IntermediateTensors`` on non-last PP ranks; otherwise the final
            normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Only run the layers this rank owns. With a single PP rank this range
        # covers every layer; iterating the full list on other ranks would also
        # run the entries outside [start_layer, end_layer).
        for idx in range(self.start_layer, self.end_layer):
            hidden_states, residual = self.layers[idx](
                positions=positions, hidden_states=hidden_states, residual=residual
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm_f(hidden_states, residual)

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this module.

        Checkpoint names containing ``A_log`` are renamed to use ``A``. Extra
        ``.bias`` entries with no matching parameter are skipped (GPTQ), as are
        parameters that ``is_pp_missing_parameter`` reports for this rank. Each
        parameter's own ``weight_loader`` is used when present, otherwise
        ``default_weight_loader``.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "A_log" in name:
                name = name.replace("A_log", "A")

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

191

Chen Zhang's avatar
Chen Zhang committed
192
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
    """Mamba2 causal LM: a ``Mamba2Model`` backbone plus a parallel LM head.

    ``forward`` returns hidden states; ``compute_logits`` applies the LM head
    through the logits processor.
    """

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the conv/SSM state dtypes derived from the model dtype and the
        configured mamba cache dtypes."""
        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Calculate shapes for Mamba's convolutional and state caches.

        Args:
            vllm_config: vLLM config

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
            - temporal_state_shape: Shape for state space model cache
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config
        # Must match the size used to build MambaMixer2 in Mamba2DecoderLayer:
        # an explicit ``intermediate_size`` on the config takes precedence over
        # ``expand * hidden_size``, otherwise cache shapes disagree with the mixer.
        intermediate_size = getattr(
            hf_config, "intermediate_size", hf_config.expand * hf_config.hidden_size
        )

        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=intermediate_size,
            tp_world_size=parallel_config.tensor_parallel_size,
            n_groups=hf_config.n_groups,
            num_heads=hf_config.num_heads,
            head_dim=hf_config.head_dim,
            state_size=hf_config.state_size,
            conv_kernel=hf_config.conv_kernel,
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head and logits processor.

        Args:
            vllm_config: Full vLLM config.
            prefix: Parameter-name prefix for ``backbone`` and ``lm_head``.
        """
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config

        super().__init__()
        self.config = config
        self.vllm_config = vllm_config
        self.scheduler_config = scheduler_config
        self.model_config = vllm_config.model_config

        self.backbone = Mamba2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "backbone")
        )

        # NOTE: Mamba2Model asserts LoRA is disabled, so the LoRA branches below
        # are not expected to be taken.
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        # We need bigger padding if using LoRA, for kernel compatibility.
        if lora_config:
            padding_size = lora_config.lora_vocab_padding_size
        else:
            padding_size = DEFAULT_VOCAB_PADDING_SIZE

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=padding_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        # Share the input embedding weights with the LM head when requested.
        if config.tie_word_embeddings:
            self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings)

        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.vocab_size
        )

        self.make_empty_intermediate_tensors = (
            self.backbone.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.backbone.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ):
        """Run the backbone and return its output (hidden states on the last PP
        rank, ``IntermediateTensors`` otherwise). Extra ``kwargs`` are ignored."""
        hidden_states = self.backbone(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

        return hidden_states

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        # NOTE(review): ``self.mamba_cache`` is never assigned in this class, so
        # this raises AttributeError unless it is set externally — likely a
        # leftover from an older cache design; confirm before relying on it.
        return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        # NOTE(review): same caveat as above — ``self.mamba_cache`` is not
        # assigned in this class.
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocabulary logits through the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``; returns the names
        of the loaded parameters."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)