# Adapted from
# https://huggingface.co/xverse/XVERSE-7B/blob/main/modeling_xverse.py
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Xverse model compatible with HuggingFace weights."""
22
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
23
24
25
26
27
28

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
29
from vllm.compilation.decorators import support_torch_compile
30
from vllm.config import CacheConfig, VllmConfig
31
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
32
33
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
34
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
35
36
37
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
38
from vllm.model_executor.layers.quantization import QuantizationConfig
39
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
40
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
41
42
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
43
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
44
from vllm.model_executor.sampling_metadata import SamplingMetadata
45
from vllm.sequence import IntermediateTensors
46

47
48
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
49
50
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
51

52
53
54
55
56
57
58
59

class XverseMLP(nn.Module):
    """Feed-forward block: merged gate/up projection, SiluAndMul, down
    projection.

    Args:
        hidden_size: Input size of the merged projection and output size of
            the down projection.
        intermediate_size: Size of each of the two merged gate/up output
            slices, and input size of the down projection.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to both linear
            layers.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # Gate and up projections live in one layer with two output slices
        # of identical size.
        merged_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to x."""
        # Both linear layers return a tuple; only the first element is used.
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected


class XverseAttention(nn.Module):
    """Self-attention with rotary position embedding.

    The query/key/value projections are fused in a ``QKVParallelLinear``
    layer and the output projection is a ``RowParallelLinear``; head counts
    are divided by the tensor-parallel world size.

    Args:
        hidden_size: Model hidden size.
        num_heads: Total number of query heads (before TP partitioning).
        num_kv_heads: Total number of key/value heads (before TP
            partitioning).
        rope_theta: Base passed to ``get_rope``.
        rope_scaling: Optional rope scaling config passed to ``get_rope``.
        max_position_embeddings: Maximum position passed to ``get_rope``.
        quant_config: Optional quantization config for the linear layers and
            the attention layer.
        bias: Whether the QKV and output projections carry a bias.
        cache_config: Optional cache config for the attention layer.
        prefix: Module name prefix; the attention layer is registered under
            ``f"{prefix}.attn"``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        # Either the KV heads are partitioned evenly across the tensor
        # parallel ranks, or there are fewer KV heads than ranks and each KV
        # head is shared by an equal number of ranks. The previous
        # unconditional divisibility assert rejected the second case, which
        # made the ``max(1, ...)`` clamp below unreachable.
        # NOTE(review): relies on QKVParallelLinear replicating KV heads when
        # tp_size > total_num_kv_heads -- confirm against its documentation.
        kv_heads_partitionable = self.total_num_kv_heads % tp_size == 0
        kv_heads_replicable = (0 < self.total_num_kv_heads < tp_size
                               and tp_size % self.total_num_kv_heads == 0)
        assert kv_heads_partitionable or kv_heads_replicable
        # At least one KV head per rank (replication case yields exactly 1).
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths used to split the fused QKV output in forward().
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=bias,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dimension into Q, K, V.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class XverseDecoderLayer(nn.Module):
    """One decoder layer: RMSNorm, self-attention, RMSNorm, MLP.

    Args:
        config: HuggingFace config supplying sizes, head counts, rope
            settings, activation name and RMSNorm epsilon.
        cache_config: Optional cache config forwarded to the attention block.
        quant_config: Optional quantization config forwarded to the attention
            block and the MLP.
        prefix: Module name prefix; attention is built with
            ``f"{prefix}.self_attn"``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        # Optional config fields fall back to the defaults given here.
        attention_kwargs = dict(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 8192),
            quant_config=quant_config,
            bias=getattr(config, "bias", False),
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn = XverseAttention(**attention_kwargs)

        self.mlp = XverseMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(mlp_output, residual)``.

        When ``residual`` is ``None`` the incoming hidden states become the
        residual; otherwise the input norm is called with both tensors and
        returns the updated pair.
        """
        # Self Attention
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        mlp_out = self.mlp(mlp_in)
        return mlp_out, residual


228
@support_torch_compile
class XverseModel(nn.Module):
    """Token embedding, the decoder layers from ``make_layers`` and a final
    RMSNorm.

    ``forward`` runs only the layers in ``[start_layer, end_layer)`` and
    exchanges ``hidden_states``/``residual`` with neighbouring pipeline
    ranks through ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.padding_idx = config.pad_token_id

        # Extra embedding rows reserved for LoRA-added vocabulary, if any.
        if lora_config:
            lora_vocab = (lora_config.lora_extra_vocab_size *
                          (lora_config.max_loras or 1))
        else:
            lora_vocab = 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        # The parameter keeps the name ``prefix`` in case make_layers passes
        # it by keyword.
        def build_layer(prefix: str) -> XverseDecoderLayer:
            return XverseDecoderLayer(config,
                                      cache_config,
                                      quant_config,
                                      prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's decoder layers.

        Returns the normalized hidden states on the last pipeline rank, and
        an ``IntermediateTensors`` holding ``hidden_states`` and ``residual``
        on every other rank.
        """
        if get_pp_group().is_first_rank:
            # Caller-provided embeddings take precedence over token lookup.
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None
        else:
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer_idx in range(self.start_layer, self.end_layer):
            # kv_caches is indexed relative to this rank's first layer.
            local_idx = layer_idx - self.start_layer
            decoder_layer = self.layers[layer_idx]
            hidden_states, residual = decoder_layer(
                positions,
                hidden_states,
                kv_caches[local_idx],
                attn_metadata,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        normed_states, _ = self.norm(hidden_states, residual)
        return normed_states


298
class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Xverse causal LM: ``XverseModel`` plus LM head, logits processor and
    sampler, with checkpoint weight loading."""

    # Fused parameter name -> checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = XverseModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            # Share the embedding matrix with the output head.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to ``XverseModel`` and return its result unchanged."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor with the LM head on ``hidden_states``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint names containing ``q_proj``/``k_proj``/``v_proj`` or
        ``gate_proj``/``up_proj`` are redirected to the fused ``qkv_proj`` /
        ``gate_up_proj`` parameters and loaded with the matching shard id.
        Rotary-embedding buffers, biases that have no parameter in this
        model, and parameters reported missing by
        ``is_pp_missing_parameter`` are skipped.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names (after fused renaming) that were
            loaded.

        Raises:
            KeyError: If a non-skipped name has no matching parameter.
        """
        stacked_params_mapping = [
            # (fused param name, checkpoint weight name, shard id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if ("rotary_emb.inv_freq" in name
                    or "rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                continue

            # Resolve the target name exactly once. The first matching entry
            # wins and matching stops there, so an already-rewritten name
            # (e.g. "qkv_proj", which contains "v_proj") is never matched and
            # rewritten a second time.
            is_stacked = False
            shard_id = None
            for param_name, weight_name, candidate in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = candidate
                    is_stacked = True
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if is_stacked:
                # Fused parameters need the shard id to place the slice.
                param.weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params