# coding=utf-8
# Adapted from
# https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/modeling_orion.py
# Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput


class OrionMLP(nn.Module):
    """Feed-forward block: merged gate/up projection, SiluAndMul, down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the projections and the activation.

        Args:
            hidden_size: Input size of ``gate_up_proj`` and output size of
                ``down_proj``.
            intermediate_size: Size of each of the two merged gate/up outputs
                and input size of ``down_proj``.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Forwarded unchanged to both linear layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate first so an unsupported configuration is rejected before
        # any projection layer is constructed.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``gate_up_proj``, the activation, then ``down_proj``."""
        # Both projections return a 2-tuple; only the first element is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class OrionAttention(nn.Module):
    """Self-attention with a fused QKV projection and rotary embeddings."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Derive per-rank head counts and build the projection/attention layers.

        Args:
            hidden_size: Model hidden size; also the input size of
                ``qkv_proj`` and the output size of ``o_proj``.
            num_heads: Total number of query heads across all tensor-parallel
                ranks.
            num_kv_heads: Total number of key/value heads across all
                tensor-parallel ranks.
            rope_theta: Passed to ``get_rope`` as ``base``.
            rope_scaling: Passed to ``get_rope`` unchanged.
            max_position_embeddings: Passed to ``get_rope`` as
                ``max_position``.
            quant_config: Forwarded unchanged to ``qkv_proj`` and ``o_proj``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # At least one KV head per rank, even when KV heads are replicated.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Widths used in forward() to split the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding to Q and K, attend, project out.

        Args:
            positions: Passed to ``rotary_emb`` together with Q and K.
            hidden_states: Input to the fused QKV projection.
            kv_cache: Passed through to the attention layer.
            attn_metadata: Passed through to the attention layer.

        Returns:
            The first element of the ``o_proj`` result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused output is laid out as [q | k | v] along the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class OrionDecoderLayer(nn.Module):
    """One decoder layer: LayerNorm -> attention -> add, LayerNorm -> MLP -> add."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the attention, MLP and the two layer norms from ``config``.

        Args:
            config: Model config. ``rope_theta``, ``rope_scaling`` and
                ``max_position_embeddings`` are optional and fall back to
                ``10000``, ``None`` and ``8192`` respectively.
            quant_config: Forwarded unchanged to the attention and MLP blocks.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = OrionAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
        )
        self.mlp = OrionMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )

        # The norms are standard nn.LayerNorm; the epsilon is read from the
        # config field named ``rms_norm_eps``.
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the attention and MLP sub-blocks, each with a residual add.

        Args:
            positions: Forwarded to the attention block.
            hidden_states: Layer input.
            kv_cache: Forwarded to the attention block.
            attn_metadata: Forwarded to the attention block.
            residual: Accepted but not read; it is overwritten with
                ``hidden_states`` on entry.

        Returns:
            ``(hidden_states, None)``; the second element is always ``None``
            because both residual adds are applied inside this layer.
        """
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states, None


class OrionModel(nn.Module):
    """Token embedding, a stack of ``OrionDecoderLayer``s, and a final LayerNorm."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the embedding, ``config.num_hidden_layers`` layers and the norm.

        Args:
            config: Model config supplying vocabulary size, hidden size, layer
                count, pad token id and the norm epsilon.
            quant_config: Forwarded unchanged to every decoder layer.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            OrionDecoderLayer(config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer in order, and apply the final norm.

        Args:
            input_ids: Input to ``embed_tokens``.
            positions: Forwarded to every layer.
            kv_caches: One entry per layer; ``kv_caches[i]`` goes to layer
                ``i``.
            attn_metadata: Forwarded to every layer.

        Returns:
            The normalized hidden states.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class OrionForCausalLM(nn.Module):
    """Orion causal LM: ``OrionModel`` backbone plus LM head, logits processor and sampler."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the backbone, LM head, logits processor and sampler.

        Args:
            config: Model config; ``vocab_size`` and ``hidden_size`` size the
                LM head.
            quant_config: Stored on the instance and forwarded to the
                backbone.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OrionModel(config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the backbone and return its hidden states (no logits here)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from hidden states using the LM head weight."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its result."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Checkpoint names containing ``q_proj``/``k_proj``/``v_proj`` are
        redirected to the fused ``qkv_proj`` parameter, and
        ``gate_proj``/``up_proj`` to ``gate_up_proj``, with the shard id passed
        to the parameter's ``weight_loader``. All other tensors are loaded by
        name with the parameter's ``weight_loader`` if it has one, otherwise
        with ``default_weight_loader``.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Raises:
            KeyError: If a non-skipped name has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Checkpoint entries that are skipped instead of being loaded.
        skipped_substrings = (
            "rotary_emb.inv_freq",
            # Models trained using ColossalAI may include these tensors in
            # the checkpoint. Skip them.
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if any(substring in name for substring in skipped_substrings):
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE: this moves on to the next mapping entry with the
                # already-rewritten name; if nothing loads, the ``else``
                # branch below performs the same bias check again.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when no stacked mapping loaded the tensor.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)