sglang / Commits / Commit a88b006e (unverified)
Authored Oct 27, 2025 by Yuxuan Zhang; committed via GitHub on Oct 27, 2025

GLM-4-0414 and GLM-4.1V Code Refactor (#12117)

Parent: ce112c07
Showing 4 changed files with 679 additions and 173 deletions (+679 / -173)
python/sglang/srt/layers/rotary_embedding.py   +92   -40
python/sglang/srt/models/glm4.py               +391  -77
python/sglang/srt/models/glm4v.py              +196  -55
python/sglang/srt/models/glm4v_moe.py          +0    -1
python/sglang/srt/layers/rotary_embedding.py (view file @ a88b006e)
...
@@ -1070,6 +1070,7 @@ def _triton_mrope_forward(
    mrope_section_h: tl.constexpr,
    mrope_section_w: tl.constexpr,
    is_interleaved: tl.constexpr,
    is_neox_style: tl.constexpr,
):
    # Adapted from
    # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
...
@@ -1124,51 +1125,99 @@ def _triton_mrope_forward(
    # program instance (i.e. for the current token) separately
    # ####################################################################
    # left half of the head
    first_half_q_offsets = (
        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    first_half_k_offsets = (
        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
    )
    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
    )
    if is_neox_style:
        first_half_q_offsets = (
            tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
        )
        first_half_k_offsets = (
            tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
        )
        first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
            tl.arange(0, pad_hd // 2)[None, :] < rd // 2
        )
        first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
            tl.arange(0, pad_hd // 2)[None, :] < rd // 2
        )
    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
        sin_row.dtype
    )
        q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
            sin_row.dtype
        )
        k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
            sin_row.dtype
        )
    # right half of the head
    second_half_q_offsets = first_half_q_offsets + (rd // 2)
    second_half_k_offsets = first_half_k_offsets + (rd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask
        # right half of the head
        second_half_q_offsets = first_half_q_offsets + (rd // 2)
        second_half_k_offsets = first_half_k_offsets + (rd // 2)
        second_q_mask = first_q_mask
        second_k_mask = first_k_mask
        q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
            sin_row.dtype
        )
        k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
            sin_row.dtype
        )
        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
        # Since cos and sin are now half-size,
        # we use the same cos_row and sin_row for both halves
        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
    else:
        base_q = tl.arange(0, pad_n_qh)[:, None] * hd
        base_k = tl.arange(0, pad_n_kh)[:, None] * hd
        even_idx = 2 * tl.arange(0, pad_hd // 2)[None, :]
        odd_idx = even_idx + 1
        even_q_offsets = base_q + even_idx
        odd_q_offsets = base_q + odd_idx
        even_k_offsets = base_k + even_idx
        odd_k_offsets = base_k + odd_idx
        idx_mask = tl.arange(0, pad_hd // 2)[None, :] < (rd // 2)
        qn_mask = tl.arange(0, pad_n_qh)[:, None] < n_qh
        kn_mask = tl.arange(0, pad_n_kh)[:, None] < n_kh
        even_q_mask = qn_mask & idx_mask
        odd_q_mask = qn_mask & idx_mask
        even_k_mask = kn_mask & idx_mask
        odd_k_mask = kn_mask & idx_mask
        q_tile_1 = tl.load(q_ptr + even_q_offsets, mask=even_q_mask, other=0).to(sin_row.dtype)
        k_tile_1 = tl.load(k_ptr + even_k_offsets, mask=even_k_mask, other=0).to(sin_row.dtype)
    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
        sin_row.dtype
    )
        q_tile_2 = tl.load(q_ptr + odd_q_offsets, mask=odd_q_mask, other=0).to(sin_row.dtype)
        k_tile_2 = tl.load(k_ptr + odd_k_offsets, mask=odd_k_mask, other=0).to(sin_row.dtype)
    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
    # Since cos and sin are now half-size,
    # we use the same cos_row and sin_row for both halves
    new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
    tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
    new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
    tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
        # y = [x_even, x_odd] * [cos, cos] + [-x_odd, x_even] * [sin, sin]
        # NeoX-style rotary embedding:
        # Each (even, odd) channel pair forms one rotation arm.
        # cos_row and sin_row each have length rd//2, shared across all (even, odd) pairs.
        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
        tl.store(q_ptr + even_q_offsets, new_q_tile_1, mask=even_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
        tl.store(q_ptr + odd_q_offsets, new_q_tile_2, mask=odd_q_mask)
    new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
    tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
    new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
    tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
        tl.store(k_ptr + even_k_offsets, new_k_tile_1, mask=even_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
        tl.store(k_ptr + odd_k_offsets, new_k_tile_2, mask=odd_k_mask)
def triton_mrope(
...
@@ -1180,6 +1229,7 @@ def triton_mrope(
    head_size: int,
    rotary_dim: int,
    mrope_interleaved: bool,
    is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """The mrope triton kernel.
...
@@ -1230,6 +1280,7 @@ def triton_mrope(
        mrope_section[1],
        mrope_section[2],
        mrope_interleaved,
        is_neox_style,
    )
    return q, k
...
@@ -1400,6 +1451,7 @@ class MRotaryEmbedding(RotaryEmbedding):
            self.head_size,
            self.rotary_dim,
            self.mrope_interleaved,
            self.is_neox_style,
        )
        return q.reshape(query_shape), k.reshape(key_shape)
...
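For reference, a minimal PyTorch sketch (not taken from the diff) of the two rotary layouts the kernel distinguishes: the half-split layout rotates the first and second halves of the rotary channels against each other, while the interleaved layout rotates each (even, odd) channel pair; in both cases cos and sin have length rd // 2 and are shared by the two members of a pair, matching the kernel's comments. Shapes and names here are illustrative assumptions.

import torch

def rotate_half_split(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [..., rd]; cos/sin: [..., rd // 2], applied to both halves.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

def rotate_interleaved(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [..., rd]; each (even, odd) channel pair forms one rotation arm.
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_odd * cos + x_even * sin
    return out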
python/sglang/srt/models/glm4.py (view file @ a88b006e)
...
@@ -15,46 +15,119 @@
# Modeling from:
# ./llama.py and
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4/modular_glm4.py
"""Inference-only GLM4 model compatible with THUDM weights."""
"""Inference-only GLM-4-0414 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union
import logging
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import torch
from torch import nn
from transformers import Glm4Config

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaMLP as Glm4MLP
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    kv_cache_scales_loader,
)
from sglang.srt.utils import add_prefix, make_layers

Glm4Config = None

logger = logging.getLogger(__name__)


class Glm4MLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
            reduce_results=reduce_results,
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(
        self,
        x,
        forward_batch=None,
        use_reduce_scatter: bool = False,
    ):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(
            x,
            skip_all_reduce=use_reduce_scatter,
        )
        return x


class Glm4Attention(nn.Module):
    def __init__(
        self,
        config,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: Optional[int] = None,
        layer_id: int = 0,
        rope_theta: float = 1000000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 131072,
        quant_config: Optional[QuantizationConfig] = None,
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
        partial_rotary_factor: float = 0.5,
        prefix: str = "",
    ):
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
...
@@ -63,27 +136,30 @@ class Glm4Attention(nn.Module):
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        if head_dim is not None:
            self.head_dim = head_dim
        else:
            self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = getattr(config, "rope_theta", 1000000)
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.partial_rotary_factor = partial_rotary_factor
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
...
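The comments above describe the usual tensor-parallel rule for KV heads: partition them when there are at least as many KV heads as TP ranks, otherwise replicate them. A small standalone sketch of the resulting per-rank count (illustrative only, not sglang code):

def kv_heads_per_rank(total_num_kv_heads: int, tp_size: int) -> int:
    if total_num_kv_heads >= tp_size:
        # Partition: each rank owns an equal slice of the KV heads.
        assert total_num_kv_heads % tp_size == 0
        return total_num_kv_heads // tp_size
    # Replicate: several ranks share a copy of the same KV head.
    assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)  # always 1 in this branch

assert kv_heads_per_rank(8, 4) == 2   # partitioned
assert kv_heads_per_rank(2, 8) == 1   # replicated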
@@ -92,9 +168,10 @@ class Glm4Attention(nn.Module):
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=self.rope_scaling,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            dual_chunk_attention_config=dual_chunk_attention_config,
            partial_rotary_factor=partial_rotary_factor,
            is_neox_style=False,
        )
...
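For context, get_rope is configured here with partial_rotary_factor (defaulting to 0.5 in this file) and is_neox_style=False, meaning only rotary_dim = head_dim * partial_rotary_factor channels are rotated, in interleaved (even/odd) pairs, while the remaining channels pass through unchanged. A hedged PyTorch sketch of that idea, with names and shapes assumed for illustration only:

import torch

def apply_partial_rotary(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                         partial_rotary_factor: float = 0.5) -> torch.Tensor:
    # q: [num_tokens, num_heads, head_dim]; cos/sin: [num_tokens, rotary_dim // 2]
    head_dim = q.size(-1)
    rotary_dim = int(head_dim * partial_rotary_factor)
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    cos = cos.unsqueeze(1)  # broadcast over the head dimension
    sin = sin.unsqueeze(1)
    x_even, x_odd = q_rot[..., 0::2], q_rot[..., 1::2]
    rot = torch.empty_like(q_rot)
    rot[..., 0::2] = x_even * cos - x_odd * sin  # interleaved pairing (is_neox_style=False)
    rot[..., 1::2] = x_odd * cos + x_even * sin
    return torch.cat((rot, q_pass), dim=-1)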
@@ -117,14 +194,9 @@ class Glm4Attention(nn.Module):
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        context_layer = self.attn(
            q,
            k,
            v,
            forward_batch,
        )
        attn_output, _ = self.o_proj(context_layer)
        return attn_output
        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output


class Glm4DecoderLayer(nn.Module):
...
@@ -136,15 +208,35 @@ class Glm4DecoderLayer(nn.Module):
    def __init__(
        self,
        config,
        layer_id: int,
        config: Glm4Config,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
        # Self attention.
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
        head_dim = getattr(config, "head_dim", None)
        partial_rotary_factor = getattr(config, "partial_rotary_factor", None)
        dual_chunk_attention_config = getattr(config, "dual_chunk_attention_config", None)
        self.self_attn = Glm4Attention(
            config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix)
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=head_dim,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            dual_chunk_attention_config=dual_chunk_attention_config,
            partial_rotary_factor=partial_rotary_factor,
            prefix=add_prefix("self_attn", prefix),
        )
        # MLP
...
@@ -199,54 +291,125 @@ class Glm4Model(nn.Module):
        config: Glm4Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        decoder_layer_type: type[nn.Module] = Glm4DecoderLayer,
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("embed_tokens", prefix),
        )
        self.layers = make_layers(
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pp_group = get_pp_group()
        if self.pp_group.is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                enable_tp=not is_dp_attention_enabled(),
                prefix=add_prefix("embed_tokens", prefix),
            )
        else:
            self.embed_tokens = PPMissingLayer()

        # Use the provided decoder layer type or default to Glm4DecoderLayer
        decoder_layer_type = decoder_layer_type or Glm4DecoderLayer
        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Glm4DecoderLayer(
                config=config,
                layer_id=idx,
                quant_config=quant_config,
                prefix=prefix
            lambda idx, prefix: decoder_layer_type(
                layer_id=idx,
                config=config,
                quant_config=quant_config,
                prefix=prefix,
                alt_stream=alt_stream,
            ),
            prefix="model.layers",
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix=add_prefix("layers", prefix),
        )
        if self.pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer(return_tuple=True)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # For EAGLE3 support
        self.layers_to_capture = []

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embed_tokens

    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            hidden_states = input_embeds
            residual = None
        for layer in self.layers:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        aux_hidden_states = []
        for i in range(self.start_layer, self.end_layer):
            if i in self.layers_to_capture:
                aux_hidden_states.append(
                    hidden_states + residual if residual is not None else hidden_states
                )
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            if hidden_states.shape[0] != 0:
                if residual is None:
                    hidden_states = self.norm(hidden_states)
                else:
                    hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) == 0:
            return hidden_states
        return hidden_states
        return hidden_states, aux_hidden_states

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
            quantization_param_path,
            tp_rank,
            tp_size,
            self.config.num_hidden_layers,
            self.config.__class__.model_type,
        ):
            if not isinstance(self.layers[layer_idx], nn.Identity):
                layer_self_attn = self.layers[layer_idx].self_attn
            if hasattr(layer_self_attn.attn, "k_scale"):
                layer_self_attn.attn.k_scale = scaling_factor
                layer_self_attn.attn.v_scale = scaling_factor
            else:
                raise RuntimeError(
                    "Self attention has no KV cache scaling factor attribute!"
                )


class Glm4ForCausalLM(nn.Module):
...
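The rewritten Glm4Model.forward follows sglang's pipeline-parallel convention: the first rank embeds tokens, later ranks receive hidden_states and residual through a PPProxyTensors object, and only the last rank applies the final norm. A hypothetical caller-side sketch of that contract (the helper name here is an assumption, for illustration only):

def run_pipeline_stage(model, input_ids, positions, forward_batch, incoming=None):
    # incoming is None on the first rank; later ranks pass the proxy tensors
    # produced by the previous stage.
    out = model(input_ids, positions, forward_batch, pp_proxy_tensors=incoming)
    # Last rank: a plain hidden-states tensor (plus aux states when captured).
    # Other ranks: a PPProxyTensors holding "hidden_states" and "residual"
    # to forward to the next stage.
    return out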
@@ -255,21 +418,54 @@ class Glm4ForCausalLM(nn.Module):
        config: Glm4Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
    ) -> None:
        super().__init__()
        self.config: Glm4Config = config
        self.pp_group = get_pp_group()
        self.config = config
        self.quant_config = quant_config
        self.model = Glm4Model(config, quant_config, add_prefix("model", prefix))
        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        self.model = Glm4Model(
            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        # handle the lm head on different pp ranks
        if self.pp_group.is_last_rank:
            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(
                    config.vocab_size,
                    config.hidden_size,
                    quant_config=quant_config,
                    prefix=add_prefix("lm_head", prefix),
                )
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix="lm_head",
            )
            # ranks other than the last rank will have a placeholder layer
            self.lm_head = PPMissingLayer()

        # perform weight tying for PP
        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
            if self.pp_group.is_first_rank:
                self.pp_group.send(
                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
                )
            else:
                emb_token_weight = self.pp_group.recv(
                    size=(config.vocab_size, config.hidden_size),
                    dtype=next(self.model.parameters()).dtype,
                    src=self.pp_group.first_rank,
                )
                self.lm_head.weight.copy_(emb_token_weight)
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        # For EAGLE3 support
        self.capture_aux_hidden_states = False

    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embedding(input_ids)

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens

    @torch.no_grad()
    def forward(
...
@@ -277,34 +473,138 @@ class Glm4ForCausalLM(nn.Module):
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )
        aux_hidden_states = None
        if self.capture_aux_hidden_states:
            hidden_states, aux_hidden_states = hidden_states

        if self.pp_group.is_last_rank:
            if not get_embedding:
                return self.logits_processor(
                    input_ids,
                    hidden_states,
                    self.lm_head,
                    forward_batch,
                    aux_hidden_states,
                )
            else:
                return self.pooler(hidden_states, forward_batch)
        else:
            return hidden_states

    @torch.no_grad()
    def forward_split_prefill(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        split_interval: Tuple[int, int],  # [start, end) 0-based
        input_embeds: torch.Tensor = None,
    ):
        start, end = split_interval
        # embed
        if start == 0:
            if input_embeds is None:
                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
            else:
                forward_batch.hidden_states = input_embeds
        # decoder layer
        for i in range(start, end):
            layer = self.model.layers[i]
            forward_batch.hidden_states, forward_batch.residual = layer(
                positions,
                forward_batch.hidden_states,
                forward_batch,
                forward_batch.residual,
            )

        if end == self.model.config.num_hidden_layers:
            # norm
            hidden_states, _ = self.model.norm(
                forward_batch.hidden_states, forward_batch.residual
            )
            forward_batch.hidden_states = hidden_states
            # logits process
            result = self.logits_processor(
                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
            )
        else:
            result = None

        return result

    @property
    def start_layer(self):
        return self.model.start_layer

    @property
    def end_layer(self):
        return self.model.end_layer

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, weight_name, shard_id)
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
            (".gate_up_proj", ".gate_proj", 0),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
            layer_id = get_layer_id(name)
            if (
                layer_id is not None
                and hasattr(self.model, "start_layer")
                and (
                    layer_id < self.model.start_layer
                    or layer_id >= self.model.end_layer
                )
            ):
                continue
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
                    # Handle pp weight tying here
                    # find the embed_tokens.weight in the weights
                    embed_token_weights = next(
                        filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
                    )[1]
                    loaded_weight = embed_token_weights
                else:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name in params_dict.keys():
                    param = params_dict[name]
                    weight_loader = getattr(
...
@@ -312,7 +612,21 @@ class Glm4ForCausalLM(nn.Module):
                    )
                    weight_loader(param, loaded_weight)
                else:
                    raise KeyError(f"Parameter '{name}' not found in model.")
                    logger.warning(f"Parameter {name} not found in params_dict")

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        self.model.load_kv_cache_scales(quantization_param_path)


EntryClass = [Glm4ForCausalLM]
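The load_weights method above folds separate q_proj/k_proj/v_proj (and gate_proj/up_proj) checkpoint tensors into the fused qkv_proj and gate_up_proj parameters. A simplified, framework-free sketch of the renaming step (illustrative only; resolve_target is a hypothetical helper, not sglang code):

stacked_params_mapping = [
    # (fused param name, checkpoint shard name, shard id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]

def resolve_target(name: str):
    """Map a checkpoint tensor name to (fused parameter name, shard id)."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            return name.replace(shard_name, param_name), shard_id
    return name, None

assert resolve_target("model.layers.0.self_attn.k_proj.weight") == (
    "model.layers.0.self_attn.qkv_proj.weight",
    "k",
)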
python/sglang/srt/models/glm4v.py (view file @ a88b006e)
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Modeling from:
# ./llama.py and
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modular_glm4v.py
"""Inference-only GLM-4.1V model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
from functools import lru_cache
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention import vision_utils
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
...
@@ -20,13 +40,14 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import MultimodalDataItem
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternMultimodalTokens,
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.glm4 import Glm4Model
from sglang.srt.models.qwen2_5_vl import (
    Qwen2_5_VisionBlock,
    Qwen2_5_VLForConditionalGeneration,
)
from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor
...
@@ -56,7 +77,7 @@ class Glm4vVisionMLP(nn.Module):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=in_features,
            output_sizes=[hidden_features] * 2,
            output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
...
@@ -77,34 +98,95 @@ class Glm4vVisionMLP(nn.Module):
        return x


class Glm4vVisionBlock(Qwen2_5_VisionBlock):
class Glm4vVisionBlock(nn.Module):
    def __init__(
        self,
        config: Glm4vVisionConfig,
        norm_layer: Optional[nn.Module] = None,
        dim: int,
        intermediate_dim: int,
        num_heads: int,
        attn_implementation: Optional[str] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        num_dummy_heads: int = 0,
        rms_norm_eps: float = 1e-5,
    ) -> None:
        super().__init__(
            dim=config.hidden_size,
            intermediate_dim=config.out_hidden_size,
            num_heads=config.num_heads,
            hidden_act=config.hidden_act,
            norm_layer=norm_layer,
        super().__init__()
        self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
        self.norm2 = RMSNorm(dim, eps=rms_norm_eps)

        if attn_implementation is None:
            softmax_in_single_precision = False
            qkv_backend = None
            flatten_batch = True
        elif attn_implementation == "sdpa":
            softmax_in_single_precision = False
            qkv_backend = "sdpa"
            flatten_batch = True
        elif attn_implementation == "flash_attention_2":
            softmax_in_single_precision = False
            qkv_backend = "triton_attn"
            flatten_batch = True
        elif attn_implementation == "eager":
            softmax_in_single_precision = True
            qkv_backend = "sdpa"
            flatten_batch = True
        elif attn_implementation == "flash_attention_3":
            softmax_in_single_precision = False
            qkv_backend = "fa3"
            flatten_batch = True

        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=True,
            rotary_embed="normal",
            proj_bias=True,
            qkv_backend=qkv_backend,
            softmax_in_single_precision=softmax_in_single_precision,
            flatten_batch=flatten_batch,
            quant_config=quant_config,
            prefix=prefix,
            num_dummy_heads=config.num_dummy_heads,
            rms_norm_eps=config.rms_norm_eps,
            prefix=add_prefix("attn", prefix),
            num_dummy_heads=num_dummy_heads,
        )
        self.mlp = Glm4vVisionMLP(
            config.hidden_size,
            config.out_hidden_size,
            bias=False,
            dim,
            intermediate_dim,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        S, B, H = x.shape
        # norm1: flatten to 2D -> [S*B, H], then reshape back
        x2d = x.reshape(-1, H)
        hidden_states = self.norm1(x2d).reshape(S, B, H)

        # Attention expects [B, S, H]
        hidden_states = rearrange(hidden_states, "s b h -> b s h")
        attn = self.attn(
            hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        attn = rearrange(attn, "b s h -> s b h")

        # norm2 with fused residual-add: also 2D
        attn2d = attn.reshape(-1, H)
        x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d)
        x_norm = x_norm_2d.reshape(S, B, H)
        x_after_add = x_after_add_2d.reshape(S, B, H)

        # MLP and final residual
        mlp_out = self.mlp(x_norm)
        x = x_after_add + mlp_out
        return x


class Glm4vVisionPatchEmbed(nn.Module):
    def __init__(
...
@@ -320,7 +402,6 @@ class Glm4vVisionModel(nn.Module):
    def __init__(
        self,
        vision_config: Glm4vVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
...
@@ -344,17 +425,18 @@ class Glm4vVisionModel(nn.Module):
            hidden_size=self.hidden_size,
        )
        norm_layer = partial(Glm4vRMSNorm, eps=norm_eps)
        head_dim = self.hidden_size // self.num_heads
        self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList(
            [
                Glm4vVisionBlock(
                    config=vision_config,
                    norm_layer=norm_layer,
                    dim=self.hidden_size,
                    intermediate_dim=self.out_hidden_size,
                    num_heads=self.num_heads,
                    quant_config=quant_config,
                    prefix=add_prefix(f"blocks.{layer_idx}", prefix),
                    rms_norm_eps=vision_config.rms_norm_eps,
                )
                for layer_idx in range(depth)
            ]
...
@@ -461,29 +543,30 @@ class Glm4vVisionModel(nn.Module):
        return x


class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
class Glm4vForConditionalGeneration(nn.Module):
    def __init__(
        self,
        config: Glm4vConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        nn.Module.__init__(self)
        super().__init__()
        self.config = config
        vision_utils.update_vit_attn_dummy_heads_config(self.config)
        self.model = Glm4Model(
            config,
            quant_config,
            prefix=add_prefix("model", prefix),
        )
        self.visual = Glm4vVisionModel(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
            quant_config=quant_config,
            prefix=add_prefix("visual", prefix),
        )
        vision_utils.update_vit_attn_dummy_heads_config(self.config)
        self.model = Glm4Model(
            config,
            quant_config=quant_config,
            prefix=add_prefix("model", prefix),
        )

        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
...
@@ -494,13 +577,18 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
                prefix=add_prefix("lm_head", prefix),
            )
        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
        # For EAGLE3 support
        self.capture_aux_hidden_states = False

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        pixel_values = torch.cat(
            [item.feature.squeeze(0) for item in items], dim=0
...
@@ -542,20 +630,60 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
        video_embeds = torch.split(video_embeds, split_sizes)
        return torch.cat(video_embeds)

    def _update_hf_config(self):
        """update hf config to ensure vision attention num_attention_heads is divisible by tp_size"""
        tp_size = get_attention_tp_size()
        num_heads = self.config.vision_config.num_heads
        head_dim = self.config.vision_config.hidden_size // num_heads
        num_dummy_heads = 0
    def get_input_embeddings(self):
        return self.model.embed_tokens

        if num_heads % tp_size != 0:
            num_dummy_heads = ((num_heads + tp_size - 1) // tp_size) * tp_size - num_heads
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        """Run forward pass for GLM-4.1V.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for GLM-4.1V
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
                (Use input_metadata.mrope_positions to replace it)
        """
        if self.is_mrope_enabled:
            positions = forward_batch.mrope_positions

        if not (
            forward_batch.forward_mode.is_decode()
            or not forward_batch.contains_image_inputs()
        ):
            if self.is_mrope_enabled:
                assert positions.ndim == 2 and positions.size(0) == 3, (
                    "multimodal section rotary embedding requires "
                    f"(3, seq_len) positions, but got {positions.size()}"
                )

        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.model,
            multimodal_model=self,
            positions=positions,
        )
        setattr(self.config.vision_config, "head_dim", head_dim)
        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)

        aux_hidden_states = None
        if self.capture_aux_hidden_states:
            hidden_states, aux_hidden_states = hidden_states

        if not get_embedding:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
            )
        else:
            return self.pooler(hidden_states, forward_batch)

    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
        """pad attn qkv weights for dummy heads"""
...
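A tiny sketch of the shape contract described in the docstring above: with mrope enabled, positions carry one row of position ids per rotary section (e.g. temporal/height/width), giving shape (3, seq_len); otherwise a flat (seq_len,) tensor is used. This is an illustrative check, not sglang code:

import torch

def check_positions(positions: torch.Tensor, is_mrope_enabled: bool) -> None:
    if is_mrope_enabled:
        assert positions.ndim == 2 and positions.size(0) == 3, (
            f"expected (3, seq_len) positions, got {tuple(positions.shape)}"
        )
    else:
        assert positions.ndim == 1

check_positions(torch.zeros(3, 16, dtype=torch.long), is_mrope_enabled=True)
check_positions(torch.zeros(16, dtype=torch.long), is_mrope_enabled=False)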
@@ -598,13 +726,12 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "language_model." in name:
                name = name.replace("language_model.", "")
            if "model.visual." in name:
                name = name.replace("model.visual.", "visual.")
            if "rotary_emb.inv_freq" in name:
                continue
            if "language_model" in name:
                name = name.replace(r"model.language_model.", r"model.")
            if "model.visual." in name:
                name = name.replace("model.visual.", "visual.")
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
...
@@ -639,5 +766,19 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
                )
                weight_loader(param, loaded_weight)

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        del self.model.embed_tokens.weight
        self.model.embed_tokens.weight = embed
        if self.config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            del self.lm_head.weight
            self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


EntryClass = [Glm4vForConditionalGeneration]
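The removed _update_hf_config helper (its role appears to be taken over by the vision_utils.update_vit_attn_dummy_heads_config call in __init__) pads the ViT head count up to a multiple of the attention tensor-parallel size. The padding arithmetic from that removed code, reproduced as a standalone sketch:

def compute_num_dummy_heads(num_heads: int, tp_size: int) -> int:
    """Dummy heads needed so that num_heads + dummies is divisible by tp_size."""
    if num_heads % tp_size == 0:
        return 0
    return ((num_heads + tp_size - 1) // tp_size) * tp_size - num_heads

assert compute_num_dummy_heads(16, 8) == 0
assert compute_num_dummy_heads(16, 6) == 2  # 18 total heads, 3 per rank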
python/sglang/srt/models/glm4v_moe.py (view file @ a88b006e)
...
@@ -53,7 +53,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
        )
        self.visual = Glm4vVisionModel(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
            quant_config=quant_config,
            prefix=add_prefix("visual", prefix),
        )
...