# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
21
from typing import Iterable, List, Optional, Tuple
22
23
24
25
26

import torch
from torch import nn
from transformers import GPTBigCodeConfig

27
from vllm.attention import Attention, AttentionMetadata
28
from vllm.config import CacheConfig
29
from vllm.distributed import get_tensor_model_parallel_world_size
30
from vllm.model_executor.layers.activation import get_act_fn
31
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
36
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
37
from vllm.model_executor.layers.sampler import Sampler
38
39
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
40
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
41
from vllm.model_executor.sampling_metadata import SamplingMetadata
42
from vllm.sequence import SamplerOutput
43
44
45
46


class GPTBigCodeAttention(nn.Module):
    """Self-attention sub-layer of a GPTBigCode block.

    Query heads are divided evenly across tensor-parallel ranks. When
    ``config.multi_query`` is set, a single key/value head is used for all
    query heads (multi-query attention); otherwise there is one key/value
    head per query head.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        heads_total = config.num_attention_heads
        tp_world = get_tensor_model_parallel_world_size()

        self.hidden_size = config.hidden_size
        self.tensor_model_parallel_world_size = tp_world

        # Every rank must receive the same whole number of query heads.
        assert heads_total % tp_world == 0
        self.num_heads = heads_total // tp_world
        self.head_dim = self.hidden_size // heads_total
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        # Multi-query attention: exactly one KV head in total (and per rank).
        # Otherwise the KV heads mirror the query heads.
        kv_heads_total = 1 if self.multi_query else heads_total
        self.num_kv_heads = 1 if self.multi_query else self.num_heads
        # Width of each of the K and V slices produced by c_attn on this rank.
        self.kv_dim = self.head_dim * self.num_kv_heads

        # Submodules are registered in the same order (c_attn, c_proj, attn)
        # so parameter names/ordering match HuggingFace checkpoints.
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            heads_total,
            kv_heads_total,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, run paged attention, and project back.

        Returns the output of ``c_proj`` (its second return value is
        discarded).
        """
        fused_qkv, _ = self.c_attn(hidden_states)
        # The query slice covers this rank's share of the hidden size; K and
        # V each take ``kv_dim`` columns from the end of the fused output.
        query_width = self.hidden_size // self.tensor_model_parallel_world_size
        split_sizes = [query_width, self.kv_dim, self.kv_dim]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.c_proj(context)
        return projected


class GPTBigMLP(nn.Module):
    """Feed-forward sub-layer: ``c_fc`` -> activation -> ``c_proj``.

    ``c_fc`` expands ``config.hidden_size`` to ``intermediate_size`` and
    ``c_proj`` maps it back; both carry biases.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        model_dim = config.hidden_size
        # Registration order (c_fc, c_proj, act) is kept as-is.
        self.c_fc = ColumnParallelLinear(model_dim,
                                         intermediate_size,
                                         bias=True,
                                         quant_config=quant_config)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        model_dim,
                                        bias=True,
                                        quant_config=quant_config)
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the expand/activate/contract sequence to ``hidden_states``."""
        expanded, _ = self.c_fc(hidden_states)
        activated = self.act(expanded)
        contracted, _ = self.c_proj(activated)
        return contracted


class GPTBigCodeBlock(nn.Module):
    """One transformer layer: LayerNorm -> attention and LayerNorm -> MLP.

    Each sub-layer's output is added back to its input (residual
    connection).
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        width = config.hidden_size
        # MLP width defaults to 4x the hidden size when n_inner is unset.
        if config.n_inner is None:
            mlp_width = 4 * width
        else:
            mlp_width = config.n_inner
        eps = config.layer_norm_epsilon

        self.ln_1 = nn.LayerNorm(width, eps=eps)
        self.attn = GPTBigCodeAttention(config, cache_config, quant_config)
        self.ln_2 = nn.LayerNorm(width, eps=eps)
        self.mlp = GPTBigMLP(mlp_width, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention then MLP, each wrapped in a residual connection."""
        # Attention sub-layer (normalized input, residual from raw input).
        attn_out = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        after_attn = attn_out + hidden_states

        # Feed-forward sub-layer, again with a residual connection.
        mlp_out = self.mlp(self.ln_2(after_attn))
        return after_attn + mlp_out


class GPTBigCodeModel(nn.Module):
    """GPTBigCode decoder stack without the LM head.

    Token embeddings (``wte``) plus learned absolute position embeddings
    (``wpe``) feed ``config.num_hidden_layers`` GPTBigCode blocks, followed
    by a final LayerNorm (``ln_f``).
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        # Cross-attention (encoder-decoder use) is not supported here.
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size

        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList([
            GPTBigCodeBlock(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed the inputs, run every block, and apply the final LayerNorm.

        Args:
            input_ids: token ids passed to ``wte``.
            position_ids: position ids passed to ``wpe``.
            kv_caches: one cache per layer; layer ``i`` uses ``kv_caches[i]``.
                An ``IndexError`` is raised if there are fewer caches than
                layers.
            attn_metadata: shared by all layers.

        Returns:
            The normalized hidden states from the last block.
        """
        hidden_states = self.wte(input_ids) + self.wpe(position_ids)

        # Index kv_caches explicitly (rather than zip) so a short list still
        # raises instead of silently skipping layers.
        for layer_idx, layer in enumerate(self.h):
            hidden_states = layer(hidden_states, kv_caches[layer_idx],
                                  attn_metadata)

        return self.ln_f(hidden_states)


class GPTBigCodeForCausalLM(nn.Module):
    """GPTBigCode causal language model for vLLM inference.

    The output projection reuses the token-embedding matrix
    (``transformer.wte.weight``), so no separate LM-head weight is created or
    loaded.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = GPTBigCodeModel(config, cache_config, quant_config)
        # Same tensor object as the input embedding: the LM head is tied.
        self.lm_head_weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the decoder's final hidden states (logits are computed
        separately by ``compute_logits``)."""
        return self.transformer(input_ids, positions, kv_caches,
                                attn_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits using the tied weight."""
        return self.logits_processor(self.lm_head_weight, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate token selection to the vLLM ``Sampler``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        ``weights`` yields ``(name, tensor)`` pairs whose names must match
        this module's parameter names. Parameters that define a custom
        ``weight_loader`` use it; others fall back to
        ``default_weight_loader``. An unknown name raises ``KeyError``.
        """
        # remove_duplicate=False keeps every name under which a shared
        # parameter is reachable.
        named_params = dict(self.named_parameters(remove_duplicate=False))
        for name, tensor in weights:
            # "lm_head.weight": tied to wte, nothing separate to load.
            # ".attn.bias": the attention mask, not a learned parameter.
            # NOTE: "c_attn.bias" should not be skipped (and is not matched).
            if "lm_head.weight" in name or ".attn.bias" in name:
                continue
            target = named_params[name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, tensor)