change / sglang · Commits

Commit 4060ed37 (Unverified)
Authored Oct 24, 2025 by Yuxuan Zhang; committed via GitHub on Oct 24, 2025.

Refactoring GLM-4.5 and GLM-4.5V related implementations (#11800)

parent 2342605e
Showing 4 changed files with 356 additions and 565 deletions (+356, -565):

  python/sglang/srt/models/glm4_moe.py                +322  -354
  python/sglang/srt/models/glm4_moe_nextn.py            +4   -14
  python/sglang/srt/models/glm4v_moe.py                +29  -196
  python/sglang/srt/multimodal/processors/glm4v.py      +1    -1
python/sglang/srt/models/glm4_moe.py  (+322, -354)

@@ -15,7 +15,7 @@
 """Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights"""

 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -27,10 +27,16 @@ from sglang.srt.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -48,7 +54,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_moe_a2a_backend
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -56,23 +65,17 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.models.deepseek_v2 import (
-    DeepseekV2DecoderLayer,
-    DeepseekV2ForCausalLM,
-    DeepseekV2Model,
-    DeepseekV2MoE,
-)
 from sglang.srt.server_args import get_global_server_args
+from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
 from sglang.srt.utils import (
-    BumpAllocator,
-    LazyValue,
     add_prefix,
     cpu_has_amx_support,
     get_bool_env_var,
@@ -80,8 +83,7 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_hip,
-    log_info_on_rank0,
-    use_intel_amx_backend,
+    make_layers,
 )

 _is_hip = is_hip()
@@ -92,11 +94,6 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _device_sm = get_device_sm()

-if _is_cuda:
-    from sgl_kernel import dsv3_router_gemm
-elif _is_cpu and _is_cpu_amx_available:
-    pass
-
 logger = logging.getLogger(__name__)
@@ -136,8 +133,7 @@ class Glm4MoeMLP(nn.Module):
         )
         if hidden_act != "silu":
             raise ValueError(
-                f"Unsupported activation: {hidden_act}. "
-                "Only silu is supported for now."
+                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
             )
         self.act_fn = SiluAndMul()
@@ -146,7 +142,6 @@ class Glm4MoeMLP(nn.Module):
         x,
         forward_batch=None,
         should_allreduce_fusion=False,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
@@ -326,47 +321,21 @@ class Glm4MoeGate(nn.Module):
         self,
         config,
         prefix: str = "",
-        is_nextn: bool = False,
     ):
         super().__init__()
-        self.is_nextn = is_nextn
         self.weight = nn.Parameter(
             torch.empty((config.n_routed_experts, config.hidden_size))
         )
         self.e_score_correction_bias = nn.Parameter(
             torch.empty((config.n_routed_experts), dtype=torch.float32)
         )
-        if _is_cpu and _is_cpu_amx_available:
-            self.quant_method = PackWeightMethod(weight_names=["weight"])

     def forward(self, hidden_states):
-        if use_intel_amx_backend(self):
-            return torch.ops.sgl_kernel.weight_packed_linear(
-                hidden_states,
-                self.weight,
-                None,  # bias
-                True,  # is_vnni
-            )
-
-        # NOTE: For some unknown reason, router_gemm seems degrade accept length.
-        if (
-            _is_cuda
-            and not self.is_nextn
-            and hidden_states.shape[0] < 4
-            and hidden_states.shape[1] == 7168
-            and self.weight.shape[0] == 256
-            and _device_sm >= 90
-        ):
-            logits = dsv3_router_gemm(hidden_states, self.weight).to(
-                hidden_states.dtype
-            )
-        else:
-            logits = F.linear(hidden_states, self.weight, None)
-
+        logits = F.linear(hidden_states, self.weight, None)
         return logits


-class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
+class Glm4MoeSparseMoeBlock(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
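Aside: a minimal, self-contained sketch of the router pattern kept after this refactor, a plain nn.Parameter gate scored with F.linear plus the bias-corrected top-k selection that e_score_correction_bias enables. Class and dimension names here (TinyRouterGate, num_experts=8, hidden_size=16) are illustrative and are not the sglang implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyRouterGate(nn.Module):
    def __init__(self, num_experts: int, hidden_size: int):
        super().__init__()
        # One router row per routed expert, analogous to Glm4MoeGate.weight.
        self.weight = nn.Parameter(torch.empty(num_experts, hidden_size))
        # Bias used only to pick which experts win top-k, not for the mixing weights.
        self.e_score_correction_bias = nn.Parameter(
            torch.zeros(num_experts, dtype=torch.float32)
        )
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Plain GEMM over the router weight, matching the simplified forward above.
        return F.linear(hidden_states, self.weight, None)


if __name__ == "__main__":
    gate = TinyRouterGate(num_experts=8, hidden_size=16)
    x = torch.randn(4, 16)
    scores = gate(x).sigmoid()
    # The correction bias shifts which experts are selected; the routing weights
    # themselves come from the uncorrected scores.
    topk_idx = (scores + gate.e_score_correction_bias).topk(2, dim=-1).indices
    topk_w = scores.gather(-1, topk_idx)
    topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)
    print(topk_idx.shape, topk_w.shape)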
@@ -374,18 +343,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         alt_stream: Optional[torch.cuda.Stream] = None,
-        is_nextn: bool = False,
     ):
         nn.Module.__init__(self)
+        self.top_k = config.num_experts_per_tok
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.ep_size = get_moe_expert_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.num_fused_shared_experts = (
-            0
-            if get_global_server_args().disable_shared_experts_fusion
-            else config.n_shared_experts
-        )
         self.config = config
         self.layer_id = layer_id
         self.alt_stream = alt_stream
@@ -402,39 +365,31 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                 "Only silu is supported for now."
             )

-        self.gate = Glm4MoeGate(
-            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
-        )
+        self.gate = Glm4MoeGate(config=config, prefix=add_prefix("gate", prefix))

         self.topk = TopK(
-            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            top_k=self.top_k,
             renormalize=config.norm_topk_prob,
             use_grouped_topk=True,
             num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
         )

         self.experts = get_moe_impl_class(quant_config)(
-            num_experts=config.n_routed_experts
-            + self.num_fused_shared_experts
-            + get_global_server_args().ep_num_redundant_experts,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            num_experts=config.n_routed_experts,
+            top_k=self.top_k,
+            layer_id=self.layer_id,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            layer_id=self.layer_id,
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
         )

-        self.shared_experts_is_int8 = False
-        self.shared_experts_is_fp8 = False
-        # self.shared_experts_weight_block_size = None
-        if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
+        # shared expert
+        if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             self.shared_experts = Glm4MoeMLP(
                 hidden_size=config.hidden_size,
@@ -443,21 +398,14 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                 quant_config=quant_config,
                 reduce_results=False,
                 prefix=add_prefix("shared_experts", prefix),
-                **(dict(tp_rank=0, tp_size=1) if self.ep_size > 1 else {}),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if get_moe_a2a_backend().is_deepep()
+                    or get_moe_a2a_backend().is_mooncake()
+                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
+                    else {}
+                ),
             )
-            is_packed_weight = hasattr(
-                self.shared_experts.gate_up_proj.quant_method, "quant_config"
-            )
-            self.shared_experts_is_int8 = (
-                not is_packed_weight
-                and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
-            )
-            self.shared_experts_is_fp8 = (
-                not is_packed_weight
-                and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
-            )

-        self.top_k = config.num_experts_per_tok
         if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
             # TODO: we will support tp < ep in the future
@@ -479,12 +427,46 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
         )

+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
+        if not self._enable_a2a_moe:
+            DUAL_STREAM_TOKEN_THRESHOLD = 1024
+            if (
+                self.alt_stream is not None
+                and hidden_states.shape[0] > 0
+                and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            ):
+                return self.forward_normal_dual_stream(
+                    hidden_states,
+                    should_allreduce_fusion,
+                    use_reduce_scatter,
+                )
+            else:
+                return self.forward_normal(
+                    hidden_states,
+                    should_allreduce_fusion,
+                    use_reduce_scatter,
+                )
+        else:
+            return self.forward_deepep(hidden_states, forward_batch)
+
     def forward_normal_dual_stream(
         self,
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         current_stream = torch.cuda.current_stream()
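Aside: the dispatcher above routes small batches to a dual-stream path that overlaps the shared-expert MLP with the routed experts. Below is a minimal sketch of that stream pattern (wait_stream on both sides of the overlapped region); the function names and the 256x1024 sizes are made up, and real workloads would run the expert GEMMs rather than single matmuls.

import torch


def overlapped_moe(x, routed_fn, shared_fn, alt_stream):
    # Run the shared-expert branch on alt_stream while the routed experts use the
    # current stream, then join before combining the two partial outputs.
    current = torch.cuda.current_stream()
    alt_stream.wait_stream(current)  # alt_stream must not read x before it is ready
    with torch.cuda.stream(alt_stream):
        shared_out = shared_fn(x)
    routed_out = routed_fn(x)
    current.wait_stream(alt_stream)  # join before the elementwise combine
    return routed_out + shared_out


if __name__ == "__main__" and torch.cuda.is_available():
    x = torch.randn(256, 1024, device="cuda")
    w_routed = torch.randn(1024, 1024, device="cuda")
    w_shared = torch.randn(1024, 1024, device="cuda")
    y = overlapped_moe(x, lambda t: t @ w_routed, lambda t: t @ w_shared, torch.cuda.Stream())
    torch.cuda.synchronize()
    print(y.shape)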
@@ -498,28 +480,21 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda:
             final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)

-        if self.ep_size > 1:
-            if (
-                self.tp_size > 1
-                and not should_allreduce_fusion
-                and not use_reduce_scatter
-            ):
-                final_hidden_states = tensor_model_parallel_all_reduce(
-                    final_hidden_states
-                )
-            final_hidden_states += shared_output
-        else:
-            final_hidden_states += shared_output
-            if (
-                self.tp_size > 1
-                and not should_allreduce_fusion
-                and not use_reduce_scatter
-            ):
-                final_hidden_states = tensor_model_parallel_all_reduce(
-                    final_hidden_states
-                )
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            final_hidden_states_out = torch.empty_like(final_hidden_states)
+            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+            final_hidden_states = final_hidden_states_out
+            sm.tag(final_hidden_states)
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

     def forward_normal(
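Aside: the reduction gating in the block above can be read as a single predicate. The helper below restates it; the flag names mirror the snippet and this is not part of the sglang API.

def needs_tp_all_reduce(
    tp_size: int,
    should_allreduce_fusion: bool,
    use_reduce_scatter: bool,
    fp4_allgather: bool,
) -> bool:
    # All-reduce here only if tensor parallelism is active and no other path
    # (fused add+allreduce, reduce-scatter, or the FP4 all-gather MoE) will do it.
    return (
        tp_size > 1
        and not should_allreduce_fusion
        and not use_reduce_scatter
        and not fp4_allgather
    )


if __name__ == "__main__":
    print(needs_tp_all_reduce(8, False, False, False))  # True: reduce in this block
    print(needs_tp_all_reduce(8, True, False, False))   # False: fused with the next layernorm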
@@ -527,39 +502,69 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
-        if hasattr(self, "shared_experts") and use_intel_amx_backend(
-            self.shared_experts.gate_up_proj
-        ):
-            return self.forward_cpu(hidden_states, should_allreduce_fusion)
-
-        shared_output = self._forward_shared_experts(hidden_states)
-        # router_logits: (num_tokens, n_experts)
-        router_logits = self.gate(hidden_states)
-        topk_output = self.topk(hidden_states, router_logits)
+        if hidden_states.shape[0] > 0:
+            shared_output = self._forward_shared_experts(hidden_states)
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            topk_output = self.topk(hidden_states, router_logits)
+        else:
+            shared_output = None
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
+
         final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
-        if self.ep_size > 1:
-            if self.tp_size > 1 and not should_allreduce_fusion:
-                final_hidden_states = tensor_model_parallel_all_reduce(
-                    final_hidden_states
-                )
-            if shared_output is not None:
-                final_hidden_states += shared_output
-        else:
-            if shared_output is not None:
-                final_hidden_states += shared_output
-            if self.tp_size > 1 and not should_allreduce_fusion:
-                final_hidden_states = tensor_model_parallel_all_reduce(
-                    final_hidden_states
-                )
+        if shared_output is not None:
+            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                final_hidden_states_out = torch.empty_like(final_hidden_states)
+                torch.add(
+                    final_hidden_states, shared_output, out=final_hidden_states_out
+                )
+                final_hidden_states = final_hidden_states_out
+                sm.tag(final_hidden_states)
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

+    def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+        shared_output = None
+        if hidden_states.shape[0] > 0:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            shared_output = self._forward_shared_experts(hidden_states)
+            topk_output = self.topk(
+                hidden_states,
+                router_logits,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_output=topk_output,
+        )
+        if shared_output is not None:
+            final_hidden_states.add_(shared_output)
+        return final_hidden_states
+
+    def _forward_shared_experts(self, hidden_states: torch.Tensor):
+        shared_output = None
+        if hidden_states.shape[0] > 0:
+            shared_output = self.shared_experts(hidden_states)
+        return shared_output
+

-class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
+class Glm4MoeDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -582,6 +587,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
         rms_norm_eps = config.rms_norm_eps
         attention_bias = config.attention_bias
         self.layer_id = layer_id
+
         self.self_attn = Glm4MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -597,15 +603,15 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
             quant_config=quant_config,
             prefix=add_prefix("self_attn", prefix),
             use_qk_norm=config.use_qk_norm,
+            alt_stream=alt_stream,
         )

         self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
         is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)
-        num_layers = 1 if is_nextn else config.num_hidden_layers

         self.layer_scatter_modes = LayerScatterModes.init_new(
             layer_id=layer_id,
-            num_layers=num_layers,
+            num_layers=1 if is_nextn else config.num_hidden_layers,
             is_layer_sparse=self.is_layer_sparse,
             is_previous_layer_sparse=is_previous_layer_sparse,
         )
@@ -616,6 +622,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
+                alt_stream=alt_stream,
             )
         else:
             if enable_moe_dense_fully_dp():
@@ -641,7 +648,16 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
-            allow_reduce_scatter=False,
+            allow_reduce_scatter=True,
+            is_last_layer=(
+                is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+            ),
         )

+    def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
+        return is_nextn or (
+            self.config.n_routed_experts is not None
+            and layer_id >= self.config.first_k_dense_replace
+        )
+
     def forward(
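Aside: _is_layer_sparse encodes the usual dense-then-MoE layout controlled by first_k_dense_replace. A toy restatement, with made-up config values:

def is_layer_sparse(
    layer_id: int,
    n_routed_experts,
    first_k_dense_replace: int,
    is_nextn: bool = False,
) -> bool:
    # NextN (speculative) layers are always MoE; otherwise only layers past the
    # first k dense layers use routed experts.
    return is_nextn or (
        n_routed_experts is not None and layer_id >= first_k_dense_replace
    )


if __name__ == "__main__":
    kinds = [
        "moe" if is_layer_sparse(i, n_routed_experts=160, first_k_dense_replace=3) else "dense"
        for i in range(6)
    ]
    print(kinds)  # ['dense', 'dense', 'dense', 'moe', 'moe', 'moe']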
@@ -650,8 +666,6 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
-        zero_allocator: BumpAllocator,
-        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         hidden_states, residual = self.layer_communicator.prepare_attn(
             hidden_states, residual, forward_batch
@@ -676,44 +690,119 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
         return hidden_states, residual


-class Glm4MoeModel(DeepseekV2Model):
+class Glm4MoeModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-    ) -> None:
-        nn.Module.__init__(self)
-        self.padding_id = config.pad_token_id
-        self.vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            enable_tp=not is_dp_attention_enabled(),
-        )
-        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
-        self.layers = nn.ModuleList(
-            [
-                Glm4MoeDecoderLayer(
-                    config,
-                    layer_id,
-                    quant_config=quant_config,
-                    prefix=add_prefix(f"layers.{layer_id}", prefix),
-                    alt_stream=self.alt_stream,
-                )
-                for layer_id in range(config.num_hidden_layers)
-            ]
-        )
-        self.pp_group = get_pp_group()
-        self.start_layer = 0
-        self.end_layer = config.num_hidden_layers
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+    ):
+        super().__init__()
+        self.pp_group = get_pp_group()
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.first_k_dense_replace = config.first_k_dense_replace
+        self.embed_dim = config.hidden_size
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                enable_tp=not is_dp_attention_enabled(),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Glm4MoeDecoderLayer(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+                prefix=prefix,
+                alt_stream=self.alt_stream,
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix=add_prefix("layers", prefix),
+        )
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+
+    def get_input_embeddings(self) -> torch.Tensor:
+        return self.embed_tokens
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        normal_start_layer = self.start_layer
+        normal_end_layer = self.end_layer
+        if forward_batch.can_run_tbo:
+            if (
+                self.first_k_dense_replace > normal_start_layer
+                and self.first_k_dense_replace < normal_end_layer
+            ):
+                normal_end_layer = self.first_k_dense_replace
+            elif self.first_k_dense_replace < normal_start_layer:
+                normal_end_layer = normal_start_layer = 0
+
+        for i in range(normal_start_layer, normal_end_layer):
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions,
+                    hidden_states,
+                    forward_batch,
+                    residual,
+                )
+
+        if normal_end_layer != self.end_layer:
+            hidden_states, residual = model_forward_maybe_tbo(
+                layers=self.layers[normal_end_layer : self.end_layer],
+                enable_tbo=True,
+                positions=positions,
+                forward_batch=forward_batch,
+                hidden_states=hidden_states,
+                residual=residual,
+                input_data_scatter_mode=self.layers[
+                    normal_end_layer - 1
+                ].layer_scatter_modes.layer_output_mode,
+            )
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if not forward_batch.forward_mode.is_idle():
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states


-class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
+class Glm4MoeForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
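Aside: make_layers gives each pipeline rank a contiguous slice of decoder layers and returns (layers, start_layer, end_layer), which the forward pass above iterates between the PPProxyTensors hand-offs. The sketch below shows one plausible even partitioning; it illustrates the idea only and is not sglang's exact policy (which also substitutes placeholder modules for layers owned by other ranks).

def partition_layers(num_layers: int, pp_rank: int, pp_size: int):
    # Even split, with the final rank absorbing the remainder (assumed policy).
    per_rank = (num_layers + pp_size - 1) // pp_size
    start = pp_rank * per_rank
    end = min(start + per_rank, num_layers)
    return start, end


if __name__ == "__main__":
    for rank in range(4):
        print(rank, partition_layers(46, rank, 4))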
@@ -721,12 +810,10 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
         prefix: str = "",
     ) -> None:
         nn.Module.__init__(self)
-        config.moe_layer_freq = 1
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
         self.pp_group = get_pp_group()
-        self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
         self.model = Glm4MoeModel(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
@@ -739,49 +826,41 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
         )
         self.logits_processor = LogitsProcessor(config)
-        self._routed_experts_weights_of_layer = LazyValue(
-            lambda: {
-                layer_id: layer.mlp.get_moe_weights()
-                for layer_id, layer in enumerate(self.model.layers)
-                if isinstance(layer.mlp, DeepseekV2MoE)
-            }
-        )
-
-    def determine_num_fused_shared_experts(
-        self, architecture: str = "Glm4MoeForCausalLM"
-    ):
-        self.num_fused_shared_experts = 0
-        if get_global_server_args().disable_shared_experts_fusion:
-            return
-
-        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
-        disable_reason = None
-        if (
-            not _is_cuda
-            or torch.cuda.get_device_capability("cuda") < (8, 0)
-            or self.config.architectures[0] != architecture
-            or self.config.n_shared_experts != 1
-        ):
-            disable_reason = "Only GLM-4.5 or GLM-4.6 on NV-platform with capability >= 80 can use shared experts fusion optimization."
-        elif get_moe_expert_parallel_world_size() > 1:
-            disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."
-
-        if disable_reason is not None:
-            get_global_server_args().disable_shared_experts_fusion = True
-            self.num_fused_shared_experts = 0
-            log_info_on_rank0(
-                logger,
-                f"{disable_reason} Shared experts fusion optimization is disabled.",
-            )
-            return
-
-        self.num_fused_shared_experts = self.config.n_shared_experts
-
-    def get_input_embeddings(self) -> nn.Embedding:
-        return self.model.embed_tokens
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
+        )
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
                 num_nextn_layers = self.config.num_nextn_predict_layers
@@ -803,117 +882,14 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.num_fused_shared_experts > 0:
-            assert self.num_fused_shared_experts == 1
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is not None:
-                if self.quant_config.get_name() == "w8a8_int8":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                    ]
-                elif (
-                    self.quant_config.get_name() == "fp8"
-                    or self.quant_config.get_name() == "blockwise_int8"
-                    or self.quant_config.get_name() == "compressed_tensors"
-                ):
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                    ]
-                elif self.quant_config.get_name() == "awq":
-                    suffix_list = [
-                        "down_proj.qweight",
-                        "down_proj.qzeros",
-                        "down_proj.scales",
-                        "gate_proj.qweight",
-                        "gate_proj.qzeros",
-                        "gate_proj.scales",
-                        "up_proj.qweight",
-                        "up_proj.qzeros",
-                        "up_proj.scales",
-                    ]
-                elif self.quant_config.get_name() == "modelopt_fp4":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "down_proj.weight_scale_2",
-                        "down_proj.input_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "gate_proj.weight_scale_2",
-                        "gate_proj.input_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                        "up_proj.weight_scale_2",
-                        "up_proj.input_scale",
-                    ]
-                else:
-                    raise ValueError(
-                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
-                    )
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "gate_proj.weight",
-                    "up_proj.weight",
-                ]
-            names_to_remove = []
-
-            moe_layers = (
-                range(
-                    self.config.first_k_dense_replace,
-                    self.config.num_hidden_layers,
-                    self.config.moe_layer_freq,
-                )
-                if not is_nextn
-                else [nextn_layer_id]
-            )
-            for moe_layer in moe_layers:
-                for suffix in suffix_list:
-                    shared_expert_weight_name = (
-                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
-                    )
-                    # online fp8 quantization does not load weight_scale
-                    if shared_expert_weight_name not in weights_dict:
-                        continue
-                    weights_list.append(
-                        (
-                            f"model.layers.{moe_layer}."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + 0}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                    names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
+            num_experts=self.config.n_routed_experts,
         )
-
-        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
-        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
-            self.config.q_lora_rank is not None
-        )
-        cached_a_proj = {} if fuse_qkv_a_proj else None

         if is_nextn:
             nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
             nextn_spec_weight_names = [
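Aside: the block removed above implemented shared-expert "fusion" at load time by renaming each MoE layer's shared-expert tensors into one extra routed-expert slot. A stripped-down restatement of that rename follows; layer indices, suffixes and the placeholder weight value are illustrative.

def fuse_shared_expert_names(weights, n_routed_experts, moe_layers, suffixes):
    """Rename mlp.shared_experts.* tensors to an appended routed-expert index."""
    out = dict(weights)
    for layer in moe_layers:
        for suffix in suffixes:
            src = f"model.layers.{layer}.mlp.shared_experts.{suffix}"
            if src not in out:
                continue  # e.g. online-quantized checkpoints omit some scale tensors
            dst = f"model.layers.{layer}.mlp.experts.{n_routed_experts}.{suffix}"
            out[dst] = out.pop(src)
    return list(out.items())


if __name__ == "__main__":
    ckpt = [("model.layers.3.mlp.shared_experts.up_proj.weight", "W")]
    print(fuse_shared_expert_names(ckpt, 160, moe_layers=[3], suffixes=["up_proj.weight"]))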
@@ -969,22 +945,36 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
                 # name will be updated to mlp.experts[0].gate_up_proj, which
                 # will then be updated below in expert_params_mapping
                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
+                if "mlp.experts" in name:
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if name not in params_dict:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Track if this is an expert weight to enable early skipping
+                is_expert_weight = False
+
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
+
+                    # Mark as expert weight regardless of whether we can process it
+                    is_expert_weight = True
+
                     name = name.replace(weight_name, param_name)
+                    if name not in params_dict:
+                        # Expert weight not on this rank, will be skipped below
+                        continue
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
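Aside: the expert branch above maps checkpoint tensor names onto fused expert parameters and skips experts that are not resident on the current rank. A simplified, hypothetical version of that loop is sketched below; the real mapping comes from FusedMoE.make_expert_params_mapping, and the parameter and shard names used here are invented for illustration.

def make_expert_mapping(num_experts: int):
    # (param_name, checkpoint_weight_name, expert_id, shard_id), hypothetical names.
    mapping = []
    for eid in range(num_experts):
        mapping.append(("experts.w13_weight", f"experts.{eid}.gate_proj.weight", eid, "w1"))
        mapping.append(("experts.w13_weight", f"experts.{eid}.up_proj.weight", eid, "w3"))
        mapping.append(("experts.w2_weight", f"experts.{eid}.down_proj.weight", eid, "w2"))
    return mapping


def route_checkpoint_name(name: str, mapping, params_dict):
    """Return (param_name, expert_id, shard_id), or None for a non-local expert tensor."""
    is_expert_weight = False
    for param_name, weight_name, expert_id, shard_id in mapping:
        if weight_name not in name:
            continue
        is_expert_weight = True  # it is an expert tensor, even if not loadable here
        target = name.replace(weight_name, param_name)
        if target not in params_dict:
            continue  # this expert lives on another rank
        return target, expert_id, shard_id
    if is_expert_weight:
        return None  # expert tensor, but not mapped locally: skip it
    raise KeyError(f"unmapped checkpoint tensor: {name}")


if __name__ == "__main__":
    mapping = make_expert_mapping(num_experts=4)
    local_params = {"experts.w13_weight": None, "experts.w2_weight": None}
    print(route_checkpoint_name("experts.2.up_proj.weight", mapping, local_params))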
@@ -996,65 +986,43 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
                     )
                     break
             else:
+                if is_expert_weight:
+                    # This is an expert weight but not mapped to this rank, skip all remaining processing
+                    continue
+
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if fuse_qkv_a_proj and (
-                    "q_a_proj" in name or "kv_a_proj_with_mqa" in name
-                ):
-                    cached_a_proj[name] = loaded_weight
-                    q_a_proj_name = (
-                        name
-                        if "q_a_proj" in name
-                        else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                    )
-                    kv_a_proj_name = (
-                        name
-                        if "kv_a_proj_with_mqa" in name
-                        else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                    )
-
-                    # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
-                    if (
-                        q_a_proj_name in cached_a_proj
-                        and kv_a_proj_name in cached_a_proj
-                    ):
-                        q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                        kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                        fused_weight = torch.cat(
-                            [q_a_proj_weight, kv_a_proj_weight], dim=0
-                        )
-                        param_name = (
-                            name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
-                            if "q_a_proj" in name
-                            else name.replace(
-                                "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
-                            )
-                        )
-                        param = params_dict[param_name]
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, fused_weight)
-                        cached_a_proj.pop(q_a_proj_name)
-                        cached_a_proj.pop(kv_a_proj_name)
-                else:
-                    if (
-                        "k_scale" in name or "v_scale" in name
-                    ) and name not in params_dict:
-                        # modelopt attn kv scale is named differently
-                        if any(scale in name for scale in ["k_scale", "v_scale"]):
-                            name = name.replace("_proj", "attn_mqa")
-                        else:
-                            logger.warning(f"Unknown scale found in checkpoint: {name}")
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
-                    )
-                    weight_loader(param, loaded_weight)
+                if name not in params_dict:
+                    continue
+
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.n_routed_experts,
+            num_groups=config.n_group,
+        )


 EntryClass = [Glm4MoeForCausalLM]
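Aside: the removed q_a_proj / kv_a_proj_with_mqa path fused two projection weights into one parameter by concatenating them along the output dimension (GLM-4.5 does not use this MLA-style layout, which is why the path could be dropped). A toy version of the concatenation, with made-up shapes:

import torch


def fuse_qkv_a(q_a_weight: torch.Tensor, kv_a_weight: torch.Tensor) -> torch.Tensor:
    # Both projections share the input dimension, so stacking their output rows
    # yields a single fused projection weight.
    return torch.cat([q_a_weight, kv_a_weight], dim=0)


if __name__ == "__main__":
    q = torch.randn(1536, 4096)  # q_lora_rank x hidden_size (made-up sizes)
    kv = torch.randn(576, 4096)  # kv_lora_rank + rope head dim x hidden_size (made-up)
    print(fuse_qkv_a(q, kv).shape)  # torch.Size([2112, 4096])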
python/sglang/srt/models/glm4_moe_nextn.py  (+4, -14)

@@ -12,7 +12,8 @@
 # limitations under the License.
 # ==============================================================================
-"""Inference-only GLM-4.5, GLM-4.6 NextN Speculative Decoding."""
+"""Inference-only GLM-4.5, GLM-4.6 Speculative Decoding."""
+
 import logging
 from typing import Iterable, Optional, Tuple
@@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import BumpAllocator, add_prefix
+from sglang.srt.utils import add_prefix

 logger = logging.getLogger(__name__)
@@ -84,14 +85,6 @@ class Glm4MoeModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        zero_allocator = BumpAllocator(
-            buffer_size=2,
-            dtype=torch.float32,
-            device=(
-                input_embeds.device if input_embeds is not None else input_ids.device
-            ),
-        )
-
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
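Aside: the NextN path above no longer constructs a BumpAllocator. For readers unfamiliar with the utility, here is a minimal bump-allocator sketch over a preallocated tensor; it is illustrative only and not sglang's implementation.

import torch


class TinyBumpAllocator:
    def __init__(self, buffer_size: int, dtype=torch.float32, device="cpu"):
        # One preallocated buffer; allocations are just advancing views into it.
        self._buffer = torch.zeros(buffer_size, dtype=dtype, device=device)
        self._offset = 0

    def allocate(self, size: int) -> torch.Tensor:
        assert self._offset + size <= self._buffer.numel(), "buffer exhausted"
        view = self._buffer[self._offset : self._offset + size]
        self._offset += size
        return view


if __name__ == "__main__":
    alloc = TinyBumpAllocator(buffer_size=2)
    print(alloc.allocate(1).shape, alloc.allocate(1).shape)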
@@ -111,7 +104,7 @@ class Glm4MoeModelNextN(nn.Module):
             residual = None
         with get_global_expert_distribution_recorder().disable_this_region():
             hidden_states, residual = self.decoder(
-                positions, hidden_states, forward_batch, residual, zero_allocator
+                positions, hidden_states, forward_batch, residual
             )

         if not forward_batch.forward_mode.is_idle():
@@ -124,7 +117,6 @@ class Glm4MoeModelNextN(nn.Module):

 class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -135,8 +127,6 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
-        self.determine_num_fused_shared_experts("Glm4MoeForCausalLMNextN")
-
         self.model = Glm4MoeModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
python/sglang/srt/models/glm4v_moe.py  (+29, -196)

@@ -6,13 +6,10 @@ import torch
 import torch.nn as nn
 from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig

-from sglang.srt.distributed import (
-    get_moe_expert_parallel_world_size,
-    get_tensor_model_parallel_world_size,
-)
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
@@ -20,7 +17,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.glm4_moe import Glm4MoeModel
 from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
+from sglang.srt.utils import add_prefix, is_cuda
 from sglang.srt.utils.hf_transformers_utils import get_processor

 _is_cuda = is_cuda()
@@ -39,12 +36,10 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
     ) -> None:
         nn.Module.__init__(self)
-        config.moe_layer_freq = 1
         self.config = config
         vision_utils.update_vit_attn_dummy_heads_config(self.config)
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
-        self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
         self.num_fused_shared_experts = (
             0
             if get_global_server_args().disable_shared_experts_fusion
@@ -77,38 +72,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         # For EAGLE3 support
         self.capture_aux_hidden_states = False

-    def determine_num_fused_shared_experts(
-        self, architecture: str = "Glm4MoeForCausalLM"
-    ):
-        self.num_fused_shared_experts = 0
-        if get_global_server_args().disable_shared_experts_fusion:
-            return
-
-        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
-        disable_reason = None
-        if (
-            not _is_cuda
-            or torch.cuda.get_device_capability("cuda") < (8, 0)
-            or self.config.architectures[0] != architecture
-            or self.config.n_shared_experts != 1
-        ):
-            disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
-        elif get_moe_expert_parallel_world_size() > 1:
-            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
-
-        if disable_reason is not None:
-            get_global_server_args().disable_shared_experts_fusion = True
-            self.num_fused_shared_experts = 0
-            log_info_on_rank0(
-                logger,
-                f"{disable_reason} Shared experts fusion optimization is disabled.",
-            )
-            return
-
-        self.num_fused_shared_experts = self.config.n_shared_experts
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
                 num_nextn_layers = self.config.num_nextn_predict_layers
@@ -130,117 +94,14 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.num_fused_shared_experts > 0:
-            assert self.num_fused_shared_experts == 1
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            [Editor's note: the removed block here is the same ~100-line shared-expert
-             fusion rewrite deleted from glm4_moe.py above: per-quantization suffix_list
-             selection (w8a8_int8 / fp8 / blockwise_int8 / compressed_tensors / awq /
-             modelopt_fp4 / unquantized), then renaming
-             "model.layers.{moe_layer}.mlp.shared_experts.{suffix}" entries to
-             "model.layers.{moe_layer}.mlp.experts.{n_routed_experts}.{suffix}" and
-             dropping the originals.]

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
+            num_experts=self.config.n_routed_experts,
         )
-
-        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
-        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
-            self.config.q_lora_rank is not None
-        )
-        cached_a_proj = {} if fuse_qkv_a_proj else None

         if is_nextn:
             nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
             nextn_spec_weight_names = [
@@ -300,23 +161,36 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                 # name will be updated to mlp.experts[0].gate_up_proj, which
                 # will then be updated below in expert_params_mapping
                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
+                if "mlp.experts" in name:
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if name not in params_dict:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Track if this is an expert weight to enable early skipping
+                is_expert_weight = False
+
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
+
+                    # Mark as expert weight regardless of whether we can process it
+                    is_expert_weight = True
+
                     name = name.replace(weight_name, param_name)
+                    if name not in params_dict:
+                        # Expert weight not on this rank, will be skipped below
+                        continue
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
@@ -328,64 +202,21 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                     )
                     break
             else:
+                if is_expert_weight:
+                    # This is an expert weight but not mapped to this rank, skip all remaining processing
+                    continue
+
                 if "visual" in name:
-                    # adapt to VisionAttention
+                    # adapt to VisionAttention for GLM-V
                     name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")

                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if fuse_qkv_a_proj and (
-                    "q_a_proj" in name or "kv_a_proj_with_mqa" in name
-                ):
-                    [Editor's note: the removed block here is the same q_a_proj /
-                     kv_a_proj_with_mqa fusion path deleted from glm4_moe.py above
-                     (cache both weights, torch.cat along dim 0 into
-                     fused_qkv_a_proj_with_mqa, then load).]
-                else:
-                    if (
-                        "k_scale" in name or "v_scale" in name
-                    ) and name not in params_dict:
-                        # modelopt attn kv scale is named differently
-                        if any(scale in name for scale in ["k_scale", "v_scale"]):
-                            name = name.replace("_proj", "attn_mqa")
-                        else:
-                            logger.warning(f"Unknown scale found in checkpoint: {name}")
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
-                    )
-                    weight_loader(param, loaded_weight)
+                if name not in params_dict:
+                    continue
+
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
@@ -395,6 +226,8 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                         self.config, name, loaded_weight
                     )
                 weight_loader(param, loaded_weight)
+            else:
+                logger.warning(f"Parameter {name} not found in params_dict")

 EntryClass = [Glm4vMoeForConditionalGeneration]
python/sglang/srt/multimodal/processors/glm4v.py  (+1, -1)

@@ -17,7 +17,7 @@ class Glm4vImageProcessor(SGLangBaseProcessor):
     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)

-        # GLM-4.1V and GLM-4.5V specific tokens
+        # GLM-V specific tokens
         self.IMAGE_TOKEN = "<|image|>"
         self.VIDEO_TOKEN = "<|video|>"
         self.IMAGE_START_TOKEN = "<|begin_of_image|>"
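Aside: a trivial illustration of how the placeholder tokens above can be located on the text side. The real processor works on tokenized inputs and expands each placeholder into per-image tokens; this snippet only counts slots in a prompt string.

IMAGE_TOKEN = "<|image|>"
VIDEO_TOKEN = "<|video|>"


def count_media_slots(prompt: str):
    # Number of image and video placeholders that multimodal inputs must fill.
    return prompt.count(IMAGE_TOKEN), prompt.count(VIDEO_TOKEN)


if __name__ == "__main__":
    print(count_media_slots("Compare <|image|> with <|image|> and this clip <|video|>."))  # (2, 1)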