Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
36735fd7
Unverified
Commit
36735fd7
authored
Mar 11, 2026
by
Nick Hill
Committed by
GitHub
Mar 12, 2026
Browse files
[BugFix] Fix multiple/duplicate stdout prefixes (#36822)
Signed-off-by:
Nick Hill
<
nickhill123@gmail.com
>
parent
6ecabe49
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
21 additions
and
37 deletions
+21
-37
vllm/entrypoints/cli/serve.py
vllm/entrypoints/cli/serve.py
+0
-2
vllm/utils/system_utils.py
vllm/utils/system_utils.py
+3
-1
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+4
-12
vllm/v1/engine/utils.py
vllm/v1/engine/utils.py
+14
-22
No files found.
vllm/entrypoints/cli/serve.py
View file @
36735fd7
...
@@ -21,7 +21,6 @@ from vllm.usage.usage_lib import UsageContext
...
@@ -21,7 +21,6 @@ from vllm.usage.usage_lib import UsageContext
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.network_utils
import
get_tcp_uri
from
vllm.utils.network_utils
import
get_tcp_uri
from
vllm.utils.system_utils
import
decorate_logs
,
set_process_title
from
vllm.utils.system_utils
import
decorate_logs
,
set_process_title
from
vllm.v1.engine.core
import
EngineCoreProc
from
vllm.v1.engine.utils
import
CoreEngineProcManager
,
launch_core_engines
from
vllm.v1.engine.utils
import
CoreEngineProcManager
,
launch_core_engines
from
vllm.v1.executor
import
Executor
from
vllm.v1.executor
import
Executor
from
vllm.v1.executor.multiproc_executor
import
MultiprocExecutor
from
vllm.v1.executor.multiproc_executor
import
MultiprocExecutor
...
@@ -210,7 +209,6 @@ def run_headless(args: argparse.Namespace):
...
@@ -210,7 +209,6 @@ def run_headless(args: argparse.Namespace):
# Create the engines.
# Create the engines.
engine_manager
=
CoreEngineProcManager
(
engine_manager
=
CoreEngineProcManager
(
target_fn
=
EngineCoreProc
.
run_engine_core
,
local_engine_count
=
local_engine_count
,
local_engine_count
=
local_engine_count
,
start_index
=
vllm_config
.
parallel_config
.
data_parallel_rank
,
start_index
=
vllm_config
.
parallel_config
.
data_parallel_rank
,
local_start_index
=
0
,
local_start_index
=
0
,
...
...
vllm/utils/system_utils.py
View file @
36735fd7
...
@@ -204,7 +204,8 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
...
@@ -204,7 +204,8 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
prefix
=
f
"(
{
worker_name
}
pid=
{
pid
}
) "
prefix
=
f
"(
{
worker_name
}
pid=
{
pid
}
) "
else
:
else
:
prefix
=
f
"
{
CYAN
}
(
{
worker_name
}
pid=
{
pid
}
)
{
RESET
}
"
prefix
=
f
"
{
CYAN
}
(
{
worker_name
}
pid=
{
pid
}
)
{
RESET
}
"
file_write
=
file
.
write
# Use the original write to avoid nesting prefixes on repeated calls.
file_write
=
getattr
(
file
,
"_original_write"
,
file
.
write
)
def
write_with_prefix
(
s
:
str
):
def
write_with_prefix
(
s
:
str
):
if
not
s
:
if
not
s
:
...
@@ -224,6 +225,7 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
...
@@ -224,6 +225,7 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
file
.
start_new_line
=
False
# type: ignore[attr-defined]
file
.
start_new_line
=
False
# type: ignore[attr-defined]
file
.
start_new_line
=
True
# type: ignore[attr-defined]
file
.
start_new_line
=
True
# type: ignore[attr-defined]
file
.
_original_write
=
file_write
# type: ignore[attr-defined]
file
.
write
=
write_with_prefix
# type: ignore[method-assign]
file
.
write
=
write_with_prefix
# type: ignore[method-assign]
...
...
vllm/v1/engine/core.py
View file @
36735fd7
...
@@ -1045,19 +1045,11 @@ class EngineCoreProc(EngineCore):
...
@@ -1045,19 +1045,11 @@ class EngineCoreProc(EngineCore):
data_parallel
=
parallel_config
.
data_parallel_size
>
1
or
dp_rank
>
0
data_parallel
=
parallel_config
.
data_parallel_size
>
1
or
dp_rank
>
0
if
data_parallel
:
if
data_parallel
:
parallel_config
.
data_parallel_rank_local
=
local_dp_rank
parallel_config
.
data_parallel_rank_local
=
local_dp_rank
maybe_init_worker_tracer
(
process_title
=
f
"EngineCore_DP
{
dp_rank
}
"
instrumenting_module_name
=
"vllm.engine_core"
,
process_kind
=
"engine_core"
,
process_name
=
f
"EngineCore_DP
{
dp_rank
}
"
,
)
set_process_title
(
"EngineCore"
,
f
"DP
{
dp_rank
}
"
)
else
:
else
:
maybe_init_worker_tracer
(
process_title
=
"EngineCore"
instrumenting_module_name
=
"vllm.engine_core"
,
set_process_title
(
process_title
)
process_kind
=
"engine_core"
,
maybe_init_worker_tracer
(
"vllm.engine_core"
,
"engine_core"
,
process_title
)
process_name
=
"EngineCore"
,
)
set_process_title
(
"EngineCore"
)
decorate_logs
()
decorate_logs
()
if
data_parallel
and
vllm_config
.
kv_transfer_config
is
not
None
:
if
data_parallel
and
vllm_config
.
kv_transfer_config
is
not
None
:
...
...
vllm/v1/engine/utils.py
View file @
36735fd7
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
import
contextlib
import
contextlib
import
os
import
os
import
weakref
import
weakref
from
collections.abc
import
Callable
,
Iterator
from
collections.abc
import
Iterator
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
multiprocessing
import
Process
,
connection
from
multiprocessing
import
Process
,
connection
...
@@ -85,7 +85,6 @@ class CoreEngineProcManager:
...
@@ -85,7 +85,6 @@ class CoreEngineProcManager:
def
__init__
(
def
__init__
(
self
,
self
,
target_fn
:
Callable
,
local_engine_count
:
int
,
local_engine_count
:
int
,
start_index
:
int
,
start_index
:
int
,
local_start_index
:
int
,
local_start_index
:
int
,
...
@@ -108,6 +107,10 @@ class CoreEngineProcManager:
...
@@ -108,6 +107,10 @@ class CoreEngineProcManager:
if
client_handshake_address
:
if
client_handshake_address
:
common_kwargs
[
"client_handshake_address"
]
=
client_handshake_address
common_kwargs
[
"client_handshake_address"
]
=
client_handshake_address
is_dp
=
vllm_config
.
parallel_config
.
data_parallel_size
>
1
from
vllm.v1.engine.core
import
EngineCoreProc
self
.
processes
:
list
[
BaseProcess
]
=
[]
self
.
processes
:
list
[
BaseProcess
]
=
[]
local_dp_ranks
=
[]
local_dp_ranks
=
[]
for
index
in
range
(
local_engine_count
):
for
index
in
range
(
local_engine_count
):
...
@@ -118,35 +121,27 @@ class CoreEngineProcManager:
...
@@ -118,35 +121,27 @@ class CoreEngineProcManager:
local_dp_ranks
.
append
(
local_index
)
local_dp_ranks
.
append
(
local_index
)
self
.
processes
.
append
(
self
.
processes
.
append
(
context
.
Process
(
context
.
Process
(
target
=
target_fn
,
target
=
EngineCoreProc
.
run_engine_core
,
name
=
f
"EngineCore_DP
{
global_index
}
"
,
name
=
f
"EngineCore_DP
{
global_index
}
"
if
is_dp
else
"EngineCore"
,
kwargs
=
common_kwargs
kwargs
=
common_kwargs
|
{
|
{
"dp_rank"
:
global_index
,
"local_dp_rank"
:
local_index
},
"dp_rank"
:
global_index
,
"local_dp_rank"
:
local_index
,
},
)
)
)
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
processes
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
processes
)
data_parallel
=
vllm_config
.
parallel_config
.
data_parallel_size
>
1
try
:
try
:
for
proc
,
local_dp_rank
in
zip
(
self
.
processes
,
local_dp_ranks
):
for
proc
,
local_dp_rank
in
zip
(
self
.
processes
,
local_dp_ranks
):
# Adjust device control in DP for non-CUDA platforms
# Adjust device control in DP for non-CUDA platforms
# as well as external and ray launchers
# as well as external and ray launchers
# For CUDA platforms, we use torch.cuda.set_device()
# For CUDA platforms, we use torch.cuda.set_device()
with
(
if
is_dp
and
(
set_device_control_env_var
(
vllm_config
,
local_dp_rank
)
not
current_platform
.
is_cuda_alike
()
if
(
or
vllm_config
.
parallel_config
.
use_ray
data_parallel
and
(
not
current_platform
.
is_cuda_alike
()
or
vllm_config
.
parallel_config
.
use_ray
)
)
else
contextlib
.
nullcontext
()
):
):
with
set_device_control_env_var
(
vllm_config
,
local_dp_rank
):
proc
.
start
()
else
:
proc
.
start
()
proc
.
start
()
finally
:
finally
:
# Kill other procs if not all are running.
# Kill other procs if not all are running.
...
@@ -926,12 +921,9 @@ def launch_core_engines(
...
@@ -926,12 +921,9 @@ def launch_core_engines(
with
zmq_socket_ctx
(
with
zmq_socket_ctx
(
local_handshake_address
,
zmq
.
ROUTER
,
bind
=
True
local_handshake_address
,
zmq
.
ROUTER
,
bind
=
True
)
as
handshake_socket
:
)
as
handshake_socket
:
from
vllm.v1.engine.core
import
EngineCoreProc
# Start local engines.
# Start local engines.
if
local_engine_count
:
if
local_engine_count
:
local_engine_manager
=
CoreEngineProcManager
(
local_engine_manager
=
CoreEngineProcManager
(
EngineCoreProc
.
run_engine_core
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
executor_class
=
executor_class
,
log_stats
=
log_stats
,
log_stats
=
log_stats
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment