# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

# Adapted from
# https://github.com/lgai-exaone/transformers/blob/add-exaone4/src/transformers/models/exaone4/modeling_exaone4.py
# Copyright 2025 The LG CNS Gen AI Solution Delivery Team.
# Copyright 2025 The LG AI Research and HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Exaone model compatible with HuggingFace weights."""

from collections.abc import Iterable
25
from itertools import islice
26
27
28

import torch
from torch import nn
29
from transformers import Exaone4Config
30

31
from vllm.attention.layer import Attention
32
33
34
35
36
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
37
38
39
40
41
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
42
43
44
45
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
46
47
48
    ParallelLMHead,
    VocabParallelEmbedding,
)
49
from vllm.model_executor.model_loader.weight_utils import (
50
51
52
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
53
from vllm.sequence import IntermediateTensors
54
from vllm.transformers_utils.config import set_default_rope_theta
55
56

from .interfaces import SupportsLoRA, SupportsPP
57
58
59
60
61
62
63
64
65
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
66
67
68
69
70
71
72
73


class Exaone4GatedMLP(nn.Module):
    """Gated feed-forward block.

    Applies a merged gate/up projection, the ``SiluAndMul`` activation, and a
    down projection back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        """Build the projections.

        Args:
            hidden_size: Input size of the gate/up projection and output size
                of the down projection.
            intermediate_size: Size of each of the two merged gate/up outputs
                and input size of the down projection.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Forwarded to both linear layers.
            reduce_results: Forwarded to the down projection.
            bias: Forwarded as ``bias`` to both linear layers.
            prefix: Module-name prefix used to name the sub-layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate before building any layers so an unsupported activation
        # fails without constructing the projections first.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run gate/up projection, activation, then down projection on ``x``."""
        # Both linear layers return a tuple; only the first element is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Exaone4Attention(nn.Module):
    """Self-attention block with per-head RMSNorm on queries and keys.

    Rotary embeddings are applied only on layers that use a sliding window,
    unless the model has no ``"sliding_attention"`` layers at all, in which
    case they are applied on every layer.
    """

    def __init__(
        self,
        config: Exaone4Config,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, Q/K norms, rotary embedding and attention op.

        Args:
            config: Model config; read for ``head_dim``, ``rms_norm_eps``,
                ``layer_types``, ``sliding_window`` and ``rope_parameters``.
            hidden_size: Model hidden size.
            num_heads: Total number of query heads across all TP ranks.
            num_kv_heads: Total number of key/value heads across all TP ranks.
            max_position_embeddings: Passed to ``get_rope`` as ``max_position``.
            quant_config: Forwarded to the linear layers and attention op.
            bias: Forwarded as ``bias`` to the QKV and output projections.
            cache_config: Forwarded to the attention op.
            prefix: Module-name prefix; also parsed for the layer index.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # Use an explicit ``head_dim`` from the config when present; otherwise
        # derive it from the hidden size and the total number of heads.
        self.head_dim = getattr(config, "head_dim", None)
        if self.head_dim is None:
            self.head_dim = self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Norms operate over the last dimension of size ``head_dim``.
        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        # NeoX-style rotary layout is disabled when the quantization method
        # reports itself as "gguf".
        is_neox_style = True
        if quant_config is not None and quant_config.get_name() == "gguf":
            is_neox_style = False

        layer_idx = extract_layer_index(prefix)
        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
        self.sliding_window = config.sliding_window if is_sliding else None

        # apply rotary embeddings to every layer in full attention models
        self.apply_rope_all_layers = "sliding_attention" not in config.layer_types

        # NOTE(review): helper name suggests it fills in a default rope theta
        # on ``config`` when none is set — confirm against
        # vllm.transformers_utils.config.
        set_default_rope_theta(config, default_theta=1000000)
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=is_neox_style,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=self.sliding_window,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention output for ``hidden_states``.

        Args:
            positions: Position tensor passed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.

        Returns:
            Output of the ``o_proj`` projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Split the last dimension into (heads, head_dim) so the RMSNorm is
        # applied per head, then restore the flat layout.
        q = q.unflatten(-1, (self.num_heads, self.head_dim))
        q = self.q_norm(q)
        q = q.flatten(-2, -1)
        k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
        k = self.k_norm(k)
        k = k.flatten(-2, -1)

        # Rotary embedding is skipped on non-sliding layers of models that
        # contain at least one sliding-attention layer.
        if self.sliding_window or self.apply_rope_all_layers:
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Exaone4DecoderLayer(nn.Module):
    """One decoder layer: attention and MLP sub-layers, each followed by an
    RMSNorm on the sub-layer output before the residual add (post-LN)."""

    def __init__(
        self,
        config: Exaone4Config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the attention block, MLP and the two post-norms.

        Args:
            config: Model config supplying sizes, head counts, activation,
                norm epsilon and the optional bias flags.
            cache_config: Forwarded to the attention block.
            quant_config: Forwarded to the attention block and the MLP.
            prefix: Module-name prefix used to name the sub-modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )

        self.self_attn = Exaone4Attention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            # Fall back to the number of attention heads when the config has
            # no separate key/value head count.
            num_kv_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = Exaone4GatedMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_feedforward_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the attention and MLP sub-layers.

        Args:
            positions: Position tensor forwarded to the attention block.
            hidden_states: Layer input.
            residual: Accepted for interface compatibility; the incoming value
                is not read because it is overwritten with ``hidden_states``
                on entry.

        Returns:
            ``(hidden_states, residual)`` where ``hidden_states`` is the layer
            output and ``residual`` is the input that was fed to the MLP
            sub-layer.
        """
        residual = hidden_states

        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Use post-LN
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states

        # Fully Connected
        hidden_states = self.mlp(hidden_states)

        # Use post-LN
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, residual


@support_torch_compile
class Exaone4Model(nn.Module):
    """Decoder stack: token embedding, decoder layers and final RMSNorm.

    Modules that are not needed on the current pipeline-parallel rank are
    replaced with ``PPMissingLayer``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the sub-modules present on this pipeline-parallel rank.

        Args:
            vllm_config: Source of the HF model config, cache config and
                quantization config.
            prefix: Module-name prefix used to name the decoder layers.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.vocab_size = config.vocab_size
        # The embedding is built on the first rank, and also on the last rank
        # when word embeddings are tied.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Exaone4DecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the result of ``embed_tokens`` applied to ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; embedded on the first rank when
                ``inputs_embeds`` is not given.
            positions: Position tensor forwarded to every layer.
            intermediate_tensors: Carries ``"hidden_states"`` and
                ``"residual"`` from the previous rank; must not be ``None``
                on non-first ranks.
            inputs_embeds: Optional precomputed embeddings used on the first
                rank in place of embedding ``input_ids``.

        Returns:
            ``IntermediateTensors`` with ``"hidden_states"`` and
            ``"residual"`` on non-last ranks; the output of the final norm on
            the last rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Only layers in [start_layer, end_layer) are run on this rank.
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/``up_proj``
        checkpoint entries are routed into the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters with the matching shard id.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Use the tensor as-is when it is 0-dim, else its first element.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE: these `continue`s advance the inner mapping loop with
                # `name` already rewritten; the loop then ends without `break`
                # and the `else` branch below repeats the same checks to skip
                # the entry.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)

                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Causal-LM wrapper: ``Exaone4Model`` plus LM head and logits processor.

    The LM head and logits processor are only built on the last
    pipeline-parallel rank.
    """

    # Fused parameter name -> the separate checkpoint names packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack and, on the last rank, the output head.

        Args:
            vllm_config: Source of the HF model config and quantization
                config; also passed through to ``Exaone4Model``.
            prefix: Module-name prefix for the sub-modules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.model = Exaone4Model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                # Share the input embedding's weight with the LM head.
                self.lm_head.weight = self.model.embed_tokens.weight

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                config.vocab_size, scale=logit_scale
            )
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate to ``Exaone4Model.embed_input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack and return its output unchanged.

        Args:
            input_ids: Token ids, forwarded to the inner model.
            positions: Position tensor, forwarded to the inner model.
            intermediate_tensors: Tensors from the previous pipeline rank,
                if any.
            inputs_embeds: Optional precomputed embeddings.

        Returns:
            Whatever ``Exaone4Model.forward`` returns for this rank.
        """
        model_output = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Apply the logits processor with the LM head to ``hidden_states``."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors via ``AutoWeightsLoader``.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(
            self,
            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)