Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8a0cf1dd
Unverified
Commit
8a0cf1dd
authored
Sep 14, 2024
by
ywfang
Committed by
GitHub
Sep 14, 2024
Browse files
[Model] support minicpm3 (#8297)
Co-authored-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
1ef0d2ef
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
281 additions
and
37 deletions
+281
-37
.buildkite/run-cpu-test.sh
.buildkite/run-cpu-test.sh
+1
-1
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+4
-0
requirements-test.txt
requirements-test.txt
+1
-0
tests/models/decoder_only/language/test_big_models.py
tests/models/decoder_only/language/test_big_models.py
+9
-6
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+1
-0
vllm/model_executor/models/minicpm.py
vllm/model_executor/models/minicpm.py
+49
-30
vllm/model_executor/models/minicpm3.py
vllm/model_executor/models/minicpm3.py
+216
-0
No files found.
.buildkite/run-cpu-test.sh
View file @
8a0cf1dd
...
@@ -22,7 +22,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
...
@@ -22,7 +22,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
# Run basic model test
# Run basic model test
docker
exec
cpu-test bash
-c
"
docker
exec
cpu-test bash
-c
"
pip install pytest matplotlib einops transformers_stream_generator
pip install pytest matplotlib einops transformers_stream_generator
datamodel_code_generator
pytest -v -s tests/models/decoder_only/language
\
pytest -v -s tests/models/decoder_only/language
\
--ignore=tests/models/test_fp8.py
\
--ignore=tests/models/test_fp8.py
\
--ignore=tests/models/decoder_only/language/test_jamba.py
\
--ignore=tests/models/decoder_only/language/test_jamba.py
\
...
...
docs/source/models/supported_models.rst
View file @
8a0cf1dd
...
@@ -107,6 +107,10 @@ Decoder-only Language Models
...
@@ -107,6 +107,10 @@ Decoder-only Language Models
- MiniCPM
- MiniCPM
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
-
-
* - :code:`MiniCPM3ForCausalLM`
- MiniCPM3
- :code:`openbmb/MiniCPM3-4B`, etc.
-
* - :code:`MistralForCausalLM`
* - :code:`MistralForCausalLM`
- Mistral, Mistral-Instruct
- Mistral, Mistral-Instruct
- :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
- :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
...
...
requirements-test.txt
View file @
8a0cf1dd
...
@@ -21,6 +21,7 @@ compressed-tensors==0.4.0 # required for compressed-tensors
...
@@ -21,6 +21,7 @@ compressed-tensors==0.4.0 # required for compressed-tensors
timm # required for internvl test
timm # required for internvl test
transformers_stream_generator # required for qwen-vl test
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
matplotlib # required for qwen-vl test
datamodel_code_generator # required for minicpm3 test
# TODO: Add this after fully implementing llava(mantis)
# TODO: Add this after fully implementing llava(mantis)
# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test
# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test
...
...
tests/models/decoder_only/language/test_big_models.py
View file @
8a0cf1dd
...
@@ -5,7 +5,8 @@ This tests bigger models and use half precision.
...
@@ -5,7 +5,8 @@ This tests bigger models and use half precision.
Run `pytest tests/models/test_big_models.py`.
Run `pytest tests/models/test_big_models.py`.
"""
"""
import
pytest
import
pytest
import
torch
from
vllm.platforms
import
current_platform
from
...utils
import
check_outputs_equal
from
...utils
import
check_outputs_equal
...
@@ -19,10 +20,12 @@ MODELS = [
...
@@ -19,10 +20,12 @@ MODELS = [
# "Qwen/Qwen1.5-0.5B" # Broken,
# "Qwen/Qwen1.5-0.5B" # Broken,
]
]
if
not
current_platform
.
is_cpu
():
# MiniCPM requires fused_moe which is not supported by CPU
MODELS
.
append
(
"openbmb/MiniCPM3-4B"
)
#TODO: remove this after CPU float16 support ready
#TODO: remove this after CPU float16 support ready
target_dtype
=
"float"
target_dtype
=
"float"
if
current_platform
.
is_cpu
()
else
"half"
if
torch
.
cuda
.
is_available
():
target_dtype
=
"half"
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
...
@@ -39,7 +42,7 @@ def test_models(
...
@@ -39,7 +42,7 @@ def test_models(
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
with
vllm_runner
(
model
,
dtype
=
dtype
,
enforce_eager
=
True
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
check_outputs_equal
(
check_outputs_equal
(
...
@@ -57,7 +60,7 @@ def test_model_print(
...
@@ -57,7 +60,7 @@ def test_model_print(
model
:
str
,
model
:
str
,
dtype
:
str
,
dtype
:
str
,
)
->
None
:
)
->
None
:
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
with
vllm_runner
(
model
,
dtype
=
dtype
,
enforce_eager
=
True
)
as
vllm_model
:
# This test is for verifying whether the model's extra_repr
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
# can be printed correctly.
print
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
print
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
...
...
vllm/model_executor/models/__init__.py
View file @
8a0cf1dd
...
@@ -43,6 +43,7 @@ _GENERATION_MODELS = {
...
@@ -43,6 +43,7 @@ _GENERATION_MODELS = {
"MptForCausalLM"
:
(
"mpt"
,
"MPTForCausalLM"
),
"MptForCausalLM"
:
(
"mpt"
,
"MPTForCausalLM"
),
"MPTForCausalLM"
:
(
"mpt"
,
"MPTForCausalLM"
),
"MPTForCausalLM"
:
(
"mpt"
,
"MPTForCausalLM"
),
"MiniCPMForCausalLM"
:
(
"minicpm"
,
"MiniCPMForCausalLM"
),
"MiniCPMForCausalLM"
:
(
"minicpm"
,
"MiniCPMForCausalLM"
),
"MiniCPM3ForCausalLM"
:
(
"minicpm3"
,
"MiniCPM3ForCausalLM"
),
"NemotronForCausalLM"
:
(
"nemotron"
,
"NemotronForCausalLM"
),
"NemotronForCausalLM"
:
(
"nemotron"
,
"NemotronForCausalLM"
),
"OlmoForCausalLM"
:
(
"olmo"
,
"OlmoForCausalLM"
),
"OlmoForCausalLM"
:
(
"olmo"
,
"OlmoForCausalLM"
),
"OPTForCausalLM"
:
(
"opt"
,
"OPTForCausalLM"
),
"OPTForCausalLM"
:
(
"opt"
,
"OPTForCausalLM"
),
...
...
vllm/model_executor/models/minicpm.py
View file @
8a0cf1dd
...
@@ -270,38 +270,47 @@ class MiniCPMDecoderLayer(nn.Module):
...
@@ -270,38 +270,47 @@ class MiniCPMDecoderLayer(nn.Module):
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
cache_config
=
cache_config
self
.
quant_config
=
quant_config
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
self
.
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
self
.
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
self
.
max_position_embeddings
=
getattr
(
config
,
8192
)
"max_position_embeddings"
,
8192
)
self
.
_init_attn_block
()
self
.
_init_ffn_block
()
def
_init_attn_block
(
self
):
self
.
input_layernorm
=
RMSNorm
(
self
.
config
.
hidden_size
,
eps
=
self
.
config
.
rms_norm_eps
)
self
.
self_attn
=
MiniCPMAttention
(
self
.
self_attn
=
MiniCPMAttention
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_heads
=
self
.
config
.
num_attention_heads
,
num_kv_heads
=
config
.
num_key_value_heads
,
num_kv_heads
=
self
.
config
.
num_key_value_heads
,
rope_theta
=
rope_theta
,
rope_theta
=
self
.
rope_theta
,
rope_scaling
=
rope_scaling
,
rope_scaling
=
self
.
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
max_position_embeddings
=
self
.
max_position_embeddings
,
cache_config
=
cache_config
,
cache_config
=
self
.
cache_config
,
quant_config
=
quant_config
,
quant_config
=
self
.
quant_config
,
)
)
def
_init_ffn_block
(
self
):
self
.
post_attention_layernorm
=
RMSNorm
(
self
.
config
.
hidden_size
,
eps
=
self
.
config
.
rms_norm_eps
)
self
.
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
0
)
self
.
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
0
)
if
self
.
num_experts
==
0
:
if
self
.
num_experts
==
0
:
self
.
mlp
=
MiniCPMMLP
(
self
.
mlp
=
MiniCPMMLP
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
self
.
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
self
.
config
.
hidden_act
,
quant_config
=
quant_config
,
quant_config
=
self
.
quant_config
,
)
)
else
:
else
:
self
.
mlp
=
MiniCPMMoE
(
num_experts
=
config
.
num_experts
,
self
.
mlp
=
MiniCPMMoE
(
top_k
=
config
.
num_experts_per_tok
,
num_experts
=
self
.
config
.
num_experts
,
hidden_size
=
config
.
hidden_size
,
top_k
=
self
.
config
.
num_experts_per_tok
,
intermediate_size
=
config
.
intermediate_size
)
hidden_size
=
self
.
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
intermediate_size
=
self
.
config
.
intermediate_size
)
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -344,6 +353,8 @@ class MiniCPMModel(nn.Module):
...
@@ -344,6 +353,8 @@ class MiniCPMModel(nn.Module):
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
cache_config
=
cache_config
self
.
quant_config
=
quant_config
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
...
@@ -354,11 +365,15 @@ class MiniCPMModel(nn.Module):
...
@@ -354,11 +365,15 @@ class MiniCPMModel(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
org_num_embeddings
=
config
.
vocab_size
,
)
)
self
.
_init_layers
()
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
_init_layers
(
self
):
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
MiniCPMDecoderLayer
(
config
,
cache_config
,
quant_config
)
MiniCPMDecoderLayer
(
self
.
config
,
self
.
cache_config
,
for
_
in
range
(
config
.
num_hidden_layers
)
self
.
quant_config
)
for
_
in
range
(
self
.
config
.
num_hidden_layers
)
])
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
embedding
=
self
.
embed_tokens
(
input_ids
)
embedding
=
self
.
embed_tokens
(
input_ids
)
...
@@ -431,13 +446,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
...
@@ -431,13 +446,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
self
.
config
=
config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
cache_config
=
cache_config
self
.
quant_config
=
quant_config
self
.
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
0
)
self
.
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
0
)
self
.
quant_config
=
quant_config
self
.
_init_model
()
self
.
model
=
MiniCPMModel
(
config
,
cache_config
,
quant_config
,
lora_config
=
lora_config
)
unpadded_vocab_size
=
config
.
vocab_size
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
if
lora_config
:
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
...
@@ -458,6 +471,12 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
...
@@ -458,6 +471,12 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
config
.
vocab_size
)
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
def
_init_model
(
self
):
self
.
model
=
MiniCPMModel
(
config
=
self
.
config
,
cache_config
=
self
.
cache_config
,
quant_config
=
self
.
quant_config
,
lora_config
=
self
.
lora_config
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/minicpm3.py
0 → 100644
View file @
8a0cf1dd
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2024 The ModelBest team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM3 model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Optional
import
torch
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.models.minicpm
import
(
MiniCPMDecoderLayer
,
MiniCPMForCausalLM
,
MiniCPMModel
)
class MiniCPM3Attention(nn.Module):
    """Attention block for MiniCPM3 using low-rank query and key/value projections.

    Queries are projected down to ``q_lora_rank``, normalized, and projected
    back up to ``num_heads * (qk_nope_head_dim + qk_rope_head_dim)``.
    Keys/values are produced from a single replicated projection of width
    ``kv_lora_rank + qk_rope_head_dim``: the first ``kv_lora_rank`` channels
    are normalized and expanded per head into the non-rotary key part and the
    value, while the trailing ``qk_rope_head_dim`` channels form a rotary key
    part shared by every head.
    """

    def __init__(
        self,
        config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the projection, normalization, rotary and attention layers.

        Args:
            config: Model config; only ``rms_norm_eps`` is read from it here.
            hidden_size: Width of the incoming hidden states.
            num_heads: Total number of attention heads; must be divisible by
                the tensor-parallel world size.
            qk_nope_head_dim: Per-head query/key width that is NOT rotated.
            qk_rope_head_dim: Per-head query/key width that is rotated.
            v_head_dim: Per-head value width.
            q_lora_rank: Bottleneck width of the query projection.
            kv_lora_rank: Bottleneck width of the key/value projection.
            rope_theta: Passed to ``get_rope`` as ``base``.
            rope_scaling: Passed through to ``get_rope`` unchanged.
            max_position_embeddings: Passed to ``get_rope`` as ``max_position``.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to every linear layer and to ``Attention``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        # Full per-head query/key width: non-rotary part followed by rotary part.
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads

        # Heads are split evenly across tensor-parallel ranks.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        # Scaling uses the full query/key head width, not v_head_dim.
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Query path: hidden_size -> q_lora_rank -> num_heads * qk_head_dim.
        self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                         self.q_lora_rank,
                                         bias=False,
                                         quant_config=quant_config)
        self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
        self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                             self.num_heads * self.qk_head_dim,
                                             bias=False,
                                             quant_config=quant_config)

        # Key/value path: one projection yields the kv_lora_rank bottleneck
        # plus the qk_rope_head_dim rotary key channels.
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config)
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        # Expands the bottleneck into per-head [k_nope | v].
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config)
        # O projection.
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config)

        # Rotary embedding is sized for the rotary slice only.
        self.rotary_emb = get_rope(
            self.qk_rope_head_dim,
            rotary_dim=self.qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        # NOTE(review): the second positional argument is presumably the head
        # size; it is qk_head_dim, which is why forward() pads v up to that
        # width. The same count is used for query and KV heads.
        self.attn = Attention(self.num_local_heads,
                              self.qk_head_dim,
                              self.scaling,
                              num_kv_heads=self.num_local_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output.

        Args:
            positions: Passed to the rotary embedding.
            hidden_states: Input to the query and key/value projections.
            kv_cache: Forwarded to ``self.attn``.
            attn_metadata: Forwarded to ``self.attn``.

        Returns:
            The first return value of ``self.o_proj`` applied to the
            attention output.
        """
        # Query: down-project, normalize, up-project. The second return value
        # of each linear layer is discarded.
        q, _ = self.q_a_proj(hidden_states)
        q = self.q_a_layernorm(q)
        q, _ = self.q_b_proj(q)
        q = q.view(-1, self.num_local_heads, self.qk_head_dim)
        # Only the rotary slice is needed separately; the non-rotary slice
        # stays in place inside q.
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                          dim=-1)
        # Joint projection: [kv bottleneck | rotary key channels].
        latent_cache, _ = self.kv_a_proj_with_mqa(hidden_states)
        kv_a, _ = latent_cache.split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # Insert a singleton axis at dim 1 so the slice below is 3-D.
        latent_cache = latent_cache.unsqueeze(1)
        # split() returns a view, so make it contiguous before normalizing.
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv, _ = self.kv_b_proj(kv_a)
        kv = kv.view(-1, self.num_local_heads,
                     self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # Rotary key channels: everything after the kv bottleneck.
        k_pe = latent_cache[:, :, self.kv_lora_rank:]

        # The rotary embedding is called with 2-D (flattened) query/key
        # inputs; results are reshaped back per head afterwards.
        q_pe, k_pe = self.rotary_emb(
            positions,
            q_pe.reshape(-1, self.num_local_heads * self.qk_rope_head_dim),
            k_pe.reshape(-1, self.qk_rope_head_dim))
        q_pe = q_pe.view(-1, self.num_local_heads, self.qk_rope_head_dim)
        k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)

        # Write the rotated channels back into q in place.
        q[..., self.qk_nope_head_dim:] = q_pe

        # k has the same shape as q; both slices are filled below, so the
        # uninitialized contents of empty_like are fully overwritten.
        k = torch.empty_like(q)

        k[..., :self.qk_nope_head_dim] = k_nope
        # k_pe has a singleton head axis, so it is broadcast to every head.
        k[..., self.qk_nope_head_dim:] = k_pe

        q = q.reshape(-1, self.num_local_heads * self.qk_head_dim)
        k = k.view(-1, self.num_local_heads * self.qk_head_dim)
        # Zero-pad v on the last axis from v_head_dim up to qk_head_dim so
        # q, k and v share the same per-head width.
        v = torch.nn.functional.pad(
            v, [0, self.qk_head_dim - self.v_head_dim],
            value=0).view(-1, self.num_local_heads * self.qk_head_dim)

        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        # Drop the padded channels again, keeping the first v_head_dim per head.
        attn_output = attn_output.view(
            -1, self.num_local_heads,
            self.qk_head_dim)[..., :self.v_head_dim].reshape(
                -1, self.num_local_heads * self.v_head_dim)

        output, _ = self.o_proj(attn_output)
        return output
class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):
    """MiniCPM decoder layer whose attention block is ``MiniCPM3Attention``."""

    def _init_attn_block(self):
        """Create the input layer norm and the MiniCPM3 attention module.

        Reads the head/rank dimensions from ``self.config`` and the rotary,
        cache and quantization settings from attributes already set on
        ``self``.
        """
        cfg = self.config
        self.input_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)

        # Collect the constructor arguments first so the call stays readable.
        attn_kwargs = {
            "config": cfg,
            "hidden_size": self.hidden_size,
            "num_heads": cfg.num_attention_heads,
            "qk_nope_head_dim": cfg.qk_nope_head_dim,
            "qk_rope_head_dim": cfg.qk_rope_head_dim,
            "v_head_dim": cfg.v_head_dim,
            "q_lora_rank": cfg.q_lora_rank,
            "kv_lora_rank": cfg.kv_lora_rank,
            "rope_theta": self.rope_theta,
            "rope_scaling": self.rope_scaling,
            "max_position_embeddings": self.max_position_embeddings,
            "cache_config": self.cache_config,
            "quant_config": self.quant_config,
        }
        self.self_attn = MiniCPM3Attention(**attn_kwargs)
class MiniCPM3Model(MiniCPMModel):
    """MiniCPM model whose layer stack is built from ``MiniCPM3DecoderLayer``."""

    def _init_layers(self):
        """Populate ``self.layers`` with ``num_hidden_layers`` decoder layers."""
        cfg = self.config
        cache_cfg = self.cache_config
        quant_cfg = self.quant_config
        decoder_layers = [
            MiniCPM3DecoderLayer(cfg, cache_cfg, quant_cfg)
            for _ in range(cfg.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
    """MiniCPM causal LM wrapper that uses ``MiniCPM3Model`` as its backbone."""

    def _init_model(self):
        """Set ``self.model`` to a ``MiniCPM3Model`` built from stored configs."""
        backbone = MiniCPM3Model(
            config=self.config,
            cache_config=self.cache_config,
            quant_config=self.quant_config,
            lora_config=self.lora_config,
        )
        self.model = backbone
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment