internlm2.py 14.7 KB
Newer Older
Fengzhe Zhou's avatar
Fengzhe Zhou committed
1
# -*- coding: utf-8 -*-
2
from functools import partial
3
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
Fengzhe Zhou's avatar
Fengzhe Zhou committed
4
5

import torch
6
from torch import nn
Fengzhe Zhou's avatar
Fengzhe Zhou committed
7
8
from transformers import PretrainedConfig

9
from vllm.attention import Attention, AttentionMetadata
10
from vllm.config import CacheConfig
11
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
12
13
14
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather)
15
16
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
17
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
18
19
                                               QKVParallelLinear,
                                               RowParallelLinear)
20
from vllm.model_executor.layers.logits_processor import LogitsProcessor
21
22
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
23
from vllm.model_executor.layers.rotary_embedding import get_rope
24
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
25
from vllm.model_executor.layers.vocab_parallel_embedding import (
26
    ParallelLMHead, VocabParallelEmbedding)
27
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
28
from vllm.model_executor.sampling_metadata import SamplingMetadata
29
from vllm.sequence import IntermediateTensors
Fengzhe Zhou's avatar
Fengzhe Zhou committed
30

31
32
33
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)

34
35
36
37
38
39
40
41

class InternLM2MLP(nn.Module):
    """Feed-forward block of an InternLM2 decoder layer.

    The input goes through a merged two-way column-parallel projection
    (``gate_up_proj``), the ``SiluAndMul`` activation, and finally the
    row-parallel ``w2`` projection.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Create the projections and the activation.

        Args:
            hidden_size: Input size of ``gate_up_proj`` and output size of
                ``w2``.
            intermediate_size: Output size of each of the two merged
                projections and input size of ``w2``.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Optional quantization config forwarded to both
                linear layers.

        Raises:
            ValueError: If ``hidden_act`` is anything other than
                ``"silu"``.
        """
        super().__init__()
        # Reject unsupported activations before any layer is allocated.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.w2 = RowParallelLinear(intermediate_size,
                                    hidden_size,
                                    bias=False,
                                    quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Run ``gate_up_proj`` -> ``act_fn`` -> ``w2`` on ``x``.

        Both linear layers return a pair; only the first element of each
        pair is used here.
        """
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.w2(x)
        return x


class InternLM2Attention(nn.Module):
    """Attention block built from a fused ``wqkv`` projection, rotary
    position embeddings, the ``Attention`` backend layer and the ``wo``
    output projection.

    Head counts are divided by the tensor-parallel world size; the
    ``total_*`` attributes hold the undivided values.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Create the projections, rotary embedding and attention layer.

        Args:
            hidden_size: Model hidden size; ``head_dim`` is
                ``hidden_size // num_heads``.
            num_heads: Total number of query heads; must be divisible by
                the tensor-parallel world size.
            num_kv_heads: Total number of key/value heads; must either be
                divisible by, or be a divisor of, the tensor-parallel
                world size.
            rope_theta: Passed to ``get_rope`` as ``base``.
            rope_scaling: Passed through to ``get_rope``.
            max_position_embeddings: Passed to ``get_rope`` as
                ``max_position``.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.total_num_heads = num_heads
        assert self.total_num_heads % self.tp_size == 0
        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= self.tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert self.tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Integer division keeps this an exact int without a float
        # round-trip.
        self.key_value_groups = self.num_heads // self.num_kv_heads
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def split_qkv(self, qkv: torch.Tensor):
        """Split the fused ``wqkv`` output into ``(q, k, v)``.

        The first dimension of ``qkv`` is taken as ``seq_len``. The tensor
        is viewed as ``(seq_len, total_num_kv_heads,
        key_value_groups + 2, head_dim)`` and split along the third axis
        into ``key_value_groups`` query slots, one key slot and one value
        slot, which are then flattened back to two dimensions.

        With tensor parallelism (``tp_size > 1``) the per-rank outputs are
        first all-gathered and regrouped so that all query chunks come
        first, then all key chunks, then all value chunks; after the split
        each rank keeps only its own partition of the last dimension.
        """
        seq_len = qkv.shape[0]
        if self.tp_size > 1:
            # Per-rank layout along the last dim is [q, k, v]; after the
            # gather that pattern repeats tp_size times.
            qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
            qkv = tensor_model_parallel_all_gather(qkv)
            qkv = torch.split(qkv, qkv_map, dim=-1)
            # Reorder chunks to [q_0..q_n, k_0..k_n, v_0..v_n].
            qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
            qkv = torch.cat(qkv, dim=-1)

        # NOTE(review): the view uses total_num_kv_heads; confirm the
        # element count still matches when KV heads are replicated
        # (total_num_kv_heads < tp_size).
        qkv = qkv.view(seq_len, self.total_num_kv_heads,
                       self.key_value_groups + 2, self.head_dim)
        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
        q = q.reshape(seq_len, self.q_size * self.tp_size)
        k = k.reshape(seq_len, self.kv_size * self.tp_size)
        v = v.reshape(seq_len, self.kv_size * self.tp_size)

        if self.tp_size > 1:
            # Keep only this rank's slice of each tensor.
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]
        return q, k, v

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embedding to q and k, run
        attention and return the ``wo`` projection of the result.
        """
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = self.split_qkv(qkv)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.wo(attn_output)
        return output


class InternLMDecoderLayer(nn.Module):
    """One decoder layer: ``attention_norm`` -> ``attention`` ->
    ``ffn_norm`` -> ``feed_forward``, with the residual stream carried
    alongside the hidden states.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the attention, MLP and the two RMSNorm modules.

        Args:
            config: Model config. ``rope_theta``, ``rope_scaling`` and
                ``max_position_embeddings`` are optional and fall back to
                ``10000``, ``None`` and ``8192`` respectively.
            cache_config: Forwarded to the attention block.
            quant_config: Forwarded to the attention block and the MLP.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.attention_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` the incoming ``hidden_states``
        becomes the residual and ``attention_norm`` is called with a
        single argument. Otherwise ``attention_norm`` is called with both
        tensors and returns the updated pair.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        else:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


class InternLM2Model(nn.Module):
    """Token embedding, the decoder layers owned by this pipeline rank,
    and the final RMSNorm.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the embedding, decoder layers and final norm.

        Args:
            config: Model config supplying vocabulary/hidden sizes, layer
                count, padding id and the RMSNorm epsilon.
            cache_config: Forwarded to every decoder layer.
            quant_config: Forwarded to every decoder layer.
            prefix: Name prefix; ``"{prefix}.layers"`` is handed to
                ``make_layers``.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # make_layers passes a per-layer prefix to the factory; the decoder
        # layer constructor does not take one, so it is ignored here.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: InternLMDecoderLayer(config, cache_config,
                                                quant_config),
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return ``tok_embeddings(input_ids)``."""
        return self.tok_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's decoder layers.

        On the first pipeline rank the hidden states come from
        ``inputs_embeds`` when given, otherwise from embedding
        ``input_ids``. On later ranks ``intermediate_tensors`` is required
        and supplies ``"hidden_states"`` and ``"residual"``.

        Returns:
            ``IntermediateTensors`` holding ``hidden_states`` and
            ``residual`` on every rank except the last; on the last rank,
            the hidden states after the final norm.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.tok_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            # kv_caches is indexed relative to this rank's first layer.
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class InternLM2ForCausalLM(nn.Module):
    """``InternLM2Model`` plus the ``output`` LM head, a logits processor
    and a sampler, with checkpoint weight loading.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the backbone, LM head, logits processor and sampler.

        When ``config.tie_word_embeddings`` is truthy the LM head weight
        is set to the token-embedding weight.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = InternLM2Model(config, cache_config, quant_config)
        self.output = ParallelLMHead(config.vocab_size,
                                     config.hidden_size,
                                     quant_config=quant_config)
        if self.config.tie_word_embeddings:
            self.output.weight = self.model.tok_embeddings.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: IntermediateTensors,
    ) -> torch.Tensor:
        """Delegate to ``self.model`` and return its result unchanged."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return ``logits_processor(output, hidden_states,
        sampling_metadata)``.
        """
        logits = self.logits_processor(self.output, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Return ``sampler(logits, sampling_metadata)``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Names containing ``w1`` / ``w3`` are rewritten to ``gate_up_proj``
        and loaded as shard 0 / 1 through the parameter's own
        ``weight_loader``. Every other name is loaded with the parameter's
        ``weight_loader`` if it has one, else ``default_weight_loader``.
        Entries containing ``rotary_emb.inv_freq``, ``.bias`` entries with
        no matching parameter, and parameters reported missing by
        ``is_pp_missing_parameter`` are skipped.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # NOTE: `name` keeps the rewritten value even if one of
                # the checks below skips this mapping entry.
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when no stacked mapping loaded the tensor.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)