phi.py 12.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# coding=utf-8
# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Woosuk Kwon's avatar
Woosuk Kwon committed
37
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
38
from typing import Iterable, List, Optional, Tuple
39
40
41

import torch
from torch import nn
42
from transformers import PhiConfig
43

44
from vllm.attention import Attention, AttentionMetadata
45
from vllm.config import CacheConfig, LoRAConfig
46
from vllm.distributed import get_tensor_model_parallel_world_size
47
48
49
50
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
51
from vllm.model_executor.layers.logits_processor import LogitsProcessor
52
53
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
54
from vllm.model_executor.layers.rotary_embedding import get_rope
55
56
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
57
    ParallelLMHead, VocabParallelEmbedding)
58
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
59
from vllm.model_executor.sampling_metadata import SamplingMetadata
60
from vllm.sequence import IntermediateTensors, SamplerOutput
61

62
63
from .interfaces import SupportsLoRA

64
65
66
67

class PhiAttention(nn.Module):
    """Phi self-attention block.

    Projects the hidden states with a fused QKV linear layer, applies a
    (partial) rotary embedding to the queries and keys, runs attention
    against the KV cache and projects the result back with ``dense``.
    The attention heads are divided evenly across the tensor-parallel
    group.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: Model configuration; supplies the head count, hidden
                size, partial rotary factor and rotary parameters.
            cache_config: Forwarded unchanged to ``Attention``.
            quant_config: Forwarded unchanged to the linear layers and to
                ``Attention``.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Every tensor-parallel rank must own a whole number of heads.
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # pylint: disable=C0103
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            quant_config=quant_config,
        )

        scaling = self.head_size**-0.5
        # Only the first ``rotary_dim`` features of each head are rotated.
        rotary_dim = int(config.partial_rotary_factor *
                         (config.hidden_size // config.num_attention_heads))
        assert rotary_dim % 2 == 0

        # pylint: disable=C0301
        # Refer to:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        # Honour the values carried by the config when present; the
        # previous hard-coded values (10000 / 2048) remain the fallbacks.
        rope_theta = getattr(config, "rope_theta", None)
        if rope_theta is None:
            rope_theta = 10000
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          None)
        if max_position_embeddings is None:
            # Legacy attribute name kept for backward compatibility.
            max_position_embeddings = getattr(config, "n_positions", 2048)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            position_ids: Positions handed to the rotary embedding.
            hidden_states: Input activations for the fused QKV projection.
            kv_cache: Cache tensor passed through to ``Attention``.
            attn_metadata: Metadata passed through to ``Attention``.

        Returns:
            The output of the ``dense`` projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is split into three equal chunks on the
        # last dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class PhiMLP(nn.Module):
    """Phi feed-forward block: ``fc1`` -> activation -> ``fc2``."""

    def __init__(self,
                 config: PhiConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the two projections and the activation.

        Args:
            config: Model configuration; supplies the hidden size, the
                inner width and the activation name (``hidden_act``).
            quant_config: Forwarded unchanged to both linear layers and to
                ``get_act_fn``.
        """
        super().__init__()

        # Inner width resolution order: the legacy ``n_inner`` attribute,
        # then ``intermediate_size``, and finally 4x the hidden size.
        n_inner = getattr(config, "n_inner", None)
        if n_inner is None:
            n_inner = getattr(config, "intermediate_size", None)
        if n_inner is None:
            n_inner = 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            n_inner,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            n_inner,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act, quant_config, n_inner)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1``, the activation and ``fc2`` in sequence."""
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class PhiLayer(nn.Module):
    """A single Phi decoder layer.

    One LayerNorm output feeds both the attention branch and the MLP
    branch; the two branch outputs are added together and then added to
    the un-normalised input.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Create the LayerNorm, attention and MLP sub-modules.

        Args:
            config: Model configuration shared by the sub-modules.
            cache_config: Passed on to ``PhiAttention``.
            quant_config: Passed on to ``PhiAttention`` and ``PhiMLP``.
        """
        super().__init__()
        norm_eps = config.layer_norm_eps
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.self_attn = PhiAttention(config, cache_config, quant_config)
        self.mlp = PhiMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the layer and return the updated hidden states.

        Args:
            position_ids: Positions forwarded to the attention block.
            hidden_states: Input activations (also used as the residual).
            kv_cache: Cache tensor forwarded to the attention block.
            attn_metadata: Metadata forwarded to the attention block.
        """
        normed = self.input_layernorm(hidden_states)
        attention_branch = self.self_attn(
            position_ids=position_ids,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        mlp_branch = self.mlp(normed)
        # Same association as before: (attention + mlp) + residual.
        return attention_branch + mlp_branch + hidden_states


class PhiModel(nn.Module):
    """Phi backbone: token embedding, decoder layers, final LayerNorm."""

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Create the embedding, ``config.num_hidden_layers`` decoder layers
        and the final LayerNorm.

        Args:
            config: Model configuration.
            cache_config: Passed on to every ``PhiLayer``.
            quant_config: Stored on the instance and passed on to every
                ``PhiLayer``.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        decoder_layers = [
            PhiLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.final_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer in order and apply
        the final LayerNorm.

        Args:
            input_ids: Token ids passed to the embedding.
            positions: Positions forwarded to each layer.
            kv_caches: One cache tensor per layer, indexed by layer number.
            attn_metadata: Metadata forwarded to each layer.
        """
        hidden_states = self.embed_tokens(input_ids)
        for layer_idx, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
            )
        return self.final_layernorm(hidden_states)
232
233


234
class PhiForCausalLM(nn.Module, SupportsLoRA):
    """Phi backbone with a language-model head, logits processor and
    sampler, plus a weight loader for checkpoints with separate q/k/v
    projections."""

    # Fused module name -> names of the separate modules it packs.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    # LoRA specific attributes
    supported_lora_modules = ["qkv_proj", "dense", "fc1", "fc2"]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: PhiConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        """Build the backbone, LM head, logits processor and sampler.

        Args:
            config: Model configuration.
            cache_config: Passed on to ``PhiModel``.
            quant_config: Stored and passed on to ``PhiModel`` and the
                LM head.
            lora_config: Stored on the instance.
        """
        super().__init__()

        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.model = PhiModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      bias=True,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Return the backbone's hidden states.

        ``intermediate_tensors`` is accepted but not used here.
        """
        return self.model(input_ids, positions, kv_caches, attn_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Run the logits processor with the LM head and its bias."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata, self.lm_head.bias)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Entries whose name contains ``rotary_emb.inv_freq`` are skipped.
        Names containing ``q_proj``/``k_proj``/``v_proj`` are rewritten to
        ``qkv_proj`` and loaded as the matching shard; everything else is
        loaded with the parameter's own ``weight_loader`` or, failing
        that, ``default_weight_loader``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        stacked_params_mapping = (
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        )
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name in name:
                    name = name.replace(shard_name, fused_name)
                    # Skip loading extra bias for GPTQ models.
                    extra_bias = (name.endswith(".bias")
                                  and name not in params_dict)
                    if not extra_bias:
                        target = params_dict[name]
                        target.weight_loader(target, loaded_weight, shard_id)
                        break
            else:
                # No stacked shard was loaded for this entry.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # pylint: disable=E1136
                target = params_dict[name]
                loader = getattr(target, "weight_loader",
                                 default_weight_loader)
                loader(target, loaded_weight)