"""PyTorch MAMBA model."""
2
from typing import Iterable, List, Optional, Set, Tuple, Union
3
4
5
6
7
8

import torch
from torch import nn
from transformers import MambaConfig

from vllm.attention.backends.abstract import AttentionMetadata
9
from vllm.config import CacheConfig, VllmConfig
10
from vllm.distributed import get_tensor_model_parallel_world_size
11
from vllm.distributed.parallel_state import get_pp_group
12
13
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
14
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
15
16
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
17
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
18
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
20
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
21
from vllm.model_executor.models.interfaces import (HasInnerState,
                                                   IsAttentionFree, SupportsPP)
23
24
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                    MambaCacheParams)
25
26
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
27
from vllm.utils import LayerBlockType
28

29
30
31
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
32

33
34
35
36
37
38
39
40
41
42
43
# Element type of the ``kv_caches`` list in ``MambaForCausalLM.forward``'s
# signature. That argument is accepted but never read by the method.
KVCache = Tuple[
    torch.Tensor,
    torch.Tensor,
]


class MambaDecoderLayer(nn.Module):
    """A single Mamba decoder block: an RMSNorm followed by a ``MambaMixer``.

    When ``config.model_type == "falcon_mamba"`` the mixer is built with its
    internal RMS norm enabled (without a weight, eps ``config.mixer_rms_eps``);
    for every other config that internal norm is disabled.

    ``cache_config`` and ``quant_config`` are accepted but not used by this
    layer.
    """

    def __init__(self,
                 config: MambaConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        self.is_falcon_mamba = config.model_type == "falcon_mamba"
        # ``mixer_rms_eps`` is only read for Falcon-Mamba configs. Otherwise
        # the mixer's RMS norm is off (use_rms_norm=False), so None is passed.
        mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
        self.mixer = MambaMixer(hidden_size=config.hidden_size,
                                ssm_state_size=config.state_size,
                                conv_kernel_size=config.conv_kernel,
                                intermediate_size=config.intermediate_size,
                                time_step_rank=config.time_step_rank,
                                use_conv_bias=config.use_conv_bias,
                                use_bias=config.use_bias,
                                use_rms_norm=self.is_falcon_mamba,
                                rms_norm_has_weight=not self.is_falcon_mamba,
                                rms_norm_eps=mixer_rms_eps,
                                activation=config.hidden_act)

        # Norm applied to the layer input before the mixer.
        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        **kwargs,
    ):
        """Normalize the input and run it through the Mamba mixer.

        Args:
            hidden_states: Output of the previous layer (or the embeddings).
            attn_metadata: Batch metadata, forwarded unchanged to the mixer.
            residual: Pending residual stream, or ``None`` for the first
                layer, in which case ``hidden_states`` itself becomes the
                residual.
            mamba_cache_params: Mamba cache parameters for this layer,
                forwarded unchanged to the mixer.
            **kwargs: Ignored (``MambaModel`` passes ``positions`` here,
                which this layer does not use).

        Returns:
            ``(hidden_states, residual)``. The mixer output is not added to
            ``residual`` here; both are handed to the next norm call.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            # Two-argument RMSNorm returns (normalized, new_residual);
            # presumably it folds the pending residual add into the norm —
            # NOTE(review): confirm against RMSNorm.
            hidden_states, residual = self.norm(hidden_states, residual)

        hidden_states = self.mixer(hidden_states, attn_metadata,
                                   mamba_cache_params)
        return hidden_states, residual


class MambaModel(nn.Module):
    """Mamba backbone: token embeddings, a stack of ``MambaDecoderLayer``s and
    a final RMSNorm.

    Layers are created through ``make_layers``; ``forward`` only runs the
    layers in ``[start_layer, end_layer)`` on the current pipeline-parallel
    rank.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.padding_idx = config.pad_token_id
        # Extra embedding rows for LoRA vocabulary: lora_extra_vocab_size
        # rows per LoRA slot (at least one slot).
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embeddings = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MambaDecoderLayer(
                config, cache_config=cache_config, quant_config=quant_config),
            prefix=f"{prefix}.layers")

        self.norm_f = RMSNorm(config.hidden_size,
                              eps=config.layer_norm_epsilon)
        # Factory for the tensors exchanged between pipeline stages.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        attn_metadata: AttentionMetadata,
        mamba_cache_params: MambaCacheParams,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the backbone.

        Args:
            input_ids: Token ids; embedded on the first PP rank unless
                ``inputs_embeds`` is given.
            positions: Passed to each layer (which ignores it).
            attn_metadata: Batch metadata forwarded to every layer.
            mamba_cache_params: Cache parameters for all layers on this
                rank; each layer receives ``at_layer_idx(i - start_layer)``.
            intermediate_tensors: Required on non-first PP ranks; holds the
                ``"hidden_states"`` and ``"residual"`` from the previous stage.
            inputs_embeds: Optional precomputed embeddings (first rank only).

        Returns:
            On the last PP rank, the final-normed hidden states. On any other
            rank, an ``IntermediateTensors`` with ``"hidden_states"`` and
            ``"residual"`` for the next stage.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                attn_metadata=attn_metadata,
                residual=residual,
                # Cache params are indexed relative to this rank's first
                # layer, not by the global layer index.
                mamba_cache_params=mamba_cache_params.at_layer_idx(
                    i - self.start_layer))
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm_f(hidden_states, residual)

        return hidden_states


156
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
    """Mamba / Falcon-Mamba causal language model.

    Wraps a ``MambaModel`` backbone with an LM head, a logits processor and
    a sampler. Mamba state is held in a ``MambaCacheManager`` that is created
    lazily on the first ``forward`` call.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config
        assert not cache_config.enable_prefix_caching, \
            "Mamba does not support prefix caching"

        super().__init__()
        self.config = config
        self.vllm_config = vllm_config
        self.scheduler_config = scheduler_config
        self.model_config = vllm_config.model_config
        self.backbone = MambaModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "backbone"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        if config.tie_word_embeddings:
            # Reuse the input embedding module as the output projection.
            self.lm_head = self.backbone.embeddings
        else:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size,
            )

        # Holds the Mamba state across steps; created lazily in forward().
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.backbone.make_empty_intermediate_tensors)

        # Number of sequence slots the Mamba cache is allocated with.
        if self.scheduler_config is not None and \
            not self.model_config.enforce_eager:
            # Non-eager: pad max_num_seqs via pad_for_cudagraph, capped at
            # the largest CUDA-graph capture size.
            # NOTE(review): when max_num_seqs > max_capture_size the cache
            # gets fewer slots than max_num_seqs; confirm MambaCacheManager
            # handles eager batches larger than max_capture_size.
            if self.scheduler_config.max_num_seqs > \
                vllm_config.compilation_config.max_capture_size:
                self.max_batch_size = \
                    vllm_config.compilation_config.max_capture_size
            else:
                self.max_batch_size = vllm_config.pad_for_cudagraph(
                    self.scheduler_config.max_num_seqs)
        else:
            # NOTE(review): fixed fallback size; the rationale for 8192 + 2
            # is not visible in this file.
            self.max_batch_size = 8192 + 2

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.backbone.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[KVCache],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
        """Run one step of the model with this step's Mamba cache tensors.

        ``kv_caches`` is accepted but not read. ``**kwargs`` are forwarded
        to ``MambaCacheManager.current_run_tensors``.

        Returns:
            Whatever the backbone returns: hidden states on the last PP
            rank, otherwise ``IntermediateTensors`` for the next stage.
        """
        if self.mamba_cache is None:
            # First call: allocate the cache with the LM head's dtype, one
            # entry per Mamba layer, and max_batch_size sequence slots.
            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                self.vllm_config.parallel_config, LayerBlockType.mamba)
            self.mamba_cache = MambaCacheManager(
                self.lm_head.weight.dtype, num_mamba_layers,
                self.max_batch_size, *self._get_mamba_cache_shape())
        (
            mamba_cache_tensors,
            state_indices_tensor,
        ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
                                                 **kwargs)

        mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
                                              mamba_cache_tensors[1],
                                              state_indices_tensor)

        hidden_states = self.backbone(input_ids, positions, attn_metadata,
                                      mamba_cache_params, intermediate_tensors,
                                      inputs_embeds)

        return hidden_states

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the cache manager.

        ``self.mamba_cache`` is created in ``forward``; NOTE(review): there
        is no ``None`` check here, so ``forward`` must have run first.
        """
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the cache manager (requires ``forward`` to have run,
        since ``self.mamba_cache`` is created there)."""
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def _get_mamba_cache_shape(
            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
        """Return ``(conv_state_shape, temporal_state_shape)`` per sequence.

        Both shapes split ``intermediate_size`` across the tensor-parallel
        world size.
        """
        world_size = get_tensor_model_parallel_world_size()
        conv_state_shape = (
            self.config.intermediate_size // world_size,
            self.config.conv_kernel - 1,
        )
        temporal_state_shape = (
            self.config.intermediate_size // world_size,
            self.config.state_size,
        )
        return conv_state_shape, temporal_state_shape

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint ``(name, tensor)`` pairs into this model.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Checkpoint names containing ``A_log`` are mapped to the
            # parameter named ``A``. NOTE(review): any log -> A conversion is
            # left to that parameter's weight_loader; confirm in MambaMixer.
            if "A_log" in name:
                name = name.replace("A_log", "A")
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Parameters owned by another pipeline stage are not present here.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params