Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d16aa3da
Unverified
Commit
d16aa3da
authored
Aug 13, 2025
by
zzh142857
Committed by
GitHub
Aug 13, 2025
Browse files
[Model] Add option to run Step3VisionEncoder in DP (#22697)
Signed-off-by:
zzh142857
<
chaorenzhaozhenghao@gmail.com
>
parent
6807af8f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
91 additions
and
41 deletions
+91
-41
vllm/model_executor/models/step3_vl.py
vllm/model_executor/models/step3_vl.py
+91
-41
No files found.
vllm/model_executor/models/step3_vl.py
View file @
d16aa3da
...
...
@@ -21,6 +21,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
...
...
@@ -33,6 +34,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.utils
import
run_dp_sharded_vision_model
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs
import
Step3VisionEncoderConfig
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
...
...
@@ -650,7 +652,8 @@ class Step3VisionAttention(nn.Module):
def
__init__
(
self
,
config
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
...
...
@@ -659,20 +662,42 @@ class Step3VisionAttention(nn.Module):
self
.
scale
=
self
.
head_dim
**-
0.5
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
(
1
if
use_data_parallel
else
get_tensor_model_parallel_world_size
())
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
embed_dim
,
self
.
head_dim
,
self
.
total_num_heads
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
prefix
)
self
.
out_proj
=
RowParallelLinear
(
self
.
embed_dim
,
self
.
embed_dim
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
prefix
)
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
if
use_data_parallel
:
self
.
qkv_proj
=
ReplicatedLinear
(
self
.
embed_dim
,
3
*
self
.
q_size
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
prefix
,
)
self
.
out_proj
=
ReplicatedLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
embed_dim
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
prefix
,
)
else
:
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
embed_dim
,
self
.
head_dim
,
self
.
total_num_heads
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
prefix
,
)
self
.
out_proj
=
RowParallelLinear
(
self
.
embed_dim
,
self
.
embed_dim
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
prefix
)
def
_shape
(
self
,
tensor
:
torch
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
return
tensor
.
view
(
bsz
,
seq_len
,
self
.
num_heads
,
...
...
@@ -712,20 +737,25 @@ class Step3VisionMLP(nn.Module):
def
__init__
(
self
,
config
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
):
super
().
__init__
()
self
.
config
=
config
self
.
activation_fn
=
get_act_fn
(
config
.
hidden_act
)
self
.
fc1
=
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
intermediate_size
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
prefix
)
self
.
fc2
=
RowParallelLinear
(
config
.
intermediate_size
,
config
.
hidden_size
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
prefix
)
cls_fc1
=
(
ReplicatedLinear
if
use_data_parallel
else
ColumnParallelLinear
)
self
.
fc1
=
cls_fc1
(
config
.
hidden_size
,
config
.
intermediate_size
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
prefix
)
cls_fc2
=
(
ReplicatedLinear
if
use_data_parallel
else
RowParallelLinear
)
self
.
fc2
=
cls_fc2
(
config
.
intermediate_size
,
config
.
hidden_size
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
prefix
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
,
_
=
self
.
fc1
(
hidden_states
)
...
...
@@ -739,15 +769,22 @@ class Step3VisionEncoderLayer(nn.Module):
def
__init__
(
self
,
config
:
Step3VisionEncoderConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
):
super
().
__init__
()
self
.
use_data_parallel
=
use_data_parallel
self
.
embed_dim
=
config
.
hidden_size
self
.
self_attn
=
Step3VisionAttention
(
config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
)
self
.
self_attn
=
Step3VisionAttention
(
config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
use_data_parallel
=
self
.
use_data_parallel
)
self
.
layer_norm1
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
mlp
=
Step3VisionMLP
(
config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
self
.
mlp
=
Step3VisionMLP
(
config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
use_data_parallel
=
self
.
use_data_parallel
)
self
.
layer_norm2
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
...
...
@@ -767,13 +804,16 @@ class Step3VisionEncoder(nn.Module):
def
__init__
(
self
,
config
:
Step3VisionEncoderConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
):
super
().
__init__
()
self
.
config
=
config
self
.
use_data_parallel
=
use_data_parallel
self
.
layers
=
nn
.
ModuleList
([
Step3VisionEncoderLayer
(
config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.layers.
{
i
}
"
)
prefix
=
f
"
{
prefix
}
.layers.
{
i
}
"
,
use_data_parallel
=
self
.
use_data_parallel
)
for
i
in
range
(
config
.
num_hidden_layers
)
])
...
...
@@ -792,21 +832,29 @@ class Step3VisionTransformer(nn.Module):
def
__init__
(
self
,
config
:
Step3VisionEncoderConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
):
super
().
__init__
()
self
.
config
=
config
self
.
use_data_parallel
=
use_data_parallel
self
.
image_size
=
config
.
image_size
self
.
embeddings
=
Step3VisionEmbeddings
(
config
)
self
.
transformer
=
Step3VisionEncoder
(
config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.transformer"
)
self
.
transformer
=
Step3VisionEncoder
(
config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.transformer"
,
use_data_parallel
=
self
.
use_data_parallel
)
def
forward
(
self
,
pixel_values
:
torch
.
Tensor
,
):
hidden_states
=
self
.
embeddings
(
pixel_values
)
hidden_states
=
self
.
transformer
(
inputs_embeds
=
hidden_states
)
if
self
.
use_data_parallel
:
hidden_states
=
run_dp_sharded_vision_model
(
hidden_states
,
self
.
transformer
)
else
:
hidden_states
=
self
.
transformer
(
inputs_embeds
=
hidden_states
)
return
hidden_states
...
...
@@ -836,13 +884,15 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
use_data_parallel
=
(
vllm_config
.
parallel_config
.
enable_multimodal_encoder_data_parallel
)
if
multimodal_config
.
get_limit_per_prompt
(
"image"
):
self
.
vision_model
=
Step3VisionTransformer
(
config
.
vision_config
,
None
,
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
)
)
self
.
vision_model
=
Step3VisionTransformer
(
config
.
vision_config
,
None
,
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
)
,
use_data_parallel
=
self
.
use_data_parallel
)
self
.
vit_downsampler
=
nn
.
Conv2d
(
config
.
vision_config
.
hidden_size
,
config
.
vision_config
.
output_hidden_size
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment