mamba.py 11.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""PyTorch MAMBA model."""
3
from typing import Iterable, Optional, Set, Tuple
4
5
6
7
8

import torch
from torch import nn
from transformers import MambaConfig

9
from vllm.config import CacheConfig, VllmConfig
10
from vllm.distributed import get_tensor_model_parallel_world_size
11
from vllm.distributed.parallel_state import get_pp_group
12
13
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
14
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
15
16
17
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
18
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
19
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
20
from vllm.model_executor.models.interfaces import (HasInnerState,
21
22
                                                   IsAttentionFree, SupportsPP,
                                                   SupportsV0Only)
23
24
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                    MambaCacheParams)
25
26
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
27
from vllm.utils import LayerBlockType
28

29
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
30
31
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
# Alias for a pair of tensors, named after an attention (key, value) cache.
# This model is attention-free (``MambaForCausalLM`` mixes in
# ``IsAttentionFree``) and no code visible in this file references the alias;
# it is kept only for any external importers.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class MambaDecoderLayer(nn.Module):
    """A single Mamba block: an ``RMSNorm`` followed by a ``MambaMixer``.

    The residual is carried alongside the hidden states. ``forward`` returns
    ``(hidden_states, residual)``. When a residual is supplied, both tensors
    go through the two-argument form of ``self.norm``, which yields an
    updated ``(hidden_states, residual)`` pair.

    Args:
        config: HF ``MambaConfig``. ``model_type == "falcon_mamba"`` switches
            the mixer to the Falcon-Mamba normalisation settings.
        cache_config: Accepted for interface parity; not used by this block.
        quant_config: Accepted for interface parity; not forwarded to the
            mixer here.
        is_lora_enabled: Forwarded to ``MambaMixer``.
    """

    def __init__(self,
                 config: MambaConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 is_lora_enabled: Optional[bool] = False) -> None:
        super().__init__()
        falcon_variant = config.model_type == "falcon_mamba"

        self.config = config
        self.is_falcon_mamba = falcon_variant
        self.is_lora_enabled = is_lora_enabled

        # Falcon-Mamba asks the mixer for an RMSNorm without a learned weight,
        # using ``config.mixer_rms_eps``. Plain Mamba disables that norm, so
        # no epsilon is read (the attribute is only accessed for Falcon).
        if falcon_variant:
            mixer_eps = config.mixer_rms_eps
        else:
            mixer_eps = None

        self.mixer = MambaMixer(
            hidden_size=config.hidden_size,
            ssm_state_size=config.state_size,
            conv_kernel_size=config.conv_kernel,
            intermediate_size=config.intermediate_size,
            time_step_rank=config.time_step_rank,
            use_conv_bias=config.use_conv_bias,
            use_bias=config.use_bias,
            use_rms_norm=falcon_variant,
            rms_norm_has_weight=not falcon_variant,
            rms_norm_eps=mixer_eps,
            activation=config.hidden_act,
            is_lora_enabled=is_lora_enabled,
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        **kwargs,
    ):
        """Normalise the input, run the mixer, and return the new pair.

        Args:
            hidden_states: Output of the previous layer, or the embeddings.
            residual: ``None`` when there is no residual yet. ``MambaModel``
                passes ``None`` to the first layer on the first pipeline
                rank. In that case the un-normalised input becomes the
                residual.
            mamba_cache_params: This layer's slice of the Mamba state cache,
                handed straight to the mixer.
            **kwargs: Ignored. ``MambaModel.forward`` passes ``positions``.

        Returns:
            ``(hidden_states, residual)``.
        """
        if residual is not None:
            hidden_states, residual = self.norm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)

        hidden_states = self.mixer(hidden_states, mamba_cache_params)
        return hidden_states, residual


class MambaModel(nn.Module):
    """Mamba backbone: token embeddings, decoder layers, and a final RMSNorm.

    ``make_layers`` decides which layer indices this rank owns. It returns
    ``start_layer``/``end_layer``, and ``forward`` iterates exactly
    ``[start_layer, end_layer)``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        is_lora_enabled = bool(lora_config)

        # With LoRA enabled, the embedding table gets
        # ``lora_extra_vocab_size`` additional rows per adapter slot.
        # A falsy ``max_loras`` is treated as a single slot.
        extra_lora_rows = 0
        if lora_config:
            extra_lora_rows = (lora_config.lora_extra_vocab_size *
                               (lora_config.max_loras or 1))

        self.config = config
        self.vocab_size = config.vocab_size + extra_lora_rows
        self.org_vocab_size = config.vocab_size

        self.embeddings = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        def build_layer(prefix: str) -> MambaDecoderLayer:
            # ``make_layers`` supplies a per-layer ``prefix``. The decoder
            # layer takes none, so the argument is deliberately unused.
            return MambaDecoderLayer(config,
                                     cache_config=cache_config,
                                     quant_config=quant_config,
                                     is_lora_enabled=is_lora_enabled)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, build_layer, prefix=f"{prefix}.layers")

        self.norm_f = RMSNorm(config.hidden_size,
                              eps=config.layer_norm_epsilon)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the layer stack.

        On the first pipeline rank, the input is ``inputs_embeds`` if given,
        otherwise the embeddings of ``input_ids``. Other ranks read
        ``hidden_states``/``residual`` from ``intermediate_tensors``, which
        must then be provided.

        Returns:
            On the last pipeline rank, the hidden states after ``norm_f``.
            On any other rank, an ``IntermediateTensors`` carrying
            ``hidden_states`` and ``residual``. The return annotation says
            ``torch.Tensor`` only.
        """
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            residual = None
            if inputs_embeds is None:
                hidden_states = self.get_input_embeddings(input_ids)
            else:
                hidden_states = inputs_embeds

        for local_idx, layer_idx in enumerate(
                range(self.start_layer, self.end_layer)):
            layer = self.layers[layer_idx]
            # Cache slots are indexed relative to ``start_layer``. Presumably
            # the cache holds only the layers this rank runs.
            layer_cache = mamba_cache_params.at_layer_idx(local_idx)
            hidden_states, residual = layer(positions=positions,
                                            hidden_states=hidden_states,
                                            residual=residual,
                                            mamba_cache_params=layer_cache)

        if pp_group.is_last_rank:
            hidden_states, _ = self.norm_f(hidden_states, residual)
            return hidden_states

        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Copy ``(name, tensor)`` checkpoint entries into matching parameters.

        Each tensor goes through the parameter's ``weight_loader`` attribute
        if present, otherwise through ``default_weight_loader``.

        Returns:
            The (renamed) parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Checkpoint entries named ``...A_log`` are loaded into the
            # parameter named ``...A``. The replace is a no-op when the
            # substring is absent. Any conversion from the log form would
            # happen in the weight loader, which is not visible here.
            name = name.replace("A_log", "A")
            # Skip extra biases that have no matching parameter (GPTQ
            # checkpoints), and names that ``is_pp_missing_parameter``
            # reports as absent on this pipeline rank.
            if ((name.endswith(".bias") and name not in params_dict)
                    or is_pp_missing_parameter(name, self)):
                continue

            param = params_dict[name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

176

177
178
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
                       SupportsV0Only):
    """Causal-LM wrapper: ``MambaModel`` backbone, LM head, and logits.

    The Mamba state cache (``MambaCacheManager``) is not built in
    ``__init__``. It is created on the first ``forward`` call, and
    ``self.mamba_cache`` is ``None`` until then.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config

        assert not cache_config.enable_prefix_caching, \
            "Mamba does not support prefix caching"

        super().__init__()
        self.config = config
        self.scheduler_config = vllm_config.scheduler_config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config

        self.backbone = MambaModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "backbone"))

        extra_lora_vocab = (lora_config.lora_extra_vocab_size
                            if lora_config else 0)
        self.unpadded_vocab_size = config.vocab_size + extra_lora_vocab

        if config.tie_word_embeddings:
            # Share the embedding module as the output projection.
            self.lm_head = self.backbone.embeddings
        else:
            # We need bigger padding if using lora for kernel compatibility.
            padding_size = (lora_config.lora_vocab_padding_size
                            if lora_config else DEFAULT_VOCAB_PADDING_SIZE)
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=padding_size,
            )

        # Filled in lazily by ``forward``; stores Mamba state between steps.
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.backbone.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the backbone."""
        return self.backbone.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        """Create the state cache if needed, then run the backbone.

        ``kwargs`` are passed unchanged to
        ``MambaCacheManager.current_run_tensors``. Presumably they carry
        per-step scheduling metadata; confirm against the model runner.
        """
        if self.mamba_cache is None:
            # The manager receives the LM head's weight dtype (presumably
            # used as the cache dtype) and the number of mamba-type layers
            # reported by the model config for this parallel configuration.
            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                self.vllm_config.parallel_config, LayerBlockType.mamba)
            cache_shapes = self._get_mamba_cache_shape()
            self.mamba_cache = MambaCacheManager(self.vllm_config,
                                                 self.lm_head.weight.dtype,
                                                 num_mamba_layers,
                                                 *cache_shapes)

        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
        return self.backbone(input_ids, positions, mamba_cache_params,
                             intermediate_tensors, inputs_embeds)

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the cache manager.

        This requires ``forward`` to have run at least once, because
        ``self.mamba_cache`` is ``None`` until then.
        """
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the cache manager; requires a prior ``forward``."""
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def _get_mamba_cache_shape(
            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
        """Return the per-layer ``(conv_state_shape, temporal_state_shape)``.

        Both shapes split ``intermediate_size`` evenly across tensor-parallel
        ranks. The conv state keeps ``conv_kernel - 1`` columns, and the
        temporal (SSM) state keeps ``state_size`` columns.
        """
        tp_size = get_tensor_model_parallel_world_size()
        local_inner = self.config.intermediate_size // tp_size
        conv_state_shape = (local_inner, self.config.conv_kernel - 1)
        temporal_state_shape = (local_inner, self.config.state_size)
        return conv_state_shape, temporal_state_shape

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Returns:
            The set of loaded parameter names reported by the loader.
        """
        return AutoWeightsLoader(self).load_weights(weights)