"tests/vscode:/vscode.git/clone" did not exist on "8ecb3e9e9336ce47e47b61417e24161b38079e93"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Solar model compatible with HuggingFace weights."""

from collections.abc import Iterable
from typing import Any

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)


class SolarMLP(nn.Module):
    """Gated feed-forward block of a Solar decoder layer.

    The gate and up projections are fused into a single
    ``MergedColumnParallelLinear`` (``gate_up_proj``) producing
    ``2 * intermediate_size`` features, which are combined by ``SiluAndMul``
    and projected back to ``hidden_size`` by ``down_proj``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        """Build the projections.

        Args:
            hidden_size: Input and output feature size of the block.
            intermediate_size: Feature size of each of the gate/up halves.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Quantization settings forwarded to both linears.
            bias: Whether both linear layers carry a bias term.
            prefix: Parameter-name prefix of this module.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply ``gate_up_proj`` -> ``act_fn`` -> ``down_proj`` to ``x``."""
        # Both linear layers return a 2-tuple; only the first element (the
        # projected activations) is used here.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class SolarAttention(nn.Module):
    """Self-attention block of a Solar decoder layer.

    Uses a fused QKV projection, rotary position embeddings on the query and
    key tensors, vLLM's ``Attention`` layer, and a row-parallel output
    projection. Head counts are divided across tensor-parallel ranks.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, rotary embedding and attention layer.

        Args:
            config: HF model config; only its optional ``head_dim`` is read.
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (before TP sharding).
            num_kv_heads: Total number of key/value heads (before TP sharding).
            rope_theta: Base passed to ``get_rope``.
            rope_scaling: Optional scaling dict passed to ``get_rope``.
            max_position_embeddings: Maximum position passed to ``get_rope``.
            quant_config: Quantization settings forwarded to the sub-layers.
            bias: Whether the QKV and output projections carry a bias.
            cache_config: KV-cache settings forwarded to ``Attention``.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # At least one KV head per rank, even when heads are replicated.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(config, "head_dim", None)
        if self.head_dim is None:
            self.head_dim = self.hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output.

        Args:
            positions: Position tensor handed to the rotary embedding.
            hidden_states: Input activations for the fused QKV projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dim into q / k / v.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class SolarDecoderLayer(nn.Module):
    """One Solar transformer layer: RMSNorm -> attention -> RMSNorm -> MLP.

    The residual stream is carried alongside ``hidden_states`` as a separate
    tensor; the two-argument ``RMSNorm`` call returns both the normalized
    activations and the updated residual.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, MLP and the two layer norms from ``config``.

        Args:
            config: HF model config supplying sizes, RoPE and bias settings.
            cache_config: KV-cache settings forwarded to the attention block.
            quant_config: Quantization settings forwarded to the sub-layers.
            prefix: Parameter-name prefix of this layer.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)

        # NOTE: this writes into the dict object held by ``config`` (no copy
        # is made), so the key is visible to every later reader of
        # ``config.rope_scaling``.
        if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None
        ):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings
            )
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        self.self_attn = SolarAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            # Fall back to the query-head count when the config does not
            # define a separate number of key/value heads.
            num_kv_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = SolarMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Position tensor forwarded to the attention block.
            hidden_states: Layer input activations.
            residual: Residual stream from the previous layer, or ``None``
                when there is none yet; in that case the incoming
                ``hidden_states`` becomes the residual.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


271
@support_torch_compile
class SolarModel(nn.Module):
    """Solar decoder stack: token embedding, decoder layers and final norm.

    Besides the plain layer stack, ``forward`` implements the skip
    connections driven by the config fields ``bskcn_1`` .. ``bskcn_4``
    (collections of layer indices) and ``bskcn_tv`` (a pair of blend
    weights, index 0 used in training mode and index 1 otherwise):

    * before a layer listed in ``bskcn_1`` / ``bskcn_2`` the current
      ``hidden_states`` and ``residual`` are snapshotted;
    * before a layer listed in ``bskcn_3`` / ``bskcn_4`` the current tensors
      are replaced by ``snapshot * w + current * (1 - w)`` using the first /
      second snapshot respectively.

    Under pipeline parallelism only the layers in
    ``[start_layer, end_layer)`` are executed on this rank.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, this rank's decoder layers and the final norm.

        Args:
            vllm_config: Engine config; its HF model config, cache config and
                quant config are read.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.vocab_size = config.vocab_size

        # The embedding is materialized on the first PP rank, and also on the
        # last rank when the config ties word embeddings.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: SolarDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    @staticmethod
    def _snapshot_residual(
        hidden_states: torch.Tensor, residual: torch.Tensor | None
    ) -> torch.Tensor:
        """Return a copy of ``residual`` suitable for a later blend.

        ``residual`` is still ``None`` before the first layer of the first
        pipeline rank has run. A decoder layer treats that case as
        ``residual = hidden_states`` (i.e. nothing has been accumulated yet),
        so an all-zero tensor shaped like ``hidden_states`` is the equivalent
        snapshot and avoids calling ``.clone()`` on ``None``.
        """
        if residual is None:
            return torch.zeros_like(hidden_states)
        return residual.clone()

    @staticmethod
    def _blend_skip(
        saved_hidden: torch.Tensor | None,
        saved_residual: torch.Tensor | None,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        weight: float,
        layer_idx: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Blend a snapshot into the current activations.

        Returns ``snapshot * weight + current * (1 - weight)`` for both the
        hidden states and the residual.

        Raises:
            RuntimeError: If the snapshot was never taken on this rank.
        """
        if saved_hidden is None or saved_residual is None:
            # Fail with an explicit message instead of an opaque
            # "NoneType * float" TypeError.
            raise RuntimeError(
                f"Solar skip connection at layer {layer_idx} has no saved "
                "activations on this rank (e.g. the snapshot layer runs on an "
                "earlier pipeline-parallel rank or the bskcn_* config is "
                "inconsistent)."
            )
        hidden_states = saved_hidden * weight + hidden_states * (1 - weight)
        residual = saved_residual * weight + residual * (1 - weight)
        return hidden_states, residual

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used on the first PP rank when
                ``inputs_embeds`` is not given.
            positions: Position tensor forwarded to every layer.
            intermediate_tensors: ``hidden_states`` / ``residual`` received
                from the previous PP rank; required on non-first ranks.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            The final-normed hidden states on the last PP rank, otherwise an
            ``IntermediateTensors`` carrying ``hidden_states`` and
            ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        bskcn_h_1 = None
        bskcn_h_2 = None
        bskcn_r_1 = None
        bskcn_r_2 = None
        # bskcn_tv holds two blend weights: [training, inference].
        bskcn_tv = self.config.bskcn_tv[0] if self.training else self.config.bskcn_tv[1]

        for i in range(self.start_layer, self.end_layer):
            # Snapshots are taken before blending, so a layer index may
            # appear in both a snapshot list and a blend list.
            if i in self.config.bskcn_1:
                bskcn_h_1 = hidden_states.clone()
                bskcn_r_1 = self._snapshot_residual(hidden_states, residual)
            if i in self.config.bskcn_2:
                bskcn_h_2 = hidden_states.clone()
                bskcn_r_2 = self._snapshot_residual(hidden_states, residual)
            if i in self.config.bskcn_3:
                hidden_states, residual = self._blend_skip(
                    bskcn_h_1, bskcn_r_1, hidden_states, residual, bskcn_tv, i
                )
            if i in self.config.bskcn_4:
                hidden_states, residual = self._blend_skip(
                    bskcn_h_2, bskcn_r_2, hidden_states, residual, bskcn_tv, i
                )
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint names containing ``.q_proj`` / ``.k_proj`` / ``.v_proj``
        are routed into the fused ``.qkv_proj`` parameter and ``.gate_proj``
        / ``.up_proj`` into ``.gate_up_proj``, each with its shard id.
        KV-cache quantization scales reported by the quant config are loaded
        first; everything else is loaded by (possibly remapped) name.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of names recorded as handled. NOTE(review): as in the
            original flow, a name skipped inside the stacked-parameter loop
            still falls through to the ``for``/``else`` branch.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Scales are loaded as scalars: take element 0 of
                # non-0-dim tensors.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)

                break
            else:
                # Reached when no stacked mapping loaded the tensor.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

426

427
class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Solar causal language model: ``SolarModel`` plus an LM head.

    The LM head and logits processor are created only on the last
    pipeline-parallel rank; other ranks hold a ``PPMissingLayer`` instead.
    """

    # Fused module name -> the checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack and, on the last PP rank, the LM head.

        Args:
            vllm_config: Engine config; its HF model config and quant config
                are read here and the whole object is passed to
                ``SolarModel``.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.model = SolarModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                # Share the input embedding matrix with the output head.
                self.lm_head.weight = self.model.embed_tokens.weight

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                config.vocab_size, scale=logit_scale
            )
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Delegate to ``SolarModel.forward`` and return its output unchanged."""
        model_output = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Turn ``hidden_states`` into logits via the LM head.

        ``logits_processor`` is only created on the last PP rank (see
        ``__init__``), so this is only callable there.
        """
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Returns whatever ``AutoWeightsLoader.load_weights`` returns (annotated
        as a set of names).
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)