internlm2.py 15.2 KB
Newer Older
1
from functools import partial
2
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
Fengzhe Zhou's avatar
Fengzhe Zhou committed
3
4

import torch
5
from torch import nn
Fengzhe Zhou's avatar
Fengzhe Zhou committed
6
7
from transformers import PretrainedConfig

8
from vllm.attention import Attention, AttentionMetadata
9
from vllm.compilation.decorators import support_torch_compile
10
from vllm.config import CacheConfig, VllmConfig
11
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
12
13
14
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather)
15
16
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
17
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
18
19
                                               QKVParallelLinear,
                                               RowParallelLinear)
20
from vllm.model_executor.layers.logits_processor import LogitsProcessor
21
from vllm.model_executor.layers.quantization import QuantizationConfig
22
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
23
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
24
from vllm.model_executor.layers.vocab_parallel_embedding import (
25
    ParallelLMHead, VocabParallelEmbedding)
26
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
27
from vllm.model_executor.sampling_metadata import SamplingMetadata
28
from vllm.sequence import IntermediateTensors
Fengzhe Zhou's avatar
Fengzhe Zhou committed
29

30
from .interfaces import SupportsPP
31
from .utils import (is_pp_missing_parameter,
32
33
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
34

class InternLM2MLP(nn.Module):
    """Gated feed-forward network used inside each InternLM2 decoder layer.

    Two input projections of width ``intermediate_size`` are fused into one
    column-parallel ``gate_up_proj``. ``SiluAndMul`` combines the two halves
    of its output, and the row-parallel ``w2`` projects back to
    ``hidden_size``. Only the ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # One output shard per fused projection, each ``intermediate_size``
        # wide; ``load_weights`` maps checkpoint w1 -> shard 0, w3 -> shard 1.
        fused_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.w2 = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Run fused projection -> SiluAndMul -> ``w2`` on ``x``.

        The parallel linear layers return a pair; only the first element is
        used here.
        """
        fused, _ = self.gate_up_proj(x)
        down, _ = self.w2(self.act_fn(fused))
        return down


class InternLM2Attention(nn.Module):
    """Self-attention layer of InternLM2 with a single fused ``wqkv`` matrix.

    ``split_qkv`` treats the full (all-rank) ``wqkv`` output as
    ``total_num_kv_heads`` groups. Each group holds ``key_value_groups``
    query heads followed by one key head and one value head.
    ``split_qkv`` de-interleaves that layout into separate q, k and v
    tensors for this tensor-parallel rank. Those tensors then go through
    rotary embedding, ``Attention`` and the output projection ``wo``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build projections, rotary embedding and the attention backend.

        Args:
            hidden_size: model width; ``head_dim`` is
                ``hidden_size // num_heads``.
            num_heads: total number of query heads (all TP ranks).
            num_kv_heads: total number of key/value heads (all TP ranks).
            rope_theta: base passed to ``get_rope``.
            rope_scaling: optional rope-scaling dict passed to ``get_rope``.
            max_position_embeddings: ``max_position`` for ``get_rope``.
            cache_config: forwarded to ``Attention``.
            quant_config: forwarded to the linear layers and ``Attention``.
            prefix: parameter-name prefix for the sub-modules.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.total_num_heads = num_heads
        assert self.total_num_heads % self.tp_size == 0
        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= self.tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            # NOTE(review): in this branch split_qkv all-gathers tp_size
            # replicated K/V chunks, giving (num_heads_total + 2 * tp_size)
            # heads. It then views them as total_num_kv_heads *
            # (key_value_groups + 2) heads; those counts differ, so the view
            # is expected to raise. Confirm before relying on
            # tp_size > num_kv_heads.
            assert self.tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Query heads served by each KV head. Floor division keeps this an
        # exact integer; ``int(a / b)`` would round-trip through a float.
        self.key_value_groups = self.num_heads // self.num_kv_heads
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wqkv",
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wo",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def split_qkv(self, qkv: torch.Tensor):
        """Split this rank's fused ``wqkv`` output into local q, k and v.

        Args:
            qkv: this rank's ``wqkv`` output. Dim 0 is the token count; the
                last dim holds ``q_size + 2 * kv_size`` values per token
                (assumed 2-D (num_tokens, q_size + 2 * kv_size) — TODO
                confirm).

        Returns:
            ``(q, k, v)`` for this rank, with last dims ``q_size``,
            ``kv_size`` and ``kv_size``.
        """
        seq_len = qkv.shape[0]
        if self.tp_size > 1:
            # Each rank holds its own [q | k | v] columns. Gather all ranks,
            # then regroup into [all q | all k | all v], presumably restoring
            # the unsharded wqkv column order.
            qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
            qkv = tensor_model_parallel_all_gather(qkv)
            qkv = torch.split(qkv, qkv_map, dim=-1)
            qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
            qkv = torch.cat(qkv, dim=-1)

        # De-interleave: per KV head, `key_value_groups` q heads, 1 k, 1 v.
        qkv = qkv.view(seq_len, self.total_num_kv_heads,
                       self.key_value_groups + 2, self.head_dim)
        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
        q = q.reshape(seq_len, self.q_size * self.tp_size)
        k = k.reshape(seq_len, self.kv_size * self.tp_size)
        v = v.reshape(seq_len, self.kv_size * self.tp_size)

        if self.tp_size > 1:
            # Keep only this rank's slice of the full q/k/v.
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]
        return q, k, v

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply wqkv, split, rotary embedding, attention and ``wo``."""
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = self.split_qkv(qkv)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.wo(attn_output)
        return output


class InternLMDecoderLayer(nn.Module):
    """One InternLM2 transformer block.

    Pre-norm attention is followed by a pre-norm gated MLP. The residual
    stream is threaded through the two RMSNorm calls rather than added
    explicitly in this module.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        width = config.hidden_size
        self.hidden_size = width
        # Optional rope settings fall back to fixed defaults when the HF
        # config does not define them.
        self.attention = InternLM2Attention(
            hidden_size=width,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 8192),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=width,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        eps = config.rms_norm_eps
        self.attention_norm = RMSNorm(config.hidden_size, eps=eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP; return ``(mlp_output, residual)``."""
        if residual is not None:
            # The norm receives the running residual and returns the
            # normalized states together with the updated residual.
            normed, residual = self.attention_norm(hidden_states, residual)
        else:
            # No residual yet (InternLM2Model passes None on the first
            # pipeline rank): the raw input becomes the residual.
            residual = hidden_states
            normed = self.attention_norm(hidden_states)

        attn_out = self.attention(
            positions=positions,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.ffn_norm(attn_out, residual)
        return self.feed_forward(normed), residual


@support_torch_compile
class InternLM2Model(nn.Module):
    """InternLM2 decoder stack for one pipeline-parallel stage.

    The first stage embeds tokens (or takes ``inputs_embeds``). Later stages
    resume from ``intermediate_tensors``. Non-last stages hand off
    ``hidden_states``/``residual``; the last stage applies the final RMSNorm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

        def build_layer(prefix):
            # `make_layers` supplies each layer's own prefix.
            return InternLMDecoderLayer(config,
                                        cache_config,
                                        quant_config,
                                        prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the vocab-parallel embedding table."""
        return self.tok_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            # Later stages continue from the previous stage's activations.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            if inputs_embeds is None:
                hidden_states = self.tok_embeddings(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None

        # `kv_caches` is indexed relative to this stage's first layer.
        first = self.start_layer
        for idx in range(first, self.end_layer):
            hidden_states, residual = self.layers[idx](
                positions,
                hidden_states,
                kv_caches[idx - first],
                attn_metadata,
                residual,
            )

        if pp_group.is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


class InternLM2ForCausalLM(nn.Module, SupportsPP):
    """InternLM2 causal language model.

    Wraps ``InternLM2Model`` with an LM head (``output``), a logits
    processor and a sampler. The LM head optionally shares its weight with
    the token embedding.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = InternLM2Model(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "model"))
        self.output = ParallelLMHead(config.vocab_size,
                                     config.hidden_size,
                                     quant_config=quant_config,
                                     prefix=maybe_prefix(prefix, "output"))
        if self.config.tie_word_embeddings:
            # Reuse the input-embedding tensor as the LM-head weight.
            self.output.weight = self.model.tok_embeddings.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack for this pipeline stage.

        Args:
            input_ids: token ids, used when ``inputs_embeds`` is None.
            positions: token positions for rotary embedding.
            kv_caches: per-layer KV caches for this stage's layers.
            attn_metadata: attention backend metadata.
            intermediate_tensors: activations from the previous pipeline
                stage; required on non-first stages.
            inputs_embeds: optional precomputed input embeddings, forwarded
                to ``InternLM2Model`` (which already accepts them).

        Returns:
            Final normalized hidden states on the last pipeline stage,
            otherwise ``IntermediateTensors`` for the next stage.
        """
        hidden_states = self.model(input_ids,
                                   positions,
                                   kv_caches,
                                   attn_metadata,
                                   intermediate_tensors,
                                   inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.output, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` checkpoint pairs into this module.

        Checkpoint ``w1``/``w3`` tensors are loaded into shards 0/1 of the
        fused ``gate_up_proj``. All other tensors are loaded by name with the
        parameter's ``weight_loader`` (``default_weight_loader`` if it has
        none). The following are skipped:

        - ``rotary_emb.inv_freq`` entries;
        - ``.bias`` tensors with no matching parameter (extra GPTQ biases);
        - names that ``is_pp_missing_parameter`` reports as absent on this
          pipeline stage.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # NOTE: `name` is rewritten in place. If one of the skips
                # below fires, the inner loop finishes without `break`, and
                # the `else` branch re-applies the same skip checks to the
                # rewritten name.
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)