Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7c2bdb83
Unverified
Commit
7c2bdb83
authored
Oct 27, 2025
by
Cyrus Leung
Committed by
GitHub
Oct 27, 2025
Browse files
[Misc] Clean up utils (#27552)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
9932ed6a
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
41 additions
and
258 deletions
+41
-258
docs/mkdocs/hooks/generate_argparse.py
docs/mkdocs/hooks/generate_argparse.py
+3
-1
tests/utils.py
tests/utils.py
+1
-3
tests/utils_/test_argparse_utils.py
tests/utils_/test_argparse_utils.py
+2
-91
tests/utils_/test_serial_utils.py
tests/utils_/test_serial_utils.py
+1
-1
vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py
...ted/kv_transfer/kv_connector/v1/decode_bench_connector.py
+1
-1
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
...er/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
+2
-1
vllm/entrypoints/anthropic/api_server.py
vllm/entrypoints/anthropic/api_server.py
+2
-1
vllm/entrypoints/cli/serve.py
vllm/entrypoints/cli/serve.py
+1
-1
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+2
-1
vllm/lora/punica_wrapper/punica_gpu.py
vllm/lora/punica_wrapper/punica_gpu.py
+1
-1
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+1
-1
vllm/utils/__init__.py
vllm/utils/__init__.py
+24
-155
No files found.
docs/mkdocs/hooks/generate_argparse.py
View file @
7c2bdb83
...
@@ -65,7 +65,9 @@ ChatCommand = auto_mock("vllm.entrypoints.cli.openai", "ChatCommand")
...
@@ -65,7 +65,9 @@ ChatCommand = auto_mock("vllm.entrypoints.cli.openai", "ChatCommand")
CompleteCommand
=
auto_mock
(
"vllm.entrypoints.cli.openai"
,
"CompleteCommand"
)
CompleteCommand
=
auto_mock
(
"vllm.entrypoints.cli.openai"
,
"CompleteCommand"
)
cli_args
=
auto_mock
(
"vllm.entrypoints.openai"
,
"cli_args"
)
cli_args
=
auto_mock
(
"vllm.entrypoints.openai"
,
"cli_args"
)
run_batch
=
auto_mock
(
"vllm.entrypoints.openai"
,
"run_batch"
)
run_batch
=
auto_mock
(
"vllm.entrypoints.openai"
,
"run_batch"
)
FlexibleArgumentParser
=
auto_mock
(
"vllm.utils"
,
"FlexibleArgumentParser"
)
FlexibleArgumentParser
=
auto_mock
(
"vllm.utils.argparse_utils"
,
"FlexibleArgumentParser"
)
class
MarkdownFormatter
(
HelpFormatter
):
class
MarkdownFormatter
(
HelpFormatter
):
...
...
tests/utils.py
View file @
7c2bdb83
...
@@ -45,9 +45,7 @@ from vllm.entrypoints.cli.serve import ServeSubcommand
...
@@ -45,9 +45,7 @@ from vllm.entrypoints.cli.serve import ServeSubcommand
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.utils
import
(
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
FlexibleArgumentParser
,
)
from
vllm.utils.mem_constants
import
GB_bytes
from
vllm.utils.mem_constants
import
GB_bytes
from
vllm.utils.network_utils
import
get_open_port
from
vllm.utils.network_utils
import
get_open_port
from
vllm.utils.torch_utils
import
cuda_device_count_stateless
from
vllm.utils.torch_utils
import
cuda_device_count_stateless
...
...
tests/utils_/test_utils.py
→
tests/utils_/test_
argparse_
utils.py
View file @
7c2bdb83
...
@@ -4,23 +4,15 @@
...
@@ -4,23 +4,15 @@
import
json
import
json
import
os
import
os
import
tempfile
from
pathlib
import
Path
from
unittest.mock
import
patch
import
pytest
import
pytest
import
torch
import
yaml
import
yaml
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.transformers_utils.detokenizer_utils
import
convert_ids_list_to_tokens
from
vllm.transformers_utils.detokenizer_utils
import
convert_ids_list_to_tokens
from
vllm.utils
import
(
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
FlexibleArgumentParser
,
from
..utils
import
flat_product
bind_kv_cache
,
)
from
..utils
import
create_new_process_for_each_test
,
flat_product
# Tests for FlexibleArgumentParser
# Tests for FlexibleArgumentParser
...
@@ -256,87 +248,6 @@ def test_duplicate_dict_args(caplog_vllm, parser):
...
@@ -256,87 +248,6 @@ def test_duplicate_dict_args(caplog_vllm, parser):
assert
"-O.mode"
in
caplog_vllm
.
text
assert
"-O.mode"
in
caplog_vllm
.
text
def
test_bind_kv_cache
():
from
vllm.attention
import
Attention
ctx
=
{
"layers.0.self_attn"
:
Attention
(
32
,
128
,
0.1
),
"layers.1.self_attn"
:
Attention
(
32
,
128
,
0.1
),
"layers.2.self_attn"
:
Attention
(
32
,
128
,
0.1
),
"layers.3.self_attn"
:
Attention
(
32
,
128
,
0.1
),
}
kv_cache
=
[
torch
.
zeros
((
1
,)),
torch
.
zeros
((
1
,)),
torch
.
zeros
((
1
,)),
torch
.
zeros
((
1
,)),
]
bind_kv_cache
(
ctx
,
[
kv_cache
])
assert
ctx
[
"layers.0.self_attn"
].
kv_cache
[
0
]
is
kv_cache
[
0
]
assert
ctx
[
"layers.1.self_attn"
].
kv_cache
[
0
]
is
kv_cache
[
1
]
assert
ctx
[
"layers.2.self_attn"
].
kv_cache
[
0
]
is
kv_cache
[
2
]
assert
ctx
[
"layers.3.self_attn"
].
kv_cache
[
0
]
is
kv_cache
[
3
]
def
test_bind_kv_cache_kv_sharing
():
from
vllm.attention
import
Attention
ctx
=
{
"layers.0.self_attn"
:
Attention
(
32
,
128
,
0.1
),
"layers.1.self_attn"
:
Attention
(
32
,
128
,
0.1
),
"layers.2.self_attn"
:
Attention
(
32
,
128
,
0.1
),
"layers.3.self_attn"
:
Attention
(
32
,
128
,
0.1
),
}
kv_cache
=
[
torch
.
zeros
((
1
,)),
torch
.
zeros
((
1
,)),
torch
.
zeros
((
1
,)),
torch
.
zeros
((
1
,)),
]
shared_kv_cache_layers
=
{
"layers.2.self_attn"
:
"layers.1.self_attn"
,
"layers.3.self_attn"
:
"layers.0.self_attn"
,
}
bind_kv_cache
(
ctx
,
[
kv_cache
],
shared_kv_cache_layers
)
assert
ctx
[
"layers.0.self_attn"
].
kv_cache
[
0
]
is
kv_cache
[
0
]
assert
ctx
[
"layers.1.self_attn"
].
kv_cache
[
0
]
is
kv_cache
[
1
]
assert
ctx
[
"layers.2.self_attn"
].
kv_cache
[
0
]
is
kv_cache
[
1
]
assert
ctx
[
"layers.3.self_attn"
].
kv_cache
[
0
]
is
kv_cache
[
0
]
def
test_bind_kv_cache_non_attention
():
from
vllm.attention
import
Attention
# example from Jamba PP=2
ctx
=
{
"model.layers.20.attn"
:
Attention
(
32
,
128
,
0.1
),
"model.layers.28.attn"
:
Attention
(
32
,
128
,
0.1
),
}
kv_cache
=
[
torch
.
zeros
((
1
,)),
torch
.
zeros
((
1
,)),
]
bind_kv_cache
(
ctx
,
[
kv_cache
])
assert
ctx
[
"model.layers.20.attn"
].
kv_cache
[
0
]
is
kv_cache
[
0
]
assert
ctx
[
"model.layers.28.attn"
].
kv_cache
[
0
]
is
kv_cache
[
1
]
def
test_bind_kv_cache_pp
():
with
patch
(
"vllm.utils.torch_utils.cuda_device_count_stateless"
,
lambda
:
2
):
# this test runs with 1 GPU, but we simulate 2 GPUs
cfg
=
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
2
))
with
set_current_vllm_config
(
cfg
):
from
vllm.attention
import
Attention
ctx
=
{
"layers.0.self_attn"
:
Attention
(
32
,
128
,
0.1
),
}
kv_cache
=
[[
torch
.
zeros
((
1
,))],
[
torch
.
zeros
((
1
,))]]
bind_kv_cache
(
ctx
,
kv_cache
)
assert
ctx
[
"layers.0.self_attn"
].
kv_cache
[
0
]
is
kv_cache
[
0
][
0
]
assert
ctx
[
"layers.0.self_attn"
].
kv_cache
[
1
]
is
kv_cache
[
1
][
0
]
def
test_model_specification
(
def
test_model_specification
(
parser_with_config
,
cli_config_file
,
cli_config_file_with_model
parser_with_config
,
cli_config_file
,
cli_config_file_with_model
):
):
...
...
tests/utils_/test_serial_utils.py
View file @
7c2bdb83
...
@@ -14,7 +14,7 @@ from vllm.utils.serial_utils import (
...
@@ -14,7 +14,7 @@ from vllm.utils.serial_utils import (
@
pytest
.
mark
.
parametrize
(
"endianness"
,
ENDIANNESS
)
@
pytest
.
mark
.
parametrize
(
"endianness"
,
ENDIANNESS
)
@
pytest
.
mark
.
parametrize
(
"embed_dtype"
,
EMBED_DTYPE_TO_TORCH_DTYPE
.
keys
())
@
pytest
.
mark
.
parametrize
(
"embed_dtype"
,
EMBED_DTYPE_TO_TORCH_DTYPE
.
keys
())
@
torch
.
inference_mode
@
torch
.
inference_mode
()
def
test_encode_and_decode
(
embed_dtype
:
str
,
endianness
:
str
):
def
test_encode_and_decode
(
embed_dtype
:
str
,
endianness
:
str
):
for
i
in
range
(
10
):
for
i
in
range
(
10
):
tensor
=
torch
.
rand
(
2
,
3
,
5
,
7
,
11
,
13
,
device
=
"cpu"
,
dtype
=
torch
.
float32
)
tensor
=
torch
.
rand
(
2
,
3
,
5
,
7
,
11
,
13
,
device
=
"cpu"
,
dtype
=
torch
.
float32
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py
View file @
7c2bdb83
...
@@ -42,7 +42,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
...
@@ -42,7 +42,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
)
)
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
KVConnectorMetadata
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
KVConnectorMetadata
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.utils
.math_utils
import
cdiv
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
...
...
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
View file @
7c2bdb83
...
@@ -44,7 +44,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils impo
...
@@ -44,7 +44,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils impo
)
)
from
vllm.distributed.parallel_state
import
get_tensor_model_parallel_rank
,
get_tp_group
from
vllm.distributed.parallel_state
import
get_tensor_model_parallel_rank
,
get_tp_group
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
cdiv
,
get_kv_cache_torch_dtype
from
vllm.utils
import
get_kv_cache_torch_dtype
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
...
...
vllm/entrypoints/anthropic/api_server.py
View file @
7c2bdb83
...
@@ -51,7 +51,8 @@ from vllm.entrypoints.utils import (
...
@@ -51,7 +51,8 @@ from vllm.entrypoints.utils import (
with_cancellation
,
with_cancellation
,
)
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
FlexibleArgumentParser
,
set_ulimit
from
vllm.utils
import
set_ulimit
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.network_utils
import
is_valid_ipv6_address
from
vllm.utils.network_utils
import
is_valid_ipv6_address
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
...
...
vllm/entrypoints/cli/serve.py
View file @
7c2bdb83
...
@@ -18,7 +18,7 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_se
...
@@ -18,7 +18,7 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_se
from
vllm.entrypoints.utils
import
VLLM_SUBCMD_PARSER_EPILOG
from
vllm.entrypoints.utils
import
VLLM_SUBCMD_PARSER_EPILOG
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.network_utils
import
get_tcp_uri
from
vllm.utils.network_utils
import
get_tcp_uri
from
vllm.utils.system_utils
import
decorate_logs
,
set_process_title
from
vllm.utils.system_utils
import
decorate_logs
,
set_process_title
from
vllm.v1.engine.core
import
EngineCoreProc
from
vllm.v1.engine.core
import
EngineCoreProc
...
...
vllm/entrypoints/openai/api_server.py
View file @
7c2bdb83
...
@@ -108,7 +108,8 @@ from vllm.entrypoints.utils import (
...
@@ -108,7 +108,8 @@ from vllm.entrypoints.utils import (
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.reasoning
import
ReasoningParserManager
from
vllm.reasoning
import
ReasoningParserManager
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Device
,
FlexibleArgumentParser
,
set_ulimit
from
vllm.utils
import
Device
,
set_ulimit
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.network_utils
import
is_valid_ipv6_address
from
vllm.utils.network_utils
import
is_valid_ipv6_address
from
vllm.utils.system_utils
import
decorate_logs
from
vllm.utils.system_utils
import
decorate_logs
from
vllm.v1.engine.exceptions
import
EngineDeadError
from
vllm.v1.engine.exceptions
import
EngineDeadError
...
...
vllm/lora/punica_wrapper/punica_gpu.py
View file @
7c2bdb83
...
@@ -13,7 +13,7 @@ import torch
...
@@ -13,7 +13,7 @@ import torch
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.layers
import
LoRAMapping
from
vllm.triton_utils
import
HAS_TRITON
,
triton
from
vllm.triton_utils
import
HAS_TRITON
,
triton
from
vllm.utils
import
round_up
from
vllm.utils
.math_utils
import
round_up
if
HAS_TRITON
:
if
HAS_TRITON
:
from
vllm.lora.ops.triton_ops
import
(
from
vllm.lora.ops.triton_ops
import
(
...
...
vllm/model_executor/layers/quantization/mxfp4.py
View file @
7c2bdb83
...
@@ -48,9 +48,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_s
...
@@ -48,9 +48,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_s
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
round_up
from
vllm.utils.flashinfer
import
has_flashinfer
from
vllm.utils.flashinfer
import
has_flashinfer
from
vllm.utils.import_utils
import
has_triton_kernels
from
vllm.utils.import_utils
import
has_triton_kernels
from
vllm.utils.math_utils
import
round_up
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/utils/__init__.py
View file @
7c2bdb83
...
@@ -12,12 +12,10 @@ import signal
...
@@ -12,12 +12,10 @@ import signal
import
sys
import
sys
import
tempfile
import
tempfile
import
threading
import
threading
import
traceback
import
uuid
import
uuid
import
warnings
import
warnings
import
weakref
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
functools
import
cache
,
partial
,
wraps
from
functools
import
partial
,
wraps
from
typing
import
TYPE_CHECKING
,
Any
,
TypeVar
from
typing
import
TYPE_CHECKING
,
Any
,
TypeVar
import
cloudpickle
import
cloudpickle
...
@@ -28,34 +26,6 @@ import vllm.envs as envs
...
@@ -28,34 +26,6 @@ import vllm.envs as envs
from
vllm.logger
import
enable_trace_function_call
,
init_logger
from
vllm.logger
import
enable_trace_function_call
,
init_logger
from
vllm.ray.lazy_utils
import
is_in_ray_actor
from
vllm.ray.lazy_utils
import
is_in_ray_actor
# Import utilities from specialized modules for backward compatibility
from
vllm.utils.argparse_utils
import
(
FlexibleArgumentParser
,
SortedHelpFormatter
,
StoreBoolean
,
)
from
vllm.utils.math_utils
import
(
cdiv
,
next_power_of_2
,
prev_power_of_2
,
round_down
,
round_up
,
)
from
vllm.utils.platform_utils
import
cuda_is_initialized
,
xpu_is_initialized
__all__
=
[
# Argparse utilities
"FlexibleArgumentParser"
,
"SortedHelpFormatter"
,
"StoreBoolean"
,
# Math utilities
"cdiv"
,
"next_power_of_2"
,
"prev_power_of_2"
,
"round_down"
,
"round_up"
,
]
_DEPRECATED_MAPPINGS
=
{
_DEPRECATED_MAPPINGS
=
{
"cprofile"
:
"profiling"
,
"cprofile"
:
"profiling"
,
"cprofile_context"
:
"profiling"
,
"cprofile_context"
:
"profiling"
,
...
@@ -84,12 +54,8 @@ def __dir__() -> list[str]:
...
@@ -84,12 +54,8 @@ def __dir__() -> list[str]:
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
argparse
import
Namespace
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
else
:
else
:
Namespace
=
object
ModelConfig
=
object
ModelConfig
=
object
VllmConfig
=
object
VllmConfig
=
object
...
@@ -149,35 +115,33 @@ class Counter:
...
@@ -149,35 +115,33 @@ class Counter:
self
.
counter
=
0
self
.
counter
=
0
def
random_uuid
()
->
str
:
class
AtomicCounter
:
return
str
(
uuid
.
uuid4
().
hex
)
"""An atomic, thread-safe counter"""
def
update_environment_variables
(
envs
:
dict
[
str
,
str
]):
def
__init__
(
self
,
initial
=
0
):
for
k
,
v
in
envs
.
items
():
"""Initialize a new atomic counter to given initial value"""
if
k
in
os
.
environ
and
os
.
environ
[
k
]
!=
v
:
self
.
_value
=
initial
logger
.
warning
(
self
.
_lock
=
threading
.
Lock
()
"Overwriting environment variable %s from '%s' to '%s'"
,
k
,
os
.
environ
[
k
],
v
,
)
os
.
environ
[
k
]
=
v
def
inc
(
self
,
num
=
1
):
"""Atomically increment the counter by num and return the new value"""
with
self
.
_lock
:
self
.
_value
+=
num
return
self
.
_value
@
cache
def
dec
(
self
,
num
=
1
):
def
is_pin_memory_available
()
->
bool
:
"""Atomically decrement the counter by num and return the new value"""
from
vllm.platforms
import
current_platform
with
self
.
_lock
:
self
.
_value
-=
num
return
self
.
_value
return
current_platform
.
is_pin_memory_available
()
@
property
def
value
(
self
):
return
self
.
_value
@
cache
def
random_uuid
()
->
str
:
def
is_uva_available
()
->
bool
:
return
str
(
uuid
.
uuid4
().
hex
)
"""Check if Unified Virtual Addressing (UVA) is available."""
# UVA requires pinned memory.
# TODO: Add more requirements for UVA if needed.
return
is_pin_memory_available
()
# TODO: This function can be removed if transformer_modules classes are
# TODO: This function can be removed if transformer_modules classes are
...
@@ -212,47 +176,6 @@ def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
...
@@ -212,47 +176,6 @@ def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
enable_trace_function_call
(
log_path
)
enable_trace_function_call
(
log_path
)
def
weak_bind
(
bound_method
:
Callable
[...,
Any
],
)
->
Callable
[...,
None
]:
"""Make an instance method that weakly references
its associated instance and no-ops once that
instance is collected."""
ref
=
weakref
.
ref
(
bound_method
.
__self__
)
# type: ignore[attr-defined]
unbound
=
bound_method
.
__func__
# type: ignore[attr-defined]
def
weak_bound
(
*
args
,
**
kwargs
)
->
None
:
if
inst
:
=
ref
():
unbound
(
inst
,
*
args
,
**
kwargs
)
return
weak_bound
class
AtomicCounter
:
"""An atomic, thread-safe counter"""
def
__init__
(
self
,
initial
=
0
):
"""Initialize a new atomic counter to given initial value"""
self
.
_value
=
initial
self
.
_lock
=
threading
.
Lock
()
def
inc
(
self
,
num
=
1
):
"""Atomically increment the counter by num and return the new value"""
with
self
.
_lock
:
self
.
_value
+=
num
return
self
.
_value
def
dec
(
self
,
num
=
1
):
"""Atomically decrement the counter by num and return the new value"""
with
self
.
_lock
:
self
.
_value
-=
num
return
self
.
_value
@
property
def
value
(
self
):
return
self
.
_value
def
kill_process_tree
(
pid
:
int
):
def
kill_process_tree
(
pid
:
int
):
"""
"""
Kills all descendant processes of the given pid by sending SIGKILL.
Kills all descendant processes of the given pid by sending SIGKILL.
...
@@ -303,13 +226,6 @@ def set_ulimit(target_soft_limit=65535):
...
@@ -303,13 +226,6 @@ def set_ulimit(target_soft_limit=65535):
)
)
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501
def
get_exception_traceback
():
etype
,
value
,
tb
=
sys
.
exc_info
()
err_str
=
""
.
join
(
traceback
.
format_exception
(
etype
,
value
,
tb
))
return
err_str
def
_maybe_force_spawn
():
def
_maybe_force_spawn
():
"""Check if we need to force the use of the `spawn` multiprocessing start
"""Check if we need to force the use of the `spawn` multiprocessing start
method.
method.
...
@@ -327,6 +243,8 @@ def _maybe_force_spawn():
...
@@ -327,6 +243,8 @@ def _maybe_force_spawn():
os
.
environ
[
"RAY_ADDRESS"
]
=
ray
.
get_runtime_context
().
gcs_address
os
.
environ
[
"RAY_ADDRESS"
]
=
ray
.
get_runtime_context
().
gcs_address
reasons
.
append
(
"In a Ray actor and can only be spawned"
)
reasons
.
append
(
"In a Ray actor and can only be spawned"
)
from
.platform_utils
import
cuda_is_initialized
,
xpu_is_initialized
if
cuda_is_initialized
():
if
cuda_is_initialized
():
reasons
.
append
(
"CUDA is initialized"
)
reasons
.
append
(
"CUDA is initialized"
)
elif
xpu_is_initialized
():
elif
xpu_is_initialized
():
...
@@ -356,55 +274,6 @@ def get_mp_context():
...
@@ -356,55 +274,6 @@ def get_mp_context():
return
multiprocessing
.
get_context
(
mp_method
)
return
multiprocessing
.
get_context
(
mp_method
)
def
bind_kv_cache
(
ctx
:
dict
[
str
,
Any
],
kv_cache
:
list
[
list
[
torch
.
Tensor
]],
# [virtual_engine][layer_index]
shared_kv_cache_layers
:
dict
[
str
,
str
]
|
None
=
None
,
)
->
None
:
# Bind the kv_cache tensor to Attention modules, similar to
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
# Special things handled here:
# 1. Some models have non-attention layers, e.g., Jamba
# 2. Pipeline parallelism, each rank only has a subset of layers
# 3. Encoder attention has no kv cache
# 4. Encoder-decoder models, encoder-decoder attention and decoder-only
# attention of the same layer (e.g., bart's decoder.layers.1.self_attn
# and decoder.layers.1.encoder_attn) is mapped to the same kv cache
# tensor
# 5. Some models have attention layers that share kv cache with previous
# layers, this is specified through shared_kv_cache_layers
if
shared_kv_cache_layers
is
None
:
shared_kv_cache_layers
=
{}
from
vllm.attention
import
AttentionType
from
vllm.model_executor.models.utils
import
extract_layer_index
layer_need_kv_cache
=
[
layer_name
for
layer_name
in
ctx
if
(
hasattr
(
ctx
[
layer_name
],
"attn_type"
)
and
ctx
[
layer_name
].
attn_type
in
(
AttentionType
.
DECODER
,
AttentionType
.
ENCODER_DECODER
)
)
and
ctx
[
layer_name
].
kv_sharing_target_layer_name
is
None
]
layer_index_sorted
=
sorted
(
set
(
extract_layer_index
(
layer_name
)
for
layer_name
in
layer_need_kv_cache
)
)
for
layer_name
in
layer_need_kv_cache
:
kv_cache_idx
=
layer_index_sorted
.
index
(
extract_layer_index
(
layer_name
))
forward_ctx
=
ctx
[
layer_name
]
assert
len
(
forward_ctx
.
kv_cache
)
==
len
(
kv_cache
)
for
ve
,
ve_kv_cache
in
enumerate
(
kv_cache
):
forward_ctx
.
kv_cache
[
ve
]
=
ve_kv_cache
[
kv_cache_idx
]
if
shared_kv_cache_layers
is
not
None
:
for
layer_name
,
target_layer_name
in
shared_kv_cache_layers
.
items
():
assert
extract_layer_index
(
target_layer_name
)
<
extract_layer_index
(
layer_name
),
"v0 doesn't support interleaving kv sharing"
ctx
[
layer_name
].
kv_cache
=
ctx
[
target_layer_name
].
kv_cache
def
run_method
(
def
run_method
(
obj
:
Any
,
obj
:
Any
,
method
:
str
|
bytes
|
Callable
,
method
:
str
|
bytes
|
Callable
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment