Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
74333ae2
Unverified
Commit
74333ae2
authored
Aug 05, 2025
by
Ning Xie
Committed by
GitHub
Aug 05, 2025
Browse files
[Misc] correct static type check for GroupCoordinator (#21946)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
parent
83156c7b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
4 deletions
+29
-4
vllm/distributed/device_communicators/ray_communicator.py
vllm/distributed/device_communicators/ray_communicator.py
+1
-0
vllm/distributed/eplb/eplb_state.py
vllm/distributed/eplb/eplb_state.py
+3
-0
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+25
-4
No files found.
vllm/distributed/device_communicators/ray_communicator.py
View file @
74333ae2
...
...
@@ -70,6 +70,7 @@ class RayPPCommunicator(Communicator):
assert
ray
.
get_gpu_ids
(),
"RayPPCommunicator has no GPUs assigned"
self
.
_comm
=
get_pp_group
().
device_communicator
assert
self
.
_comm
is
not
None
# Since we wrap around the vLLM _PP communicator, we use
# the rank from the vLLM communicator, and ignore the rank
...
...
vllm/distributed/eplb/eplb_state.py
View file @
74333ae2
...
...
@@ -251,6 +251,7 @@ class EplbState:
if
global_expert_load
is
not
None
:
ep_group
=
get_ep_group
().
device_group
assert
ep_group
is
not
None
assert
global_expert_load
.
shape
==
(
model
.
num_moe_layers
,
model
.
num_logical_experts
)
assert
global_expert_load
.
dtype
==
torch
.
int64
...
...
@@ -357,6 +358,7 @@ class EplbState:
# Collect load metrics from all ranks
ep_group
=
get_ep_group
().
device_group
assert
ep_group
is
not
None
num_tokens_list
=
[
torch
.
empty_like
(
num_tokens
)
for
_
in
range
(
ep_group
.
size
())
]
...
...
@@ -412,6 +414,7 @@ class EplbState:
"""
ep_group
=
get_ep_group
().
device_group
assert
ep_group
is
not
None
ep_rank
=
ep_group
.
rank
()
time_start
=
None
...
...
vllm/distributed/parallel_state.py
View file @
74333ae2
...
...
@@ -196,10 +196,11 @@ class GroupCoordinator:
# 3 | 1 | 3 | 1 | 3
local_rank
:
int
# local rank used to assign devices
rank_in_group
:
int
# rank inside the group
cpu_group
:
ProcessGroup
# group for CPU communication
device_group
:
ProcessGroup
# group for device communication
cpu_group
:
Optional
[
ProcessGroup
]
# group for CPU communication
device_group
:
Optional
[
ProcessGroup
]
# group for device communication
use_device_communicator
:
bool
# whether to use device communicator
device_communicator
:
DeviceCommunicatorBase
# device communicator
device_communicator
:
Optional
[
DeviceCommunicatorBase
]
# device communicator
mq_broadcaster
:
Optional
[
Any
]
# shared memory broadcaster
def
__init__
(
...
...
@@ -250,7 +251,7 @@ class GroupCoordinator:
self
.
use_device_communicator
=
use_device_communicator
self
.
device_communicator
:
DeviceCommunicatorBase
=
None
# type: ignore
self
.
device_communicator
=
None
if
use_device_communicator
and
self
.
world_size
>
1
:
device_comm_cls
=
resolve_obj_by_qualname
(
current_platform
.
get_device_communicator_cls
())
...
...
@@ -364,6 +365,8 @@ class GroupCoordinator:
return
self
.
_all_reduce_out_place
(
input_
)
def
_all_reduce_out_place
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
self
.
device_communicator
is
None
:
raise
ValueError
(
"No device communicator found"
)
return
self
.
device_communicator
.
all_reduce
(
input_
)
def
all_gather
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
...
...
@@ -384,12 +387,16 @@ class GroupCoordinator:
def
_all_gather_out_place
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
)
->
torch
.
Tensor
:
if
self
.
device_communicator
is
None
:
raise
ValueError
(
"No device communicator found"
)
return
self
.
device_communicator
.
all_gather
(
input_
,
dim
)
def
all_gatherv
(
self
,
input_
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
dim
:
int
=
0
,
sizes
:
Optional
[
list
[
int
]]
=
None
):
if
self
.
device_communicator
is
None
:
raise
ValueError
(
"No device communicator found"
)
return
self
.
device_communicator
.
all_gatherv
(
input_
,
dim
,
sizes
)
def
reduce_scatter
(
self
,
...
...
@@ -414,10 +421,14 @@ class GroupCoordinator:
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
,
sizes
:
Optional
[
list
[
int
]]
=
None
)
->
torch
.
Tensor
:
if
self
.
device_communicator
is
None
:
raise
ValueError
(
"No device communicator found"
)
return
self
.
device_communicator
.
reduce_scatterv
(
input_
,
dim
,
sizes
)
def
_reduce_scatter_out_place
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
)
->
torch
.
Tensor
:
if
self
.
device_communicator
is
None
:
raise
ValueError
(
"No device communicator found"
)
return
self
.
device_communicator
.
reduce_scatter
(
input_
,
dim
)
def
gather
(
self
,
...
...
@@ -433,6 +444,8 @@ class GroupCoordinator:
# Bypass the function if we are using only 1 GPU.
if
world_size
==
1
:
return
input_
if
self
.
device_communicator
is
None
:
raise
ValueError
(
"No device communicator found"
)
return
self
.
device_communicator
.
gather
(
input_
,
dst
,
dim
)
def
broadcast
(
self
,
input_
:
torch
.
Tensor
,
src
:
int
=
0
):
...
...
@@ -667,6 +680,8 @@ class GroupCoordinator:
assert
dst
<
self
.
world_size
,
f"Invalid dst rank ({dst})"
if
self
.
use_cpu_custom_send_recv
:
if
self
.
device_communicator
is
None
:
raise
ValueError
(
"No device communicator found"
)
self
.
device_communicator
.
send_tensor_dict
(
# type: ignore
tensor_dict
,
dst
)
return
None
...
...
@@ -727,6 +742,8 @@ class GroupCoordinator:
assert
src
<
self
.
world_size
,
f"Invalid src rank ({src})"
if
self
.
use_cpu_custom_send_recv
:
if
self
.
device_communicator
is
None
:
raise
ValueError
(
"No device communicator found"
)
return
self
.
device_communicator
.
recv_tensor_dict
(
# type: ignore
src
)
...
...
@@ -784,6 +801,8 @@ class GroupCoordinator:
def
send
(
self
,
tensor
:
torch
.
Tensor
,
dst
:
Optional
[
int
]
=
None
)
->
None
:
"""Sends a tensor to the destination rank in a blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if
self
.
device_communicator
is
None
:
raise
ValueError
(
"No device communicator found"
)
self
.
device_communicator
.
send
(
tensor
,
dst
)
def
recv
(
self
,
...
...
@@ -792,6 +811,8 @@ class GroupCoordinator:
src
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if
self
.
device_communicator
is
None
:
raise
ValueError
(
"No device communicator found"
)
return
self
.
device_communicator
.
recv
(
size
,
dtype
,
src
)
def
destroy
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment