internlm2.py 16.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from functools import partial
4
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
Fengzhe Zhou's avatar
Fengzhe Zhou committed
5
6

import torch
7
from torch import nn
Fengzhe Zhou's avatar
Fengzhe Zhou committed
8
9
from transformers import PretrainedConfig

10
from vllm.attention import Attention
11
from vllm.compilation.decorators import support_torch_compile
12
from vllm.config import CacheConfig, VllmConfig
13
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
14
15
16
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather)
17
18
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
19
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
20
21
                                               QKVParallelLinear,
                                               RowParallelLinear)
22
from vllm.model_executor.layers.logits_processor import LogitsProcessor
23
from vllm.model_executor.layers.pooler import Pooler, PoolingType
24
from vllm.model_executor.layers.quantization import QuantizationConfig
25
from vllm.model_executor.layers.rotary_embedding import get_rope
26
from vllm.model_executor.layers.vocab_parallel_embedding import (
27
    ParallelLMHead, VocabParallelEmbedding)
28
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
29
from vllm.model_executor.pooling_metadata import PoolingMetadata
30
from vllm.model_executor.sampling_metadata import SamplingMetadata
31
from vllm.sequence import IntermediateTensors, PoolerOutput
Fengzhe Zhou's avatar
Fengzhe Zhou committed
32

33
from .interfaces import SupportsLoRA, SupportsPP
34
from .utils import (is_pp_missing_parameter,
35
36
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
37

38
39
40
41
42
43
44
45

class InternLM2MLP(nn.Module):
    """Gated feed-forward sub-layer of an InternLM2 decoder layer.

    The checkpoint's ``w1`` and ``w3`` projections are held fused in
    ``gate_up_proj`` (two ``intermediate_size``-wide outputs), ``SiluAndMul``
    is applied to that fused output, and ``w2`` projects back to
    ``hidden_size``. Only the ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # A single column-parallel matmul produces both fused halves (w1, w3).
        fused_out_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_out_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        # Row-parallel down projection back to the model width.
        self.w2 = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Return ``w2(act_fn(gate_up_proj(x)))``.

        The second value returned by each projection is discarded.
        """
        fused, _unused = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        out, _unused = self.w2(activated)
        return out


class InternLM2Attention(nn.Module):
    """Grouped-query self-attention for InternLM2 with a fused ``wqkv``.

    The fused projection output is treated as laid out per KV head: for each
    of the ``total_num_kv_heads`` heads, ``key_value_groups`` query heads are
    followed by one key head and one value head (see ``split_qkv``). Under
    tensor parallelism the shards are all-gathered to rebuild that layout
    before it is split and re-sharded.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the fused QKV/output projections, rotary embedding and attn.

        Raises:
            ValueError: if ``num_kv_heads`` is smaller than the tensor
                parallel world size; ``split_qkv`` cannot handle replicated
                KV heads.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.total_num_heads = num_heads
        assert self.total_num_heads % self.tp_size == 0
        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < self.tp_size:
            # With KV heads replicated there is one KV head per rank, so the
            # all-gather in split_qkv yields tp_size * (q_size + 2 * kv_size)
            # columns, while its per-KV-head view() expects
            # total_num_kv_heads * (key_value_groups + 2) * head_dim. The two
            # agree only when total_num_kv_heads == tp_size, so every forward
            # would fail with a shape error; reject the config up front.
            raise ValueError(
                f"InternLM2 requires num_key_value_heads ({num_kv_heads}) to "
                f"be at least the tensor parallel size ({self.tp_size}); "
                "replicating KV heads across ranks is not supported.")
        # KV heads are partitioned evenly across tensor parallel ranks.
        assert self.total_num_kv_heads % self.tp_size == 0
        self.num_kv_heads = self.total_num_kv_heads // self.tp_size
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Number of query heads that share each KV head.
        self.key_value_groups = self.num_heads // self.num_kv_heads
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wqkv",
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wo",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def split_qkv(self, qkv: torch.Tensor):
        """Split this rank's fused projection output into ``(q, k, v)``.

        ``qkv`` is assumed 2-D ``(seq_len, q_size + 2 * kv_size)``. With
        tensor parallelism the rank shards are all-gathered and regrouped as
        all q-parts, then all k-parts, then all v-parts. The following view()
        treats the result as per-KV-head ``[q * key_value_groups, k, v]``
        blocks. Each output is then re-sharded so this rank gets its own
        slice.
        """
        seq_len = qkv.shape[0]
        if self.tp_size > 1:
            qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
            qkv = tensor_model_parallel_all_gather(qkv)
            qkv = torch.split(qkv, qkv_map, dim=-1)
            # Regroup rank-major [q0, k0, v0, q1, k1, v1, ...] into
            # [q0, q1, ..., k0, k1, ..., v0, v1, ...].
            qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
            qkv = torch.cat(qkv, dim=-1)

        qkv = qkv.view(seq_len, self.total_num_kv_heads,
                       self.key_value_groups + 2, self.head_dim)
        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
        q = q.reshape(seq_len, self.q_size * self.tp_size)
        k = k.reshape(seq_len, self.kv_size * self.tp_size)
        v = v.reshape(seq_len, self.kv_size * self.tp_size)

        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]
        return q, k, v

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embedding, attend, project back."""
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = self.split_qkv(qkv)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.wo(attn_output)
        return output


class InternLMDecoderLayer(nn.Module):
    """One InternLM2 transformer block: pre-norm attention, then pre-norm MLP.

    The residual stream is threaded through the layer rather than added
    explicitly. Each norm is called with ``(hidden_states, residual)`` and
    returns a ``(normed, residual)`` pair, and the layer returns
    ``(hidden_states, residual)``. The two-argument RMSNorm call is assumed
    to fold in the residual add; confirm against RMSNorm.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Rotary settings fall back to defaults when absent from the config.
        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings",
                                            8192),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        norm_eps = config.rms_norm_eps
        self.attention_norm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Attention sub-block.
        if residual is not None:
            normed, residual = self.attention_norm(hidden_states, residual)
        else:
            # Nothing carried in yet: the raw input becomes the residual.
            residual = hidden_states
            normed = self.attention_norm(hidden_states)
        attn_out = self.attention(positions=positions, hidden_states=normed)

        # Feed-forward sub-block.
        normed, residual = self.ffn_norm(attn_out, residual)
        return self.feed_forward(normed), residual


247
@support_torch_compile
class InternLM2Model(nn.Module):
    """InternLM2 transformer trunk: token embedding, decoder layers, norm.

    Pipeline-parallel aware. This rank runs only
    ``self.layers[self.start_layer:self.end_layer]``. The first rank embeds
    the input, and every rank but the last hands ``hidden_states`` and
    ``residual`` on as ``IntermediateTensors``.
    """

    def __init__(
            self,
            *,
            vllm_config: VllmConfig,
            prefix: str = "",
            layer_type: Type[InternLMDecoderLayer] = InternLMDecoderLayer):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.vocab_size = hf_config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        def build_layer(prefix):
            # Keep this parameter named ``prefix``, as in the original lambda.
            # make_layers may pass it by keyword (TODO confirm).
            return layer_type(hf_config,
                              cache_config,
                              quant_config,
                              prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        # Empty-tensor factory keyed like the IntermediateTensors from forward.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.tok_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Returns the final-normed hidden states on the last pipeline rank, and
        an ``IntermediateTensors`` holding ``hidden_states``/``residual`` on
        every other rank.
        """
        if not get_pp_group().is_first_rank:
            # Later stages resume from the previous stage's output.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            if inputs_embeds is None:
                hidden_states = self.get_input_embeddings(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None

        for decoder_layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = decoder_layer(positions, hidden_states,
                                                    residual)

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


309
310
311
312
313
314
class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """InternLM2 decoder trunk topped with a vocab-parallel LM head.

    ``output`` is the LM head. When ``config.tie_word_embeddings`` is set,
    it shares its weight Parameter with ``model.tok_embeddings``.
    """

    # Fused vLLM module names -> checkpoint sub-modules packed into each.
    packed_modules_mapping = {
        "wqkv": ["wqkv"],
        "gate_up_proj": ["w1", "w3"],
    }

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 model_type: Type[InternLM2Model] = InternLM2Model):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        self.lora_config = lora_config

        # Trunk class is overridable via ``model_type``.
        self.model = model_type(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        self.output = ParallelLMHead(config.vocab_size,
                                     config.hidden_size,
                                     quant_config=quant_config,
                                     prefix=maybe_prefix(prefix, "output"))
        if self.config.tie_word_embeddings:
            self.output.weight = self.model.tok_embeddings.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the trunk."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the trunk; returns its output unchanged (no LM head here)."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Apply the LM head to ``hidden_states`` via the logits processor."""
        logits = self.logits_processor(self.output, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        ``w1``/``w3`` tensors are routed into the fused ``gate_up_proj``
        parameter through its ``weight_loader`` with shard ids 0 and 1.
        Everything else is loaded by name. The following are skipped:

        - ``rotary_emb.inv_freq`` entries;
        - extra GPTQ biases with no matching parameter;
        - parameters that belong to another pipeline stage;
        - ``output.weight`` when word embeddings are tied.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # With tied embeddings, output.weight is the very same Parameter
            # as model.tok_embeddings.weight. named_parameters() deduplicates
            # it, so it has no params_dict entry (a lookup would raise
            # KeyError), and the embedding copy already provides its value.
            if self.config.tie_word_embeddings and name == "output.weight":
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415


class InternLM2ForRewardModel(InternLM2ForCausalLM):
    """InternLM2 reward model: the causal-LM trunk with a scalar value head.

    The inherited ``output`` LM head and ``logits_processor`` are removed.
    ``v_head`` maps each hidden state to a single value. The pooler defaults
    to ``PoolingType.ALL``, with normalization and softmax disabled.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        model_type: Type[InternLM2Model] = InternLM2Model,
    ):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         model_type=model_type)

        # The reward head replaces the LM head entirely.
        for attr in ("output", "logits_processor"):
            delattr(self, attr)

        config = vllm_config.model_config.hf_config
        self.v_head = RowParallelLinear(
            config.hidden_size,
            1,
            bias=False,
            input_is_parallel=False,
            prefix=maybe_prefix(prefix, "v_head"),
        )

        pooler_config = vllm_config.model_config.pooler_config
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.ALL,
            normalize=False,
            softmax=False,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the trunk and score its output with ``v_head``.

        On non-final pipeline stages the trunk returns
        ``IntermediateTensors``; those are passed through unchanged.
        """
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        # v_head cannot consume IntermediateTensors, and only the last
        # pipeline stage should produce scores.
        if isinstance(hidden_states, IntermediateTensors):
            return hidden_states
        logits, _ = self.v_head(hidden_states)
        return logits

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Pool the per-token values with the configured pooler."""
        return self._pooler(hidden_states, pooling_metadata)