Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
484e22bc
Unverified
Commit
484e22bc
authored
Jan 16, 2026
by
Chenyaaang
Committed by
GitHub
Jan 16, 2026
Browse files
[TPU][Core] Enable Pipeline Parallelism on TPU backend (#28506)
Signed-off-by:
Chenyaaang
<
chenyangli@google.com
>
parent
ca212880
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
14 deletions
+34
-14
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+30
-13
vllm/v1/executor/ray_utils.py
vllm/v1/executor/ray_utils.py
+4
-1
No files found.
vllm/v1/executor/multiproc_executor.py
View file @
484e22bc
...
@@ -104,16 +104,7 @@ class MultiprocExecutor(Executor):
...
@@ -104,16 +104,7 @@ class MultiprocExecutor(Executor):
self
.
shutdown_event
=
threading
.
Event
()
self
.
shutdown_event
=
threading
.
Event
()
self
.
failure_callback
:
FailureCallback
|
None
=
None
self
.
failure_callback
:
FailureCallback
|
None
=
None
self
.
world_size
=
self
.
parallel_config
.
world_size
tp_size
,
pp_size
,
pcp_size
=
self
.
_get_parallel_sizes
()
assert
self
.
world_size
%
self
.
parallel_config
.
nnodes_within_dp
==
0
,
(
f
"global world_size (
{
self
.
parallel_config
.
world_size
}
) must be "
f
"divisible by nnodes_within_dp "
f
"(
{
self
.
parallel_config
.
nnodes_within_dp
}
). "
)
self
.
local_world_size
=
self
.
parallel_config
.
local_world_size
tp_size
=
self
.
parallel_config
.
tensor_parallel_size
pp_size
=
self
.
parallel_config
.
pipeline_parallel_size
pcp_size
=
self
.
parallel_config
.
prefill_context_parallel_size
assert
self
.
world_size
==
tp_size
*
pp_size
*
pcp_size
,
(
assert
self
.
world_size
==
tp_size
*
pp_size
*
pcp_size
,
(
f
"world_size (
{
self
.
world_size
}
) must be equal to the "
f
"world_size (
{
self
.
world_size
}
) must be equal to the "
f
"tensor_parallel_size (
{
tp_size
}
) x pipeline"
f
"tensor_parallel_size (
{
tp_size
}
) x pipeline"
...
@@ -154,6 +145,7 @@ class MultiprocExecutor(Executor):
...
@@ -154,6 +145,7 @@ class MultiprocExecutor(Executor):
)
)
for
local_rank
in
range
(
self
.
local_world_size
):
for
local_rank
in
range
(
self
.
local_world_size
):
global_rank
=
global_start_rank
+
local_rank
global_rank
=
global_start_rank
+
local_rank
is_driver_worker
=
self
.
_is_driver_worker
(
global_rank
)
unready_workers
.
append
(
unready_workers
.
append
(
WorkerProc
.
make_worker_process
(
WorkerProc
.
make_worker_process
(
vllm_config
=
self
.
vllm_config
,
vllm_config
=
self
.
vllm_config
,
...
@@ -162,6 +154,7 @@ class MultiprocExecutor(Executor):
...
@@ -162,6 +154,7 @@ class MultiprocExecutor(Executor):
distributed_init_method
=
distributed_init_method
,
distributed_init_method
=
distributed_init_method
,
input_shm_handle
=
scheduler_output_handle
,
input_shm_handle
=
scheduler_output_handle
,
shared_worker_lock
=
shared_worker_lock
,
shared_worker_lock
=
shared_worker_lock
,
is_driver_worker
=
is_driver_worker
,
)
)
)
)
...
@@ -199,6 +192,11 @@ class MultiprocExecutor(Executor):
...
@@ -199,6 +192,11 @@ class MultiprocExecutor(Executor):
# Wait for all remote response mqs to be ready.
# Wait for all remote response mqs to be ready.
for
response_mq
in
self
.
response_mqs
:
for
response_mq
in
self
.
response_mqs
:
response_mq
.
wait_until_ready
()
response_mq
.
wait_until_ready
()
self
.
futures_queue
=
deque
[
tuple
[
FutureWrapper
,
Callable
]]()
self
.
_post_init_executor
()
success
=
True
success
=
True
finally
:
finally
:
if
not
success
:
if
not
success
:
...
@@ -209,10 +207,27 @@ class MultiprocExecutor(Executor):
...
@@ -209,10 +207,27 @@ class MultiprocExecutor(Executor):
uw
.
death_writer
.
close
()
uw
.
death_writer
.
close
()
self
.
_ensure_worker_termination
([
uw
.
proc
for
uw
in
unready_workers
])
self
.
_ensure_worker_termination
([
uw
.
proc
for
uw
in
unready_workers
])
self
.
futures_queue
=
deque
[
tuple
[
FutureWrapper
,
Callable
]]()
self
.
output_rank
=
self
.
_get_output_rank
()
self
.
output_rank
=
self
.
_get_output_rank
()
def
_get_parallel_sizes
(
self
)
->
tuple
[
int
,
int
,
int
]:
self
.
world_size
=
self
.
parallel_config
.
world_size
assert
self
.
world_size
%
self
.
parallel_config
.
nnodes_within_dp
==
0
,
(
f
"global world_size (
{
self
.
parallel_config
.
world_size
}
) must be "
f
"divisible by nnodes_within_dp "
f
"(
{
self
.
parallel_config
.
nnodes_within_dp
}
). "
)
self
.
local_world_size
=
self
.
parallel_config
.
local_world_size
tp_size
=
self
.
parallel_config
.
tensor_parallel_size
pp_size
=
self
.
parallel_config
.
pipeline_parallel_size
pcp_size
=
self
.
parallel_config
.
prefill_context_parallel_size
return
tp_size
,
pp_size
,
pcp_size
def
_post_init_executor
(
self
)
->
None
:
pass
def
_is_driver_worker
(
self
,
rank
:
int
)
->
bool
:
return
rank
%
self
.
parallel_config
.
tensor_parallel_size
==
0
def
start_worker_monitor
(
self
,
inline
=
False
)
->
None
:
def
start_worker_monitor
(
self
,
inline
=
False
)
->
None
:
workers
=
self
.
workers
workers
=
self
.
workers
self_ref
=
weakref
.
ref
(
self
)
self_ref
=
weakref
.
ref
(
self
)
...
@@ -517,6 +532,7 @@ class WorkerProc:
...
@@ -517,6 +532,7 @@ class WorkerProc:
distributed_init_method
:
str
,
distributed_init_method
:
str
,
input_shm_handle
:
Handle
,
input_shm_handle
:
Handle
,
shared_worker_lock
:
LockType
,
shared_worker_lock
:
LockType
,
is_driver_worker
:
bool
,
):
):
self
.
rank
=
rank
self
.
rank
=
rank
wrapper
=
WorkerWrapperBase
(
rpc_rank
=
local_rank
,
global_rank
=
rank
)
wrapper
=
WorkerWrapperBase
(
rpc_rank
=
local_rank
,
global_rank
=
rank
)
...
@@ -524,7 +540,6 @@ class WorkerProc:
...
@@ -524,7 +540,6 @@ class WorkerProc:
all_kwargs
:
list
[
dict
]
=
[
all_kwargs
:
list
[
dict
]
=
[
{}
for
_
in
range
(
vllm_config
.
parallel_config
.
world_size
)
{}
for
_
in
range
(
vllm_config
.
parallel_config
.
world_size
)
]
]
is_driver_worker
=
rank
%
vllm_config
.
parallel_config
.
tensor_parallel_size
==
0
all_kwargs
[
local_rank
]
=
{
all_kwargs
[
local_rank
]
=
{
"vllm_config"
:
vllm_config
,
"vllm_config"
:
vllm_config
,
"local_rank"
:
local_rank
,
"local_rank"
:
local_rank
,
...
@@ -571,6 +586,7 @@ class WorkerProc:
...
@@ -571,6 +586,7 @@ class WorkerProc:
distributed_init_method
:
str
,
distributed_init_method
:
str
,
input_shm_handle
,
# Receive SchedulerOutput
input_shm_handle
,
# Receive SchedulerOutput
shared_worker_lock
:
LockType
,
shared_worker_lock
:
LockType
,
is_driver_worker
:
bool
,
)
->
UnreadyWorkerProcHandle
:
)
->
UnreadyWorkerProcHandle
:
context
=
get_mp_context
()
context
=
get_mp_context
()
# (reader, writer)
# (reader, writer)
...
@@ -588,6 +604,7 @@ class WorkerProc:
...
@@ -588,6 +604,7 @@ class WorkerProc:
"ready_pipe"
:
(
reader
,
writer
),
"ready_pipe"
:
(
reader
,
writer
),
"death_pipe"
:
death_reader
,
"death_pipe"
:
death_reader
,
"shared_worker_lock"
:
shared_worker_lock
,
"shared_worker_lock"
:
shared_worker_lock
,
"is_driver_worker"
:
is_driver_worker
,
}
}
# Run EngineCore busy loop in background process.
# Run EngineCore busy loop in background process.
proc
=
context
.
Process
(
proc
=
context
.
Process
(
...
...
vllm/v1/executor/ray_utils.py
View file @
484e22bc
...
@@ -103,7 +103,7 @@ try:
...
@@ -103,7 +103,7 @@ try:
output
=
self
.
worker
.
model_runner
.
execute_model
(
output
=
self
.
worker
.
model_runner
.
execute_model
(
scheduler_output
,
intermediate_tensors
scheduler_output
,
intermediate_tensors
)
)
if
isinstance
(
output
,
I
ntermediate
T
ensors
):
if
self
.
_is_i
ntermediate
_t
ensors
(
output
):
return
scheduler_output
,
grammar_output
,
output
return
scheduler_output
,
grammar_output
,
output
if
isinstance
(
output
,
AsyncModelRunnerOutput
):
if
isinstance
(
output
,
AsyncModelRunnerOutput
):
...
@@ -125,6 +125,9 @@ try:
...
@@ -125,6 +125,9 @@ try:
def
override_env_vars
(
self
,
vars
:
dict
[
str
,
str
]):
def
override_env_vars
(
self
,
vars
:
dict
[
str
,
str
]):
os
.
environ
.
update
(
vars
)
os
.
environ
.
update
(
vars
)
def
_is_intermediate_tensors
(
self
,
output
)
->
bool
:
return
isinstance
(
output
,
IntermediateTensors
)
ray_import_err
=
None
ray_import_err
=
None
except
ImportError
as
e
:
except
ImportError
as
e
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment