Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bc5bd45c
Unverified
Commit
bc5bd45c
authored
Nov 12, 2025
by
Canlin Guo
Committed by
GitHub
Nov 12, 2025
Browse files
[Refactor] Remove redundant TP gather/split in split_qkv in QwenVL (#28271)
Signed-off-by:
gcanlin
<
canlinguosdu@gmail.com
>
parent
f76e85c2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
1 addition
and
42 deletions
+1
-42
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+0
-30
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+1
-12
No files found.
vllm/model_executor/models/qwen2_5_vl.py
View file @
bc5bd45c
...
@@ -291,25 +291,6 @@ class Qwen2_5_VisionMLP(nn.Module):
...
@@ -291,25 +291,6 @@ class Qwen2_5_VisionMLP(nn.Module):
return
x_down
return
x_down
def
all_gather_interleave
(
local_tensor
,
hidden_size
:
int
,
tp_size
:
int
):
"""All-gather the input tensor interleavely across model parallel group."""
import
torch.distributed
as
dist
gathered_tensors
=
[
torch
.
zeros_like
(
local_tensor
)
for
_
in
range
(
tp_size
)]
dist
.
all_gather
(
gathered_tensors
,
local_tensor
,
group
=
parallel_state
.
get_tp_group
().
device_group
)
gathered_tensors_split
=
[
torch
.
split
(
tensor
,
hidden_size
//
tp_size
,
-
1
)
for
tensor
in
gathered_tensors
]
ordered_tensors
=
[
tensor
for
pair
in
zip
(
*
gathered_tensors_split
)
for
tensor
in
pair
]
result_tensor
=
torch
.
cat
(
ordered_tensors
,
dim
=-
1
)
return
result_tensor
class
Qwen2_5_VisionAttention
(
nn
.
Module
):
class
Qwen2_5_VisionAttention
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -383,21 +364,10 @@ class Qwen2_5_VisionAttention(nn.Module):
...
@@ -383,21 +364,10 @@ class Qwen2_5_VisionAttention(nn.Module):
def
split_qkv
(
self
,
qkv
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
...]:
def
split_qkv
(
self
,
qkv
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
...]:
# [s, b, 3 * head * head_dim]
# [s, b, 3 * head * head_dim]
seq_len
,
bs
,
_
=
qkv
.
shape
seq_len
,
bs
,
_
=
qkv
.
shape
if
self
.
tp_size
>
1
:
qkv
=
all_gather_interleave
(
qkv
,
self
.
qkv
.
hidden_size
,
self
.
tp_size
)
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
q
,
k
,
v
=
qkv
.
chunk
(
3
,
dim
=
2
)
q
,
k
,
v
=
qkv
.
chunk
(
3
,
dim
=
2
)
# 3 * [s, b, head * head_dim]
if
self
.
tp_size
>
1
:
splitter
=
partial
(
dist_utils
.
split_tensor_along_last_dim
,
num_partitions
=
self
.
tp_size
)
q
=
splitter
(
q
)[
self
.
tp_rank
]
k
=
splitter
(
k
)[
self
.
tp_rank
]
v
=
splitter
(
v
)[
self
.
tp_rank
]
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
new_shape
=
(
new_shape
=
(
seq_len
,
seq_len
,
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
bc5bd45c
...
@@ -50,7 +50,7 @@ from vllm.attention.layer import (
...
@@ -50,7 +50,7 @@ from vllm.attention.layer import (
)
)
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.distributed
import
parallel_state
,
tensor_model_parallel_all_gather
from
vllm.distributed
import
parallel_state
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
QuickGELU
from
vllm.model_executor.layers.activation
import
QuickGELU
...
@@ -396,21 +396,10 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -396,21 +396,10 @@ class Qwen2VisionAttention(nn.Module):
def
split_qkv
(
self
,
qkv
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
...]:
def
split_qkv
(
self
,
qkv
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
...]:
# [s, b, 3 * head * head_dim]
# [s, b, 3 * head * head_dim]
seq_len
,
bs
,
_
=
qkv
.
shape
seq_len
,
bs
,
_
=
qkv
.
shape
if
self
.
tp_size
>
1
:
qkv
=
tensor_model_parallel_all_gather
(
qkv
)
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
q
,
k
,
v
=
qkv
.
chunk
(
3
,
dim
=
2
)
q
,
k
,
v
=
qkv
.
chunk
(
3
,
dim
=
2
)
# 3 * [s, b, head * head_dim]
if
self
.
tp_size
>
1
:
splitter
=
partial
(
dist_utils
.
split_tensor_along_last_dim
,
num_partitions
=
self
.
tp_size
)
q
=
splitter
(
q
)[
self
.
tp_rank
]
k
=
splitter
(
k
)[
self
.
tp_rank
]
v
=
splitter
(
v
)[
self
.
tp_rank
]
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
new_shape
=
(
new_shape
=
(
seq_len
,
seq_len
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment