"vscode:/vscode.git/clone" did not exist on "905411f123f65442794f7e985928e2728f9a08ce"
internlm2.py 15.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5
from functools import partial
6
from itertools import islice
7
from typing import Any
Fengzhe Zhou's avatar
Fengzhe Zhou committed
8
9

import torch
10
from torch import nn
Fengzhe Zhou's avatar
Fengzhe Zhou committed
11
12
from transformers import PretrainedConfig

13
from vllm.attention.layer import Attention
14
from vllm.compilation.decorators import support_torch_compile
15
from vllm.config import CacheConfig, VllmConfig
16
17
18
19
20
21
22
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
)
23
24
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
25
26
27
28
29
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
30
from vllm.model_executor.layers.logits_processor import LogitsProcessor
31
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
32
from vllm.model_executor.layers.quantization import QuantizationConfig
33
from vllm.model_executor.layers.rotary_embedding import get_rope
34
from vllm.model_executor.layers.vocab_parallel_embedding import (
35
36
37
    ParallelLMHead,
    VocabParallelEmbedding,
)
38
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
39
from vllm.sequence import IntermediateTensors
Fengzhe Zhou's avatar
Fengzhe Zhou committed
40

41
42
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces_base import default_pooling_type
43
from .utils import (
44
    StageMissingLayer,
45
46
47
48
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
49
    no_init_weights,
50
)
51

52
53
54
55
56
57
58

class InternLM2MLP(nn.Module):
    """Gated feed-forward network used inside each InternLM2 decoder layer.

    Two projections of width ``intermediate_size`` are fused into a single
    column-parallel ``gate_up_proj`` (its shards 0 and 1 are loaded from the
    checkpoint's ``w1`` and ``w3``). ``w2`` is the row-parallel projection
    back to ``hidden_size``.

    Args:
        hidden_size: Width of the residual stream.
        intermediate_size: Width of each of the two fused projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization settings passed to both linears.
        prefix: Parameter-name prefix for this module.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # One fused output made of two equally sized shards.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.w2 = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        if hidden_act == "silu":
            # SiluAndMul consumes the fused two-shard output in one op.
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

    def forward(self, x):
        """Apply the fused projection, the activation, then ``w2`` to ``x``."""
        # Both parallel linears return an (output, bias) pair; bias is unused.
        fused, _ = self.gate_up_proj(x)
        down, _ = self.w2(self.act_fn(fused))
        return down


class InternLM2Attention(nn.Module):
    """Self-attention for InternLM2 with a single fused ``wqkv`` projection.

    The fused projection output is not laid out as ``[Q | K | V]``. For every
    KV head it holds ``key_value_groups`` query heads followed by one key head
    and one value head. :meth:`split_qkv` regroups that layout into separate
    Q, K and V tensors for this tensor-parallel rank.

    Args:
        hidden_size: Model width; ``head_dim = hidden_size // num_heads``.
        num_heads: Total number of query heads across all TP ranks.
        num_kv_heads: Total number of key/value heads across all TP ranks.
        rope_parameters: Rotary-embedding settings forwarded to ``get_rope``.
        max_position_embeddings: Maximum position passed to ``get_rope``.
        cache_config: KV-cache configuration for the attention backend.
        quant_config: Optional quantization settings for the linears.
        prefix: Parameter-name prefix for this module.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        tp_size = self.tp_size

        # Query heads are always sharded evenly across TP ranks.
        self.total_num_heads = num_heads
        assert num_heads % tp_size == 0
        self.num_heads = num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if num_kv_heads < tp_size:
            # Fewer KV heads than ranks: each KV head is replicated on
            # several ranks, so the rank count must be a multiple of it.
            assert tp_size % num_kv_heads == 0
        else:
            # At least one KV head per rank: shard them evenly.
            assert num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, num_kv_heads // tp_size)

        self.head_dim = hidden_size // num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Number of query heads that share each KV head on this rank.
        self.key_value_groups = int(self.num_heads / self.num_kv_heads)
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            num_heads,
            num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wqkv",
        )
        self.wo = RowParallelLinear(
            num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wo",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def split_qkv(self, qkv: torch.Tensor):
        """Split the fused ``wqkv`` output into this rank's Q, K and V.

        ``qkv`` is treated as ``(num_tokens, fused_width)``; dim 0 is kept as
        the token count throughout (assumed 2-D input — TODO confirm).

        With TP > 1, every rank's slice is all-gathered first. The gathered
        tensor is reordered so that all ranks' Q pieces come first, then all
        K pieces, then all V pieces. Then every rank regroups the full tensor
        per KV head and keeps only its own contiguous shard of Q, K and V.

        Returns:
            ``(q, k, v)`` with last-dim widths ``q_size``, ``kv_size`` and
            ``kv_size`` respectively.
        """
        num_tokens = qkv.shape[0]
        world = self.tp_size

        if world > 1:
            # Each rank contributes [q_size | kv_size | kv_size]; undo that
            # per-rank interleave so the full fused layout is restored.
            pieces = torch.split(
                tensor_model_parallel_all_gather(qkv),
                [self.q_size, self.kv_size, self.kv_size] * world,
                dim=-1,
            )
            qkv = torch.cat(pieces[0::3] + pieces[1::3] + pieces[2::3], dim=-1)

        # Per KV head: `key_value_groups` query heads, then one K, then one V.
        per_kv_head = qkv.view(
            num_tokens, self.total_num_kv_heads, self.key_value_groups + 2, self.head_dim
        )
        q, k, v = torch.split(per_kv_head, [self.key_value_groups, 1, 1], dim=-2)
        q = q.reshape(num_tokens, self.q_size * world)
        k = k.reshape(num_tokens, self.kv_size * world)
        v = v.reshape(num_tokens, self.kv_size * world)

        if world > 1:
            # Keep only this rank's contiguous shard of each tensor.
            rank = self.tp_rank
            q, k, v = (
                split_tensor_along_last_dim(part, num_partitions=world)[rank]
                for part in (q, k, v)
            )
        return q, k, v

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run fused QKV projection, rotary embedding, attention and ``wo``."""
        fused, _ = self.wqkv(hidden_states)
        q, k, v = self.split_qkv(fused)
        q, k = self.rotary_emb(positions, q, k)
        context = self.attn(q, k, v)
        projected, _ = self.wo(context)
        return projected


class InternLMDecoderLayer(nn.Module):
    """One InternLM2 transformer block: normed attention, then normed MLP.

    The residual stream is threaded explicitly through ``forward``. When given
    a residual, vLLM's ``RMSNorm`` returns both the normalized states and the
    updated residual, which is how the two sub-blocks are chained.

    Args:
        config: Hugging Face config providing sizes, head counts, activation,
            rope parameters and ``rms_norm_eps``.
        cache_config: KV-cache configuration for the attention backend.
        quant_config: Optional quantization settings.
        prefix: Parameter-name prefix for this layer.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Older configs may omit this field; 8192 is the fallback.
        max_positions = getattr(config, "max_position_embeddings", 8192)

        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_parameters=config.rope_parameters,
            max_position_embeddings=max_positions,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        eps = config.rms_norm_eps
        self.attention_norm = RMSNorm(config.hidden_size, eps=eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the block; returns ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer of the stack; the
        un-normalized input then becomes the start of the residual stream.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.attention_norm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        return self.feed_forward(hidden_states), residual


249
@support_torch_compile
class InternLM2Model(nn.Module):
    """InternLM2 decoder stack: token embedding, transformer layers, final norm.

    Under pipeline parallelism this rank only runs layers
    ``[start_layer, end_layer)`` as returned by ``make_layers``. Non-last ranks
    return ``IntermediateTensors`` carrying ``hidden_states`` and
    ``residual``; the last rank returns the final normalized hidden states.

    Args:
        vllm_config: Engine configuration; its ``model_config.hf_config``
            supplies the model hyper-parameters.
        prefix: Parameter-name prefix for this module.
        layer_type: Decoder-layer class to instantiate for every layer.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[InternLMDecoderLayer] = InternLMDecoderLayer,
    ):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.vocab_size = hf_config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        # make_layers calls this factory with a per-layer `prefix` keyword.
        def build_layer(prefix: str):
            return layer_type(hf_config, cache_config, quant_config, prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], hf_config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.tok_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        The first PP rank embeds ``input_ids`` (unless ``inputs_embeds`` is
        given). Later ranks resume from ``intermediate_tensors``, which must
        then be provided.
        """
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = (
                inputs_embeds
                if inputs_embeds is not None
                else self.embed_input_ids(input_ids)
            )
            residual = None

        for decoder_layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = decoder_layer(positions, hidden_states, residual)

        if not pp_group.is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        # Final fused residual-add + norm; the updated residual is discarded.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


312
313
314
315
316
317
class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """InternLM2 decoder with a vocabulary-parallel LM head.

    Args:
        vllm_config: Engine configuration; ``model_config.hf_config`` supplies
            the hyper-parameters and ``quant_config`` the quantization setup.
        prefix: Parameter-name prefix for this module.
        model_type: Backbone class to instantiate (``InternLM2Model`` by
            default).
    """

    # Fused module -> checkpoint sub-module names that are packed into it.
    packed_modules_mapping = {
        "wqkv": ["wqkv"],
        "gate_up_proj": ["w1", "w3"],
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        model_type: type[InternLM2Model] = InternLM2Model,
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.model = model_type(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.output = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "output"),
        )
        if self.config.tie_word_embeddings:
            # Share a single weight tensor between input embedding and LM head.
            self.output.weight = self.model.tok_embeddings.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone.

        Returns hidden states on the last pipeline rank and
        ``IntermediateTensors`` on every other rank (as returned by
        ``self.model``); logits are produced separately by
        :meth:`compute_logits`.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits with the LM head."""
        logits = self.logits_processor(self.output, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint ``w1``/``w3`` tensors are loaded as shards 0/1 of the fused
        ``gate_up_proj``. Everything else is loaded by name. Two kinds of
        tensor are skipped: extra biases with no matching parameter (seen in
        GPTQ checkpoints) and parameters that belong to another pipeline
        stage.

        Returns:
            The set of parameter names that were actually loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # Rotary inverse frequencies are recomputed, never loaded.
            if "rotary_emb.inv_freq" in name:
                continue

            # Map the checkpoint name onto a (possibly fused) parameter name.
            shard_id = None
            for param_name, weight_name, stacked_shard in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_shard
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Parameter lives on a different pipeline stage.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                # Fused parameters always provide a shard-aware loader.
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params
405
406


407
@default_pooling_type(tok_pooling_type="ALL")
class InternLM2ForRewardModel(InternLM2ForCausalLM):
    """InternLM2 backbone with a scalar ``v_head`` instead of an LM head.

    The parent's ``ParallelLMHead`` and ``LogitsProcessor`` are built under
    ``no_init_weights``, with ``StageMissingLayer("output", ...)`` as the
    substitute factory; presumably this means they are not materialized —
    confirm against ``no_init_weights``. ``v_head`` maps each hidden state to
    a single value, and ``self.pooler`` (built from the pooler config) handles
    pooling.

    Args:
        vllm_config: Engine configuration; must carry a ``pooler_config``.
        prefix: Parameter-name prefix for this module.
        model_type: Backbone class to instantiate.
    """

    is_pooling_model = True

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        model_type: type[InternLM2Model] = InternLM2Model,
    ):
        with no_init_weights(
            self,
            lambda mod: StageMissingLayer("output", mod),
            targets=(LogitsProcessor, ParallelLMHead),
        ):
            super().__init__(
                vllm_config=vllm_config,
                prefix=prefix,
                model_type=model_type,
            )

        config = vllm_config.model_config.hf_config
        self.head_dtype = vllm_config.model_config.head_dtype

        # hidden_size -> 1 projection, computed in `head_dtype`.
        self.v_head = RowParallelLinear(
            config.hidden_size,
            1,
            bias=False,
            input_is_parallel=False,
            params_dtype=self.head_dtype,
            prefix=maybe_prefix(prefix, "v_head"),
            return_bias=False,
        )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler = pooler_for_token_classify(pooler_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and score every position with ``v_head``.

        On non-last pipeline-parallel ranks the backbone returns
        ``IntermediateTensors``; these are passed through unchanged to the
        next stage instead of being fed to ``v_head``.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        if isinstance(hidden_states, IntermediateTensors):
            # Previously this fell through to `.to(...)`, which
            # IntermediateTensors does not provide, crashing PP inference.
            return hidden_states
        hidden_states = hidden_states.to(self.head_dtype)
        logits = self.v_head(hidden_states)
        return logits