Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2f0bab3f
Unverified
Commit
2f0bab3f
authored
Sep 02, 2025
by
WeiQing Chen
Committed by
GitHub
Sep 02, 2025
Browse files
[Model] Support dp on ViT on GLM-4.5V (#23168)
Signed-off-by:
David Chen
<
530634352@qq.com
>
parent
fad73be1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
145 additions
and
59 deletions
+145
-59
docs/configuration/optimization.md
docs/configuration/optimization.md
+1
-0
vllm/model_executor/models/glm4_1v.py
vllm/model_executor/models/glm4_1v.py
+144
-59
No files found.
docs/configuration/optimization.md
View file @
2f0bab3f
...
@@ -174,6 +174,7 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u
...
@@ -174,6 +174,7 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u
Known supported models:
Known supported models:
-
GLM-4.5V GLM-4.1V (
<gh-pr:23168>
)
-
Kimi-VL (
<gh-pr:23817>
)
-
Kimi-VL (
<gh-pr:23817>
)
-
Llama4 (
<gh-pr:18368>
)
-
Llama4 (
<gh-pr:18368>
)
-
MiniCPM-V-2.5 or above (
<gh-pr:23327>
,
<gh-pr:23948>
)
-
MiniCPM-V-2.5 or above (
<gh-pr:23327>
,
<gh-pr:23948>
)
...
...
vllm/model_executor/models/glm4_1v.py
View file @
2f0bab3f
...
@@ -45,15 +45,20 @@ from transformers.models.glm4v.video_processing_glm4v import (
...
@@ -45,15 +45,20 @@ from transformers.models.glm4v.video_processing_glm4v import (
from
transformers.video_utils
import
VideoMetadata
from
transformers.video_utils
import
VideoMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
parallel_state
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
parallel_state
)
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
# yapf: disable
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
MergedReplicatedLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
# yapf: enable
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
...
@@ -66,6 +71,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -66,6 +71,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.utils
import
run_dp_sharded_mrope_vision_model
from
vllm.platforms
import
_Backend
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.config
import
uses_mrope
...
@@ -153,7 +159,7 @@ class Glm4vVideoEmbeddingInputs(TensorSchema):
...
@@ -153,7 +159,7 @@ class Glm4vVideoEmbeddingInputs(TensorSchema):
Glm4vVideoInputs
=
Union
[
Glm4vVideoPixelInputs
,
Glm4vVideoEmbeddingInputs
]
Glm4vVideoInputs
=
Union
[
Glm4vVideoPixelInputs
,
Glm4vVideoEmbeddingInputs
]
# === Vision Encoder === #
# ===
=
Vision Encoder ===
=
#
class
Glm4vVisionMLP
(
nn
.
Module
):
class
Glm4vVisionMLP
(
nn
.
Module
):
...
@@ -165,19 +171,23 @@ class Glm4vVisionMLP(nn.Module):
...
@@ -165,19 +171,23 @@ class Glm4vVisionMLP(nn.Module):
bias
:
bool
=
False
,
bias
:
bool
=
False
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
cls_gate_up
=
(
MergedReplicatedLinear
input_size
=
in_features
,
if
use_data_parallel
else
MergedColumnParallelLinear
)
output_sizes
=
[
hidden_features
]
*
2
,
self
.
gate_up_proj
=
cls_gate_up
(
input_size
=
in_features
,
bias
=
bias
,
output_sizes
=
[
hidden_features
]
*
2
,
quant_config
=
quant_config
,
bias
=
bias
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
)
quant_config
=
quant_config
,
self
.
down_proj
=
RowParallelLinear
(
hidden_features
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
)
in_features
,
cls_down
=
(
ReplicatedLinear
bias
=
bias
,
if
use_data_parallel
else
RowParallelLinear
)
quant_config
=
quant_config
,
self
.
down_proj
=
cls_down
(
hidden_features
,
prefix
=
f
"
{
prefix
}
.down_proj"
)
in_features
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.down_proj"
)
self
.
act_fn
=
SiluAndMul
()
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
:
torch
.
Tensor
):
def
forward
(
self
,
x
:
torch
.
Tensor
):
...
@@ -218,33 +228,54 @@ class Glm4vVisionAttention(nn.Module):
...
@@ -218,33 +228,54 @@ class Glm4vVisionAttention(nn.Module):
projection_size
:
int
,
projection_size
:
int
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
# Per attention head and per partition values.
# Per attention head and per partition values.
self
.
tp_size
=
parallel_state
.
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
(
1
if
use_data_parallel
else
get_tensor_model_parallel_world_size
())
self
.
tp_rank
=
parallel_state
.
get_tensor_model_parallel_rank
()
self
.
tp_rank
=
parallel_state
.
get_tensor_model_parallel_rank
()
self
.
hidden_size_per_attention_head
=
dist_utils
.
divide
(
self
.
hidden_size_per_attention_head
=
dist_utils
.
divide
(
projection_size
,
num_heads
)
projection_size
,
num_heads
)
self
.
num_attention_heads_per_partition
=
dist_utils
.
divide
(
self
.
num_attention_heads_per_partition
=
dist_utils
.
divide
(
num_heads
,
self
.
tp_size
)
num_heads
,
self
.
tp_size
)
self
.
qkv
=
QKVParallelLinear
(
if
use_data_parallel
:
hidden_size
=
embed_dim
,
self
.
qkv
=
ReplicatedLinear
(
head_size
=
self
.
hidden_size_per_attention_head
,
input_size
=
embed_dim
,
total_num_heads
=
num_heads
,
output_size
=
3
*
projection_size
,
total_num_kv_heads
=
num_heads
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
# Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
# Change qkv prefix to align with GLM-4.5V-FP8 quantization config
prefix
=
f
"
{
prefix
}
.qkv_proj"
prefix
=
f
"
{
prefix
}
.qkv_proj"
if
quant_config
else
f
"
{
prefix
}
.qkv"
,
if
quant_config
else
f
"
{
prefix
}
.qkv"
,
)
)
self
.
proj
=
RowParallelLinear
(
self
.
proj
=
ReplicatedLinear
(
input_size
=
projection_size
,
input_size
=
projection_size
,
output_size
=
embed_dim
,
output_size
=
embed_dim
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.proj"
,
prefix
=
f
"
{
prefix
}
.proj"
,
bias
=
False
,
bias
=
False
,
)
)
else
:
self
.
qkv
=
QKVParallelLinear
(
hidden_size
=
embed_dim
,
head_size
=
self
.
hidden_size_per_attention_head
,
total_num_heads
=
num_heads
,
total_num_kv_heads
=
num_heads
,
bias
=
False
,
quant_config
=
quant_config
,
# Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
prefix
=
f
"
{
prefix
}
.qkv_proj"
if
quant_config
else
f
"
{
prefix
}
.qkv"
,
)
self
.
proj
=
RowParallelLinear
(
input_size
=
projection_size
,
output_size
=
embed_dim
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.proj"
,
bias
=
False
,
)
# Detect attention implementation.
# Detect attention implementation.
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
...
@@ -375,6 +406,7 @@ class Glm4vVisionBlock(nn.Module):
...
@@ -375,6 +406,7 @@ class Glm4vVisionBlock(nn.Module):
norm_layer
:
Optional
[
Callable
[[
int
],
nn
.
Module
]]
=
None
,
norm_layer
:
Optional
[
Callable
[[
int
],
nn
.
Module
]]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
if
norm_layer
is
None
:
if
norm_layer
is
None
:
...
@@ -387,6 +419,7 @@ class Glm4vVisionBlock(nn.Module):
...
@@ -387,6 +419,7 @@ class Glm4vVisionBlock(nn.Module):
projection_size
=
dim
,
projection_size
=
dim
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
,
use_data_parallel
=
use_data_parallel
,
)
)
self
.
mlp
=
Glm4vVisionMLP
(
self
.
mlp
=
Glm4vVisionMLP
(
dim
,
dim
,
...
@@ -394,6 +427,7 @@ class Glm4vVisionBlock(nn.Module):
...
@@ -394,6 +427,7 @@ class Glm4vVisionBlock(nn.Module):
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
prefix
=
f
"
{
prefix
}
.mlp"
,
use_data_parallel
=
use_data_parallel
,
)
)
def
forward
(
def
forward
(
...
@@ -456,24 +490,40 @@ class Glm4vPatchMerger(nn.Module):
...
@@ -456,24 +490,40 @@ class Glm4vPatchMerger(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
bias
:
bool
=
False
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
d_model
self
.
hidden_size
=
d_model
self
.
proj
=
ColumnParallelLinear
(
self
.
hidden_size
,
if
use_data_parallel
:
self
.
hidden_size
,
self
.
proj
=
ReplicatedLinear
(
bias
=
bias
,
input_size
=
self
.
hidden_size
,
gather_output
=
True
,
output_size
=
self
.
hidden_size
,
quant_config
=
quant_config
,
bias
=
bias
,
prefix
=
f
"
{
prefix
}
.proj"
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.proj"
,
)
else
:
self
.
proj
=
ColumnParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
bias
,
gather_output
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.proj"
,
)
self
.
post_projection_norm
=
nn
.
LayerNorm
(
self
.
hidden_size
)
self
.
post_projection_norm
=
nn
.
LayerNorm
(
self
.
hidden_size
)
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
cls_gate_up
=
(
MergedReplicatedLinear
if
use_data_parallel
else
MergedColumnParallelLinear
)
self
.
gate_up_proj
=
cls_gate_up
(
input_size
=
self
.
hidden_size
,
input_size
=
self
.
hidden_size
,
output_sizes
=
[
context_dim
]
*
2
,
output_sizes
=
[
context_dim
]
*
2
,
bias
=
bias
,
bias
=
bias
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
)
)
self
.
down_proj
=
RowParallelLinear
(
cls_down
=
(
ReplicatedLinear
if
use_data_parallel
else
RowParallelLinear
)
self
.
down_proj
=
cls_down
(
context_dim
,
context_dim
,
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
bias
,
bias
=
bias
,
...
@@ -548,14 +598,33 @@ class Glm4vVisionEmbeddings(nn.Module):
...
@@ -548,14 +598,33 @@ class Glm4vVisionEmbeddings(nn.Module):
dtype
=
torch
.
float32
))
dtype
=
torch
.
float32
))
# Calculate target dimensions for each patch
# Calculate target dimensions for each patch
target_h
=
torch
.
cat
([
# Add bounds checking for data parallel mode
image_shapes
[
i
,
1
].
repeat
(
lengths
[
i
])
if
len
(
lengths
)
>
image_shapes
.
shape
[
0
]:
for
i
in
range
(
len
(
lengths
))
# In data parallel mode, some GPUs might not have all
]).
to
(
device
=
device
,
dtype
=
torch
.
float32
)
# image shapes
target_w
=
torch
.
cat
([
# Use available image shapes, cycling if necessary
image_shapes
[
i
,
2
].
repeat
(
lengths
[
i
])
target_h_list
=
[]
for
i
in
range
(
len
(
lengths
))
target_w_list
=
[]
]).
to
(
device
=
device
,
dtype
=
torch
.
float32
)
for
i
in
range
(
len
(
lengths
)):
# Cycle through available shapes
shape_idx
=
i
%
image_shapes
.
shape
[
0
]
target_h_list
.
append
(
image_shapes
[
shape_idx
,
1
].
repeat
(
lengths
[
i
]))
target_w_list
.
append
(
image_shapes
[
shape_idx
,
2
].
repeat
(
lengths
[
i
]))
target_h
=
torch
.
cat
(
target_h_list
).
to
(
device
=
device
,
dtype
=
torch
.
float32
)
target_w
=
torch
.
cat
(
target_w_list
).
to
(
device
=
device
,
dtype
=
torch
.
float32
)
else
:
target_h
=
torch
.
cat
([
image_shapes
[
i
,
1
].
repeat
(
lengths
[
i
])
for
i
in
range
(
len
(
lengths
))
]).
to
(
device
=
device
,
dtype
=
torch
.
float32
)
target_w
=
torch
.
cat
([
image_shapes
[
i
,
2
].
repeat
(
lengths
[
i
])
for
i
in
range
(
len
(
lengths
))
]).
to
(
device
=
device
,
dtype
=
torch
.
float32
)
# Normalize coordinates to [-1, 1] range for grid_sample
# Normalize coordinates to [-1, 1] range for grid_sample
h_coords
=
h_coords
.
to
(
device
=
device
,
dtype
=
torch
.
float32
)
h_coords
=
h_coords
.
to
(
device
=
device
,
dtype
=
torch
.
float32
)
...
@@ -629,6 +698,7 @@ class Glm4vVisionTransformer(nn.Module):
...
@@ -629,6 +698,7 @@ class Glm4vVisionTransformer(nn.Module):
norm_eps
:
float
=
1e-6
,
norm_eps
:
float
=
1e-6
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -638,6 +708,7 @@ class Glm4vVisionTransformer(nn.Module):
...
@@ -638,6 +708,7 @@ class Glm4vVisionTransformer(nn.Module):
depth
=
vision_config
.
depth
depth
=
vision_config
.
depth
self
.
hidden_size
=
vision_config
.
hidden_size
self
.
hidden_size
=
vision_config
.
hidden_size
self
.
num_heads
=
vision_config
.
num_heads
self
.
num_heads
=
vision_config
.
num_heads
self
.
use_data_parallel
=
use_data_parallel
self
.
patch_size
=
vision_config
.
patch_size
self
.
patch_size
=
vision_config
.
patch_size
self
.
spatial_merge_size
=
vision_config
.
spatial_merge_size
self
.
spatial_merge_size
=
vision_config
.
spatial_merge_size
...
@@ -661,6 +732,7 @@ class Glm4vVisionTransformer(nn.Module):
...
@@ -661,6 +732,7 @@ class Glm4vVisionTransformer(nn.Module):
norm_layer
=
norm_layer
,
norm_layer
=
norm_layer
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.blocks.
{
layer_idx
}
"
,
prefix
=
f
"
{
prefix
}
.blocks.
{
layer_idx
}
"
,
use_data_parallel
=
self
.
use_data_parallel
,
)
for
layer_idx
in
range
(
depth
)
)
for
layer_idx
in
range
(
depth
)
])
])
self
.
merger
=
Glm4vPatchMerger
(
self
.
merger
=
Glm4vPatchMerger
(
...
@@ -669,6 +741,7 @@ class Glm4vVisionTransformer(nn.Module):
...
@@ -669,6 +741,7 @@ class Glm4vVisionTransformer(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
bias
=
False
,
bias
=
False
,
prefix
=
f
"
{
prefix
}
.merger"
,
prefix
=
f
"
{
prefix
}
.merger"
,
use_data_parallel
=
self
.
use_data_parallel
,
)
)
self
.
embeddings
=
Glm4vVisionEmbeddings
(
vision_config
)
self
.
embeddings
=
Glm4vVisionEmbeddings
(
vision_config
)
...
@@ -731,8 +804,11 @@ class Glm4vVisionTransformer(nn.Module):
...
@@ -731,8 +804,11 @@ class Glm4vVisionTransformer(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
grid_thw
:
torch
.
Tensor
,
grid_thw
:
list
[
list
[
int
]]
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Convert grid_thw to tensor (always expecting list format now)
grid_thw
=
torch
.
tensor
(
grid_thw
,
device
=
x
.
device
,
dtype
=
torch
.
long
)
# patchify
# patchify
x
=
x
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
x
=
x
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
x
=
self
.
patch_embed
(
x
)
x
=
self
.
patch_embed
(
x
)
...
@@ -1250,6 +1326,8 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1250,6 +1326,8 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
"model.visual."
:
"visual."
,
"model.visual."
:
"visual."
,
})
})
supports_encoder_tp_data
=
True
@
classmethod
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
Optional
[
str
]:
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
Optional
[
str
]:
if
modality
.
startswith
(
"image"
):
if
modality
.
startswith
(
"image"
):
...
@@ -1267,12 +1345,14 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1267,12 +1345,14 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
config
=
config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
multimodal_config
=
multimodal_config
self
.
use_data_parallel
=
multimodal_config
.
mm_encoder_tp_mode
==
"data"
self
.
visual
=
Glm4vVisionTransformer
(
self
.
visual
=
Glm4vVisionTransformer
(
config
.
vision_config
,
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-5
),
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-5
),
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
use_data_parallel
=
self
.
use_data_parallel
,
)
)
if
config
.
model_type
==
"glm4v"
:
if
config
.
model_type
==
"glm4v"
:
...
@@ -1382,8 +1462,14 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1382,8 +1462,14 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds
=
image_input
[
"image_embeds"
].
type
(
self
.
visual
.
dtype
)
image_embeds
=
image_input
[
"image_embeds"
].
type
(
self
.
visual
.
dtype
)
else
:
else
:
pixel_values
=
image_input
[
"pixel_values"
].
type
(
self
.
visual
.
dtype
)
pixel_values
=
image_input
[
"pixel_values"
].
type
(
self
.
visual
.
dtype
)
image_embeds
=
self
.
visual
(
pixel_values
,
grid_thw
=
grid_thw
)
if
self
.
use_data_parallel
:
return
run_dp_sharded_mrope_vision_model
(
self
.
visual
,
pixel_values
,
grid_thw
.
tolist
(),
rope_type
=
"rope_3d"
)
else
:
image_embeds
=
self
.
visual
(
pixel_values
,
grid_thw
=
grid_thw
.
tolist
())
merge_size
=
self
.
visual
.
spatial_merge_size
merge_size
=
self
.
visual
.
spatial_merge_size
sizes
=
grid_thw
.
prod
(
-
1
)
//
merge_size
//
merge_size
sizes
=
grid_thw
.
prod
(
-
1
)
//
merge_size
//
merge_size
return
image_embeds
.
split
(
sizes
.
tolist
())
return
image_embeds
.
split
(
sizes
.
tolist
())
...
@@ -1393,23 +1479,22 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1393,23 +1479,22 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
grid_thw
=
video_input
[
"video_grid_thw"
]
grid_thw
=
video_input
[
"video_grid_thw"
]
assert
grid_thw
.
ndim
==
2
assert
grid_thw
.
ndim
==
2
device
=
self
.
visual
.
device
flat_grid_thw
=
torch
.
cat
([
torch
.
tensor
([[
1
,
h
,
w
]]
*
t
,
device
=
device
)
for
t
,
h
,
w
in
grid_thw
])
if
video_input
[
"type"
]
==
"video_embeds"
:
if
video_input
[
"type"
]
==
"video_embeds"
:
video_embeds
=
video_input
[
"video_embeds"
].
type
(
self
.
visual
.
dtype
)
video_embeds
=
video_input
[
"video_embeds"
].
type
(
self
.
visual
.
dtype
)
else
:
else
:
pixel_values_videos
=
video_input
[
"pixel_values_videos"
].
type
(
pixel_values_videos
=
video_input
[
"pixel_values_videos"
].
type
(
self
.
visual
.
dtype
)
self
.
visual
.
dtype
)
video_embeds
=
self
.
visual
(
pixel_values_videos
,
if
self
.
use_data_parallel
:
grid_thw
=
flat_grid_thw
)
return
run_dp_sharded_mrope_vision_model
(
self
.
visual
,
pixel_values_videos
,
grid_thw
.
tolist
(),
rope_type
=
"rope_3d"
)
else
:
video_embeds
=
self
.
visual
(
pixel_values_videos
,
grid_thw
=
grid_thw
.
tolist
())
# Split concatenated embeddings for each video item.
# Split concatenated embeddings for each video item.
merge_size
=
self
.
visual
.
spatial_merge_size
merge_size
=
self
.
visual
.
spatial_merge_size
sizes
=
grid_thw
.
prod
(
-
1
)
//
merge_size
//
merge_size
sizes
=
grid_thw
.
prod
(
-
1
)
//
merge_size
//
merge_size
return
video_embeds
.
split
(
sizes
.
tolist
())
return
video_embeds
.
split
(
sizes
.
tolist
())
def
_parse_and_validate_multimodal_inputs
(
self
,
**
kwargs
:
object
)
->
dict
:
def
_parse_and_validate_multimodal_inputs
(
self
,
**
kwargs
:
object
)
->
dict
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment