phi.py 12 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# coding=utf-8
# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Woosuk Kwon's avatar
Woosuk Kwon committed
37
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
38
from typing import Iterable, List, Optional, Tuple
39
40
41
42
43

import torch
from torch import nn
from transformers import PretrainedConfig

44
from vllm.attention import Attention, AttentionMetadata
45
from vllm.config import CacheConfig, LoRAConfig
46
from vllm.distributed import get_tensor_model_parallel_world_size
47
48
49
50
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
51
from vllm.model_executor.layers.logits_processor import LogitsProcessor
52
53
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
54
from vllm.model_executor.layers.rotary_embedding import get_rope
55
56
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
57
    ParallelLMHead, VocabParallelEmbedding)
58
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
59
from vllm.model_executor.sampling_metadata import SamplingMetadata
60
61
62
63
64
65
66
from vllm.sequence import SamplerOutput


class PhiAttention(nn.Module):
    """Self-attention for Phi with partial rotary position embeddings.

    Q, K and V are produced by one fused ``QKVParallelLinear`` projection,
    rotary embeddings are applied to the first ``rotary_dim`` channels of
    each head, and the attention output is projected back to the hidden size
    by the ``RowParallelLinear`` layer ``dense``. Attention heads are split
    evenly across tensor-parallel ranks.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the attention projections, rotary embedding and kernel.

        Args:
            config: model config; reads ``num_attention_heads``,
                ``hidden_size``, ``partial_rotary_factor`` and, when present,
                ``rope_theta`` / ``max_position_embeddings`` /
                ``n_positions``.
            cache_config: forwarded to the ``Attention`` layer.
            quant_config: forwarded to the linear layers and ``Attention``.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Heads must divide evenly across tensor-parallel ranks.
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # pylint: disable=C0103
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            quant_config=quant_config,
        )

        scaling = self.head_size**-0.5
        # Only a fraction of each head's channels is rotated ("partial"
        # rotary); the rotated width must be even.
        rotary_dim = int(config.partial_rotary_factor * self.head_size)
        assert rotary_dim % 2 == 0

        # pylint: disable=C0301
        # Refer to:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        # Honour RoPE settings carried by the config (HF `PhiConfig` exposes
        # `rope_theta` and `max_position_embeddings`); fall back to the legacy
        # `n_positions` field and to the original Phi-1.5 defaults.
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(
            config, "max_position_embeddings",
            getattr(config, "n_positions", 2048))

        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            position_ids: positions fed to the rotary embedding.
            hidden_states: layer input, projected by ``qkv_proj``.
            kv_cache: this layer's KV cache, passed to ``Attention``.
            attn_metadata: attention metadata, passed to ``Attention``.

        Returns:
            The attention output after the ``dense`` projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Q, K and V have equal width here (no grouped-query attention),
        # so an even three-way split separates them.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class PhiMLP(nn.Module):
    """Feed-forward sublayer: ``fc1`` -> activation -> ``fc2``.

    ``fc1`` is column-parallel and ``fc2`` row-parallel; the inner width is
    ``config.n_inner`` or, when that is missing/None, four times the hidden
    size.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        """Create the two projections and the activation.

        Args:
            config: model config; reads ``hidden_size``, ``hidden_act`` and
                optionally ``n_inner``.
            quant_config: forwarded to both linear layers and the activation.
        """
        super().__init__()

        inner_dim = getattr(config, "n_inner", None)
        if inner_dim is None:
            inner_dim = 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        inner_dim,
                                        quant_config=quant_config)
        self.fc2 = RowParallelLinear(inner_dim,
                                     config.hidden_size,
                                     quant_config=quant_config)
        self.act = get_act_fn(config.hidden_act, quant_config, inner_dim)

    def forward(self, hidden_states):
        """Apply ``fc1``, the activation and ``fc2``; return the result."""
        # The parallel linear layers return (output, bias); bias is unused.
        projected, _ = self.fc1(hidden_states)
        activated = self.act(projected)
        output, _ = self.fc2(activated)
        return output


class PhiLayer(nn.Module):
    """One Phi decoder layer with parallel attention and MLP branches.

    A single LayerNorm feeds both branches; their outputs are summed with
    the un-normalized layer input (the residual).
    """

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Create the shared input LayerNorm, attention and MLP.

        Args:
            config: model config; reads ``hidden_size`` and
                ``layer_norm_eps`` here and passes itself to the sublayers.
            cache_config: forwarded to ``PhiAttention``.
            quant_config: forwarded to ``PhiAttention`` and ``PhiMLP``.
        """
        super().__init__()
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.self_attn = PhiAttention(config, cache_config, quant_config)
        self.mlp = PhiMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return ``attn(norm(x)) + mlp(norm(x)) + x`` for input ``x``."""
        normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            position_ids=position_ids,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        mlp_out = self.mlp(normed)
        # Same summation order as before: (attention + MLP) + residual.
        return attn_out + mlp_out + hidden_states


class PhiModel(nn.Module):
    """Phi transformer trunk: token embedding, decoder stack, final LayerNorm."""

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the embedding, ``num_hidden_layers`` layers and final norm.

        Args:
            config: model config; reads ``vocab_size``, ``hidden_size``,
                ``num_hidden_layers`` and ``layer_norm_eps``.
            cache_config: forwarded to every ``PhiLayer``.
            quant_config: stored and forwarded to every ``PhiLayer``.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        self.layers = nn.ModuleList(
            PhiLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers))
        self.final_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer, and apply the final norm.

        Args:
            input_ids: token ids for ``embed_tokens``.
            positions: positions passed to every layer.
            kv_caches: one KV cache per layer, indexed by layer number.
            attn_metadata: attention metadata passed to every layer.

        Returns:
            The final-LayerNorm output (hidden states, not logits).
        """
        hidden_states = self.embed_tokens(input_ids)
        for layer_idx, layer in enumerate(self.layers):
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
            )
        return self.final_layernorm(hidden_states)
230
231


232
class PhiForCausalLM(nn.Module):
    """Phi causal language model: ``PhiModel`` trunk plus a biased LM head.

    ``forward`` returns final hidden states; logits are produced separately
    by ``compute_logits`` and next tokens by ``sample``.
    """

    # Checkpoint q/k/v projections map onto the fused `qkv_proj` parameter
    # (see `stacked_params_mapping` in `load_weights`).
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "dense",
        "fc1",
        "fc2",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        """Build the trunk, LM head, logits processor and sampler.

        Args:
            config: model config; reads ``vocab_size`` and ``hidden_size``
                here and is passed to ``PhiModel``.
            cache_config: forwarded to ``PhiModel``.
            quant_config: stored and forwarded to ``PhiModel``.
            lora_config: accepted for interface compatibility; unused.
        """
        del lora_config  # Unused.
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        self.model = PhiModel(config, cache_config, quant_config)

        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      bias=True)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the trunk and return its final hidden states (not logits)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits with the LM head.

        The LM head's weight and bias are both handed to the logits
        processor, since Phi's output projection is biased.
        """
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate next-token selection to the ``Sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are loaded into
        the fused ``qkv_proj`` parameter as shards ``"q"`` / ``"k"`` /
        ``"v"``. All other tensors are loaded by name, using the parameter's
        ``weight_loader`` when it has one and ``default_weight_loader``
        otherwise.

        Args:
            weights: iterable of ``(name, tensor)`` checkpoint entries.

        Raises:
            KeyError: if a non-bias checkpoint tensor has no parameter with
                the (possibly rewritten) name.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v")
        ]
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            # Rotary inverse-frequency tensors are not loaded from the
            # checkpoint.
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. Use `break`, not
                # `continue`: the rewritten name must not be matched against
                # the remaining shard names ("qkv_proj" contains "v_proj"),
                # and `break` also bypasses the `else` branch below.
                if name.endswith(".bias") and name not in params_dict:
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a q/k/v shard: load by name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # pylint: disable=E1136

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)