# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import math
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class MiniCPMMoE(nn.Module):
    """A tensor-parallel MoE implementation that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.

    Args:
        num_experts: Total number of experts; every rank holds a slice of
            every expert.
        top_k: Number of experts selected per token by the router.
        hidden_size: Model hidden size.
        intermediate_size: Full (unsharded) expert intermediate size. It
            must be divisible by the tensor-parallel size.
        params_dtype: Dtype of the router and expert weights. Defaults to
            ``torch.get_default_dtype()``.
        tp_size: Tensor-parallel size. Defaults to the current
            tensor-model-parallel world size.

    Raises:
        ValueError: If ``intermediate_size`` is not divisible by the
            tensor-parallel size.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        tp_size: Optional[int] = None,
    ):
        super().__init__()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        # Floor division below would otherwise silently drop the trailing
        # rows/columns of every expert when the weights are sharded.
        if intermediate_size % self.tp_size != 0:
            raise ValueError(
                f"MoE intermediate_size ({intermediate_size}) must be "
                f"divisible by the tensor-parallel size ({self.tp_size}).")
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        # Per-rank slice of each expert's intermediate dimension.
        self.intermediate_size = intermediate_size // self.tp_size

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        # Router producing one logit per expert; replicated on every rank
        # and never quantized.
        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     params_dtype=self.params_dtype,
                                     quant_config=None)

        # Per expert: w1 shard in rows [0, I) and w3 shard in rows [I, 2I),
        # where I is the per-rank intermediate size (see weight_loader).
        self.ws = nn.Parameter(
            torch.empty(self.num_total_experts,
                        2 * self.intermediate_size,
                        self.hidden_size,
                        device=current_platform.device_type,
                        dtype=self.params_dtype))
        # Per expert: w2 sharded along its intermediate (input) dimension.
        self.w2s = nn.Parameter(
            torch.empty(self.num_total_experts,
                        self.hidden_size,
                        self.intermediate_size,
                        device=current_platform.device_type,
                        dtype=self.params_dtype))

        set_weight_attrs(self.ws, {
            "weight_loader": self.weight_loader,
        })
        set_weight_attrs(self.w2s, {
            "weight_loader": self.weight_loader,
        })

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
        """Copy this rank's shard of one expert's checkpoint tensor.

        ``weight_name`` is matched by suffix:

        * ``w1.weight`` fills rows ``[0, I)`` of ``param[expert_id]``.
        * ``w3.weight`` fills rows ``[I, 2I)``.
        * ``w2.weight`` fills all of ``param[expert_id]``.

        For w1/w3 the checkpoint tensor is sliced to this rank's range
        ``[rank * I, (rank + 1) * I)`` along dim 0; for w2 it is sliced
        along dim 1. Names matching none of the suffixes are ignored.
        """
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w3.weight"):
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w2.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens to their top-k experts and return the combined output.

        Any leading dimensions are flattened into a token axis for the fused
        kernel and restored on the output. With ``tp_size > 1`` the partial
        per-rank results are all-reduced.
        """
        # Keep the caller's shape: unpacking ``.shape`` into two names would
        # reject inputs with more than one leading dimension.
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.ws,
                                        self.w2s,
                                        router_logits,
                                        self.top_k,
                                        renormalize=True,
                                        inplace=True)

        # Each rank only holds a slice of every expert, so outputs are
        # partial sums until reduced across the tensor-parallel group.
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(orig_shape)


