internlm2.py 16.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5
from functools import partial
6
from itertools import islice
7
from typing import Any, Optional, Union
Fengzhe Zhou's avatar
Fengzhe Zhou committed
8
9

import torch
10
from torch import nn
Fengzhe Zhou's avatar
Fengzhe Zhou committed
11
12
from transformers import PretrainedConfig

13
from vllm.attention import Attention
14
from vllm.compilation.decorators import support_torch_compile
15
from vllm.config import CacheConfig, VllmConfig
16
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
17
18
19
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather)
20
21
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
22
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
23
24
                                               QKVParallelLinear,
                                               RowParallelLinear)
25
from vllm.model_executor.layers.logits_processor import LogitsProcessor
26
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
27
from vllm.model_executor.layers.quantization import QuantizationConfig
28
from vllm.model_executor.layers.rotary_embedding import get_rope
29
from vllm.model_executor.layers.vocab_parallel_embedding import (
30
    ParallelLMHead, VocabParallelEmbedding)
31
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
32
from vllm.sequence import IntermediateTensors
Fengzhe Zhou's avatar
Fengzhe Zhou committed
33

34
35
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces_base import default_pooling_type
36
from .utils import (is_pp_missing_parameter,
37
38
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
39

40
41
42
43
44
45
46
47

class InternLM2MLP(nn.Module):
    """Gated feed-forward sublayer of an InternLM2 decoder block.

    A fused column-parallel projection produces two ``intermediate_size``
    wide halves (checkpoint ``w1`` and ``w3``, loaded as shards 0 and 1),
    ``SiluAndMul`` combines them, and the row-parallel ``w2`` maps the
    result back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # One matmul emits both halves side by side.
        fused_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.w2 = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Run fused gate/up projection, SiLU-and-multiply, then ``w2``."""
        projected, _bias = self.gate_up_proj(x)
        activated = self.act_fn(projected)
        out, _bias = self.w2(activated)
        return out


class InternLM2Attention(nn.Module):
    """Multi-head / grouped-query attention for InternLM2.

    Q, K and V come from one fused ``wqkv`` projection. Its output columns
    are regrouped by ``split_qkv`` before rotary embedding and attention,
    and ``wo`` projects the attention output back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.total_num_heads = num_heads
        assert self.total_num_heads % self.tp_size == 0
        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= self.tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert self.tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Query heads served by each KV head. Integer floor division keeps
        # this an exact int instead of round-tripping through a float.
        self.key_value_groups = self.num_heads // self.num_kv_heads
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wqkv",
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wo",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def split_qkv(self, qkv: torch.Tensor):
        """Regroup this rank's fused ``wqkv`` output into q, k and v.

        Args:
            qkv: This rank's ``wqkv`` output. Dim 0 is taken as the token
                count; the remaining width is ``q_size + 2 * kv_size``
                (assumed ``(num_tokens, width)`` — TODO confirm).

        Returns:
            ``(q, k, v)`` whose last dims are ``q_size``, ``kv_size`` and
            ``kv_size`` for this tensor-parallel rank.
        """
        seq_len = qkv.shape[0]
        if self.tp_size > 1:
            # Gather every rank's [q | k | v] slice, then reorder the pieces
            # into [q_0..q_{tp-1} | k_0..k_{tp-1} | v_0..v_{tp-1}] so the
            # columns form the full, unsharded projection output.
            qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
            qkv = tensor_model_parallel_all_gather(qkv)
            qkv = torch.split(qkv, qkv_map, dim=-1)
            qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
            qkv = torch.cat(qkv, dim=-1)

        # Interpret the full projection as grouped per KV head: for each KV
        # head, ``key_value_groups`` query heads, then one key head and one
        # value head. NOTE(review): presumably this mirrors how the
        # checkpoint's ``wqkv`` columns are laid out (the weight is loaded
        # without a shard id) — confirm against the reference model.
        # NOTE(review): with replicated KV heads (total_num_kv_heads <
        # tp_size) the gathered width is tp_size * (q_size + 2 * kv_size),
        # which differs from total_num_kv_heads * (key_value_groups + 2) *
        # head_dim, so this view would fail — confirm that configuration is
        # intentionally unsupported.
        qkv = qkv.view(seq_len, self.total_num_kv_heads,
                       self.key_value_groups + 2, self.head_dim)
        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
        q = q.reshape(seq_len, self.q_size * self.tp_size)
        k = k.reshape(seq_len, self.kv_size * self.tp_size)
        v = v.reshape(seq_len, self.kv_size * self.tp_size)

        if self.tp_size > 1:
            # Re-shard: keep only this rank's contiguous slice of each.
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]
        return q, k, v

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embedding, attend, project out."""
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = self.split_qkv(qkv)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.wo(attn_output)
        return output


class InternLMDecoderLayer(nn.Module):
    """One InternLM2 decoder block: pre-norm attention, then pre-norm MLP.

    ``forward`` returns ``(hidden_states, residual)``. The residual stream is
    threaded through the two-argument ``RMSNorm`` calls, which return
    ``(normalized, residual)``, rather than being added back explicitly here.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        hidden_size = config.hidden_size
        self.hidden_size = hidden_size
        # Optional config fields fall back to the defaults written here.
        self.attention = InternLM2Attention(
            hidden_size=hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 8192),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        eps = config.rms_norm_eps
        self.attention_norm = RMSNorm(hidden_size, eps=eps)
        self.ffn_norm = RMSNorm(hidden_size, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Attention sublayer. The first layer has no residual yet, so the
        # raw input becomes the residual and only it is normalized.
        if residual is not None:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        attn_out = self.attention(positions=positions,
                                  hidden_states=hidden_states)

        # Feed-forward sublayer.
        normed, residual = self.ffn_norm(attn_out, residual)
        return self.feed_forward(normed), residual


249
@support_torch_compile
class InternLM2Model(nn.Module):
    """InternLM2 decoder stack for one pipeline-parallel stage.

    Holds the token embedding table, the decoder layers assigned to this
    stage by ``make_layers`` (``start_layer``..``end_layer``), and the final
    RMSNorm.
    """

    def __init__(
            self,
            *,
            vllm_config: VllmConfig,
            prefix: str = "",
            layer_type: type[InternLMDecoderLayer] = InternLMDecoderLayer):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.vocab_size = hf_config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        def build_layer(prefix):
            # make_layers supplies the per-layer ``prefix`` keyword.
            return layer_type(hf_config,
                              cache_config,
                              quant_config,
                              prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        embeddings = self.tok_embeddings(input_ids)
        return embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this stage's layers.

        Returns the final normalized hidden states on the last PP rank,
        otherwise an ``IntermediateTensors`` carrying ``hidden_states`` and
        ``residual`` for the next stage.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            # Later stages resume from the previous stage's outputs.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            if inputs_embeds is None:
                hidden_states = self.get_input_embeddings(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None

        for decoder_layer in islice(self.layers, self.start_layer,
                                    self.end_layer):
            hidden_states, residual = decoder_layer(positions, hidden_states,
                                                    residual)

        if pp_group.is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


311
312
313
314
315
316
class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """InternLM2 backbone plus an ``output`` LM head for causal LM."""

    # Checkpoint ``w1``/``w3`` are packed into ``gate_up_proj``; ``wqkv``
    # maps to itself (it is a single fused projection).
    packed_modules_mapping = {
        "wqkv": ["wqkv"],
        "gate_up_proj": ["w1", "w3"],
    }

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 model_type: type[InternLM2Model] = InternLM2Model):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        self.lora_config = lora_config

        self.model = model_type(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        self.output = ParallelLMHead(config.vocab_size,
                                     config.hidden_size,
                                     quant_config=quant_config,
                                     prefix=maybe_prefix(prefix, "output"))
        if self.config.tie_word_embeddings:
            # Share one weight tensor between input embeddings and LM head.
            self.output.weight = self.model.tok_embeddings.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone.

        Returns final hidden states on the last pipeline-parallel rank and
        ``IntermediateTensors`` on earlier ranks (whatever the backbone
        returns). ``intermediate_tensors`` defaults to ``None`` to match the
        backbone's and the reward model's signatures.
        """
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``output``."""
        logits = self.logits_processor(self.output, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        ``w1``/``w3`` tensors are renamed to ``gate_up_proj`` and loaded as
        shards 0 and 1. Every other tensor is loaded whole with the
        parameter's own ``weight_loader`` (or ``default_weight_loader``).
        ``rotary_emb.inv_freq`` entries, extra GPTQ biases, and parameters
        owned by other pipeline stages are skipped.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard (or a skipped one): load it whole.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
401
402


403
@default_pooling_type("ALL")
class InternLM2ForRewardModel(InternLM2ForCausalLM):
    """InternLM2 reward model.

    Reuses the causal-LM backbone, drops the LM head and logits processor,
    and scores hidden states with a single-output ``v_head`` projection.
    """

    is_pooling_model = True

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        model_type: type[InternLM2Model] = InternLM2Model,
    ):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         model_type=model_type)

        # The parent's LM head and logits processor are not used here.
        for attr in ("output", "logits_processor"):
            delattr(self, attr)

        config = vllm_config.model_config.hf_config
        self.head_dtype = vllm_config.model_config.head_dtype

        self.v_head = RowParallelLinear(config.hidden_size,
                                        1,
                                        bias=False,
                                        input_is_parallel=False,
                                        params_dtype=self.head_dtype,
                                        prefix=maybe_prefix(prefix, "v_head"),
                                        return_bias=False)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler(
            {"encode": Pooler.for_encode(pooler_config)})

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and, on the last PP rank, the value head.

        Returns the ``v_head`` scores (computed in ``head_dtype``) on the
        last pipeline-parallel rank. On earlier ranks the backbone's
        ``IntermediateTensors`` are returned unchanged for the next stage.
        """
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        # Non-last pipeline stages get IntermediateTensors from the
        # backbone; the value head only applies to final hidden states, and
        # IntermediateTensors cannot be cast with ``.to``.
        if isinstance(hidden_states, IntermediateTensors):
            return hidden_states
        hidden_states = hidden_states.to(self.head_dtype)
        logits = self.v_head(hidden_states)
        return logits