internlm2.py 18.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from functools import partial
4
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
Fengzhe Zhou's avatar
Fengzhe Zhou committed
5
6

import torch
7
from torch import nn
Fengzhe Zhou's avatar
Fengzhe Zhou committed
8
9
from transformers import PretrainedConfig

10
from vllm.attention import Attention, AttentionMetadata
11
from vllm.compilation.decorators import support_torch_compile
12
from vllm.config import CacheConfig, VllmConfig
13
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
14
15
16
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather)
17
18
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
19
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
20
21
                                               QKVParallelLinear,
                                               RowParallelLinear)
22
from vllm.model_executor.layers.logits_processor import LogitsProcessor
23
from vllm.model_executor.layers.pooler import Pooler, PoolingType
24
from vllm.model_executor.layers.quantization import QuantizationConfig
25
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
26
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
27
from vllm.model_executor.layers.vocab_parallel_embedding import (
28
    ParallelLMHead, VocabParallelEmbedding)
29
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
30
from vllm.model_executor.pooling_metadata import PoolingMetadata
31
from vllm.model_executor.sampling_metadata import SamplingMetadata
32
from vllm.sequence import IntermediateTensors, PoolerOutput
Fengzhe Zhou's avatar
Fengzhe Zhou committed
33

34
from .interfaces import SupportsLoRA, SupportsPP
35
from .utils import (is_pp_missing_parameter,
36
37
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
38

39
40
41
42
43
44
45
46

class InternLM2MLP(nn.Module):
    """Gated feed-forward block of an InternLM2 decoder layer.

    The gate (``w1``) and up (``w3``) projections live in one fused
    column-parallel ``gate_up_proj``; ``SiluAndMul`` combines its two halves
    and the row-parallel ``w2`` projects the result back to ``hidden_size``.

    Raises:
        ValueError: if ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # One output partition per fused projection (gate, up).
        fused_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.w2 = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Return ``w2(SiluAndMul(gate_up_proj(x)))``."""
        fused, _ = self.gate_up_proj(x)
        projected, _ = self.w2(self.act_fn(fused))
        return projected


class InternLM2Attention(nn.Module):
    """Grouped-query self-attention for InternLM2.

    InternLM2 checkpoints store a single fused ``wqkv`` projection whose
    output features are interleaved per KV group as
    ``[q_0, ..., q_{g-1}, k, v]`` with ``g = num_heads // num_kv_heads``
    (this is the layout :meth:`split_qkv` de-interleaves with ``view``).
    Each tensor-parallel rank's ``wqkv`` output is handled as
    ``[q | k | v]`` slices (see ``qkv_map``); gathered across ranks and
    regrouped, those slices reproduce the checkpoint feature order, which is
    then de-interleaved and re-sharded for the local rank.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.total_num_heads = num_heads
        assert self.total_num_heads % self.tp_size == 0
        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= self.tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert self.tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        # Number of consecutive TP ranks that hold the same KV head
        # (1 when KV heads are partitioned rather than replicated).
        self.num_kv_head_replicas = max(
            1, self.tp_size // self.total_num_kv_heads)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Query heads per KV head in the checkpoint's interleaved layout.
        # Computed from the *global* head counts: with replicated KV heads
        # the per-rank ratio num_heads / num_kv_heads is a different number
        # and would not match the wqkv layout.
        self.key_value_groups = (self.total_num_heads //
                                 self.total_num_kv_heads)
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wqkv",
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wo",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def split_qkv(self, qkv: torch.Tensor):
        """Split the fused ``wqkv`` output into this rank's q, k and v.

        Args:
            qkv: this rank's ``wqkv`` output; dim 0 is the token dimension
                and the last dim is ``q_size + 2 * kv_size``.

        Returns:
            ``(q, k, v)`` with last dims ``q_size``, ``kv_size`` and
            ``kv_size`` respectively.
        """
        seq_len = qkv.shape[0]
        if self.tp_size > 1:
            qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
            qkv = tensor_model_parallel_all_gather(qkv)
            qkv = torch.split(qkv, qkv_map, dim=-1)
            # Regroup the per-rank slices as all q, then all k, then all v.
            # When KV heads are replicated, consecutive ranks hold identical
            # k/v slices, so keep only one of every `replicas` copies; this
            # restores exactly the checkpoint's total feature width.
            stride = 3 * self.num_kv_head_replicas
            qkv = qkv[::3] + qkv[1::stride] + qkv[2::stride]
            qkv = torch.cat(qkv, dim=-1)

        qkv = qkv.view(seq_len, self.total_num_kv_heads,
                       self.key_value_groups + 2, self.head_dim)
        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
        q = q.reshape(seq_len, self.total_num_heads * self.head_dim)
        k = k.reshape(seq_len, self.total_num_kv_heads * self.head_dim)
        v = v.reshape(seq_len, self.total_num_kv_heads * self.head_dim)

        if self.tp_size > 1:
            q = split_tensor_along_last_dim(
                q, num_partitions=self.tp_size)[self.tp_rank]
            # KV is sharded into total_num_kv_heads // num_kv_heads pieces
            # (== tp_size when partitioned, == total_num_kv_heads when
            # replicated); ranks sharing a replicated head pick the same one.
            kv_splitter = partial(
                split_tensor_along_last_dim,
                num_partitions=self.total_num_kv_heads // self.num_kv_heads)
            kv_rank = self.tp_rank // self.num_kv_head_replicas
            k = kv_splitter(k)[kv_rank]
            v = kv_splitter(v)[kv_rank]

        return q, k, v

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project, apply rotary embedding, attend, and project back."""
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = self.split_qkv(qkv)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.wo(attn_output)
        return output


class InternLMDecoderLayer(nn.Module):
    """One InternLM2 transformer block.

    Data flow: ``attention_norm -> attention -> ffn_norm -> feed_forward``.
    The residual stream is carried alongside the hidden states: when the
    RMSNorm modules are called with ``(hidden_states, residual)`` they
    return ``(normed, residual)``, as used in :meth:`forward`.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        hidden_size = config.hidden_size
        self.hidden_size = hidden_size
        self.attention = InternLM2Attention(
            hidden_size=hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings",
                                            8192),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        eps = config.rms_norm_eps
        self.attention_norm = RMSNorm(hidden_size, eps=eps)
        self.ffn_norm = RMSNorm(hidden_size, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block; returns ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer, in which case the
        incoming hidden states themselves become the residual.
        """
        if residual is not None:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)

        attended = self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.ffn_norm(attended, residual)
        return self.feed_forward(normed), residual


254
@support_torch_compile
class InternLM2Model(nn.Module):
    """Token embedding, InternLM2 decoder stack and final RMSNorm.

    Under pipeline parallelism only layers ``[start_layer, end_layer)``
    (as returned by ``make_layers``) are run on this rank. Non-first ranks
    read their input from ``intermediate_tensors``; non-last ranks return an
    ``IntermediateTensors`` holding ``hidden_states`` and ``residual``
    instead of normalized output.
    """

    def __init__(
            self,
            *,
            vllm_config: VllmConfig,
            prefix: str = "",
            layer_type: Type[InternLMDecoderLayer] = InternLMDecoderLayer):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.padding_idx = hf_config.pad_token_id
        self.vocab_size = hf_config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        # NOTE: make_layers invokes the factory with a keyword argument
        # named ``prefix``, so the parameter name must stay ``prefix``.
        def build_layer(prefix: str) -> nn.Module:
            return layer_type(hf_config,
                              cache_config,
                              quant_config,
                              prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.tok_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        ``kv_caches`` is indexed relative to ``start_layer``. On the first
        rank ``inputs_embeds`` (if given) takes precedence over
        ``input_ids``.
        """
        if not get_pp_group().is_first_rank:
            # Later pipeline stages resume from the previous stage's output.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        elif inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            hidden_states = inputs_embeds
            residual = None

        for cache_idx, layer_idx in enumerate(
                range(self.start_layer, self.end_layer)):
            hidden_states, residual = self.layers[layer_idx](
                positions,
                hidden_states,
                kv_caches[cache_idx],
                attn_metadata,
                residual,
            )

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """InternLM2 decoder with an LM head (``output``) for text generation."""

    # Fused modules and the checkpoint tensors they are built from.
    packed_modules_mapping = {
        "wqkv": ["wqkv"],
        "gate_up_proj": ["w1", "w3"],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "wqkv",
        "wo",
        "gate_up_proj",
        "w2",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 model_type: Type[InternLM2Model] = InternLM2Model):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        self.lora_config = lora_config

        self.model = model_type(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        self.output = ParallelLMHead(config.vocab_size,
                                     config.hidden_size,
                                     quant_config=quant_config,
                                     prefix=maybe_prefix(prefix, "output"))
        if self.config.tie_word_embeddings:
            # Share one Parameter between the input embedding and LM head.
            self.output.weight = self.model.tok_embeddings.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone.

        Returns the final hidden states on the last pipeline rank and an
        ``IntermediateTensors`` on earlier ranks (as produced by the
        backbone model).
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``output``."""
        logits = self.logits_processor(self.output, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        ``w1`` / ``w3`` tensors are routed into the fused ``gate_up_proj``
        as shards 0 / 1; every other tensor is loaded by its own name.

        Returns:
            The names of the parameters that received a tensor.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # With tied embeddings ``output.weight`` is the very same
            # Parameter as ``model.tok_embeddings.weight``. It is therefore
            # de-duplicated out of ``named_parameters()``, and a checkpoint
            # copy of it would otherwise raise KeyError below.
            if (self.config.tie_word_embeddings
                    and "output.weight" in name):
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Parameter lives on another pipeline-parallel rank.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Parameter lives on another pipeline-parallel rank.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495


class InternLM2ForRewardModel(InternLM2ForCausalLM):
    """InternLM2 backbone with a scalar ``v_head`` in place of the LM head.

    The generation components created by the parent (``output``,
    ``logits_processor``, ``sampler``) are removed. ``forward`` returns the
    ``v_head`` projection of every token's hidden state, and pooling
    defaults to ``PoolingType.ALL`` without normalization or softmax.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        model_type: Type[InternLM2Model] = InternLM2Model,
    ):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         model_type=model_type)

        # Drop the parent's generation-only components.
        del self.output
        del self.logits_processor
        del self.sampler

        hidden_size = vllm_config.model_config.hf_config.hidden_size
        self.v_head = RowParallelLinear(
            hidden_size,
            1,
            bias=False,
            input_is_parallel=False,
            prefix=maybe_prefix(prefix, "v_head"),
        )

        self._pooler = Pooler.from_config_with_defaults(
            vllm_config.model_config.pooler_config,
            pooling_type=PoolingType.ALL,
            normalize=False,
            softmax=False,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and project each hidden state through ``v_head``."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        rewards, _bias = self.v_head(hidden_states)
        return rewards

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Pool the per-token reward outputs with the configured pooler."""
        return self._pooler(hidden_states, pooling_metadata)