class MiniCPMMLP(nn.Module):
    """Gated feed-forward block used by dense MiniCPM layers.

    A merged gate/up column-parallel projection feeds an activation-and-
    multiply op (``SiluAndMul`` or ``FatreluAndMul``), followed by a
    row-parallel down projection back to ``hidden_size``.

    Raises:
        ValueError: If ``hidden_act`` is neither ``"silu"`` nor
            ``"fatrelu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        hidden_act_param: float,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        # Activation factories keyed by config name; fatrelu takes its
        # threshold from ``hidden_act_param``.
        activation_factories = {
            "silu": SiluAndMul,
            "fatrelu": lambda: FatreluAndMul(threshold=hidden_act_param),
        }
        if hidden_act not in activation_factories:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu and fatrelu are supported for now.")
        self.act_fn = activation_factories[hidden_act]()

    def forward(self, x):
        projected, _ = self.gate_up_proj(x)
        activated = self.act_fn(projected)
        result, _ = self.down_proj(activated)
        return result


class MiniCPMAttention(nn.Module):
    """MiniCPM self-attention with tensor-parallel projections.

    Query heads are divided evenly across tensor-parallel ranks. KV heads
    are divided when there are at least as many as ranks, and replicated
    otherwise. Rotary embeddings are applied to q/k in float32, which are
    then cast back to the activation dtype before attention.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        world_size = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        # Enough KV heads: shard them across ranks. Too few: replicate them.
        # Either way the split must be even.
        kv_heads_sharded = self.total_num_kv_heads >= world_size
        if kv_heads_sharded:
            assert self.total_num_kv_heads % world_size == 0
        else:
            assert world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        # Recompute the cos/sin cache so rope stays in fp32 instead of bf16,
        # matching the float32 rotary step in forward().
        rope = self.rotary_emb
        rope.cos_sin_cache = rope._compute_cos_sin_cache()

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        activation_dtype = q.dtype
        # Rotary embedding runs in float32; results go back to the
        # projection dtype before the attention kernel.
        q, k = self.rotary_emb(positions, q.float(), k.float())
        attn_output = self.attn(q.to(activation_dtype),
                                k.to(activation_dtype), v)
        output, _ = self.o_proj(attn_output)
        return output


