change / sglang · Commits · a88b006e

Commit a88b006e (Unverified)
authored Oct 27, 2025 by Yuxuan Zhang, committed by GitHub on Oct 27, 2025

GLM-4-0414 and GLM-4.1V Code Refactor (#12117)
Parent: ce112c07

Showing 4 changed files with 679 additions and 173 deletions (+679 −173):

- python/sglang/srt/layers/rotary_embedding.py  (+92 −40)
- python/sglang/srt/models/glm4.py  (+391 −77)
- python/sglang/srt/models/glm4v.py  (+196 −55)
- python/sglang/srt/models/glm4v_moe.py  (+0 −1)
python/sglang/srt/layers/rotary_embedding.py  (+92 −40)

@@ -1070,6 +1070,7 @@ def _triton_mrope_forward(
     mrope_section_h: tl.constexpr,
     mrope_section_w: tl.constexpr,
     is_interleaved: tl.constexpr,
+    is_neox_style: tl.constexpr,
 ):
     # Adapted from
     # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py

@@ -1124,51 +1125,99 @@ def _triton_mrope_forward(
     # program instance (i.e. for the current token) separately
     # ####################################################################
-    # left half of the head
-    first_half_q_offsets = (
-        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_half_k_offsets = (
-        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
-    )
-    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
-    )
-    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-        sin_row.dtype
-    )
-
-    # right half of the head
-    second_half_q_offsets = first_half_q_offsets + (rd // 2)
-    second_half_k_offsets = first_half_k_offsets + (rd // 2)
-    second_q_mask = first_q_mask
-    second_k_mask = first_k_mask
-    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-        sin_row.dtype
-    )
-
-    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
-    # Since cos and sin are now half-size,
-    # we use the same cos_row and sin_row for both halves
-    new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
-    tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
-    new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
-    tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
-    new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
-    tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
-    new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
-    tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+    if is_neox_style:
+        # left half of the head
+        first_half_q_offsets = (
+            tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+        )
+        first_half_k_offsets = (
+            tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+        )
+        first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
+            tl.arange(0, pad_hd // 2)[None, :] < rd // 2
+        )
+        first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
+            tl.arange(0, pad_hd // 2)[None, :] < rd // 2
+        )
+        q_tile_1 = tl.load(
+            q_ptr + first_half_q_offsets, mask=first_q_mask, other=0
+        ).to(sin_row.dtype)
+        k_tile_1 = tl.load(
+            k_ptr + first_half_k_offsets, mask=first_k_mask, other=0
+        ).to(sin_row.dtype)
+
+        # right half of the head
+        second_half_q_offsets = first_half_q_offsets + (rd // 2)
+        second_half_k_offsets = first_half_k_offsets + (rd // 2)
+        second_q_mask = first_q_mask
+        second_k_mask = first_k_mask
+        q_tile_2 = tl.load(
+            q_ptr + second_half_q_offsets, mask=second_q_mask, other=0
+        ).to(sin_row.dtype)
+        k_tile_2 = tl.load(
+            k_ptr + second_half_k_offsets, mask=second_k_mask, other=0
+        ).to(sin_row.dtype)
+
+        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
+        # Since cos and sin are now half-size,
+        # we use the same cos_row and sin_row for both halves
+        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
+        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
+        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
+        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
+        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
+        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
+        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+    else:
+        base_q = tl.arange(0, pad_n_qh)[:, None] * hd
+        base_k = tl.arange(0, pad_n_kh)[:, None] * hd
+        even_idx = 2 * tl.arange(0, pad_hd // 2)[None, :]
+        odd_idx = even_idx + 1
+        even_q_offsets = base_q + even_idx
+        odd_q_offsets = base_q + odd_idx
+        even_k_offsets = base_k + even_idx
+        odd_k_offsets = base_k + odd_idx
+
+        idx_mask = tl.arange(0, pad_hd // 2)[None, :] < (rd // 2)
+        qn_mask = tl.arange(0, pad_n_qh)[:, None] < n_qh
+        kn_mask = tl.arange(0, pad_n_kh)[:, None] < n_kh
+        even_q_mask = qn_mask & idx_mask
+        odd_q_mask = qn_mask & idx_mask
+        even_k_mask = kn_mask & idx_mask
+        odd_k_mask = kn_mask & idx_mask
+
+        q_tile_1 = tl.load(q_ptr + even_q_offsets, mask=even_q_mask, other=0).to(
+            sin_row.dtype
+        )
+        k_tile_1 = tl.load(k_ptr + even_k_offsets, mask=even_k_mask, other=0).to(
+            sin_row.dtype
+        )
+        q_tile_2 = tl.load(q_ptr + odd_q_offsets, mask=odd_q_mask, other=0).to(
+            sin_row.dtype
+        )
+        k_tile_2 = tl.load(k_ptr + odd_k_offsets, mask=odd_k_mask, other=0).to(
+            sin_row.dtype
+        )
+
+        # y = [x_even, x_odd] * [cos, cos] + [-x_odd, x_even] * [sin, sin]
+        # NeoX-style rotary embedding:
+        # Each (even, odd) channel pair forms one rotation arm.
+        # cos_row and sin_row each have length rd//2, shared across all (even, odd) pairs.
+        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
+        tl.store(q_ptr + even_q_offsets, new_q_tile_1, mask=even_q_mask)
+        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
+        tl.store(q_ptr + odd_q_offsets, new_q_tile_2, mask=odd_q_mask)
+        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
+        tl.store(k_ptr + even_k_offsets, new_k_tile_1, mask=even_k_mask)
+        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
+        tl.store(k_ptr + odd_k_offsets, new_k_tile_2, mask=odd_k_mask)


 def triton_mrope(

@@ -1180,6 +1229,7 @@ def triton_mrope(
     head_size: int,
     rotary_dim: int,
     mrope_interleaved: bool,
+    is_neox_style: bool,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """The mrope triton kernel.

@@ -1230,6 +1280,7 @@ def triton_mrope(
         mrope_section[1],
         mrope_section[2],
         mrope_interleaved,
+        is_neox_style,
     )
     return q, k

@@ -1400,6 +1451,7 @@ class MRotaryEmbedding(RotaryEmbedding):
             self.head_size,
             self.rotary_dim,
             self.mrope_interleaved,
+            self.is_neox_style,
         )
         return q.reshape(query_shape), k.reshape(key_shape)
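Note: the new is_neox_style flag selects between the two standard channel pairings for rotary embeddings: the half-split layout in the if branch above and the interleaved even/odd layout in the else branch. The kernel itself is Triton, but the underlying arithmetic can be illustrated with a small self-contained PyTorch sketch; the function names and toy tensors below are illustrative only and are not part of sglang.

    import torch

    def rope_half_split(x, cos, sin):
        # Pair channel i with channel i + d/2 (the "rotate_half" layout).
        d = x.shape[-1]
        x1, x2 = x[..., : d // 2], x[..., d // 2 :]
        return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

    def rope_interleaved(x, cos, sin):
        # Pair channel 2i with channel 2i + 1 (the even/odd layout).
        x1, x2 = x[..., 0::2], x[..., 1::2]
        out = torch.empty_like(x)
        out[..., 0::2] = x1 * cos - x2 * sin
        out[..., 1::2] = x2 * cos + x1 * sin
        return out

    # Toy usage: one token, head_dim = 8, so cos/sin each have length 4.
    head_dim = 8
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2) / head_dim))
    angle = 3 * inv_freq  # position id = 3
    cos, sin = angle.cos(), angle.sin()
    x = torch.randn(head_dim)
    print(rope_half_split(x, cos, sin))
    print(rope_interleaved(x, cos, sin))

Both functions apply the same rotation per pair; they differ only in which channel is paired with which, which is exactly what the flag toggles in the kernel.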
python/sglang/srt/models/glm4.py  (+391 −77)

(This diff is collapsed in the commit view and is not reproduced here.)
python/sglang/srt/models/glm4v.py  (+196 −55)

 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 # Modeling from:
 # ./llama.py and
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modular_glm4v.py
 """Inference-only GLM-4.1V model compatible with HuggingFace weights."""
 import logging
-from functools import lru_cache, partial
+from functools import lru_cache
 from typing import Iterable, List, Optional, Tuple

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from einops import rearrange
 from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention import vision_utils
-from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -20,13 +40,14 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.schedule_batch import MultimodalDataItem
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.glm4 import Glm4Model
-from sglang.srt.models.qwen2_5_vl import (
-    Qwen2_5_VisionBlock,
-    Qwen2_5_VLForConditionalGeneration,
-)
 from sglang.srt.utils import add_prefix
 from sglang.srt.utils.hf_transformers_utils import get_processor
@@ -56,7 +77,7 @@ class Glm4vVisionMLP(nn.Module):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             input_size=in_features,
-            output_sizes=[hidden_features] * 2,
+            output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
             bias=bias,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
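Note: the comment added above documents that the fused gate_up_proj packs the gate and up projections in that order; SiluAndMul then splits them and gates the up half before the down projection. A minimal PyTorch sketch of that dataflow, with made-up weights and shapes (not taken from the model):

    import torch
    import torch.nn.functional as F

    def swiglu_mlp(x, gate_up_weight, down_weight):
        # gate_up_weight fuses [gate_proj; up_proj], mirroring output_sizes=[hidden_features] * 2.
        gate_up = x @ gate_up_weight.T              # [..., 2 * hidden_features]
        gate, up = gate_up.chunk(2, dim=-1)         # split back into the two projections
        return (F.silu(gate) * up) @ down_weight.T  # SiluAndMul followed by down_proj

    # Toy shapes (illustrative only): in_features=16, hidden_features=32
    x = torch.randn(4, 16)
    gate_up_w = torch.randn(2 * 32, 16)
    down_w = torch.randn(16, 32)
    print(swiglu_mlp(x, gate_up_w, down_w).shape)  # torch.Size([4, 16])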
@@ -77,34 +98,95 @@ class Glm4vVisionMLP(nn.Module):
         return x


-class Glm4vVisionBlock(Qwen2_5_VisionBlock):
+class Glm4vVisionBlock(nn.Module):
     def __init__(
         self,
-        config: Glm4vVisionConfig,
-        norm_layer: Optional[nn.Module] = None,
+        dim: int,
+        intermediate_dim: int,
+        num_heads: int,
+        attn_implementation: Optional[str] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        num_dummy_heads: int = 0,
+        rms_norm_eps: float = 1e-5,
     ) -> None:
-        super().__init__(
-            dim=config.hidden_size,
-            intermediate_dim=config.out_hidden_size,
-            num_heads=config.num_heads,
-            hidden_act=config.hidden_act,
-            norm_layer=norm_layer,
-            quant_config=quant_config,
-            prefix=prefix,
-            num_dummy_heads=config.num_dummy_heads,
-            rms_norm_eps=config.rms_norm_eps,
-        )
+        super().__init__()
+        self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
+        self.norm2 = RMSNorm(dim, eps=rms_norm_eps)
+
+        if attn_implementation is None:
+            softmax_in_single_precision = False
+            qkv_backend = None
+            flatten_batch = True
+        elif attn_implementation == "sdpa":
+            softmax_in_single_precision = False
+            qkv_backend = "sdpa"
+            flatten_batch = True
+        elif attn_implementation == "flash_attention_2":
+            softmax_in_single_precision = False
+            qkv_backend = "triton_attn"
+            flatten_batch = True
+        elif attn_implementation == "eager":
+            softmax_in_single_precision = True
+            qkv_backend = "sdpa"
+            flatten_batch = True
+        elif attn_implementation == "flash_attention_3":
+            softmax_in_single_precision = False
+            qkv_backend = "fa3"
+            flatten_batch = True
+
+        self.attn = VisionAttention(
+            embed_dim=dim,
+            num_heads=num_heads,
+            projection_size=dim,
+            use_qkv_parallel=True,
+            rotary_embed="normal",
+            proj_bias=True,
+            qkv_backend=qkv_backend,
+            softmax_in_single_precision=softmax_in_single_precision,
+            flatten_batch=flatten_batch,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+            num_dummy_heads=num_dummy_heads,
+        )
         self.mlp = Glm4vVisionMLP(
-            config.hidden_size,
-            config.out_hidden_size,
+            dim,
+            intermediate_dim,
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("mlp", prefix),
         )

+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        position_embeddings: torch.Tensor,
+    ) -> torch.Tensor:
+        S, B, H = x.shape
+        # norm1: flatten to 2D -> [S*B, H], then reshape back
+        x2d = x.reshape(-1, H)
+        hidden_states = self.norm1(x2d).reshape(S, B, H)
+
+        # Attention expects [B, S, H]
+        hidden_states = rearrange(hidden_states, "s b h -> b s h")
+        attn = self.attn(
+            hidden_states,
+            cu_seqlens=cu_seqlens,
+            position_embeddings=position_embeddings,
+        )
+        attn = rearrange(attn, "b s h -> s b h")
+
+        # norm2 with fused residual-add: also 2D
+        attn2d = attn.reshape(-1, H)
+        x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d)
+        x_norm = x_norm_2d.reshape(S, B, H)
+        x_after_add = x_after_add_2d.reshape(S, B, H)
+
+        # MLP and final residual
+        mlp_out = self.mlp(x_norm)
+        x = x_after_add + mlp_out
+        return x
+

 class Glm4vVisionPatchEmbed(nn.Module):
     def __init__(
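Note: the if/elif chain in the new Glm4vVisionBlock maps a HuggingFace-style attn_implementation string onto the qkv_backend, softmax_in_single_precision, and flatten_batch options handed to VisionAttention. The same selection restated as a lookup table, purely for illustration (this is not how the file is written):

    from typing import Optional

    # (qkv_backend, softmax_in_single_precision, flatten_batch) per attn_implementation,
    # transcribed from the diff above.
    _VISION_ATTN_OPTIONS = {
        None: (None, False, True),
        "sdpa": ("sdpa", False, True),
        "flash_attention_2": ("triton_attn", False, True),
        "eager": ("sdpa", True, True),
        "flash_attention_3": ("fa3", False, True),
    }

    def resolve_vision_attn(attn_implementation: Optional[str]):
        # An unknown string raises KeyError here; the diff's if/elif chain would
        # simply leave the variables unset in that case.
        return _VISION_ATTN_OPTIONS[attn_implementation]

    print(resolve_vision_attn("flash_attention_3"))  # ('fa3', False, True)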
@@ -320,7 +402,6 @@ class Glm4vVisionModel(nn.Module):
     def __init__(
         self,
         vision_config: Glm4vVisionConfig,
-        norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -344,17 +425,18 @@ class Glm4vVisionModel(nn.Module):
             hidden_size=self.hidden_size,
         )
-        norm_layer = partial(Glm4vRMSNorm, eps=norm_eps)
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
         self.blocks = nn.ModuleList(
             [
                 Glm4vVisionBlock(
-                    config=vision_config,
-                    norm_layer=norm_layer,
+                    dim=self.hidden_size,
+                    intermediate_dim=self.out_hidden_size,
+                    num_heads=self.num_heads,
                     quant_config=quant_config,
                     prefix=add_prefix(f"blocks.{layer_idx}", prefix),
+                    rms_norm_eps=vision_config.rms_norm_eps,
                 )
                 for layer_idx in range(depth)
             ]
@@ -461,29 +543,30 @@ class Glm4vVisionModel(nn.Module):
         return x


-class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
+class Glm4vForConditionalGeneration(nn.Module):
     def __init__(
         self,
         config: Glm4vConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
-        nn.Module.__init__(self)
+        super().__init__()
         self.config = config
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
+        self.model = Glm4Model(
+            config,
+            quant_config,
+            prefix=add_prefix("model", prefix),
+        )
         self.visual = Glm4vVisionModel(
             config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
             quant_config=quant_config,
             prefix=add_prefix("visual", prefix),
         )
-        vision_utils.update_vit_attn_dummy_heads_config(self.config)
-        self.model = Glm4Model(
-            config,
-            quant_config=quant_config,
-            prefix=add_prefix("model", prefix),
-        )
+
         if config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
@@ -494,13 +577,18 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
                 prefix=add_prefix("lm_head", prefix),
             )
-        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling

         # For EAGLE3 support
         self.capture_aux_hidden_states = False

+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         pixel_values = torch.cat(
             [item.feature.squeeze(0) for item in items], dim=0
@@ -542,20 +630,60 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         video_embeds = torch.split(video_embeds, split_sizes)
         return torch.cat(video_embeds)

-    def _update_hf_config(self):
-        """update hf config to ensure vision attention num_attention_heads is divisible by tp_size"""
-        tp_size = get_attention_tp_size()
-        num_heads = self.config.vision_config.num_heads
-        head_dim = self.config.vision_config.hidden_size // num_heads
-        num_dummy_heads = 0
-
-        if num_heads % tp_size != 0:
-            num_dummy_heads = (
-                (num_heads + tp_size - 1) // tp_size
-            ) * tp_size - num_heads
-
-        setattr(self.config.vision_config, "head_dim", head_dim)
-        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ):
+        """Run forward pass for GLM-4.1V.
+
+        Args:
+            input_ids: Flattened (concatenated) input_ids corresponding to a
+                batch.
+            positions: Flattened (concatenated) position ids corresponding to a
+                batch.
+                **NOTE**: If mrope is enabled (default setting for GLM-4.1V
+                opensource models), the shape will be `(3, seq_len)`,
+                otherwise it will be `(seq_len,).
+                (Use input_metadata.mrope_positions to replace it)
+        """
+        if self.is_mrope_enabled:
+            positions = forward_batch.mrope_positions
+
+        if not (
+            forward_batch.forward_mode.is_decode()
+            or not forward_batch.contains_image_inputs()
+        ):
+            if self.is_mrope_enabled:
+                assert positions.ndim == 2 and positions.size(0) == 3, (
+                    "multimodal section rotary embedding requires "
+                    f"(3, seq_len) positions, but got {positions.size()}"
+                )
+
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.model,
+            multimodal_model=self,
+            positions=positions,
+        )
+
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)

     def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
         """pad attn qkv weights for dummy heads"""
@@ -598,13 +726,12 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
-            if "language_model." in name:
-                name = name.replace("language_model.", "")
-            if "model.visual." in name:
-                name = name.replace("model.visual.", "visual.")
             if "rotary_emb.inv_freq" in name:
                 continue
+            if "language_model" in name:
+                name = name.replace(r"model.language_model.", r"model.")
+            if "model.visual." in name:
+                name = name.replace("model.visual.", "visual.")
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
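Note: the remapping above rewrites HuggingFace checkpoint keys to sglang module paths before parameter lookup. Applied to a made-up key (the exact key names depend on the checkpoint), it behaves like this:

    # Illustrative only: a hypothetical checkpoint key run through the
    # replacements shown in the diff above.
    name = "model.language_model.layers.0.self_attn.qkv_proj.weight"
    if "language_model" in name:
        name = name.replace(r"model.language_model.", r"model.")
    if "model.visual." in name:
        name = name.replace("model.visual.", "visual.")
    print(name)  # model.layers.0.self_attn.qkv_proj.weight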
@@ -639,5 +766,19 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
                 )
                 weight_loader(param, loaded_weight)

+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        self.model.embed_tokens.weight = embed
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            del self.lm_head.weight
+            self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+

 EntryClass = [Glm4vForConditionalGeneration]
python/sglang/srt/models/glm4v_moe.py  (+0 −1)

@@ -53,7 +53,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         )
         self.visual = Glm4vVisionModel(
             config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
             quant_config=quant_config,
             prefix=add_prefix("visual", prefix),
         )