Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d411df02
Unverified
Commit
d411df02
authored
Aug 10, 2025
by
Cyrus Leung
Committed by
GitHub
Aug 10, 2025
Browse files
[Misc] Further refine type annotations in parallel state (#22499)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
010e0e39
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
20 deletions
+19
-20
vllm/distributed/eplb/eplb_state.py
vllm/distributed/eplb/eplb_state.py
+0
-3
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+19
-17
No files found.
vllm/distributed/eplb/eplb_state.py
View file @
d411df02
...
...
@@ -259,7 +259,6 @@ class EplbState:
if
global_expert_load
is
not
None
:
ep_group
=
get_ep_group
().
device_group
assert
ep_group
is
not
None
assert
global_expert_load
.
shape
==
(
model
.
num_moe_layers
,
model
.
num_logical_experts
)
assert
global_expert_load
.
dtype
==
torch
.
int64
...
...
@@ -366,7 +365,6 @@ class EplbState:
# Collect load metrics from all ranks
ep_group
=
get_ep_group
().
device_group
assert
ep_group
is
not
None
all_reduce
(
total_expert_load_pass
,
group
=
ep_group
)
# num_tokens_per_rank: (num_moe_layers, num_ranks)
...
...
@@ -422,7 +420,6 @@ class EplbState:
"""
ep_group
=
get_ep_group
().
device_group
assert
ep_group
is
not
None
ep_rank
=
ep_group
.
rank
()
time_start
=
None
...
...
vllm/distributed/parallel_state.py
View file @
d411df02
...
...
@@ -197,11 +197,10 @@ class GroupCoordinator:
# 3 | 1 | 3 | 1 | 3
local_rank
:
int
# local rank used to assign devices
rank_in_group
:
int
# rank inside the group
cpu_group
:
Optional
[
ProcessGroup
]
# group for CPU communication
device_group
:
Optional
[
ProcessGroup
]
# group for device communication
use_device_communicator
:
bool
# whether to use device communicator
device_communicator
:
Optional
[
DeviceCommunicatorBase
]
# device communicator
cpu_group
:
ProcessGroup
# group for CPU communication
device_group
:
ProcessGroup
# group for device communication
# device communicator (if use_device_communicator=True)
device_communicator
:
Optional
[
DeviceCommunicatorBase
]
mq_broadcaster
:
Optional
[
Any
]
# shared memory broadcaster
def
__init__
(
...
...
@@ -209,7 +208,7 @@ class GroupCoordinator:
group_ranks
:
list
[
list
[
int
]],
local_rank
:
int
,
torch_distributed_backend
:
Union
[
str
,
Backend
],
use_device_communicator
:
bool
,
use_device_communicator
:
bool
,
# whether to use device communicator
use_message_queue_broadcaster
:
bool
=
False
,
group_name
:
Optional
[
str
]
=
None
,
):
...
...
@@ -219,8 +218,9 @@ class GroupCoordinator:
self
.
rank
=
torch
.
distributed
.
get_rank
()
self
.
local_rank
=
local_rank
self
.
device_group
=
None
self
.
cpu_group
=
None
self_device_group
=
None
self_cpu_group
=
None
for
ranks
in
group_ranks
:
device_group
=
torch
.
distributed
.
new_group
(
...
...
@@ -232,11 +232,14 @@ class GroupCoordinator:
self
.
ranks
=
ranks
self
.
world_size
=
len
(
ranks
)
self
.
rank_in_group
=
ranks
.
index
(
self
.
rank
)
self
.
device_group
=
device_group
self
.
cpu_group
=
cpu_group
self_device_group
=
device_group
self_cpu_group
=
cpu_group
assert
self_cpu_group
is
not
None
assert
self_device_group
is
not
None
assert
self
.
cpu_group
is
not
None
assert
self
.
device_group
is
not
None
self
.
cpu_group
=
self_cpu_group
self
.
device_group
=
self
_
device_group
from
vllm.platforms
import
current_platform
...
...
@@ -251,7 +254,6 @@ class GroupCoordinator:
self
.
device
=
torch
.
device
(
"cpu"
)
self
.
use_device_communicator
=
use_device_communicator
self
.
device_communicator
=
None
if
use_device_communicator
and
self
.
world_size
>
1
:
device_comm_cls
=
resolve_obj_by_qualname
(
...
...
@@ -817,12 +819,12 @@ class GroupCoordinator:
return
self
.
device_communicator
.
recv
(
size
,
dtype
,
src
)
def
destroy
(
self
):
if
self
.
device_group
is
not
None
:
if
hasattr
(
self
,
"
device_group
"
)
:
torch
.
distributed
.
destroy_process_group
(
self
.
device_group
)
self
.
device_group
=
None
if
self
.
cpu_group
is
not
None
:
del
self
.
device_group
if
hasattr
(
self
,
"
cpu_group
"
)
:
torch
.
distributed
.
destroy_process_group
(
self
.
cpu_group
)
self
.
cpu_group
=
None
del
self
.
cpu_group
if
self
.
device_communicator
is
not
None
:
self
.
device_communicator
.
destroy
()
if
self
.
mq_broadcaster
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment