exaone.py 18.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# Adapted from
# https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/modeling_exaone.py
# Copyright 2024 The LG U+ CTO AI Tech Lab.
# Copyright 2021 The LG AI Research EXAONE Lab
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Exaone model compatible with HuggingFace weights."""

28
from collections.abc import Iterable
29
from itertools import islice
30
31
32

import torch
from torch import nn
33
from transformers import PretrainedConfig
34

35
from vllm.attention import Attention
36
from vllm.compilation.decorators import support_torch_compile
37
from vllm.config import CacheConfig, VllmConfig
38
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
39
40
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
41
42
43
44
45
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
46
from vllm.model_executor.layers.logits_processor import LogitsProcessor
47
from vllm.model_executor.layers.quantization import QuantizationConfig
48
49
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
50
51
52
    ParallelLMHead,
    VocabParallelEmbedding,
)
53
from vllm.model_executor.model_loader.weight_utils import (
54
55
56
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
57
from vllm.sequence import IntermediateTensors
58

59
from .interfaces import SupportsLoRA, SupportsPP
60
61
62
63
64
65
66
67
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
68
69
70
71
72
73
74
75


class ExaoneGatedMLP(nn.Module):
    """Gated feed-forward block used inside every EXAONE decoder layer.

    The gate and up projections share one fused column-parallel matmul
    (``gate_up_proj``, output width ``2 * intermediate_size``). The result
    goes through ``SiluAndMul``, and ``c_proj`` (row-parallel) maps it back
    to ``hidden_size``. Only the ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # One output slice for the gate half and one for the up half.
        fused_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=fused_widths,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.c_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        if hidden_act == "silu":
            # NOTE(review): SiluAndMul is assumed to split its input in half
            # and return silu(first_half) * second_half; confirm in vLLM.
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

    def forward(self, x):
        """Apply fused gate/up projection, gated activation, then c_proj."""
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.c_proj(activated)
        return projected


class ExaoneAttention(nn.Module):
    """Self-attention with rotary position embeddings and TP-sharded heads.

    Query/key/value projections are fused into ``qkv_proj``. Heads are split
    across tensor-parallel ranks, and ``out_proj`` (row-parallel) maps the
    concatenated per-rank head outputs back to ``hidden_size``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads must divide evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert num_heads % tp_size == 0
        self.num_heads = num_heads // tp_size

        # KV heads are partitioned when there are at least as many as ranks;
        # otherwise each KV head is replicated across several ranks.
        self.total_num_kv_heads = num_kv_heads
        if num_kv_heads >= tp_size:
            assert num_kv_heads % tp_size == 0
        else:
            assert tp_size % num_kv_heads == 0
        self.num_kv_heads = max(1, num_kv_heads // tp_size)

        # Prefer an explicit ``head_dim`` from the config; otherwise split the
        # hidden size evenly over all (un-sharded) query heads.
        configured_head_dim = getattr(config, "head_dim", None)
        if configured_head_dim is not None:
            self.head_dim = configured_head_dim
        else:
            self.head_dim = hidden_size // num_heads

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # GGUF-quantized weights get is_neox_style=False; every other
        # configuration uses the NeoX-style rotary layout.
        uses_gguf = quant_config is not None and quant_config.get_name() == "gguf"
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=not uses_gguf,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embeddings, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attended = self.attn(q, k, v)
        projected, _ = self.out_proj(attended)
        return projected


class ExaoneBlockAttention(nn.Module):
    """Wrapper that nests ``ExaoneAttention`` under an ``attention`` child.

    The nesting makes parameter names look like ``<prefix>.attention.*``.
    NOTE(review): presumably this mirrors the EXAONE checkpoint naming;
    confirm before flattening it.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        child_prefix = f"{prefix}.attention"
        self.attention = ExaoneAttention(
            config=config,
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=bias,
            cache_config=cache_config,
            prefix=child_prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate directly to the wrapped attention module."""
        attended = self.attention(
            positions=positions,
            hidden_states=hidden_states,
        )
        return attended


class ExaoneDecoderLayer(nn.Module):
    """One decoder block: ln_1 -> attention -> ln_2 -> gated MLP.

    The residual stream is threaded through separately: ``forward`` takes
    and returns it alongside the hidden states instead of adding it back
    itself.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        max_positions = getattr(config, "max_position_embeddings", 8192)
        # Either the ``attention_bias`` flag or a generic ``bias`` flag on the
        # config turns on biases in the attention projections.
        use_attn_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        # Without an explicit KV-head count, fall back to plain multi-head.
        kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)

        self.attn = ExaoneBlockAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=kv_heads,
            max_position_embeddings=max_positions,
            quant_config=quant_config,
            bias=use_attn_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )
        self.mlp = ExaoneGatedMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.activation_function,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        norm_eps = config.layer_norm_epsilon
        self.ln_1 = RMSNorm(config.hidden_size, eps=norm_eps)
        self.ln_2 = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the block and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer of the stack. In that
        case the raw input becomes the residual and is normalized on its own.
        Otherwise the two-argument RMSNorm call returns the normalized
        activations plus the updated residual (presumably a fused
        residual-add + norm; confirm in vLLM's RMSNorm).
        """
        if residual is not None:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        hidden_states, residual = self.ln_2(hidden_states, residual)
        mlp_out = self.mlp(hidden_states)
        return mlp_out, residual


298
@support_torch_compile
class ExaoneModel(nn.Module):
    """EXAONE backbone: token embedding (``wte``), decoder layers (``h``),
    and final norm (``ln_f``).

    Pipeline parallelism: a rank that does not own a component holds a
    ``PPMissingLayer`` placeholder. Non-final ranks return their activations
    as ``IntermediateTensors`` for the next stage.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.vocab_size = config.vocab_size

        # The embedding is needed on the first PP rank, and on the last rank
        # too when its weight is tied to the LM head.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.wte = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
        else:
            self.wte = PPMissingLayer()
        # make_layers returns this rank's [start_layer, end_layer) range.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: ExaoneDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.h",
        )
        if get_pp_group().is_last_rank:
            self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        else:
            self.ln_f = PPMissingLayer()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        The first PP rank starts from ``inputs_embeds`` when given, otherwise
        from embedded ``input_ids``. Later ranks resume from
        ``intermediate_tensors``.

        Returns:
            ``IntermediateTensors`` with ``hidden_states``/``residual`` on
            non-final ranks; the ``ln_f``-normalized hidden states on the
            last rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.h, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Separate checkpoint tensors ``q_proj``/``k_proj``/``v_proj`` and
        ``c_fc_0``/``c_fc_1`` are loaded as shards of the fused ``qkv_proj``
        and ``gate_up_proj`` parameters.

        Args:
            weights: ``(name, tensor)`` pairs whose names (after the remapping
                below) match ``self.named_parameters()``.

        Returns:
            The set of parameter names that received a tensor.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".c_fc_0", 0),
            (".gate_up_proj", ".c_fc_1", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # A non-scalar scale tensor is reduced to its first element.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)

                break
            else:
                # Not a stacked shard, or a stacked shard that was skipped above;
                # both skip checks are re-applied here on the (possibly remapped) name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

441

442
class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """EXAONE causal language model: ``ExaoneModel`` backbone plus LM head.

    The LM head and logits processor only exist on the last pipeline rank;
    other ranks hold a ``PPMissingLayer`` in place of ``lm_head``.
    """

    # Fused module name -> the separate checkpoint sub-modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "c_fc_0",
            "c_fc_1",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "wte": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        # NOTE(review): the backbone is stored as ``transformer`` but built
        # with prefix "model"; confirm layer-name-based quantization matching
        # expects this before changing either side.
        self.transformer = ExaoneModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                # Share the input embedding matrix with the output head.
                self.lm_head.weight = self.transformer.wte.weight

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                config.vocab_size, scale=logit_scale
            )
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone's ``wte``.

        The backbone is held in ``self.transformer``. The previous
        ``self.model`` reference raised AttributeError because this class
        defines no ``model`` attribute.
        """
        return self.transformer.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone. Returns hidden states on the last PP rank and
        ``IntermediateTensors`` on earlier ranks."""
        model_output = self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits.

        Only valid on the last pipeline rank; ``logits_processor`` is not
        created on any other rank.
        """
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load a checkpoint through ``AutoWeightsLoader``.

        Returns:
            The set of parameter names that were loaded.
        """
        loader = AutoWeightsLoader(
            self,
            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)