lfm2.py 17.8 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
4
from itertools import islice
5
6
7
8
9

import torch
import torch.nn as nn
from transformers import Lfm2Config

10
from vllm.attention.layer import Attention
11
12
13
14
15
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
16
17
18
19
20
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
21
22
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
23
24
25
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
26
27
28
29
from vllm.model_executor.layers.mamba.short_conv import ShortConv
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
30
31
32
    ParallelLMHead,
    VocabParallelEmbedding,
)
33
34
35
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

36
37
38
39
40
41
42
43
44
45
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
46
47
48
49
50
51
52
53
54


class Lfm2MLP(nn.Module):
    """Gated feed-forward block shared by every LFM2 decoder layer.

    ``w1`` is a merged column-parallel projection producing two equally sized
    ``ff_dim`` outputs that are combined by ``SiluAndMul``; ``w2`` is a
    row-parallel projection back to ``dim``.
    """

    def __init__(
        self,
        dim: int,
        ff_dim: int,
        multiple_of: int,
        auto_adjust_ff_dim: bool,
        ffn_dim_multiplier: float | None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the projections.

        Args:
            dim: Input/output width of the block.
            ff_dim: Intermediate width (before optional auto-adjustment).
            multiple_of: Alignment used when ``auto_adjust_ff_dim`` is set.
            auto_adjust_ff_dim: If true, derive the real intermediate width
                from ``ff_dim`` via ``_resolve_ff_dim``.
            ffn_dim_multiplier: Optional extra scale applied during
                auto-adjustment.
            quant_config: Optional quantization config for both projections.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        if auto_adjust_ff_dim:
            inner_dim = self._resolve_ff_dim(ff_dim, multiple_of, ffn_dim_multiplier)
        else:
            inner_dim = ff_dim

        self.w1 = MergedColumnParallelLinear(
            input_size=dim,
            output_sizes=[inner_dim, inner_dim],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w1",
        )
        self.w2 = RowParallelLinear(
            input_size=inner_dim,
            output_size=dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        self.act_fn = SiluAndMul()

    @staticmethod
    def _resolve_ff_dim(
        ff_dim: int, multiple_of: int, ffn_dim_multiplier: float | None
    ) -> int:
        """Scale ``ff_dim`` by 2/3, apply the optional multiplier, then round
        up to the next multiple of ``multiple_of``."""
        width = int(2 * ff_dim / 3)
        if ffn_dim_multiplier is not None:
            # Custom dimension factor from the config.
            width = int(ffn_dim_multiplier * width)
        # Round up to the alignment boundary.
        return multiple_of * ((width + multiple_of - 1) // multiple_of)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project up, apply the SiLU-and-multiply gate, project back down."""
        projected = self.w1(x)[0]
        gated = self.act_fn(projected)
        return self.w2(gated)[0]


class Lfm2Attention(nn.Module):
    """Self-attention with per-head QK RMSNorm and rotary embeddings.

    Query heads are sharded across tensor-parallel ranks. KV heads are
    sharded when there are at least as many as ranks, and replicated
    otherwise (each rank then keeps one KV head).
    """

    def __init__(
        self,
        config: Lfm2Config,
        layer_idx: int,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the QKV/output projections, RoPE, attention op and QK norms.

        Args:
            config: Model config; ``rope_parameters`` and ``norm_eps`` are read.
            layer_idx: Index of the owning decoder layer.
            hidden_size: Model width; ``head_dim = hidden_size // num_heads``.
            num_heads: Total query heads across all TP ranks.
            num_kv_heads: Total key/value heads across all TP ranks.
            max_position_embeddings: Max position passed to ``get_rope``.
            cache_config: Forwarded to the ``Attention`` op.
            quant_config: Forwarded to the QKV and output projections.
            prefix: Parameter-name prefix for this module.

        Raises:
            AssertionError: If the head counts are incompatible with the
                tensor-parallel world size.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0, (
            f"num_heads={self.total_num_heads} is not divisible by "
            f"tensor_parallel_size={tp_size}"
        )
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is at least the TP size, so we partition
            # the KV heads across tensor-parallel ranks.
            assert self.total_num_kv_heads % tp_size == 0, (
                f"num_kv_heads={self.total_num_kv_heads} is not divisible by "
                f"tensor_parallel_size={tp_size}"
            )
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across tensor-parallel ranks.
            assert tp_size % self.total_num_kv_heads == 0, (
                f"tensor_parallel_size={tp_size} is not divisible by "
                f"num_kv_heads={self.total_num_kv_heads}"
            )
        # In the replicated case each rank still holds one full KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads

        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )
        # RMSNorm over each head's ``head_dim`` features, for queries and keys.
        self.q_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps)
        self.k_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` (2-D: ``(n_tokens, hidden)``).

        Queries and keys are RMS-normalized per head before the rotary
        embedding is applied; values are used as projected.
        """
        n_tokens, _ = hidden_states.shape
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Reshape to per-head layout (and make contiguous) for the head norms.
        q = q.view(n_tokens, self.num_heads, self.head_dim).contiguous()
        k = k.view(n_tokens, self.num_kv_heads, self.head_dim).contiguous()
        q = self.q_layernorm(q)
        k = self.k_layernorm(k)
        q, k = self.rotary_emb(positions, q, k)
        # Flatten heads back out for the attention op.
        q = q.view(n_tokens, self.num_heads * self.head_dim)
        k = k.view(n_tokens, self.num_kv_heads * self.head_dim)
        attn_output = self.attn(q, k, v)
        output, _ = self.out_proj(attn_output)
        return output


class Lfm2AttentionDecoderLayer(nn.Module):
    """Pre-norm LFM2 decoder layer whose token mixer is self-attention.

    Data flow: ``operator_norm -> self_attn -> ffn_norm -> feed_forward``.
    The residual stream is carried through the two-argument ``RMSNorm`` calls,
    which return both the normalized activations and the updated residual.
    """

    def __init__(
        self,
        config: Lfm2Config,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the attention mixer, MLP and the two norms.

        ``model_config`` is accepted for signature parity with
        ``Lfm2ShortConvDecoderLayer`` and is not used here.
        """
        super().__init__()
        self.prefix = prefix
        self.config = config
        self.layer_idx = layer_idx

        # Fall back to 8192 positions when the config does not define the field.
        max_positions = getattr(config, "max_position_embeddings", 8192)
        width = config.hidden_size

        self.self_attn = Lfm2Attention(
            config=config,
            layer_idx=layer_idx,
            hidden_size=width,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position_embeddings=max_positions,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.feed_forward = Lfm2MLP(
            dim=config.block_dim,
            ff_dim=config.block_ff_dim,
            multiple_of=config.block_multiple_of,
            auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
            ffn_dim_multiplier=config.block_ffn_dim_multiplier,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        self.operator_norm = RMSNorm(width, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(width, eps=config.norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one layer; returns ``(mlp_output, residual)``.

        A ``residual`` of ``None`` marks the first layer on this rank: the
        incoming activations themselves become the residual stream.
        """
        if residual is not None:
            hidden_states, residual = self.operator_norm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.operator_norm(hidden_states)

        mixed = self.self_attn(positions=positions, hidden_states=hidden_states)
        normed, residual = self.ffn_norm(mixed, residual)
        return self.feed_forward(normed), residual


class Lfm2ShortConvDecoderLayer(nn.Module):
    """Pre-norm LFM2 decoder layer whose token mixer is a short convolution.

    Data flow: ``operator_norm -> conv -> ffn_norm -> feed_forward``. The
    residual stream is carried through the two-argument ``RMSNorm`` calls.
    """

    def __init__(
        self,
        config: Lfm2Config,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the ``ShortConv`` mixer, MLP and the two norms.

        ``quant_config`` is only forwarded to the MLP; the conv mixer is
        built without it.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.conv = ShortConv(
            config=config,
            dim=config.conv_dim,
            layer_idx=layer_idx,
            model_config=model_config,
            cache_config=cache_config,
            prefix=f"{prefix}.conv",
        )

        width = config.hidden_size
        self.feed_forward = Lfm2MLP(
            dim=config.block_dim,
            ff_dim=config.block_ff_dim,
            multiple_of=config.block_multiple_of,
            auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
            ffn_dim_multiplier=config.block_ffn_dim_multiplier,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        self.operator_norm = RMSNorm(width, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(width, eps=config.norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ):
        """Run one layer; returns ``(mlp_output, residual)``.

        ``positions`` arrives through ``**kwargs`` (the model passes it to
        every layer) and is not needed by the convolution.
        """
        if residual is not None:
            hidden_states, residual = self.operator_norm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.operator_norm(hidden_states)

        # ``self.conv`` fills this preallocated buffer in place; its return
        # value is not used.
        conv_out = torch.empty_like(hidden_states)
        self.conv(hidden_states, conv_out)

        normed, residual = self.ffn_norm(conv_out, residual)
        return self.feed_forward(normed), residual


@support_torch_compile
class Lfm2Model(nn.Module):
    """LFM2 decoder stack: embedding, hybrid attention/short-conv layers and a
    final RMSNorm, with pipeline-parallel (PP) support."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size
        )

        def get_layer(prefix: str):
            # ``config.layer_types[i]`` selects the mixer for layer i:
            # "full_attention" -> attention layer, anything else -> short conv.
            layer_idx = extract_layer_index(prefix)
            is_attn = self.config.layer_types[layer_idx] == "full_attention"
            layer_class = (
                Lfm2AttentionDecoderLayer if is_attn else Lfm2ShortConvDecoderLayer
            )
            return layer_class(
                config,
                layer_idx,
                model_config,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
            )

        # ``forward`` only runs layers in [start_layer, end_layer) on this rank.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        # Final norm, applied after the last decoder layer; only the last PP
        # rank uses it.
        if get_pp_group().is_last_rank:
            self.embedding_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        else:
            self.embedding_norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        On the first PP rank the input is ``inputs_embeds`` if given, else the
        embeddings of ``input_ids``. Other ranks resume from
        ``intermediate_tensors``.

        Returns:
            The final-normed hidden states on the last PP rank; otherwise an
            ``IntermediateTensors`` holding ``hidden_states`` and ``residual``
            for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.embedding_norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, merging split projections into fused ones.

        Checkpoint ``q_proj``/``k_proj``/``v_proj`` are loaded as shards of
        ``qkv_proj``, and ``w1``/``w3`` as shards 0/1 of the merged ``w1``.
        Parameters not owned by this PP rank are skipped.

        Returns:
            The (renamed) parameter names that were loaded.
        """
        # (fused param name, checkpoint name, shard id)
        stacked_params_mapping = [
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".w1", ".w1", 0),
            (".w1", ".w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Falls through to the ``else`` branch below, which re-checks
                # and skips this weight.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


405
406
407
class Lfm2ForCausalLM(
    nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
):
    """LFM2 causal LM: ``Lfm2Model`` plus an LM head and logits processor.

    The short-conv layers require a conv-state cache; its dtype and shape
    are reported by the ``get_mamba_state_*_from_config`` classmethods.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "w1": [
            "w1",
            "w3",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, ...]:
        """Return the dtype(s) of the short-conv state cache."""
        return MambaStateDtypeCalculator.short_conv_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int]]:
        """Calculate shapes for LFM2's convolutional cache.

        Args:
            vllm_config: vLLM config

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config

        return MambaStateShapeCalculator.short_conv_state_shape(
            tp_world_size=parallel_config.tensor_parallel_size,
            intermediate_size=hf_config.conv_dim,
            conv_kernel=hf_config.conv_L_cache,
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        cache_config = vllm_config.cache_config

        assert not cache_config.enable_prefix_caching, (
            "Lfm2 currently does not support prefix caching"
        )

        super().__init__()
        self.config = config
        self.model = Lfm2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            # Tie only when the config says so. ``load_weights`` skips
            # ``lm_head.*`` exactly in that case. Tying unconditionally would
            # make an untied checkpoint's ``lm_head.weight`` overwrite the
            # shared embedding matrix.
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the inner model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack.

        Returns hidden states on the last PP rank (logits come from
        ``compute_logits``), otherwise ``IntermediateTensors`` for the next
        rank.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocabulary logits with the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, ignoring ``lm_head.*`` when embeddings are
        tied (the head then shares the embedding matrix)."""
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)