# SPDX-License-Identifier: Apache-2.0
"""PyTorch MAMBA model."""
from typing import Iterable, List, Optional, Set, Tuple

import torch
from torch import nn
from transformers import MambaConfig

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState,
                                                   IsAttentionFree, SupportsPP)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                    MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType

from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

# Pair-of-tensors alias used to annotate the ``kv_caches`` argument of
# ``MambaForCausalLM.forward``. That argument is not read by this model; its
# per-layer state is held by ``MambaCacheManager`` instead.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class MambaDecoderLayer(nn.Module):
    """One Mamba block: an RMSNorm over the residual stream followed by a
    :class:`MambaMixer`.

    For Falcon-Mamba configs (``config.model_type == "falcon_mamba"``) the
    mixer is built with ``use_rms_norm=True``, ``rms_norm_has_weight=False``
    and ``rms_norm_eps=config.mixer_rms_eps``.
    """

    def __init__(self,
                 config: MambaConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 is_lora_enabled: Optional[bool] = False) -> None:
        """
        Args:
            config: Hugging Face Mamba / Falcon-Mamba configuration.
            cache_config: Accepted for a uniform layer-constructor signature;
                not used by this layer.
            quant_config: Accepted for a uniform layer-constructor signature;
                not used by this layer.
            is_lora_enabled: Forwarded to the mixer.
        """
        super().__init__()
        self.config = config
        self.is_falcon_mamba = config.model_type == "falcon_mamba"
        self.is_lora_enabled = is_lora_enabled

        # ``mixer_rms_eps`` is read only for Falcon-Mamba; presumably plain
        # Mamba configs do not define it.
        mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
        self.mixer = MambaMixer(hidden_size=config.hidden_size,
                                ssm_state_size=config.state_size,
                                conv_kernel_size=config.conv_kernel,
                                intermediate_size=config.intermediate_size,
                                time_step_rank=config.time_step_rank,
                                use_conv_bias=config.use_conv_bias,
                                use_bias=config.use_bias,
                                use_rms_norm=self.is_falcon_mamba,
                                rms_norm_has_weight=not self.is_falcon_mamba,
                                rms_norm_eps=mixer_rms_eps,
                                activation=config.hidden_act,
                                is_lora_enabled=self.is_lora_enabled)

        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        **kwargs,
    ):
        """Normalize the input and run the mixer.

        Args:
            hidden_states: Output of the previous layer (or the embeddings).
            attn_metadata: Forwarded to the mixer.
            residual: Running residual stream. ``None`` for the first layer
                of the first pipeline rank; in that case ``hidden_states``
                itself becomes the residual.
            mamba_cache_params: State-cache params for this layer, as
                selected by the caller.
            **kwargs: Ignored. The backbone also passes ``positions``, which
                this layer does not use.

        Returns:
            ``(hidden_states, residual)``. The mixer output is not added to
            the residual here. Both values are handed to the next
            ``RMSNorm(x, residual)`` call, which presumably performs the add.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, residual = self.norm(hidden_states, residual)

        hidden_states = self.mixer(hidden_states, attn_metadata,
                                   mamba_cache_params)
        return hidden_states, residual


class MambaModel(nn.Module):
    """Mamba backbone: token embeddings, a stack of
    :class:`MambaDecoderLayer`, and a final RMSNorm (``norm_f``).

    With pipeline parallelism this rank runs layers
    ``[start_layer, end_layer)``. ``hidden_states`` and ``residual`` are
    passed between ranks as :class:`IntermediateTensors`.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        is_lora_enabled = bool(lora_config)

        self.config = config
        self.padding_idx = config.pad_token_id
        # Extra embedding rows for LoRA-added tokens: ``lora_extra_vocab_size``
        # per LoRA slot (``max_loras``, at least 1).
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embeddings = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        # ``start_layer``/``end_layer`` bound the layers this rank executes
        # in ``forward``.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MambaDecoderLayer(config,
                                             cache_config=cache_config,
                                             quant_config=quant_config,
                                             is_lora_enabled=is_lora_enabled),
            prefix=f"{prefix}.layers")

        self.norm_f = RMSNorm(config.hidden_size,
                              eps=config.layer_norm_epsilon)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        attn_metadata: AttentionMetadata,
        mamba_cache_params: MambaCacheParams,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the backbone.

        Args:
            input_ids: Token ids. Embedded on the first rank only when
                ``inputs_embeds`` is not given.
            positions: Passed through to the layers, which ignore it.
            attn_metadata: Forwarded to every layer's mixer.
            mamba_cache_params: State-cache params. Each layer receives
                ``at_layer_idx(i - start_layer)``.
            intermediate_tensors: Required on non-first pipeline ranks.
            inputs_embeds: Optional precomputed embeddings (first rank).

        Returns:
            The final-normalized hidden states on the last pipeline rank.
            On other ranks it returns an :class:`IntermediateTensors` with
            ``hidden_states`` and ``residual``, despite the return
            annotation.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                attn_metadata=attn_metadata,
                residual=residual,
                # Cache params are indexed relative to this rank's first
                # layer, not the global layer index.
                mamba_cache_params=mamba_cache_params.at_layer_idx(
                    i - self.start_layer))

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm_f(hidden_states, residual)
        return hidden_states


