internlm2.py 12.5 KB
Newer Older
Fengzhe Zhou's avatar
Fengzhe Zhou committed
1
# -*- coding: utf-8 -*-
2
from typing import Any, Dict, Iterable, List, Optional, Tuple
Fengzhe Zhou's avatar
Fengzhe Zhou committed
3
4

import torch
5
from torch import nn
Fengzhe Zhou's avatar
Fengzhe Zhou committed
6
7
from transformers import PretrainedConfig

8
from vllm.attention import Attention, AttentionMetadata
9
from vllm.config import CacheConfig
10
from vllm.distributed import get_tensor_model_parallel_world_size
11
12
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
13
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
14
15
                                               QKVParallelLinear,
                                               RowParallelLinear)
16
from vllm.model_executor.layers.logits_processor import LogitsProcessor
17
18
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
19
from vllm.model_executor.layers.rotary_embedding import get_rope
20
21
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
22
    ParallelLMHead, VocabParallelEmbedding)
23
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
24
from vllm.model_executor.sampling_metadata import SamplingMetadata
25
from vllm.sequence import IntermediateTensors, SamplerOutput
Fengzhe Zhou's avatar
Fengzhe Zhou committed
26

27
28
29
30
31
32
33
34

class InternLM2MLP(nn.Module):
    """Feed-forward sub-layer of an InternLM2 decoder block.

    ``gate_up_proj`` is a single column-parallel projection that holds two
    ``intermediate_size`` shards (checkpoint weights ``w1`` -> shard 0 and
    ``w3`` -> shard 1, per the mapping in ``load_weights``). Its output goes
    through vLLM's ``SiluAndMul`` and is projected back to ``hidden_size`` by
    the row-parallel ``w2``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        shard_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            shard_sizes,
            bias=False,
            quant_config=quant_config,
        )
        self.w2 = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        # Only the fused SiLU-and-multiply kernel is wired up.
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Apply the fused projection, activation and down-projection."""
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.w2(activated)
        return projected


class InternLM2Attention(nn.Module):
    """Self-attention for InternLM2 (multi-head or grouped-query).

    Query, key and value projections come from one fused ``wqkv`` linear.
    ``split_qkv`` reads its output as grouped per KV head: for each KV head,
    ``key_value_groups`` query heads are followed by one key head and one
    value head. Rotary embeddings are applied to q and k before vLLM's
    ``Attention`` op, and ``wo`` projects the result back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Query heads sharing each KV head on this rank. Exact integer
        # division instead of float division followed by int().
        self.key_value_groups = self.num_heads // self.num_kv_heads
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def split_qkv(self, qkv: torch.Tensor):
        """Split the fused ``wqkv`` output into separate q, k and v tensors.

        Args:
            qkv: Output of ``self.wqkv``. Each row is read as
                ``num_kv_heads`` groups of ``key_value_groups + 2`` heads of
                ``head_dim`` values (query heads first, then k, then v).

        Returns:
            ``(q, k, v)`` reshaped to last dimensions ``q_size``,
            ``kv_size`` and ``kv_size`` respectively.
        """
        # NOTE(review): load_weights hands the whole wqkv tensor to the
        # parameter's weight_loader without a shard id, and this method
        # de-interleaves only the local shard. That looks correct for a
        # tensor-parallel size of 1; confirm the layout before using TP > 1.
        # Use the real head size; a hard-coded 128 broke any model whose
        # hidden_size // num_attention_heads differs.
        qkv = qkv.view(-1, self.num_kv_heads, self.key_value_groups + 2,
                       self.head_dim)
        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=2)
        q = q.reshape(-1, self.q_size)
        k = k.reshape(-1, self.kv_size)
        v = v.reshape(-1, self.kv_size)
        return q, k, v

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run fused QKV projection, RoPE, attention and output projection."""
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = self.split_qkv(qkv)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.wo(attn_output)
        return output


class InternLMDecoderLayer(nn.Module):
    """One InternLM2 transformer block: attention sub-layer then feed-forward.

    Each sub-layer is preceded by its own RMSNorm (``attention_norm`` and
    ``ffn_norm``). The residual stream is passed into and out of ``forward``
    and handed to the norms. The two-argument RMSNorm call presumably fuses
    the residual add and returns ``(normed, new_residual)``; confirm against
    vLLM's RMSNorm.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        width = config.hidden_size
        self.hidden_size = width
        # RoPE settings may be absent from the HF config; use defaults then.
        rope_kwargs = {
            "rope_theta": getattr(config, "rope_theta", 10000),
            "rope_scaling": getattr(config, "rope_scaling", None),
            "max_position_embeddings": getattr(config,
                                               "max_position_embeddings",
                                               8192),
        }
        self.attention = InternLM2Attention(
            hidden_size=width,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            **rope_kwargs,
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=width,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        eps = config.rms_norm_eps
        self.attention_norm = RMSNorm(width, eps=eps)
        self.ffn_norm = RMSNorm(width, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the block and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer; the incoming hidden
        states then become the residual stream.
        """
        if residual is not None:
            normed, residual = self.attention_norm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.attention_norm(hidden_states)
        attended = self.attention(
            positions=positions,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        normed, residual = self.ffn_norm(attended, residual)
        return self.feed_forward(normed), residual


class InternLM2Model(nn.Module):
    """InternLM2 backbone: token embedding, decoder stack and final RMSNorm."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        decoder_layers = [
            InternLMDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.tok_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: IntermediateTensors = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder stack and return the final normalized states.

        ``inputs_embeds``, when given, replaces the embedding lookup of
        ``input_ids``. ``kv_caches`` is indexed by layer position.
        ``intermediate_tensors`` is accepted but not used here.
        """
        if inputs_embeds is None:
            hidden_states = self.tok_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds
        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class InternLM2ForCausalLM(nn.Module):
    """InternLM2 backbone plus LM head, logits processor and sampler."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = InternLM2Model(config, cache_config, quant_config)
        self.output = ParallelLMHead(config.vocab_size,
                                     config.hidden_size,
                                     quant_config=quant_config)
        # With tied embeddings the LM head reuses the input embedding weight.
        if self.config.tie_word_embeddings:
            self.output.weight = self.model.tok_embeddings.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: IntermediateTensors,
    ) -> torch.Tensor:
        """Return the backbone's final hidden states.

        ``intermediate_tensors`` is accepted but not forwarded.
        """
        return self.model(input_ids, positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.output, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Draw next tokens from ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        ``w1``/``w3`` checkpoint weights are routed into shards 0/1 of the
        fused ``gate_up_proj``. Every other tensor goes to the same-named
        parameter, using its ``weight_loader`` when present and
        ``default_weight_loader`` otherwise.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        params_dict = dict(self.named_parameters())

        def is_absent_bias(key: str) -> bool:
            # GPTQ checkpoints may carry biases this model has no slot for.
            return key.endswith(".bias") and key not in params_dict

        for ckpt_name, tensor in weights:
            if "rotary_emb.inv_freq" in ckpt_name:
                continue
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in ckpt_name:
                    continue
                ckpt_name = ckpt_name.replace(shard_name, fused_name)
                if is_absent_bias(ckpt_name):
                    continue
                target = params_dict[ckpt_name]
                target.weight_loader(target, tensor, shard_id)
                break
            else:
                # No fused mapping matched: load into the same-named param.
                if is_absent_bias(ckpt_name):
                    continue
                target = params_dict[ckpt_name]
                loader = getattr(target, "weight_loader",
                                 default_weight_loader)
                loader(target, tensor)