# lfm2.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from itertools import islice

import torch
import torch.nn as nn
from transformers import Lfm2Config

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
from vllm.model_executor.layers.mamba.short_conv import ShortConv
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    WeightsMapper,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)


class Lfm2MLP(nn.Module):
    """Gated feed-forward block shared by both LFM2 decoder layer types.

    The ``w1`` and ``w3`` input projections are fused into a single
    column-parallel matmul (``w13``); ``SiluAndMul`` gates one half of its
    output with the other, and ``w2`` projects the result back to ``dim``.
    """

    def __init__(
        self,
        dim: int,
        ff_dim: int,
        multiple_of: int,
        auto_adjust_ff_dim: bool,
        ffn_dim_multiplier: float | None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the fused up/gate projection, activation and down projection.

        Args:
            dim: Input and output feature size.
            ff_dim: Intermediate feature size (before adjustment, if enabled).
            multiple_of: Rounding granularity used when auto-adjusting.
            auto_adjust_ff_dim: When true, ``ff_dim`` is recomputed by
                :meth:`_adjust_ff_dim` before building the layers.
            ffn_dim_multiplier: Optional extra scale applied during the
                auto-adjustment; unused otherwise.
            quant_config: Quantization config forwarded to the linear layers.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        if auto_adjust_ff_dim:
            ff_dim = self._adjust_ff_dim(ff_dim, multiple_of, ffn_dim_multiplier)

        # One matmul produces both halves ([w1(x), w3(x)]) for the gate.
        self.w13 = MergedColumnParallelLinear(
            input_size=dim,
            output_sizes=[ff_dim] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w13",
        )
        self.w2 = RowParallelLinear(
            input_size=ff_dim,
            output_size=dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        self.act_fn = SiluAndMul()

    @staticmethod
    def _adjust_ff_dim(
        ff_dim: int, multiple_of: int, ffn_dim_multiplier: float | None
    ) -> int:
        """Return the auto-adjusted intermediate size.

        Takes 2/3 of ``ff_dim`` (truncated), optionally scales it by
        ``ffn_dim_multiplier`` (truncated again), then rounds up to the next
        multiple of ``multiple_of``.
        """
        ff_dim = int(2 * ff_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            ff_dim = int(ffn_dim_multiplier * ff_dim)
        return multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``w2(act(w13(x)))``; the linear layers' bias outputs are unused."""
        gate_up, _ = self.w13(x)
        x = self.act_fn(gate_up)
        x, _ = self.w2(x)
        return x


class Lfm2Attention(nn.Module):
    """Self-attention used by LFM2's ``full_attention`` layers.

    Q, K and V come from one fused projection. Q and K are RMS-normalized per
    head (over ``head_dim``) before the rotary embedding is applied. Query
    heads are split evenly across tensor-parallel ranks; KV heads are either
    split or, when there are fewer KV heads than ranks, replicated.
    """

    def __init__(
        self,
        config: Lfm2Config,
        layer_idx: int,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, rotary embedding, attention op and QK norms.

        Args:
            config: HF config; supplies ``rope_parameters`` and ``norm_eps``.
            layer_idx: Index of the owning decoder layer.
            hidden_size: Model width; ``head_dim = hidden_size // num_heads``.
            num_heads: Total number of query heads (across all TP ranks).
            num_kv_heads: Total number of key/value heads.
            max_position_embeddings: Maximum position handled by the rotary
                embedding.
            cache_config: KV-cache config forwarded to the attention op.
            quant_config: Quantization config for the linear projections.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads are always partitioned evenly across TP ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # KV heads held by this rank (one when they are replicated).
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads

        # Per-rank widths of the q / k / v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )
        # Per-head norms: normalized dimension is head_dim, not hidden_size.
        self.q_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps)
        self.k_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over a flat batch of tokens.

        Args:
            positions: Position ids consumed by the rotary embedding.
            hidden_states: 2-D ``(num_tokens, hidden_size)`` activations.

        Returns:
            The output projection of the attention result.
        """
        n_tokens, _ = hidden_states.shape
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Reshape to (tokens, heads, head_dim) so the norms act per head.
        q = q.view(n_tokens, self.num_heads, self.head_dim).contiguous()
        k = k.view(n_tokens, self.num_kv_heads, self.head_dim).contiguous()
        q = self.q_layernorm(q)
        k = self.k_layernorm(k)
        q, k = self.rotary_emb(positions, q, k)
        # Flatten heads back before handing off to the attention op.
        q = q.view(n_tokens, self.num_heads * self.head_dim)
        k = k.view(n_tokens, self.num_kv_heads * self.head_dim)
        attn_output = self.attn(q, k, v)
        output, _ = self.out_proj(attn_output)
        return output


class Lfm2AttentionDecoderLayer(nn.Module):
    """LFM2 decoder layer whose token mixer is :class:`Lfm2Attention`.

    Pre-norm layout: ``operator_norm`` -> attention -> ``ffn_norm`` ->
    feed-forward, with the residual stream threaded through the norms.
    """

    def __init__(
        self,
        config: Lfm2Config,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the attention mixer, feed-forward block and both norms.

        Args:
            config: HF config with attention, block and norm hyperparameters.
            layer_idx: Index of this layer in the stack.
            model_config: Accepted for signature parity with the short-conv
                layer; not used here.
            cache_config: KV-cache config forwarded to attention.
            quant_config: Quantization config forwarded to linear layers.
            prefix: Parameter-name prefix for this layer.
        """
        super().__init__()
        self.prefix = prefix
        self.config = config
        self.layer_idx = layer_idx

        # Fall back to 8192 when the config does not define the attribute.
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)

        self.self_attn = Lfm2Attention(
            config=config,
            layer_idx=layer_idx,
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        self.feed_forward = Lfm2MLP(
            dim=config.block_dim,
            ff_dim=config.block_ff_dim,
            multiple_of=config.block_multiple_of,
            auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
            ffn_dim_multiplier=config.block_ffn_dim_multiplier,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one layer.

        Args:
            positions: Position ids forwarded to attention.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual; ``None`` for the first layer, in
                which case ``hidden_states`` seeds the residual.
            **kwargs: Ignored; lets the model call every layer type uniformly.

        Returns:
            ``(feed_forward_output, residual)`` for the next layer.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.operator_norm(hidden_states)
        else:
            # Two-argument norm folds the residual in and returns
            # (normalized, updated residual).
            hidden_states, residual = self.operator_norm(hidden_states, residual)
        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        return self.feed_forward(hidden_states), residual


class Lfm2ShortConvDecoderLayer(nn.Module):
    """LFM2 decoder layer whose token mixer is a :class:`ShortConv`.

    Same pre-norm layout as the attention layer, with the short convolution
    in place of self-attention.
    """

    def __init__(
        self,
        config: Lfm2Config,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the short-conv mixer, feed-forward block and both norms.

        Args:
            config: HF config with conv, block and norm hyperparameters.
            layer_idx: Index of this layer in the stack.
            model_config: Forwarded to :class:`ShortConv`.
            cache_config: Forwarded to :class:`ShortConv`.
            quant_config: Quantization config for the feed-forward block
                (not passed to the short conv).
            prefix: Parameter-name prefix for this layer.
        """
        super().__init__()
        self.layer_idx = layer_idx
        # NOTE(review): the attribute is `short_conv` (avoids a LoRA name
        # collision with ShortConv's own `.conv` child) but the prefix keeps
        # the HF-style ".conv"; confirm name-based lookups expect that.
        self.short_conv = ShortConv(
            config=config,
            dim=config.conv_dim,
            layer_idx=layer_idx,
            model_config=model_config,
            cache_config=cache_config,
            prefix=f"{prefix}.conv",
        )

        self.feed_forward = Lfm2MLP(
            dim=config.block_dim,
            ff_dim=config.block_ff_dim,
            multiple_of=config.block_multiple_of,
            auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
            ffn_dim_multiplier=config.block_ffn_dim_multiplier,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one layer.

        Args:
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual; ``None`` for the first layer.
            **kwargs: Ignored (e.g. ``positions``, which a convolution does
                not need).

        Returns:
            ``(feed_forward_output, residual)`` for the next layer.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.operator_norm(hidden_states)
        else:
            hidden_states, residual = self.operator_norm(hidden_states, residual)
        # ShortConv fills this preallocated buffer; its return value is unused.
        output = torch.empty_like(hidden_states)
        self.short_conv(
            hidden_states,
            output,
        )
        hidden_states, residual = self.ffn_norm(output, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


@support_torch_compile
class Lfm2Model(nn.Module):
    """LFM2 backbone: token embedding, hybrid decoder stack and final norm.

    Each layer is an attention layer or a short-convolution layer, chosen per
    index from ``config.layer_types``. Under pipeline parallelism ``forward``
    only runs layers ``[start_layer, end_layer)`` and the final
    ``embedding_norm`` exists only on the last rank.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size
        )

        def get_layer(prefix: str):
            # The prefix handed in by make_layers encodes the layer index,
            # which selects the layer type from the config.
            layer_idx = extract_layer_index(prefix)
            is_attn = self.config.layer_types[layer_idx] == "full_attention"
            layer_class = (
                Lfm2AttentionDecoderLayer if is_attn else Lfm2ShortConvDecoderLayer
            )
            return layer_class(
                config,
                layer_idx,
                model_config,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        if get_pp_group().is_last_rank:
            self.embedding_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        else:
            self.embedding_norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used on the first PP rank when
                ``inputs_embeds`` is not given.
            positions: Position ids forwarded to every layer.
            intermediate_tensors: ``hidden_states``/``residual`` received from
                the previous PP rank; required on non-first ranks.
            inputs_embeds: Precomputed embeddings that bypass the lookup.

        Returns:
            Final normalized hidden states on the last PP rank, otherwise the
            ``IntermediateTensors`` to send to the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        # Final norm also folds in the outstanding residual.
        hidden_states, _ = self.embedding_norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q/k/v_proj`` tensors are loaded as shards of ``qkv_proj``
        and ``w1``/``w3`` as shards of ``w13``. HF ``.conv.`` container names
        are renamed to ``.short_conv.``. Parameters that live on another PP
        rank are skipped.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".w13", ".w1", 0),
            (".w13", ".w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # HF names the short-conv container ".conv."; this module calls it
            # ".short_conv.". Leave names already in vLLM spelling alone:
            # "…short_conv.conv.weight" also contains ".conv." and would
            # otherwise be mangled into "…short_conv.short_conv.weight".
            if ".conv." in name and ".short_conv." not in name:
                name = name.replace(".conv.", ".short_conv.", 1)

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Use segment-boundary matching (trailing dot) to prevent
                # e.g. ".w1" from matching inside ".w13" in pre-fused keys.
                if weight_name + "." not in name:
                    continue
                name = name.replace(weight_name + ".", param_name + ".")

                if is_pp_missing_parameter(name, self):
                    # Falls through to the for-else below, which re-checks
                    # and skips this weight without recording it.
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Lfm2ForCausalLM(
    nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
):
    """LFM2 causal language model: :class:`Lfm2Model` plus an LM head.

    Also provides the hybrid-model class hooks that describe the dtype, shape
    and copy function of the short-convolution state cache.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "w13": [
            "w1",
            "w3",
        ],
        "in_proj": ["in_proj"],
    }

    # HF uses .conv. but vLLM uses .short_conv. to avoid LoRA regex collision
    # with the inner .conv.conv child (ShortConv has a child self.conv, so
    # naming the container .conv too makes _match_target_modules match both)
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".conv.": ".short_conv."},
    )

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, ...]:
        """Return the conv-state cache dtype(s) for this model and cache config."""
        return MambaStateDtypeCalculator.short_conv_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int]]:
        """Calculate shapes for LFM2's convolutional cache.

        Args:
            vllm_config: vLLM config

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config

        return MambaStateShapeCalculator.short_conv_state_shape(
            tp_world_size=parallel_config.tensor_parallel_size,
            intermediate_size=hf_config.conv_dim,
            conv_kernel=hf_config.conv_L_cache,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
        """Return the function(s) used to copy short-conv state."""
        return MambaStateCopyFuncCalculator.short_conv_state_copy_func()

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the backbone, LM head and logits processor.

        Raises:
            NotImplementedError: If ``mamba_cache_mode`` is ``"all"``.
        """
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        cache_config = vllm_config.cache_config

        if cache_config.mamba_cache_mode == "all":
            raise NotImplementedError(
                "Lfm2 currently does not support 'all' prefix caching, "
                "please use '--mamba-cache-mode=align' instead"
            )

        super().__init__()
        self.config = config
        self.model = Lfm2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            # Share the embedding matrix only for tied checkpoints. This
            # mirrors `load_weights`, which skips "lm_head." in exactly that
            # case; an untied checkpoint's lm_head.weight must get its own
            # parameter rather than overwrite embed_tokens through the tie.
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone; non-last PP ranks return ``IntermediateTensors``."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project final hidden states to vocabulary logits."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping ``lm_head.`` when it is tied."""
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)