sglang · Commits · 00da9065

Unverified commit 00da9065, authored Aug 03, 2025 by yhyang201, committed via GitHub on Aug 03, 2025.

feat: Support DP Attention for step3_vl (#8699)

parent 8cd34458
Showing 3 changed files with 25 additions and 6 deletions (+25 −6):
- python/sglang/srt/layers/attention/vision.py (+13 −5)
- python/sglang/srt/models/step3_vl.py (+9 −0)
- python/sglang/srt/multimodal/processors/step3_vl.py (+3 −1)
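For context: under sglang's DP Attention, the attention layers shard over a tensor-parallel subgroup rather than the full TP world, and `get_attention_tp_rank()` / `get_attention_tp_size()` report a rank's position within that subgroup. Below is a toy sketch of the usual rank remapping, not the sglang implementation; the standalone function and the assumption that `tp_size` is divisible by `dp_size` are illustrative only.

```python
# Toy sketch of DP-Attention rank remapping (not sglang's code).
# Carving dp_size data-parallel replicas out of a tp_size-wide TP group
# leaves each replica an attention TP subgroup of tp_size // dp_size ranks.
def attention_tp_mapping(tp_rank: int, tp_size: int, dp_size: int):
    attn_tp_size = tp_size // dp_size       # size of the attention TP subgroup
    attn_tp_rank = tp_rank % attn_tp_size   # this rank's position inside it
    attn_dp_rank = tp_rank // attn_tp_size  # which replica this rank belongs to
    return attn_tp_rank, attn_tp_size, attn_dp_rank

# Example: 8 TP ranks with dp_size=2 -> two replicas, each sharding 4 ways.
assert attention_tp_mapping(tp_rank=5, tp_size=8, dp_size=2) == (1, 4, 1)
```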
python/sglang/srt/layers/attention/vision.py (+13 −5)

```diff
@@ -11,6 +11,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange

+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.utils import is_cuda, print_info_once

 _is_cuda = is_cuda()
@@ -365,19 +366,20 @@ class VisionAttention(nn.Module):
         **kwargs,
     ):
         super().__init__()
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
-        self.tp_size = world_size
-        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+        self.tp_size = attn_tp_size
+        self.tp_rank = attn_tp_rank
         self.dropout = dropout
         self.head_size = embed_dim // num_heads
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads
         )
         self.num_attention_heads_per_partition = dist_utils.divide(
-            num_dummy_heads + num_heads, world_size
+            num_dummy_heads + num_heads, self.tp_size
         )
         self.num_attention_kv_heads_per_partition = dist_utils.divide(
-            num_dummy_heads + num_heads, world_size
+            num_dummy_heads + num_heads, self.tp_size
         )
         self.q_size = self.num_attention_heads_per_partition * self.head_size
@@ -427,6 +429,8 @@ class VisionAttention(nn.Module):
                 total_num_kv_heads=num_dummy_heads + num_heads,
                 bias=qkv_bias,
                 quant_config=quant_config,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         else:
@@ -435,6 +439,8 @@ class VisionAttention(nn.Module):
                 output_size=3 * self.dummy_dim,
                 bias=qkv_bias,
                 quant_config=quant_config,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         self.proj = RowParallelLinear(
@@ -442,6 +448,8 @@ class VisionAttention(nn.Module):
             output_size=embed_dim,
             bias=proj_bias,
             quant_config=quant_config,
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
             prefix=add_prefix("proj", prefix),
         )
```
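The effect of this diff is that head partitioning and the QKV/output projections now shard over the attention TP subgroup instead of the full TP world, keeping the vision tower consistent with DP Attention. A hedged sketch of the divide arithmetic, with a plain checked division standing in for `dist_utils.divide` and hypothetical head counts:

```python
def divide(numerator: int, denominator: int) -> int:
    # Megatron-style exact division, as dist_utils.divide enforces.
    assert numerator % denominator == 0, f"{numerator} not divisible by {denominator}"
    return numerator // denominator

num_heads, num_dummy_heads = 30, 2  # hypothetical head counts
attn_tp_size = 4                    # attention TP subgroup size, not world size
heads_per_partition = divide(num_dummy_heads + num_heads, attn_tp_size)
print(heads_per_partition)  # 8 heads (including dummy padding) per attention TP rank
```

Padding with dummy heads is what makes the exact division possible when the true head count does not divide the subgroup size.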
python/sglang/srt/models/step3_vl.py (+9 −0)

```diff
@@ -531,11 +531,18 @@ class Step3VisionMLP(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        # Since this is a dense model,
+        # the MLP component likewise adopts a DP-MLP approach modeled after DP Attention.
+        # This choice may not represent the optimal solution and remains open to further deliberation.
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
         self.fc1 = ColumnParallelLinear(
             dim,
             intermediate_size,
             bias=bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("gate_proj", prefix),
         )
         self.act = ACT2FN[hidden_act]  # quick_gelu
@@ -544,6 +551,8 @@ class Step3VisionMLP(nn.Module):
             dim,
             bias=bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("down_proj", prefix),
         )
```
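As the in-code comment notes, the vision MLP mirrors the attention sharding ("DP-MLP"): fc1 is column-parallel and fc2 row-parallel over the same attention TP subgroup, so each replica's MLP shards line up with its attention shards. A toy single-process illustration of that column/row split follows; plain tensors and hypothetical sizes, with relu standing in for the model's quick_gelu.

```python
import torch

dim, intermediate_size, attn_tp_size, rank = 16, 64, 4, 1  # hypothetical sizes
shard = intermediate_size // attn_tp_size

w1 = torch.randn(intermediate_size, dim)  # full fc1 weight (column-parallel)
w2 = torch.randn(dim, intermediate_size)  # full fc2 weight (row-parallel)
x = torch.randn(2, dim)

# Each rank keeps a column shard of fc1 and the matching row shard of fc2,
# so the intermediate activation never needs to be gathered; only fc2's
# partial outputs are summed (all-reduced) across the subgroup.
w1_shard = w1[rank * shard:(rank + 1) * shard]
w2_shard = w2[:, rank * shard:(rank + 1) * shard]
partial_out = torch.relu(x @ w1_shard.T) @ w2_shard.T  # shape (2, dim)
```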
python/sglang/srt/multimodal/processors/step3_vl.py (+3 −1)

```diff
@@ -8,7 +8,7 @@ import torch
 from PIL import Image
 from torchvision import transforms
 from torchvision.transforms import InterpolationMode
-from transformers import BatchFeature, TensorType
+from transformers import BatchFeature, ProcessorMixin, TensorType

 from sglang.srt.models.step3_vl import Step3VLForConditionalGeneration
 from sglang.srt.multimodal.processors.base_processor import (
@@ -276,6 +276,8 @@ class Step3VLProcessor:
         super().__init__()
         self.config = config

+        if isinstance(tokenizer, ProcessorMixin):
+            tokenizer = tokenizer.tokenizer
         self.tokenizer = tokenizer

         self.image_size = 728
```
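The processor change makes `Step3VLProcessor` tolerant of being handed a full Hugging Face processor instead of a bare tokenizer: if the object is a `ProcessorMixin`, its inner `.tokenizer` is unwrapped before use. A minimal sketch of that guard in isolation (the helper name is hypothetical):

```python
from transformers import ProcessorMixin

def unwrap_tokenizer(tokenizer_or_processor):
    # A composite HF processor (ProcessorMixin) bundles a tokenizer with
    # image/video processors; only the tokenizer is needed here.
    if isinstance(tokenizer_or_processor, ProcessorMixin):
        return tokenizer_or_processor.tokenizer
    return tokenizer_or_processor
```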