"vscode:/vscode.git/clone" did not exist on "4fd4b743a23cc6ccbd832f11be12317a8c2f0fbc"
lfm2.py 19 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
4
from itertools import islice
5
6
7
8
9
10
11
12
13

import torch
import torch.nn as nn
from transformers import Lfm2Config

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
14
from vllm.model_executor.layers.attention import Attention
15
from vllm.model_executor.layers.layernorm import RMSNorm
16
17
18
19
20
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
21
22
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
23
24
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
25
26
27
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
28
29
30
31
from vllm.model_executor.layers.mamba.short_conv import ShortConv
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
32
33
34
    ParallelLMHead,
    VocabParallelEmbedding,
)
35
36
37
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

38
39
40
41
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
42
    WeightsMapper,
43
44
45
46
47
48
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
49
50
51
52
53
54


class Lfm2MLP(nn.Module):
    """Gated feed-forward block used by every LFM2 decoder layer.

    The two input projections (checkpoint names ``w1`` and ``w3``) are fused
    into a single column-parallel ``w13`` matmul whose output holds two
    ``intermediate_size`` halves; ``SiluAndMul`` combines them and ``w2``
    projects back to ``dim``.
    """

    def __init__(
        self,
        dim: int,
        intermediate_size: int,
        multiple_of: int,
        auto_adjust_ff_dim: bool,
        ffn_dim_multiplier: float | None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the fused gate/up projection, activation and down projection.

        Args:
            dim: Input and output feature size.
            intermediate_size: FFN hidden size before optional adjustment.
            multiple_of: When ``auto_adjust_ff_dim`` is set, the adjusted
                hidden size is rounded up to a multiple of this value.
            auto_adjust_ff_dim: If True, scale ``intermediate_size`` by 2/3,
                then by ``ffn_dim_multiplier`` (when given), then round up to
                a multiple of ``multiple_of``.
            ffn_dim_multiplier: Optional extra scale used during adjustment.
            quant_config: Quantization config forwarded to the linear layers.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        if auto_adjust_ff_dim:
            intermediate_size = int(2 * intermediate_size / 3)
            # custom dim factor multiplier
            if ffn_dim_multiplier is not None:
                intermediate_size = int(ffn_dim_multiplier * intermediate_size)
            # Round up (ceil division) to the nearest multiple of `multiple_of`.
            intermediate_size = multiple_of * (
                (intermediate_size + multiple_of - 1) // multiple_of
            )

        # Fused w1/w3: one matmul producing both halves consumed by SiluAndMul.
        self.w13 = MergedColumnParallelLinear(
            input_size=dim,
            output_sizes=[intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w13",
        )
        self.w2 = RowParallelLinear(
            input_size=intermediate_size,
            output_size=dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated FFN to ``x`` and return features of width ``dim``."""
        # The parallel linear layers return (output, bias); bias is unused.
        gate_up, _ = self.w13(x)
        x = self.act_fn(gate_up)
        x, _ = self.w2(x)
        return x


class Lfm2Attention(nn.Module):
    """Tensor-parallel grouped-query self-attention used by LFM2 attention layers.

    Query and key heads are RMS-normalized per head before rotary embedding
    is applied.
    """

    def __init__(
        self,
        config: Lfm2Config,
        layer_idx: int,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the QKV/out projections, rotary embedding and attention op.

        Args:
            config: HF LFM2 config; supplies ``rope_parameters`` and ``norm_eps``.
            layer_idx: Index of the owning decoder layer.
            hidden_size: Model hidden size; ``head_dim = hidden_size // num_heads``.
            num_heads: Total number of query heads across all TP ranks.
            num_kv_heads: Total number of key/value heads across all TP ranks.
            max_position_embeddings: Maximum position passed to the rotary
                embedding.
            cache_config: KV-cache configuration forwarded to ``Attention``.
            quant_config: Quantization config for the linear projections.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # Per-rank KV head count; at least one head when KV heads are replicated.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads

        # Per-rank widths of the q / k / v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=True,
        )
        # NOTE(review): quant_config is not forwarded to Attention here (the
        # linear layers do receive it) — confirm whether KV-cache quantization
        # scales should apply to this model.
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )
        self.q_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps)
        self.k_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over a flat batch of tokens.

        Args:
            positions: Token positions consumed by the rotary embedding.
            hidden_states: 2-D ``(n_tokens, hidden_size)`` input.

        Returns:
            The attention output after ``out_proj``.
        """
        n_tokens, _ = hidden_states.shape
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Reshape to per-head layout so the q/k norms act over head_dim.
        q = q.view(n_tokens, self.num_heads, self.head_dim).contiguous()
        k = k.view(n_tokens, self.num_kv_heads, self.head_dim).contiguous()
        q = self.q_layernorm(q)
        k = self.k_layernorm(k)
        q, k = self.rotary_emb(positions, q, k)
        # Flatten heads back out for the attention op.
        q = q.view(n_tokens, self.num_heads * self.head_dim)
        k = k.view(n_tokens, self.num_kv_heads * self.head_dim)
        attn_output = self.attn(q, k, v)
        output, _ = self.out_proj(attn_output)
        return output


class Lfm2AttentionDecoderLayer(nn.Module):
    """LFM2 decoder layer whose token mixer is full self-attention.

    Structure: operator_norm -> self_attn -> ffn_norm -> feed_forward, with
    the residual stream threaded through the norm calls.
    """

    def __init__(
        self,
        config: Lfm2Config,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the attention mixer, feed-forward block and both norms.

        Args:
            config: HF LFM2 config.
            layer_idx: Index of this layer in the model.
            model_config: Accepted for signature parity with the short-conv
                layer; not used by this layer.
            cache_config: KV-cache configuration forwarded to attention.
            quant_config: Quantization config for attention and FFN linears.
            prefix: Parameter-name prefix for this layer.
        """
        super().__init__()
        self.prefix = prefix
        self.config = config
        self.layer_idx = layer_idx

        # Fall back to 8192 when the config does not define this field.
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)

        self.self_attn = Lfm2Attention(
            config=config,
            layer_idx=layer_idx,
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        self.feed_forward = Lfm2MLP(
            dim=config.block_dim,
            intermediate_size=config.intermediate_size,
            multiple_of=config.block_multiple_of,
            auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
            ffn_dim_multiplier=config.block_ffn_dim_multiplier,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one decoder layer.

        Args:
            positions: Token positions for the rotary embedding.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Residual stream from the previous layer, or None for
                the first layer on the first pipeline rank.
            **kwargs: Ignored; accepted so all layer types share a call site.

        Returns:
            ``(hidden_states, residual)`` to pass to the next layer.
        """
        if residual is None:
            # First layer: the raw input becomes the residual stream.
            residual = hidden_states
            hidden_states = self.operator_norm(hidden_states)
        else:
            # Two-argument RMSNorm returns (normed, updated_residual); presumably
            # a fused residual-add + norm — see vLLM RMSNorm.
            hidden_states, residual = self.operator_norm(hidden_states, residual)
        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        return self.feed_forward(hidden_states), residual


class Lfm2ShortConvDecoderLayer(nn.Module):
    """LFM2 decoder layer whose token mixer is a short convolution.

    Structure: operator_norm -> short_conv -> ffn_norm -> feed_forward, with
    the residual stream threaded through the norm calls.
    """

    def __init__(
        self,
        config: Lfm2Config,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the short-conv mixer, feed-forward block and both norms.

        Args:
            config: HF LFM2 config.
            layer_idx: Index of this layer in the model.
            model_config: Forwarded to ``ShortConv``.
            cache_config: Forwarded to ``ShortConv``.
            quant_config: Quantization config for the FFN linears (not passed
                to ``ShortConv``).
            prefix: Parameter-name prefix for this layer.
        """
        super().__init__()
        self.layer_idx = layer_idx
        # Module attribute is `short_conv` while its internal prefix keeps the
        # checkpoint-style ".conv" name.
        self.short_conv = ShortConv(
            config=config,
            dim=config.conv_dim,
            layer_idx=layer_idx,
            model_config=model_config,
            cache_config=cache_config,
            prefix=f"{prefix}.conv",
        )

        self.feed_forward = Lfm2MLP(
            dim=config.block_dim,
            intermediate_size=config.intermediate_size,
            multiple_of=config.block_multiple_of,
            auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
            ffn_dim_multiplier=config.block_ffn_dim_multiplier,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ):
        """Run one decoder layer.

        Args:
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Residual stream from the previous layer, or None for
                the first layer on the first pipeline rank.
            **kwargs: Ignored (e.g. ``positions``); accepted so all layer
                types share a call site.

        Returns:
            ``(hidden_states, residual)`` to pass to the next layer.
        """
        if residual is None:
            # First layer: the raw input becomes the residual stream.
            residual = hidden_states
            hidden_states = self.operator_norm(hidden_states)
        else:
            hidden_states, residual = self.operator_norm(hidden_states, residual)
        # ShortConv writes its result into this preallocated buffer in place.
        output = torch.empty_like(hidden_states)
        self.short_conv(
            hidden_states,
            output,
        )
        hidden_states, residual = self.ffn_norm(output, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


@support_torch_compile
class Lfm2Model(nn.Module):
    """LFM2 backbone: token embedding, hybrid conv/attention layers, final norm.

    Each layer index becomes an attention or short-conv decoder layer
    according to ``config.layer_types``. ``forward`` only runs the layers in
    ``[start_layer, end_layer)`` on this pipeline rank.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size
        )

        def get_layer(prefix: str):
            # The layer index is recovered from the per-layer prefix and used
            # to select this layer's mixer type.
            layer_idx = extract_layer_index(prefix)
            is_attn = self.config.layer_types[layer_idx] == "full_attention"
            layer_class = (
                Lfm2AttentionDecoderLayer if is_attn else Lfm2ShortConvDecoderLayer
            )
            return layer_class(
                config,
                layer_idx,
                model_config,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        if get_pp_group().is_last_rank:
            # Despite the name, this norm is applied after the last decoder
            # layer in `forward`, not to the input embeddings.
            self.embedding_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        else:
            self.embedding_norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the decoder stack.

        On the first pipeline rank the input is ``inputs_embeds`` if given,
        otherwise the embeddings of ``input_ids``. Later ranks resume from
        ``intermediate_tensors``. Non-last ranks return an
        ``IntermediateTensors`` with ``hidden_states`` and ``residual``; the
        last rank returns the normalized final hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            # `positions` is passed to every layer; short-conv layers ignore it.
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.embedding_norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint ``.conv.`` modules are renamed to ``.short_conv.``.
        Separate ``q/k/v_proj`` and ``w1/w3`` tensors are routed into the
        fused ``qkv_proj`` and ``w13`` parameters with the matching shard id.
        Parameters that live on another pipeline rank are skipped.

        Returns:
            The set of (renamed) parameter names that were loaded.
        """
        # (fused param name, checkpoint name, shard id)
        stacked_params_mapping = [
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".w13", ".w1", 0),
            (".w13", ".w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Only the first occurrence is renamed so the inner ".conv." child
            # of ShortConv (".conv.conv.") keeps its name.
            if ".conv." in name:
                name = name.replace(".conv.", ".short_conv.", 1)

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Use segment-boundary matching (trailing dot) to prevent
                # e.g. ".w1" from matching inside ".w13" in pre-fused keys.
                if weight_name + "." not in name:
                    continue
                name = name.replace(weight_name + ".", param_name + ".")

                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load the tensor directly.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Lfm2ForCausalLM(
    nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
):
    """LFM2 causal language model: ``Lfm2Model`` backbone plus LM head."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "w13": [
            "w1",
            "w3",
        ],
        "in_proj": ["in_proj"],
    }

    # HF uses .conv. but vLLM uses .short_conv. to avoid LoRA regex collision
    # with the inner .conv.conv child (ShortConv has a child self.conv, so
    # naming the container .conv too makes _match_target_modules match both)
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".conv.": ".short_conv."},
    )

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, ...]:
        """Return the dtype(s) of LFM2's short-conv state cache."""
        return MambaStateDtypeCalculator.short_conv_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int]]:
        """Calculate shapes for LFM2's convolutional cache.

        Args:
            vllm_config: vLLM config

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config

        return MambaStateShapeCalculator.short_conv_state_shape(
            tp_world_size=parallel_config.tensor_parallel_size,
            intermediate_size=hf_config.conv_dim,
            conv_kernel=hf_config.conv_L_cache,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
        """Return the function(s) used to copy short-conv cache state."""
        return MambaStateCopyFuncCalculator.short_conv_state_copy_func()

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the backbone, LM head and logits processor.

        Raises:
            NotImplementedError: if ``mamba_cache_mode`` is ``"all"``.
        """
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        cache_config = vllm_config.cache_config

        if cache_config.mamba_cache_mode == "all":
            raise NotImplementedError(
                "Lfm2 currently does not support 'all' prefix caching, "
                "please use '--mamba-cache-mode=align' instead"
            )

        super().__init__()
        self.config = config
        self.model = Lfm2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            # Tie only when the config asks for it. This matches load_weights,
            # which skips "lm_head." exactly in that case. Tying
            # unconditionally would make an untied checkpoint's lm_head.weight
            # overwrite the shared embedding tensor.
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the backbone; extra ``kwargs`` are accepted and ignored."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project final hidden states to vocabulary logits."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping ``lm_head.`` when embeddings are tied."""
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)