# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Aquila model (LLaMA-style architecture) compatible with
HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.aquila import AquilaConfig

# Per-layer KV cache: a (key_cache, value_cache) pair of tensors, unpacked as
# `k_cache, v_cache = kv_cache` in AquilaAttention.forward.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class AquilaMLP(nn.Module):
    """Gated feed-forward block used inside every Aquila decoder layer.

    The gate and up projections share one fused column-parallel weight
    (``gate_up_proj``). Their outputs are combined by ``SiluAndMul``, and a
    row-parallel ``down_proj`` maps the result from ``intermediate_size``
    back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """
        Args:
            hidden_size: Input/output width of the block.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            linear_method: Optional LinearMethodBase passed to both
                projections.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # One fused weight with two output partitions of intermediate_size
        # each (gate and up).
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            linear_method=linear_method,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, gated activation and down projection."""
        # Both parallel linears return (output, bias); bias is unused here.
        projected, _ = self.gate_up_proj(x)
        activated = self.act_fn(projected)
        out, _ = self.down_proj(activated)
        return out


class AquilaRMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel scale.

    Equivalent to T5LayerNorm: there is no mean subtraction and no bias. The
    mean of squares is computed in float32, and the result is cast back to
    the input's dtype.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learned scale, initialised to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Added to the mean of squares before the reciprocal square root.
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # Mean of squares over the last dimension, accumulated in float32.
        mean_sq = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.variance_epsilon)
        # For half-precision inputs, multiplying by the float32 inv_rms
        # promotes the product to float32 before the final cast back.
        normalized = hidden_states * inv_rms
        return (self.weight * normalized).to(orig_dtype)


class AquilaAttention(nn.Module):
    """Self-attention with rotary position embeddings over a paged KV cache.

    Q, K and V come from one fused tensor-parallel projection (``qkv_proj``).
    ``PagedAttentionWithRoPE`` applies RoPE and attention using the per-layer
    KV cache, and the row-parallel ``o_proj`` maps the result back to
    ``hidden_size``. ``num_kv_heads`` may be smaller than ``num_heads``
    (grouped-query attention).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        rope_scaling: Optional[Dict[str, Any]] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """
        Args:
            hidden_size: Model width; ``head_dim = hidden_size // num_heads``.
            num_heads: Total (unsharded) number of query heads.
            num_kv_heads: Total (unsharded) number of key/value heads.
            rope_theta: RoPE base, forwarded as ``base``.
            max_position_embeddings: Forwarded as ``max_position`` to the
                rotary attention layer.
            rope_scaling: Optional RoPE scaling config, forwarded unchanged.
            linear_method: Optional LinearMethodBase passed to the QKV and
                output projections.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # At least one KV head per rank: partition them evenly.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than TP ranks: each rank keeps a single KV head
            # and the heads are replicated across ranks. Previously this case
            # was rejected outright. NOTE(review): relies on QKVParallelLinear
            # replicating KV heads when tp_size > total_num_kv_heads, as
            # vLLM's LLaMA attention does.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the Q and K/V slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_dim,
            self.scaling,
            base=self.rope_theta,
            max_position=self.max_position_embeddings,
            rotary_dim=self.head_dim,
            num_kv_heads=self.num_kv_heads,
            rope_scaling=rope_scaling)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Compute attention for ``hidden_states`` and return o_proj output."""
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused output is laid out as [q | k | v] along the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        output, _ = self.o_proj(attn_output)
        return output


class AquilaDecoderLayer(nn.Module):
    """A single pre-norm transformer block.

    The block applies RMSNorm -> self-attention -> residual add, then
    RMSNorm -> gated MLP -> residual add.
    """

    def __init__(
        self,
        config: AquilaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        # RoPE settings are optional on the config; fall back to defaults.
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_pos = getattr(config, "max_position_embeddings", 8192)

        self.self_attn = AquilaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            max_position_embeddings=max_pos,
            rope_scaling=rope_scaling,
            linear_method=linear_method,
        )
        self.mlp = AquilaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = AquilaRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = AquilaRMSNorm(config.hidden_size,
                                                      eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Attention sub-layer: normalise, attend, then add the block input.
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = hidden_states + attn_out

        # MLP sub-layer: normalise, feed forward, then add the residual.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out


class AquilaModel(nn.Module):
    """Aquila decoder stack without the LM head.

    The stack is a token embedding, ``config.num_hidden_layers`` decoder
    layers and a final RMSNorm. ``forward`` returns the normalised hidden
    states.
    """

    def __init__(
        self,
        config: AquilaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        decoder_layers = [
            AquilaDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer, apply final norm.

        ``kv_caches`` and (when given) ``cache_events`` are indexed by layer.
        """
        hidden_states = self.embed_tokens(input_ids)
        for layer_idx, decoder_layer in enumerate(self.layers):
            event = (None
                     if cache_events is None else cache_events[layer_idx])
            hidden_states = decoder_layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                input_metadata,
                event,
            )
        return self.norm(hidden_states)


class AquilaForCausalLM(nn.Module):
    """Aquila decoder stack plus a parallel LM head and sampler.

    ``forward`` returns sampled outputs rather than logits.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = AquilaModel(config, linear_method)
        # The LM head is constructed without linear_method; the sampler
        # consumes its weight directly in forward().
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the decoder stack, then sample next tokens via the LM head."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        return self.sampler(self.lm_head.weight, hidden_states,
                            input_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace checkpoint tensors into this model's parameters.

        Separate q/k/v and gate/up checkpoint tensors are written as shards
        into the fused ``qkv_proj`` / ``gate_up_proj`` parameters through
        their ``weight_loader``. Every other tensor is loaded by exact name
        with the parameter's ``weight_loader`` if it has one, or
        ``default_weight_loader`` otherwise.
        """
        # (fused parameter name, checkpoint name fragment, shard id)
        fused_shards = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        checkpoint = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                               load_format, revision)
        for name, loaded_weight in checkpoint:
            # Tensors named "rotary_emb.inv_freq" are never loaded
            # (presumably recomputed by the rotary embedding layer).
            if "rotary_emb.inv_freq" in name:
                continue
            # First fused mapping whose fragment occurs in the name, if any.
            shard = next(((fused, fragment, shard_id)
                          for fused, fragment, shard_id in fused_shards
                          if fragment in name), None)
            if shard is None:
                param = params_dict[name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, loaded_weight)
                continue
            fused, fragment, shard_id = shard
            param = params_dict[name.replace(fragment, fused)]
            param.weight_loader(param, loaded_weight, shard_id)