sglang commit e72275cf (unverified)
Authored Sep 10, 2024 by William; committed via GitHub on Sep 10, 2024
Support MiniCPM3 (#1371)
Parent: fec2d122

Showing 5 changed files with 683 additions and 2 deletions (+683, -2)
README.md                                        +1    -0
python/sglang/srt/layers/decode_attention.py     +4    -1
python/sglang/srt/layers/extend_attention.py     +4    -1
python/sglang/srt/model_config.py                +5    -0
python/sglang/srt/models/minicpm3.py             +669  -0
README.md

@@ -259,6 +259,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 - ChatGLM
 - InternLM 2
 - Exaone 3
+- MiniCPM / MiniCPM 3
 **Embedding Models**
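The launch command shown in the hunk header above applies unchanged to the newly listed model. A hedged example, assuming the Hugging Face model id openbmb/MiniCPM3-4B and that the checkpoint needs --trust-remote-code (both assumptions, not taken from this diff):

```
python -m sglang.launch_server --model-path openbmb/MiniCPM3-4B --trust-remote-code --port 30000
```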
python/sglang/srt/layers/decode_attention.py

@@ -483,11 +483,14 @@ def _decode_grouped_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 96, 128, 256, 576}
+    assert Lk in {16, 32, 64, 96, 128, 256, 576, 288}

     if Lk == 576:
         BLOCK_DMODEL = 512
         BLOCK_DPE = 64
+    elif Lk == 288:
+        BLOCK_DMODEL = 256
+        BLOCK_DPE = 32
     else:
         BLOCK_DMODEL = triton.next_power_of_2(Lk)
         BLOCK_DPE = 0
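The new 288 branch mirrors the existing 576 case used for DeepSeek-V2-style MLA: the combined latent head width is split into a power-of-two "no-PE" block plus a small RoPE block instead of being rounded up as a single block. A minimal sketch of the branching added above, assuming 288 comes from MiniCPM3's kv_lora_rank + qk_rope_head_dim (256 + 32):

```python
def split_head_dim(head_dim: int) -> tuple[int, int]:
    """Return (BLOCK_DMODEL, BLOCK_DPE) as in the kernel launch code above."""
    if head_dim == 576:    # DeepSeek-V2 MLA: 512 latent dims + 64 RoPE dims
        return 512, 64
    elif head_dim == 288:  # MiniCPM3 MLA (assumed): 256 latent dims + 32 RoPE dims
        return 256, 32
    # Fallback: one block rounded up to the next power of two, no separate PE
    # block (equivalent to triton.next_power_of_2(head_dim)).
    return 1 << (head_dim - 1).bit_length(), 0

assert split_head_dim(288) == (256, 32)
assert split_head_dim(96) == (128, 0)
```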
python/sglang/srt/layers/extend_attention.py

@@ -280,12 +280,15 @@ def extend_attention_fwd(
     assert Lq == Lk and Lv == Lo
     # TODO: is the assertion necessary?
-    assert Lq in {16, 32, 64, 96, 128, 256, 576}
+    assert Lq in {16, 32, 64, 96, 128, 256, 576, 288}
     assert Lv in {16, 32, 64, 96, 128, 256, 512}

     if Lq == 576:
         BLOCK_DMODEL = 512
         BLOCK_DPE = 64
+    elif Lq == 288:
+        BLOCK_DMODEL = 256
+        BLOCK_DPE = 32
     else:
         BLOCK_DMODEL = triton.next_power_of_2(Lq)
         BLOCK_DPE = 0
python/sglang/srt/model_config.py

@@ -64,6 +64,11 @@ class ModelConfig:
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_config.kv_lora_rank
             self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
+            self.head_dim = 128
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
         else:
             self.attention_arch = AttentionArch.MHA
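A minimal sketch of what this branch does, using a SimpleNamespace stand-in for self.hf_config; the kv_lora_rank and qk_rope_head_dim values are the ones published for MiniCPM3-4B and are assumptions here, not taken from this diff:

```python
from types import SimpleNamespace

# Stand-in for self.hf_config; values assumed for illustration.
hf_config = SimpleNamespace(
    architectures=["MiniCPM3ForCausalLM"],
    kv_lora_rank=256,
    qk_rope_head_dim=32,
)

if "MiniCPM3ForCausalLM" in hf_config.architectures:
    # MiniCPM3 takes the MLA path with a padded head_dim of 128 and the
    # low-rank KV dimensions read from the checkpoint config.
    head_dim = 128
    attention_arch = "MLA"
    kv_lora_rank = hf_config.kv_lora_rank
    qk_rope_head_dim = hf_config.qk_rope_head_dim
else:
    attention_arch = "MHA"

print(attention_arch, head_dim, kv_lora_rank + qk_rope_head_dim)  # MLA 128 288
```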
python/sglang/srt/models/minicpm3.py (new file, 0 → 100644)
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Inference-only MiniCPM3 model compatible with HuggingFace weights."""
import math
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
from flashinfer import bmm_fp8
from torch import nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
class MiniCPM3MLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
def input_to_float8(x, dtype=torch.float8_e4m3fn):
    # Symmetric per-tensor quantization: scale x so its absolute maximum maps
    # to the fp8 range, clamp, and return the fp8 tensor together with the
    # reciprocal scale needed to dequantize.
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
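A small usage sketch for the helper above (requires a PyTorch build with float8_e4m3fn support): since the function returns the reciprocal of the scale, dequantization is a single multiply.

```python
import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)
x_fp8, scale_inv = input_to_float8(x)

# Dequantize: fp8 values times the returned reciprocal scale.
x_approx = x_fp8.to(torch.float32) * scale_inv
max_err = (x.to(torch.float32) - x_approx).abs().max()
print(x_fp8.dtype, scale_inv.item(), max_err.item())
```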
class MiniCPM3Attention(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        layer_id=None,
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
            )

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )

        # TODO support head_size 96
        self.attn = RadixAttention(
            self.num_local_heads,
            128,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
        )
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        k_pe = latent_cache[:, :, self.kv_lora_rank :]

        original_shapes = [q_pe.shape, k_pe.shape]
        q_pe, k_pe = self.rotary_emb(
            positions,
            q_pe.reshape(q_pe.shape[0], -1),
            k_pe.reshape(k_pe.shape[0], -1),
        )
        q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1])

        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe

        q = torch.nn.functional.pad(q, [0, 128 - self.qk_head_dim], value=0).view(
            -1, self.num_local_heads * 128
        )
        k = torch.nn.functional.pad(k, [0, 128 - self.qk_head_dim], value=0).view(
            -1, self.num_local_heads * 128
        )
        v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view(
            -1, self.num_local_heads * 128
        )

        attn_output = self.attn(q, k, v, input_metadata)
        attn_output = attn_output.view(-1, self.num_local_heads, 128)[
            ..., : self.v_head_dim
        ].reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output
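The padding in the forward pass above exists because MiniCPM3's native query/key head width is qk_nope_head_dim + qk_rope_head_dim = 96 (assuming the published MiniCPM3-4B dims of 64 + 32), which the TODO notes is not yet a supported RadixAttention head size on this non-MLA path; q, k, and v are therefore zero-padded to 128 and the extra value columns are sliced off after attention. A toy illustration with assumed sizes:

```python
import torch

num_tokens, num_local_heads = 3, 2
qk_head_dim, v_head_dim = 96, 64          # assumed MiniCPM3 per-head widths

q = torch.randn(num_tokens, num_local_heads, qk_head_dim)
q_padded = torch.nn.functional.pad(q, [0, 128 - qk_head_dim], value=0)
assert q_padded.shape == (num_tokens, num_local_heads, 128)

# After attention, only the valid value columns are kept.
attn_out = torch.randn(num_tokens, num_local_heads, 128)
out = attn_out[..., :v_head_dim].reshape(num_tokens, num_local_heads * v_head_dim)
assert out.shape == (num_tokens, num_local_heads * v_head_dim)
```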
class MiniCPM3AttentionMLA(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        layer_id=None,
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
            )

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )

        self.attn = RadixAttention(
            self.num_local_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=1,
            layer_id=layer_id,
            v_head_dim=self.kv_lora_rank,
        )

        self.w_kc = None
        self.w_vc = None
        self.w_scale = None
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        q_len = hidden_states.shape[0]
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        if self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = input_to_float8(
                q_nope.transpose(0, 1), torch.float8_e4m3fn
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input
        k_pe = k_input[..., self.kv_lora_rank :]

        original_shapes = [q_pe.shape, k_pe.shape]
        q_pe, k_pe = self.rotary_emb(
            positions,
            q_pe.reshape(q_pe.shape[0], -1),
            k_pe.reshape(k_pe.shape[0], -1),
        )
        q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1])
        q_input[..., self.kv_lora_rank :] = q_pe
        k_input[..., self.kv_lora_rank :] = k_pe

        attn_output = self.attn(q_input, k_input, v_input, input_metadata)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        if self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = input_to_float8(
                attn_output.transpose(0, 1), torch.float8_e4m3fn
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)
        return output
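The bmm_fp8/torch.bmm calls above implement the usual MLA "weight absorption": instead of expanding the compressed KV cache through kv_b_proj, the no-PE part of q is projected into the kv_lora_rank-dimensional latent space with w_kc, and the latent attention output is mapped back per head with w_vc (both are slices of kv_b_proj prepared in load_weights below). A shape-only sketch with assumed MiniCPM3 dimensions:

```python
import torch

tokens, heads = 3, 4
qk_nope_head_dim, kv_lora_rank, v_head_dim = 64, 256, 64   # assumed values

q_nope = torch.randn(tokens, heads, qk_nope_head_dim)
w_kc = torch.randn(heads, qk_nope_head_dim, kv_lora_rank)  # per-head K up-projection
w_vc = torch.randn(heads, kv_lora_rank, v_head_dim)        # per-head V up-projection

# Query side: project q_nope into the latent space the KV cache lives in.
q_latent = torch.bmm(q_nope.transpose(0, 1), w_kc).transpose(0, 1)
assert q_latent.shape == (tokens, heads, kv_lora_rank)

# Output side: map the latent attention output back to v_head_dim per head.
attn_latent = torch.randn(tokens, heads, kv_lora_rank)
attn_out = torch.bmm(attn_latent.transpose(0, 1), w_vc).transpose(0, 1)
assert attn_out.shape == (tokens, heads, v_head_dim)
```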
class MiniCPM3DecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        if global_server_args_dict["enable_mla"]:
            self.self_attn = MiniCPM3AttentionMLA(
                config=config,
                hidden_size=self.hidden_size,
                num_heads=config.num_attention_heads,
                qk_nope_head_dim=config.qk_nope_head_dim,
                qk_rope_head_dim=config.qk_rope_head_dim,
                v_head_dim=self.hidden_size // config.num_attention_heads,
                q_lora_rank=(
                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
                ),
                kv_lora_rank=config.kv_lora_rank,
                rope_theta=rope_theta,
                rope_scaling=rope_scaling,
                max_position_embeddings=max_position_embeddings,
                cache_config=cache_config,
                quant_config=quant_config,
                layer_id=layer_id,
            )
        else:
            self.self_attn = MiniCPM3Attention(
                config=config,
                hidden_size=self.hidden_size,
                num_heads=config.num_attention_heads,
                qk_nope_head_dim=config.qk_nope_head_dim,
                qk_rope_head_dim=config.qk_rope_head_dim,
                v_head_dim=self.hidden_size // config.num_attention_heads,
                q_lora_rank=(
                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
                ),
                kv_lora_rank=config.kv_lora_rank,
                rope_theta=rope_theta,
                rope_scaling=rope_scaling,
                max_position_embeddings=max_position_embeddings,
                cache_config=cache_config,
                quant_config=quant_config,
                layer_id=layer_id,
            )
        self.mlp = MiniCPM3MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )
        hidden_states = residual + hidden_states * (
            self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
        )

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states * (
            self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
        )

        return hidden_states, None
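Both residual additions above use MiniCPM's depth-scaled residual, scale_depth / sqrt(num_hidden_layers), rather than a plain residual connection. A worked example with assumed config values (illustrative only, not read from this diff):

```python
import math

scale_depth = 1.4         # assumed
num_hidden_layers = 62    # assumed
residual_scale = scale_depth / math.sqrt(num_hidden_layers)
# hidden_states = residual + block_output * residual_scale
print(round(residual_scale, 4))  # ~0.1778, i.e. each block's output is damped
```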
class MiniCPM3Model(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList(
            [
                MiniCPM3DecoderLayer(
                    config, i, cache_config=cache_config, quant_config=quant_config
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids) * self.config.scale_emb
        else:
            hidden_states = input_embeds
        residual = None

        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                input_metadata,
                residual,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states
class MiniCPM3ForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config

        self.num_experts = getattr(self.config, "num_experts", 0)
        self.quant_config = quant_config
        self.model = MiniCPM3Model(
            config, cache_config=cache_config, quant_config=quant_config
        )
        # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        if not self.config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
            )

        self.scale_width = self.config.hidden_size / self.config.dim_model_base

        self.logits_processor = LogitsProcessor(config)
        self.sampler = Sampler()
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is not None:
            input_embeds = input_embeds * self.config.scale_emb
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        hidden_states = hidden_states / self.scale_width
        if self.config.tie_word_embeddings:
            lm_head_weight = self.model.embed_tokens.weight
        else:
            lm_head_weight = self.lm_head.weight
        logits_output = self.logits_processor(
            input_ids, hidden_states, lm_head_weight, input_metadata
        )
        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
        return sample_output, logits_output
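The division by self.scale_width above is MiniCPM's width scaling (hidden_size / dim_model_base) applied to the final hidden states before the LM head; it complements the scale_emb multiplier applied to the embeddings in MiniCPM3Model.forward. A quick arithmetic check with assumed MiniCPM3-4B values:

```python
hidden_size = 2560      # assumed
dim_model_base = 256    # assumed
scale_width = hidden_size / dim_model_base
# hidden_states = hidden_states / scale_width  -> logits computed at 1/10 scale
print(scale_width)      # 10.0
```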
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            (
                "ws" if weight_name in ["w1", "w3"] else "w2s",
                f"experts.{expert_id}.{weight_name}.weight",
                expert_id,
            )
            for expert_id in range(self.num_experts)
            for weight_name in ["w1", "w2", "w3"]
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param, loaded_weight, weight_name, expert_id=expert_id
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

        if global_server_args_dict["enable_mla"]:
            # Pre-split kv_b_proj into per-head w_kc / w_vc so the MLA attention
            # path can absorb them into bmm calls, then drop the original module.
            for layer_id in range(self.config.num_hidden_layers):
                self_attn = self.model.layers[layer_id].self_attn
                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                if hasattr(self_attn.kv_b_proj, "weight_scale"):
                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                del self_attn.kv_b_proj
EntryClass = MiniCPM3ForCausalLM