class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
    """Mamba causal language model.

    Wraps :class:`MambaModel` (``self.backbone``) with an LM head, a logits
    processor and a sampler. The Mamba state cache (conv and temporal state
    tensors, shaped by ``_get_mamba_cache_shape``) is a
    :class:`MambaCacheManager` created lazily on the first :meth:`forward`
    call.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config
        assert not cache_config.enable_prefix_caching, \
            "Mamba does not support prefix caching"

        super().__init__()
        self.config = config
        self.vllm_config = vllm_config
        self.scheduler_config = scheduler_config
        self.model_config = vllm_config.model_config
        self.backbone = MambaModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "backbone"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        if config.tie_word_embeddings:
            # Reuse the input embedding module as the output projection.
            self.lm_head = self.backbone.embeddings
        else:
            # We need bigger padding if using lora for kernel compatibility.
            padding_size = (DEFAULT_VOCAB_PADDING_SIZE if not lora_config else
                            lora_config.lora_vocab_padding_size)
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=padding_size,
            )

        # Used to track and store by the Mamba cache between steps; created
        # on the first forward() call.
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.backbone.make_empty_intermediate_tensors)

        # Batch capacity passed to MambaCacheManager. When CUDA graphs may be
        # used (not enforce_eager), it is max_num_seqs padded via
        # pad_for_cudagraph, capped at the largest capture size. Otherwise a
        # fixed fallback is used. NOTE(review): the origin of the exact
        # ``8192 + 2`` value is not visible here.
        if self.scheduler_config is not None and \
                not self.model_config.enforce_eager:
            max_capture_size = vllm_config.compilation_config.max_capture_size
            if self.scheduler_config.max_num_seqs > max_capture_size:
                self.max_batch_size = max_capture_size
            else:
                self.max_batch_size = vllm_config.pad_for_cudagraph(
                    self.scheduler_config.max_num_seqs)
        else:
            self.max_batch_size = 8192 + 2

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.backbone.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[KVCache],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        """Run one model step through the backbone.

        ``kv_caches`` is accepted for interface compatibility and not read.
        Extra ``kwargs`` are passed straight to
        ``MambaCacheManager.current_run_tensors``. That call returns the
        cache params handed to the backbone.

        Returns:
            The backbone output: hidden states on the last pipeline rank,
            otherwise :class:`IntermediateTensors`.
        """
        if self.mamba_cache is None:
            # Allocate the state cache on first use, in the LM head's dtype.
            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                self.vllm_config.parallel_config, LayerBlockType.mamba)
            self.mamba_cache = MambaCacheManager(
                self.lm_head.weight.dtype, num_mamba_layers,
                self.max_batch_size, *self._get_mamba_cache_shape())

        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

        hidden_states = self.backbone(input_ids, positions, attn_metadata,
                                      mamba_cache_params, intermediate_tensors,
                                      inputs_embeds)
        return hidden_states

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the state cache manager.

        Requires ``self.mamba_cache``, which :meth:`forward` creates on its
        first call. Before that it is ``None`` and this raises
        ``AttributeError``.
        """
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the state cache manager.

        Requires ``self.mamba_cache`` to have been created by
        :meth:`forward`.
        """
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def _get_mamba_cache_shape(
            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
        """Return ``(conv_state_shape, temporal_state_shape)``.

        The channel dimension is ``intermediate_size`` split across the
        tensor-parallel world. The manager is constructed with the layer count
        and ``max_batch_size`` separately; presumably it prepends those
        dimensions (TODO confirm in MambaCacheManager).
        """
        world_size = get_tensor_model_parallel_world_size()
        channels_per_rank = self.config.intermediate_size // world_size
        conv_state_shape = (
            channels_per_rank,
            self.config.conv_kernel - 1,
        )
        temporal_state_shape = (
            channels_per_rank,
            self.config.state_size,
        )
        return conv_state_shape, temporal_state_shape

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Checkpoints store ``A_log``; the parameter here is named ``A``.
            if "A_log" in name:
                name = name.replace("A_log", "A")

            # With tied embeddings ``lm_head`` is the same module as
            # ``backbone.embeddings``. ``named_parameters()`` de-duplicates
            # the shared weight, so ``lm_head.weight`` is absent from
            # ``params_dict``. Skip a checkpoint copy instead of raising
            # KeyError below.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params