class MiniCPMDecoderLayer(nn.Module):
    """One MiniCPM transformer layer: pre-norm attention, then a pre-norm
    feed-forward block (dense MLP or MoE).

    Each sub-block output is scaled by
    ``config.scale_depth / sqrt(config.num_hidden_layers)`` before being
    added to the residual stream. The attention and FFN construction live
    in ``_init_attn_block`` / ``_init_ffn_block`` so they can be overridden.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.hidden_size = config.hidden_size
        # Optional rope settings fall back to these defaults when absent.
        self.rope_theta = getattr(config, "rope_theta", 10000)
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.max_position_embeddings = getattr(config,
                                               "max_position_embeddings", 8192)
        self.prefix = prefix
        self._init_attn_block()
        self._init_ffn_block()

    def _init_attn_block(self):
        cfg = self.config
        self.input_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.self_attn = MiniCPMAttention(
            hidden_size=self.hidden_size,
            num_heads=cfg.num_attention_heads,
            num_kv_heads=cfg.num_key_value_heads,
            rope_theta=self.rope_theta,
            rope_scaling=self.rope_scaling,
            max_position_embeddings=self.max_position_embeddings,
            cache_config=self.cache_config,
            quant_config=self.quant_config,
            prefix=f"{self.prefix}.self_attn",
        )

    def _init_ffn_block(self):
        cfg = self.config
        self.post_attention_layernorm = RMSNorm(cfg.hidden_size,
                                                eps=cfg.rms_norm_eps)
        # A config without ``num_experts`` (or with 0) selects the dense MLP.
        self.num_experts = getattr(cfg, "num_experts", 0)
        if self.num_experts != 0:
            self.mlp = MiniCPMMoE(
                num_experts=cfg.num_experts,
                top_k=cfg.num_experts_per_tok,
                hidden_size=cfg.hidden_size,
                intermediate_size=cfg.intermediate_size)
        else:
            self.mlp = MiniCPMMLP(
                hidden_size=self.hidden_size,
                intermediate_size=cfg.intermediate_size,
                hidden_act=cfg.hidden_act,
                hidden_act_param=getattr(cfg, "hidden_act_param", 0.),
                quant_config=self.quant_config,
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Depth-scaled residual factor shared by both sub-blocks.
        branch_scale = (self.config.scale_depth /
                        math.sqrt(self.config.num_hidden_layers))

        # Self Attention
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.self_attn(
            positions=positions,
            hidden_states=attn_input,
        )
        hidden_states = hidden_states + attn_output * branch_scale

        # Fully Connected
        ffn_input = self.post_attention_layernorm(hidden_states)
        ffn_output = self.mlp(ffn_input)
        hidden_states = hidden_states + ffn_output * branch_scale

        # The incoming ``residual`` is unused and no separate residual is
        # produced; the second element is always None.
        return hidden_states, None


@support_torch_compile
class MiniCPMModel(nn.Module):
    """MiniCPM decoder stack: scaled token embedding, decoder layers, final
    RMSNorm, with pipeline-parallel hand-off via ``IntermediateTensors``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.padding_idx = config.pad_token_id

        # LoRA may add extra vocabulary rows per adapter slot.
        extra_vocab = 0
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        self.vocab_size = config.vocab_size + extra_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.num_experts = getattr(self.config, "num_experts", 0)
        self._init_layers(prefix, config, cache_config, quant_config)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], self.config.hidden_size))

    def _init_layers(
        self,
        prefix: str,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig],
        quant_config: Optional[QuantizationConfig],
    ):
        # The factory's parameter must be named ``prefix``; it receives the
        # per-layer prefix.
        def build_layer(prefix: str) -> MiniCPMDecoderLayer:
            return MiniCPMDecoderLayer(config,
                                       cache_config,
                                       quant_config,
                                       prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, build_layer, prefix=f"{prefix}.layers")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        # MiniCPM multiplies token embeddings by ``config.scale_emb``.
        return self.embed_tokens(input_ids) * self.config.scale_emb

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of layers.

        The first pipeline rank starts from ``inputs_embeds`` when given,
        otherwise from the scaled embeddings of ``input_ids``. Later ranks
        resume from ``intermediate_tensors``. Non-final ranks return
        ``IntermediateTensors``; the final rank returns normed hidden states.
        """
        pp_group = get_pp_group()
        if pp_group.is_first_rank:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None
        else:
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states,
                                            residual)

        if not pp_group.is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        return self.norm(hidden_states)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors and return the names of parameters set.

        Separate q/k/v and gate/up checkpoint tensors are routed into their
        fused parameters. Per-expert w1/w3 go into ``ws`` and w2 into
        ``w2s``. Everything else uses the parameter's own loader, or
        ``default_weight_loader``.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # (parameter name, checkpoint weight name, expert id)
        expert_params_mapping = [
            ("ws" if proj in ("w1", "w3") else "w2s",
             f"experts.{expert_id}.{proj}.weight", expert_id)
            for expert_id in range(self.num_experts)
            for proj in ("w1", "w2", "w3")
        ]
        # Rotary buffers some checkpoints carry (e.g. ColossalAI-trained
        # models); they are recomputed, not loaded.
        ignored_markers = ("rotary_emb.inv_freq", "rotary_emb.cos_cached",
                           "rotary_emb.sin_cached")

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, tensor in weights:
            if any(marker in name for marker in ignored_markers):
                continue
            # NOTE(review): a ``continue`` below keeps scanning the remaining
            # mappings with the already-rewritten ``name``; such weights end
            # up skipped by the final fallback checks.
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                param.weight_loader(param, tensor, shard_id)
                break
            else:
                for param_name, expert_weight_name, expert_id in (
                        expert_params_mapping):
                    if expert_weight_name not in name:
                        continue
                    name = name.replace(expert_weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    param.weight_loader(param,
                                        tensor,
                                        expert_weight_name,
                                        expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    loader = getattr(param, "weight_loader",
                                     default_weight_loader)
                    loader(param, tensor)
            loaded_params.add(name)
        return loaded_params


class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Causal-LM wrapper around ``MiniCPMModel``.

    Adds an LM head (optionally tied to the input embeddings), logits
    computation with MiniCPM's width scaling, and sampling.
    """

    # Fused vLLM parameter -> checkpoint sub-modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.prefix = prefix
        self.vllm_config = vllm_config
        self.config = config
        self.lora_config = lora_config
        self.cache_config = cache_config
        self.quant_config = quant_config

        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        extra_vocab = lora_config.lora_extra_vocab_size if lora_config else 0
        unpadded_vocab_size = config.vocab_size + extra_vocab
        # LoRA kernels need the larger LoRA-specific vocab padding.
        vocab_padding = (lora_config.lora_vocab_padding_size
                         if lora_config else DEFAULT_VOCAB_PADDING_SIZE)
        self.lm_head = ParallelLMHead(
            unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=vocab_padding,
            quant_config=quant_config,
        )
        if config.tie_word_embeddings:
            self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

        # Hidden states are divided by this before the LM head.
        self.scale_width = self.config.hidden_size / self.config.dim_model_base

        self.logits_processor = LogitsProcessor(unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Factory hook so subclasses can substitute their own model class.
        return MiniCPMModel(vllm_config=vllm_config, prefix=prefix)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        scaled_states = hidden_states / self.scale_width
        return self.logits_processor(self.lm_head, scaled_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        # With tied embeddings the LM head shares embed_tokens' weight, so
        # any checkpoint ``lm_head.`` tensors are skipped.
        skipped = (["lm_head."]
                   if self.config.tie_word_embeddings else None)
        loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        return loader.load_weights(weights)