Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6ebffafb
Unverified
Commit
6ebffafb
authored
Oct 27, 2025
by
Cyrus Leung
Committed by
GitHub
Oct 27, 2025
Browse files
[Misc] Clean up more utils (#27567)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
3b96f85c
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
262 additions
and
290 deletions
+262
-290
requirements/docs.txt
requirements/docs.txt
+2
-0
tools/pre_commit/check_pickle_imports.py
tools/pre_commit/check_pickle_imports.py
+0
-1
vllm/config/model.py
vllm/config/model.py
+23
-0
vllm/config/vllm.py
vllm/config/vllm.py
+27
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-2
vllm/entrypoints/anthropic/api_server.py
vllm/entrypoints/anthropic/api_server.py
+1
-1
vllm/entrypoints/api_server.py
vllm/entrypoints/api_server.py
+2
-1
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+2
-2
vllm/platforms/__init__.py
vllm/platforms/__init__.py
+1
-1
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+1
-1
vllm/utils/__init__.py
vllm/utils/__init__.py
+4
-242
vllm/utils/argparse_utils.py
vllm/utils/argparse_utils.py
+3
-23
vllm/utils/import_utils.py
vllm/utils/import_utils.py
+43
-0
vllm/utils/system_utils.py
vllm/utils/system_utils.py
+114
-8
vllm/v1/engine/coordinator.py
vllm/v1/engine/coordinator.py
+1
-2
vllm/v1/engine/utils.py
vllm/v1/engine/utils.py
+1
-1
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+6
-2
vllm/v1/executor/uniproc_executor.py
vllm/v1/executor/uniproc_executor.py
+1
-1
vllm/v1/serial_utils.py
vllm/v1/serial_utils.py
+28
-0
vllm/v1/utils.py
vllm/v1/utils.py
+1
-1
No files found.
requirements/docs.txt
View file @
6ebffafb
...
...
@@ -13,6 +13,8 @@ ruff
# Required for argparse hook only
-f https://download.pytorch.org/whl/cpu
cachetools
cloudpickle
py-cpuinfo
msgspec
pydantic
torch
tools/pre_commit/check_pickle_imports.py
View file @
6ebffafb
...
...
@@ -39,7 +39,6 @@ ALLOWED_FILES = {
"vllm/v1/executor/multiproc_executor.py"
,
"vllm/v1/executor/ray_executor.py"
,
"vllm/entrypoints/llm.py"
,
"vllm/utils/__init__.py"
,
"tests/utils.py"
,
# pickle and cloudpickle
"vllm/v1/serial_utils.py"
,
...
...
vllm/config/model.py
View file @
6ebffafb
...
...
@@ -1618,6 +1618,29 @@ class ModelConfig:
"""Extract the HF encoder/decoder model flag."""
return
is_encoder_decoder
(
self
.
hf_config
)
@
property
def
uses_alibi
(
self
)
->
bool
:
cfg
=
self
.
hf_text_config
return
(
getattr
(
cfg
,
"alibi"
,
False
)
# Falcon
or
"BloomForCausalLM"
in
self
.
architectures
# Bloom
or
getattr
(
cfg
,
"position_encoding_type"
,
""
)
==
"alibi"
# codellm_1b_alibi
or
(
hasattr
(
cfg
,
"attn_config"
)
# MPT
and
(
(
isinstance
(
cfg
.
attn_config
,
dict
)
and
cfg
.
attn_config
.
get
(
"alibi"
,
False
)
)
or
(
not
isinstance
(
cfg
.
attn_config
,
dict
)
and
getattr
(
cfg
.
attn_config
,
"alibi"
,
False
)
)
)
)
)
@
property
def
uses_mrope
(
self
)
->
bool
:
return
uses_mrope
(
self
.
hf_config
)
...
...
vllm/config/vllm.py
View file @
6ebffafb
...
...
@@ -2,12 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
import
getpass
import
hashlib
import
json
import
os
import
tempfile
import
threading
import
time
from
contextlib
import
contextmanager
from
dataclasses
import
replace
from
datetime
import
datetime
from
functools
import
lru_cache
from
pathlib
import
Path
from
typing
import
TYPE_CHECKING
,
Any
,
TypeVar
...
...
@@ -17,7 +21,7 @@ from pydantic import ConfigDict, Field
from
pydantic.dataclasses
import
dataclass
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
enable_trace_function_call
,
init_logger
from
vllm.transformers_utils.runai_utils
import
is_runai_obj_uri
from
vllm.utils
import
random_uuid
...
...
@@ -206,6 +210,28 @@ class VllmConfig:
# i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size
return
self
.
compilation_config
.
bs_to_padded_graph_size
[
batch_size
]
def
enable_trace_function_call_for_thread
(
self
)
->
None
:
"""
Set up function tracing for the current thread,
if enabled via the `VLLM_TRACE_FUNCTION` environment variable.
"""
if
envs
.
VLLM_TRACE_FUNCTION
:
tmp_dir
=
tempfile
.
gettempdir
()
# add username to tmp_dir to avoid permission issues
tmp_dir
=
os
.
path
.
join
(
tmp_dir
,
getpass
.
getuser
())
filename
=
(
f
"VLLM_TRACE_FUNCTION_for_process_
{
os
.
getpid
()
}
"
f
"_thread_
{
threading
.
get_ident
()
}
_at_
{
datetime
.
now
()
}
.log"
).
replace
(
" "
,
"_"
)
log_path
=
os
.
path
.
join
(
tmp_dir
,
"vllm"
,
f
"vllm-instance-
{
self
.
instance_id
}
"
,
filename
,
)
os
.
makedirs
(
os
.
path
.
dirname
(
log_path
),
exist_ok
=
True
)
enable_trace_function_call
(
log_path
)
@
staticmethod
def
_get_quantization_config
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
...
...
vllm/engine/arg_utils.py
View file @
6ebffafb
...
...
@@ -73,7 +73,7 @@ from vllm.config.utils import get_field
from
vllm.logger
import
init_logger
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.plugins
import
load_general_plugins
from
vllm.ray.lazy_utils
import
is_ray_initialized
from
vllm.ray.lazy_utils
import
is_in_ray_actor
,
is_ray_initialized
from
vllm.reasoning
import
ReasoningParserManager
from
vllm.test_utils
import
MODEL_WEIGHTS_S3_BUCKET
,
MODELS_ON_S3
from
vllm.transformers_utils.config
import
(
...
...
@@ -82,7 +82,6 @@ from vllm.transformers_utils.config import (
maybe_override_with_speculators
,
)
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.utils
import
is_in_ray_actor
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.mem_constants
import
GiB_bytes
from
vllm.utils.network_utils
import
get_ip
...
...
vllm/entrypoints/anthropic/api_server.py
View file @
6ebffafb
...
...
@@ -51,9 +51,9 @@ from vllm.entrypoints.utils import (
with_cancellation
,
)
from
vllm.logger
import
init_logger
from
vllm.utils
import
set_ulimit
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.network_utils
import
is_valid_ipv6_address
from
vllm.utils.system_utils
import
set_ulimit
from
vllm.version
import
__version__
as
VLLM_VERSION
prometheus_multiproc_dir
:
tempfile
.
TemporaryDirectory
...
...
vllm/entrypoints/api_server.py
View file @
6ebffafb
...
...
@@ -26,8 +26,9 @@ from vllm.entrypoints.utils import with_cancellation
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
random_uuid
,
set_ulimit
from
vllm.utils
import
random_uuid
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.system_utils
import
set_ulimit
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
"vllm.entrypoints.api_server"
)
...
...
vllm/entrypoints/openai/api_server.py
View file @
6ebffafb
...
...
@@ -108,10 +108,10 @@ from vllm.entrypoints.utils import (
from
vllm.logger
import
init_logger
from
vllm.reasoning
import
ReasoningParserManager
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Device
,
set_ulimit
from
vllm.utils
import
Device
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.network_utils
import
is_valid_ipv6_address
from
vllm.utils.system_utils
import
decorate_logs
from
vllm.utils.system_utils
import
decorate_logs
,
set_ulimit
from
vllm.v1.engine.exceptions
import
EngineDeadError
from
vllm.v1.metrics.prometheus
import
get_prometheus_registry
from
vllm.version
import
__version__
as
VLLM_VERSION
...
...
vllm/platforms/__init__.py
View file @
6ebffafb
...
...
@@ -60,7 +60,7 @@ def cuda_platform_plugin() -> str | None:
is_cuda
=
False
logger
.
debug
(
"Checking if CUDA platform is available."
)
try
:
from
vllm.utils
import
import_pynvml
from
vllm.utils
.import_utils
import
import_pynvml
pynvml
=
import_pynvml
()
pynvml
.
nvmlInit
()
...
...
vllm/platforms/cuda.py
View file @
6ebffafb
...
...
@@ -16,7 +16,7 @@ from typing_extensions import ParamSpec
import
vllm._C
# noqa
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.utils
import
import_pynvml
from
vllm.utils
.import_utils
import
import_pynvml
from
vllm.utils.torch_utils
import
cuda_device_count_stateless
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
...
...
vllm/utils/__init__.py
View file @
6ebffafb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
import
datetime
import
enum
import
getpass
import
inspect
import
multiprocessing
import
os
import
signal
import
sys
import
tempfile
import
threading
import
uuid
import
warnings
from
collections.abc
import
Callable
from
functools
import
partial
,
wraps
from
typing
import
TYPE_CHECKING
,
Any
,
TypeVar
from
functools
import
wraps
from
typing
import
Any
,
TypeVar
import
cloudpickle
import
psutil
import
torch
import
vllm.envs
as
envs
from
vllm.logger
import
enable_trace_function_call
,
init_logger
from
vllm.ray.lazy_utils
import
is_in_ray_actor
from
vllm.logger
import
init_logger
_DEPRECATED_MAPPINGS
=
{
"cprofile"
:
"profiling"
,
"cprofile_context"
:
"profiling"
,
# Used by lm-eval
"get_open_port"
:
"network_utils"
,
}
...
...
@@ -53,12 +41,6 @@ def __dir__() -> list[str]:
return
sorted
(
list
(
globals
().
keys
())
+
list
(
_DEPRECATED_MAPPINGS
.
keys
()))
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
,
VllmConfig
else
:
ModelConfig
=
object
VllmConfig
=
object
logger
=
init_logger
(
__name__
)
# This value is chosen to have a balance between ITL and TTFT. Note it is
...
...
@@ -83,13 +65,7 @@ STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL
:
str
=
"INVALID"
# ANSI color codes
CYAN
=
"
\033
[1;36m"
RESET
=
"
\033
[0;0m"
T
=
TypeVar
(
"T"
)
U
=
TypeVar
(
"U"
)
class
Device
(
enum
.
Enum
):
...
...
@@ -144,195 +120,6 @@ def random_uuid() -> str:
return
str
(
uuid
.
uuid4
().
hex
)
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def
init_cached_hf_modules
()
->
None
:
"""
Lazy initialization of the Hugging Face modules.
"""
from
transformers.dynamic_module_utils
import
init_hf_modules
init_hf_modules
()
def
enable_trace_function_call_for_thread
(
vllm_config
:
VllmConfig
)
->
None
:
"""Set up function tracing for the current thread,
if enabled via the VLLM_TRACE_FUNCTION environment variable
"""
if
envs
.
VLLM_TRACE_FUNCTION
:
tmp_dir
=
tempfile
.
gettempdir
()
# add username to tmp_dir to avoid permission issues
tmp_dir
=
os
.
path
.
join
(
tmp_dir
,
getpass
.
getuser
())
filename
=
(
f
"VLLM_TRACE_FUNCTION_for_process_
{
os
.
getpid
()
}
"
f
"_thread_
{
threading
.
get_ident
()
}
_"
f
"at_
{
datetime
.
datetime
.
now
()
}
.log"
).
replace
(
" "
,
"_"
)
log_path
=
os
.
path
.
join
(
tmp_dir
,
"vllm"
,
f
"vllm-instance-
{
vllm_config
.
instance_id
}
"
,
filename
)
os
.
makedirs
(
os
.
path
.
dirname
(
log_path
),
exist_ok
=
True
)
enable_trace_function_call
(
log_path
)
def
kill_process_tree
(
pid
:
int
):
"""
Kills all descendant processes of the given pid by sending SIGKILL.
Args:
pid (int): Process ID of the parent process
"""
try
:
parent
=
psutil
.
Process
(
pid
)
except
psutil
.
NoSuchProcess
:
return
# Get all children recursively
children
=
parent
.
children
(
recursive
=
True
)
# Send SIGKILL to all children first
for
child
in
children
:
with
contextlib
.
suppress
(
ProcessLookupError
):
os
.
kill
(
child
.
pid
,
signal
.
SIGKILL
)
# Finally kill the parent
with
contextlib
.
suppress
(
ProcessLookupError
):
os
.
kill
(
pid
,
signal
.
SIGKILL
)
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
def
set_ulimit
(
target_soft_limit
=
65535
):
if
sys
.
platform
.
startswith
(
"win"
):
logger
.
info
(
"Windows detected, skipping ulimit adjustment."
)
return
import
resource
resource_type
=
resource
.
RLIMIT_NOFILE
current_soft
,
current_hard
=
resource
.
getrlimit
(
resource_type
)
if
current_soft
<
target_soft_limit
:
try
:
resource
.
setrlimit
(
resource_type
,
(
target_soft_limit
,
current_hard
))
except
ValueError
as
e
:
logger
.
warning
(
"Found ulimit of %s and failed to automatically increase "
"with error %s. This can cause fd limit errors like "
"`OSError: [Errno 24] Too many open files`. Consider "
"increasing with ulimit -n"
,
current_soft
,
e
,
)
def
_maybe_force_spawn
():
"""Check if we need to force the use of the `spawn` multiprocessing start
method.
"""
if
os
.
environ
.
get
(
"VLLM_WORKER_MULTIPROC_METHOD"
)
==
"spawn"
:
return
reasons
=
[]
if
is_in_ray_actor
():
# even if we choose to spawn, we need to pass the ray address
# to the subprocess so that it knows how to connect to the ray cluster.
# env vars are inherited by subprocesses, even if we use spawn.
import
ray
os
.
environ
[
"RAY_ADDRESS"
]
=
ray
.
get_runtime_context
().
gcs_address
reasons
.
append
(
"In a Ray actor and can only be spawned"
)
from
.platform_utils
import
cuda_is_initialized
,
xpu_is_initialized
if
cuda_is_initialized
():
reasons
.
append
(
"CUDA is initialized"
)
elif
xpu_is_initialized
():
reasons
.
append
(
"XPU is initialized"
)
if
reasons
:
logger
.
warning
(
"We must use the `spawn` multiprocessing start method. "
"Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"See https://docs.vllm.ai/en/latest/usage/"
"troubleshooting.html#python-multiprocessing "
"for more information. Reasons: %s"
,
"; "
.
join
(
reasons
),
)
os
.
environ
[
"VLLM_WORKER_MULTIPROC_METHOD"
]
=
"spawn"
def
get_mp_context
():
"""Get a multiprocessing context with a particular method (spawn or fork).
By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to
determine the multiprocessing method (default is fork). However, under
certain conditions, we may enforce spawn and override the value of
VLLM_WORKER_MULTIPROC_METHOD.
"""
_maybe_force_spawn
()
mp_method
=
envs
.
VLLM_WORKER_MULTIPROC_METHOD
return
multiprocessing
.
get_context
(
mp_method
)
def
run_method
(
obj
:
Any
,
method
:
str
|
bytes
|
Callable
,
args
:
tuple
[
Any
],
kwargs
:
dict
[
str
,
Any
],
)
->
Any
:
"""
Run a method of an object with the given arguments and keyword arguments.
If the method is string, it will be converted to a method using getattr.
If the method is serialized bytes and will be deserialized using
cloudpickle.
If the method is a callable, it will be called directly.
"""
if
isinstance
(
method
,
bytes
):
func
=
partial
(
cloudpickle
.
loads
(
method
),
obj
)
elif
isinstance
(
method
,
str
):
try
:
func
=
getattr
(
obj
,
method
)
except
AttributeError
:
raise
NotImplementedError
(
f
"Method
{
method
!
r
}
is not implemented."
)
from
None
else
:
func
=
partial
(
method
,
obj
)
# type: ignore
return
func
(
*
args
,
**
kwargs
)
def
import_pynvml
():
"""
Historical comments:
libnvml.so is the library behind nvidia-smi, and
pynvml is a Python wrapper around it. We use it to get GPU
status without initializing CUDA context in the current process.
Historically, there are two packages that provide pynvml:
- `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official
wrapper. It is a dependency of vLLM, and is installed when users
install vLLM. It provides a Python module named `pynvml`.
- `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper.
Prior to version 12.0, it also provides a Python module `pynvml`,
and therefore conflicts with the official one. What's worse,
the module is a Python package, and has higher priority than
the official one which is a standalone Python file.
This causes errors when both of them are installed.
Starting from version 12.0, it migrates to a new module
named `pynvml_utils` to avoid the conflict.
It is so confusing that many packages in the community use the
unofficial one by mistake, and we have to handle this case.
For example, `nvcr.io/nvidia/pytorch:24.12-py3` uses the unofficial
one, and it will cause errors, see the issue
https://github.com/vllm-project/vllm/issues/12847 for example.
After all the troubles, we decide to copy the official `pynvml`
module to our codebase, and use it directly.
"""
import
vllm.third_party.pynvml
as
pynvml
return
pynvml
def
warn_for_unimplemented_methods
(
cls
:
type
[
T
])
->
type
[
T
]:
"""
A replacement for `abc.ABC`.
...
...
@@ -376,31 +163,6 @@ def warn_for_unimplemented_methods(cls: type[T]) -> type[T]:
return
cls
# Only relevant for models using ALiBi (e.g, MPT)
def
check_use_alibi
(
model_config
:
ModelConfig
)
->
bool
:
cfg
=
model_config
.
hf_text_config
return
(
getattr
(
cfg
,
"alibi"
,
False
)
# Falcon
or
(
"BloomForCausalLM"
in
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
)
# Bloom
or
getattr
(
cfg
,
"position_encoding_type"
,
""
)
==
"alibi"
# codellm_1b_alibi
or
(
hasattr
(
cfg
,
"attn_config"
)
# MPT
and
(
(
isinstance
(
cfg
.
attn_config
,
dict
)
and
cfg
.
attn_config
.
get
(
"alibi"
,
False
)
)
or
(
not
isinstance
(
cfg
.
attn_config
,
dict
)
and
getattr
(
cfg
.
attn_config
,
"alibi"
,
False
)
)
)
)
)
def
length_from_prompt_token_ids_or_embeds
(
prompt_token_ids
:
list
[
int
]
|
None
,
prompt_embeds
:
torch
.
Tensor
|
None
,
...
...
vllm/utils/argparse_utils.py
View file @
6ebffafb
...
...
@@ -10,37 +10,21 @@ from argparse import (
ArgumentDefaultsHelpFormatter
,
ArgumentParser
,
ArgumentTypeError
,
Namespace
,
RawDescriptionHelpFormatter
,
_ArgumentGroup
,
)
from
collections
import
defaultdict
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
Any
import
regex
as
re
import
yaml
from
vllm.logger
import
init_logger
if
TYPE_CHECKING
:
from
argparse
import
Namespace
else
:
Namespace
=
object
logger
=
init_logger
(
__name__
)
class
StoreBoolean
(
Action
):
def
__call__
(
self
,
parser
,
namespace
,
values
,
option_string
=
None
):
if
values
.
lower
()
==
"true"
:
setattr
(
namespace
,
self
.
dest
,
True
)
elif
values
.
lower
()
==
"false"
:
setattr
(
namespace
,
self
.
dest
,
False
)
else
:
raise
ValueError
(
f
"Invalid boolean value:
{
values
}
. Expected 'true' or 'false'."
)
class
SortedHelpFormatter
(
ArgumentDefaultsHelpFormatter
,
RawDescriptionHelpFormatter
):
"""SortedHelpFormatter that sorts arguments by their option strings."""
...
...
@@ -487,12 +471,8 @@ class FlexibleArgumentParser(ArgumentParser):
)
raise
ex
store_boolean_arguments
=
[
action
.
dest
for
action
in
self
.
_actions
if
isinstance
(
action
,
StoreBoolean
)
]
for
key
,
value
in
config
.
items
():
if
isinstance
(
value
,
bool
)
and
key
not
in
store_boolean_arguments
:
if
isinstance
(
value
,
bool
):
if
value
:
processed_args
.
append
(
"--"
+
key
)
elif
isinstance
(
value
,
list
):
...
...
vllm/utils/import_utils.py
View file @
6ebffafb
...
...
@@ -19,6 +19,49 @@ import regex as re
from
typing_extensions
import
Never
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def
init_cached_hf_modules
()
->
None
:
"""
Lazy initialization of the Hugging Face modules.
"""
from
transformers.dynamic_module_utils
import
init_hf_modules
init_hf_modules
()
def
import_pynvml
():
"""
Historical comments:
libnvml.so is the library behind nvidia-smi, and
pynvml is a Python wrapper around it. We use it to get GPU
status without initializing CUDA context in the current process.
Historically, there are two packages that provide pynvml:
- `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official
wrapper. It is a dependency of vLLM, and is installed when users
install vLLM. It provides a Python module named `pynvml`.
- `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper.
Prior to version 12.0, it also provides a Python module `pynvml`,
and therefore conflicts with the official one. What's worse,
the module is a Python package, and has higher priority than
the official one which is a standalone Python file.
This causes errors when both of them are installed.
Starting from version 12.0, it migrates to a new module
named `pynvml_utils` to avoid the conflict.
It is so confusing that many packages in the community use the
unofficial one by mistake, and we have to handle this case.
For example, `nvcr.io/nvidia/pytorch:24.12-py3` uses the unofficial
one, and it will cause errors, see the issue
https://github.com/vllm-project/vllm/issues/12847 for example.
After all the troubles, we decide to copy the official `pynvml`
module to our codebase, and use it directly.
"""
import
vllm.third_party.pynvml
as
pynvml
return
pynvml
def
import_from_path
(
module_name
:
str
,
file_path
:
str
|
os
.
PathLike
):
"""
Import a Python file according to its file path.
...
...
vllm/utils/system_utils.py
View file @
6ebffafb
...
...
@@ -4,19 +4,21 @@
from
__future__
import
annotations
import
contextlib
import
multiprocessing
import
os
import
signal
import
sys
from
collections.abc
import
Callable
,
Iterator
from
pathlib
import
Path
from
typing
import
TextIO
try
:
import
setproctitle
except
ImportError
:
setproctitle
=
None
# type: ignore[assignment]
import
psutil
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.ray.lazy_utils
import
is_in_ray_actor
from
.platform_utils
import
cuda_is_initialized
,
xpu_is_initialized
logger
=
init_logger
(
__name__
)
...
...
@@ -75,14 +77,66 @@ def unique_filepath(fn: Callable[[int], Path]) -> Path:
# Process management utilities
def
_maybe_force_spawn
():
"""Check if we need to force the use of the `spawn` multiprocessing start
method.
"""
if
os
.
environ
.
get
(
"VLLM_WORKER_MULTIPROC_METHOD"
)
==
"spawn"
:
return
reasons
=
[]
if
is_in_ray_actor
():
# even if we choose to spawn, we need to pass the ray address
# to the subprocess so that it knows how to connect to the ray cluster.
# env vars are inherited by subprocesses, even if we use spawn.
import
ray
os
.
environ
[
"RAY_ADDRESS"
]
=
ray
.
get_runtime_context
().
gcs_address
reasons
.
append
(
"In a Ray actor and can only be spawned"
)
if
cuda_is_initialized
():
reasons
.
append
(
"CUDA is initialized"
)
elif
xpu_is_initialized
():
reasons
.
append
(
"XPU is initialized"
)
if
reasons
:
logger
.
warning
(
"We must use the `spawn` multiprocessing start method. "
"Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"See https://docs.vllm.ai/en/latest/usage/"
"troubleshooting.html#python-multiprocessing "
"for more information. Reasons: %s"
,
"; "
.
join
(
reasons
),
)
os
.
environ
[
"VLLM_WORKER_MULTIPROC_METHOD"
]
=
"spawn"
def
get_mp_context
():
"""Get a multiprocessing context with a particular method (spawn or fork).
By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to
determine the multiprocessing method (default is fork). However, under
certain conditions, we may enforce spawn and override the value of
VLLM_WORKER_MULTIPROC_METHOD.
"""
_maybe_force_spawn
()
mp_method
=
envs
.
VLLM_WORKER_MULTIPROC_METHOD
return
multiprocessing
.
get_context
(
mp_method
)
def
set_process_title
(
name
:
str
,
suffix
:
str
=
""
,
prefix
:
str
=
envs
.
VLLM_PROCESS_NAME_PREFIX
name
:
str
,
suffix
:
str
=
""
,
prefix
:
str
=
envs
.
VLLM_PROCESS_NAME_PREFIX
,
)
->
None
:
"""Set the current process title with optional suffix."""
if
setproctitle
is
None
:
try
:
import
setproctitle
except
ImportError
:
return
if
suffix
:
name
=
f
"
{
name
}
_
{
suffix
}
"
setproctitle
.
setproctitle
(
f
"
{
prefix
}
::
{
name
}
"
)
...
...
@@ -114,10 +168,62 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
def
decorate_logs
(
process_name
:
str
|
None
=
None
)
->
None
:
"""Decorate stdout/stderr with process name and PID prefix."""
from
vllm.utils
import
get_mp_context
if
process_name
is
None
:
process_name
=
get_mp_context
().
current_process
().
name
pid
=
os
.
getpid
()
_add_prefix
(
sys
.
stdout
,
process_name
,
pid
)
_add_prefix
(
sys
.
stderr
,
process_name
,
pid
)
def
kill_process_tree
(
pid
:
int
):
"""
Kills all descendant processes of the given pid by sending SIGKILL.
Args:
pid (int): Process ID of the parent process
"""
try
:
parent
=
psutil
.
Process
(
pid
)
except
psutil
.
NoSuchProcess
:
return
# Get all children recursively
children
=
parent
.
children
(
recursive
=
True
)
# Send SIGKILL to all children first
for
child
in
children
:
with
contextlib
.
suppress
(
ProcessLookupError
):
os
.
kill
(
child
.
pid
,
signal
.
SIGKILL
)
# Finally kill the parent
with
contextlib
.
suppress
(
ProcessLookupError
):
os
.
kill
(
pid
,
signal
.
SIGKILL
)
# Resource utilities
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630
def
set_ulimit
(
target_soft_limit
:
int
=
65535
):
if
sys
.
platform
.
startswith
(
"win"
):
logger
.
info
(
"Windows detected, skipping ulimit adjustment."
)
return
import
resource
resource_type
=
resource
.
RLIMIT_NOFILE
current_soft
,
current_hard
=
resource
.
getrlimit
(
resource_type
)
if
current_soft
<
target_soft_limit
:
try
:
resource
.
setrlimit
(
resource_type
,
(
target_soft_limit
,
current_hard
))
except
ValueError
as
e
:
logger
.
warning
(
"Found ulimit of %s and failed to automatically increase "
"with error %s. This can cause fd limit errors like "
"`OSError: [Errno 24] Too many open files`. Consider "
"increasing with ulimit -n"
,
current_soft
,
e
,
)
vllm/v1/engine/coordinator.py
View file @
6ebffafb
...
...
@@ -10,9 +10,8 @@ import zmq
from
vllm.config
import
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_mp_context
from
vllm.utils.network_utils
import
make_zmq_socket
from
vllm.utils.system_utils
import
set_process_title
from
vllm.utils.system_utils
import
get_mp_context
,
set_process_title
from
vllm.v1.engine
import
EngineCoreOutputs
,
EngineCoreRequestType
from
vllm.v1.serial_utils
import
MsgpackDecoder
from
vllm.v1.utils
import
get_engine_client_zmq_addr
,
shutdown
...
...
vllm/v1/engine/utils.py
View file @
6ebffafb
...
...
@@ -20,8 +20,8 @@ from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.ray.ray_env
import
get_env_vars_to_copy
from
vllm.utils
import
get_mp_context
from
vllm.utils.network_utils
import
get_open_zmq_ipc_path
,
zmq_socket_ctx
from
vllm.utils.system_utils
import
get_mp_context
from
vllm.v1.engine.coordinator
import
DPCoordinator
from
vllm.v1.executor
import
Executor
from
vllm.v1.utils
import
get_engine_client_zmq_addr
,
shutdown
...
...
vllm/v1/executor/multiproc_executor.py
View file @
6ebffafb
...
...
@@ -35,13 +35,17 @@ from vllm.distributed.parallel_state import (
)
from
vllm.envs
import
enable_envs_cache
from
vllm.logger
import
init_logger
from
vllm.utils
import
_maybe_force_spawn
,
get_mp_context
from
vllm.utils.network_utils
import
(
get_distributed_init_method
,
get_loopback_ip
,
get_open_port
,
)
from
vllm.utils.system_utils
import
decorate_logs
,
set_process_title
from
vllm.utils.system_utils
import
(
_maybe_force_spawn
,
decorate_logs
,
get_mp_context
,
set_process_title
,
)
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.executor.abstract
import
Executor
,
FailureCallback
from
vllm.v1.outputs
import
AsyncModelRunnerOutput
,
DraftTokenIds
,
ModelRunnerOutput
...
...
vllm/v1/executor/uniproc_executor.py
View file @
6ebffafb
...
...
@@ -12,11 +12,11 @@ import torch.distributed as dist
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.utils
import
run_method
from
vllm.utils.network_utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.v1.engine
import
ReconfigureDistributedRequest
,
ReconfigureRankType
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.outputs
import
AsyncModelRunnerOutput
from
vllm.v1.serial_utils
import
run_method
from
vllm.v1.worker.worker_base
import
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
...
...
vllm/v1/serial_utils.py
View file @
6ebffafb
...
...
@@ -5,6 +5,7 @@ import dataclasses
import
importlib
import
pickle
from
collections.abc
import
Callable
,
Sequence
from
functools
import
partial
from
inspect
import
isclass
from
types
import
FunctionType
from
typing
import
Any
,
TypeAlias
...
...
@@ -429,3 +430,30 @@ class MsgpackDecoder:
return
cloudpickle
.
loads
(
data
)
raise
NotImplementedError
(
f
"Extension type code
{
code
}
is not supported"
)
def
run_method
(
obj
:
Any
,
method
:
str
|
bytes
|
Callable
,
args
:
tuple
[
Any
,
...],
kwargs
:
dict
[
str
,
Any
],
)
->
Any
:
"""
Run a method of an object with the given arguments and keyword arguments.
If the method is string, it will be converted to a method using getattr.
If the method is serialized bytes and will be deserialized using
cloudpickle.
If the method is a callable, it will be called directly.
"""
if
isinstance
(
method
,
bytes
):
func
=
partial
(
cloudpickle
.
loads
(
method
),
obj
)
elif
isinstance
(
method
,
str
):
try
:
func
=
getattr
(
obj
,
method
)
except
AttributeError
:
raise
NotImplementedError
(
f
"Method
{
method
!
r
}
is not implemented."
)
from
None
else
:
func
=
partial
(
method
,
obj
)
# type: ignore
return
func
(
*
args
,
**
kwargs
)
vllm/v1/utils.py
View file @
6ebffafb
...
...
@@ -25,8 +25,8 @@ from torch.autograd.profiler import record_function
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
,
is_usage_stats_enabled
,
usage_message
from
vllm.utils
import
kill_process_tree
from
vllm.utils.network_utils
import
get_open_port
,
get_open_zmq_ipc_path
,
get_tcp_uri
from
vllm.utils.system_utils
import
kill_process_tree
if
TYPE_CHECKING
:
import
numpy
as
np
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment