# coding=utf-8
# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
38
from typing import Iterable, List, Optional, Tuple
39
40
41
42
43

import torch
from torch import nn
from transformers import PretrainedConfig

44
from vllm.attention import Attention, AttentionMetadata
45
from vllm.config import CacheConfig
46
from vllm.distributed import get_tensor_model_parallel_world_size
47
48
49
50
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
51
from vllm.model_executor.layers.logits_processor import LogitsProcessor
52
53
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
54
from vllm.model_executor.layers.rotary_embedding import get_rope
55
56
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
58
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
59
from vllm.model_executor.sampling_metadata import SamplingMetadata
60
61
62
63
64
65
66
from vllm.sequence import SamplerOutput


class PhiAttention(nn.Module):
    """Phi self-attention with a fused QKV projection and partial RoPE.

    Attention heads are split evenly across tensor-parallel ranks. Rotary
    position embeddings cover ``rotary_dim`` dimensions of each head, where
    ``rotary_dim = partial_rotary_factor * head_size``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0, (
            f"num_attention_heads ({self.total_num_heads}) must be divisible "
            f"by the tensor parallel world size "
            f"({tensor_model_parallel_world_size})")
        # Number of attention heads owned by this tensor-parallel rank.
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # pylint: disable=C0103
        # Fused Q/K/V projection. Only one head count is given, and forward()
        # splits the output into three equal chunks, so Q, K and V all use
        # the same number of heads.
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            quant_config=quant_config,
        )

        scaling = self.head_size**-0.5
        # Only a fraction of each head's dimensions is rotated.
        rotary_dim = int(config.partial_rotary_factor * self.head_size)
        assert rotary_dim % 2 == 0, (
            f"rotary_dim ({rotary_dim}) must be even; check "
            f"partial_rotary_factor ({config.partial_rotary_factor})")

        # pylint: disable=C0301
        # Refer to:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        # That reference hard-codes a RoPE base of 10000 and names the context
        # length ``n_positions``. Configs that carry ``rope_theta`` /
        # ``max_position_embeddings`` are honoured too; the legacy name wins
        # when present, and the old defaults apply when neither is set.
        rope_theta = getattr(config, "rope_theta", None)
        if rope_theta is None:
            rope_theta = 10000
        max_position_embeddings = getattr(config, "n_positions", None)
        if max_position_embeddings is None:
            max_position_embeddings = getattr(config,
                                              "max_position_embeddings", 2048)

        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            position_ids: Token positions fed to the rotary embedding.
            hidden_states: Input activations for the attention projection.
            kv_cache: This layer's KV cache, forwarded to ``self.attn``.
            attn_metadata: Batch metadata, forwarded to ``self.attn``.

        Returns:
            The attention output after the ``dense`` projection.
        """
        # The second value returned by the parallel linear layers is unused.
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class PhiMLP(nn.Module):
    """Phi feed-forward block computing ``fc2(act(fc1(x)))``."""

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        # Resolve the inner (expanded) width. The phi-1_5 remote-code config
        # calls it ``n_inner``, which may be present but set to None. The
        # transformers ``PhiConfig`` calls it ``intermediate_size``.
        # 4 * hidden_size is used only when neither is set, so checkpoints
        # with a different expansion ratio still get correctly shaped layers.
        n_inner = getattr(config, "n_inner", None)
        if n_inner is None:
            n_inner = getattr(config, "intermediate_size", None)
        if n_inner is None:
            n_inner = 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            n_inner,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            n_inner,
            config.hidden_size,
            quant_config=quant_config,
        )
        # n_inner is passed through to get_act_fn, presumably so
        # quantization-aware activations can size themselves (TODO confirm).
        self.act = get_act_fn(config.hidden_act, quant_config, n_inner)

    def forward(self, hidden_states):
        """Apply ``fc1``, the activation, then ``fc2`` to ``hidden_states``."""
        # The second value returned by the parallel linear layers is unused.
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class PhiLayer(nn.Module):
    """One Phi decoder block.

    A single LayerNorm feeds two branches: self-attention and the MLP. Both
    branches consume the same normalized input instead of running one after
    the other. Their outputs are added to the un-normalized input
    (the residual).
    """

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        width = config.hidden_size
        self.input_layernorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.self_attn = PhiAttention(config, cache_config, quant_config)
        self.mlp = PhiMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the block and return ``attn + mlp + hidden_states``.

        Args:
            position_ids: Token positions, forwarded to the attention branch.
            hidden_states: Block input; also used as the residual.
            kv_cache: This layer's KV cache, forwarded to attention.
            attn_metadata: Batch metadata, forwarded to attention.
        """
        normed = self.input_layernorm(hidden_states)
        attn_branch = self.self_attn(
            position_ids=position_ids,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        mlp_branch = self.mlp(normed)
        return attn_branch + mlp_branch + hidden_states


class PhiModel(nn.Module):
    """Phi decoder stack.

    The stack is a token embedding, then ``num_hidden_layers`` ``PhiLayer``
    blocks, then a final LayerNorm.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        self.layers = nn.ModuleList([
            PhiLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.final_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer, and normalize.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions, passed to every layer.
            kv_caches: One KV cache per layer, indexed by layer number.
            attn_metadata: Batch metadata shared by all layers.

        Returns:
            Hidden states after the final LayerNorm.
        """
        hidden_states = self.embed_tokens(input_ids)
        # Layer i uses kv_caches[i]. Indexing, rather than zip(), keeps a
        # too-short kv_caches list a loud IndexError instead of silently
        # skipping layers.
        for i, layer in enumerate(self.layers):
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states


class PhiForCausalLM(nn.Module):
    """Phi causal language model: a ``PhiModel`` plus an LM head with bias."""

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        self.model = PhiModel(config, cache_config, quant_config)

        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      bias=True)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the decoder's final hidden states.

        Logits are produced separately by ``compute_logits``.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project ``hidden_states`` to vocabulary logits.

        The LM head is built with ``bias=True``, so its bias is passed to the
        logits processor alongside the weight.
        """
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` using ``self.sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are routed into
        the matching shard of the fused ``qkv_proj`` parameter. Every other
        tensor goes through the parameter's own ``weight_loader``, or
        ``default_weight_loader`` if it has none.

        Checkpoint bias tensors with no matching parameter are skipped (for
        GPTQ checkpoints), as are ``rotary_emb.inv_freq`` entries.
        A non-bias name with no matching parameter raises ``KeyError``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v")
        ]
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            # Rotary frequencies are never taken from the checkpoint.
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Rewrite into a new variable rather than mutating ``name``:
                # "v_proj" is a substring of "qkv_proj". If a rewritten name
                # flowed back into this loop, it could be rewritten a second
                # time (e.g. into "qkqkv_proj").
                stacked_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. ``break`` (not
                # ``continue``) leaves the shard loop without reaching the
                # ``else`` branch, so this tensor is dropped outright.
                if (stacked_name.endswith(".bias")
                        and stacked_name not in params_dict):
                    break
                param = params_dict[stacked_name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a q/k/v shard: load the tensor under its own name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)