Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1b62745b
Unverified
Commit
1b62745b
authored
Dec 07, 2024
by
youkaichao
Committed by
GitHub
Dec 07, 2024
Browse files
[core][executor] simplify instance id (#10976)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
78029b34
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
22 additions
and
55 deletions
+22
-55
vllm/config.py
vllm/config.py
+6
-1
vllm/envs.py
vllm/envs.py
+0
-6
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+1
-5
vllm/executor/multiproc_gpu_executor.py
vllm/executor/multiproc_gpu_executor.py
+1
-4
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+1
-6
vllm/executor/ray_hpu_executor.py
vllm/executor/ray_hpu_executor.py
+1
-6
vllm/executor/ray_tpu_executor.py
vllm/executor/ray_tpu_executor.py
+1
-5
vllm/executor/ray_xpu_executor.py
vllm/executor/ray_xpu_executor.py
+1
-5
vllm/utils.py
vllm/utils.py
+9
-16
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+1
-1
No files found.
vllm/config.py
View file @
1b62745b
...
@@ -27,7 +27,8 @@ from vllm.transformers_utils.config import (
...
@@ -27,7 +27,8 @@ from vllm.transformers_utils.config import (
get_hf_text_config
,
get_pooling_config
,
get_hf_text_config
,
get_pooling_config
,
get_sentence_transformer_tokenizer_config
,
is_encoder_decoder
,
uses_mrope
)
get_sentence_transformer_tokenizer_config
,
is_encoder_decoder
,
uses_mrope
)
from
vllm.utils
import
(
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
from
vllm.utils
import
(
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
print_warning_once
,
resolve_obj_by_qualname
)
print_warning_once
,
random_uuid
,
resolve_obj_by_qualname
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
from
ray.util.placement_group
import
PlacementGroup
...
@@ -2408,6 +2409,7 @@ class VllmConfig:
...
@@ -2408,6 +2409,7 @@ class VllmConfig:
init
=
True
)
# type: ignore
init
=
True
)
# type: ignore
kv_transfer_config
:
KVTransferConfig
=
field
(
default
=
None
,
kv_transfer_config
:
KVTransferConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
init
=
True
)
# type: ignore
instance_id
:
str
=
""
@
staticmethod
@
staticmethod
def
get_graph_batch_size
(
batch_size
:
int
)
->
int
:
def
get_graph_batch_size
(
batch_size
:
int
)
->
int
:
...
@@ -2573,6 +2575,9 @@ class VllmConfig:
...
@@ -2573,6 +2575,9 @@ class VllmConfig:
current_platform
.
check_and_update_config
(
self
)
current_platform
.
check_and_update_config
(
self
)
if
not
self
.
instance_id
:
self
.
instance_id
=
random_uuid
()[:
5
]
def
__str__
(
self
):
def
__str__
(
self
):
return
(
"model=%r, speculative_config=%r, tokenizer=%r, "
return
(
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
...
...
vllm/envs.py
View file @
1b62745b
...
@@ -8,7 +8,6 @@ if TYPE_CHECKING:
...
@@ -8,7 +8,6 @@ if TYPE_CHECKING:
VLLM_RPC_BASE_PATH
:
str
=
tempfile
.
gettempdir
()
VLLM_RPC_BASE_PATH
:
str
=
tempfile
.
gettempdir
()
VLLM_USE_MODELSCOPE
:
bool
=
False
VLLM_USE_MODELSCOPE
:
bool
=
False
VLLM_RINGBUFFER_WARNING_INTERVAL
:
int
=
60
VLLM_RINGBUFFER_WARNING_INTERVAL
:
int
=
60
VLLM_INSTANCE_ID
:
Optional
[
str
]
=
None
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
...
@@ -175,11 +174,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -175,11 +174,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_USE_MODELSCOPE"
:
"VLLM_USE_MODELSCOPE"
:
lambda
:
os
.
environ
.
get
(
"VLLM_USE_MODELSCOPE"
,
"False"
).
lower
()
==
"true"
,
lambda
:
os
.
environ
.
get
(
"VLLM_USE_MODELSCOPE"
,
"False"
).
lower
()
==
"true"
,
# Instance id represents an instance of the VLLM. All processes in the same
# instance should have the same instance id.
"VLLM_INSTANCE_ID"
:
lambda
:
os
.
environ
.
get
(
"VLLM_INSTANCE_ID"
,
None
),
# Interval in seconds to log a warning message when the ring buffer is full
# Interval in seconds to log a warning message when the ring buffer is full
"VLLM_RINGBUFFER_WARNING_INTERVAL"
:
"VLLM_RINGBUFFER_WARNING_INTERVAL"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_RINGBUFFER_WARNING_INTERVAL"
,
"60"
)),
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_RINGBUFFER_WARNING_INTERVAL"
,
"60"
)),
...
...
vllm/executor/cpu_executor.py
View file @
1b62745b
...
@@ -10,8 +10,7 @@ from vllm.lora.request import LoRARequest
...
@@ -10,8 +10,7 @@ from vllm.lora.request import LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
get_distributed_init_method
,
get_open_port
,
from
vllm.utils
import
get_distributed_init_method
,
get_open_port
,
make_async
get_vllm_instance_id
,
make_async
)
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.worker.worker_base
import
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -31,9 +30,6 @@ class CPUExecutor(ExecutorBase):
...
@@ -31,9 +30,6 @@ class CPUExecutor(ExecutorBase):
# Environment variables for CPU executor
# Environment variables for CPU executor
#
#
# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os
.
environ
[
"VLLM_INSTANCE_ID"
]
=
get_vllm_instance_id
()
# Disable torch async compiling which won't work with daemonic processes
# Disable torch async compiling which won't work with daemonic processes
os
.
environ
[
"TORCHINDUCTOR_COMPILE_THREADS"
]
=
"1"
os
.
environ
[
"TORCHINDUCTOR_COMPILE_THREADS"
]
=
"1"
...
...
vllm/executor/multiproc_gpu_executor.py
View file @
1b62745b
...
@@ -16,7 +16,7 @@ from vllm.sequence import ExecuteModelRequest
...
@@ -16,7 +16,7 @@ from vllm.sequence import ExecuteModelRequest
from
vllm.triton_utils.importing
import
HAS_TRITON
from
vllm.triton_utils.importing
import
HAS_TRITON
from
vllm.utils
import
(
_run_task_with_lock
,
cuda_device_count_stateless
,
from
vllm.utils
import
(
_run_task_with_lock
,
cuda_device_count_stateless
,
cuda_is_initialized
,
get_distributed_init_method
,
cuda_is_initialized
,
get_distributed_init_method
,
get_open_port
,
get_vllm_instance_id
,
make_async
,
get_open_port
,
make_async
,
update_environment_variables
)
update_environment_variables
)
if
HAS_TRITON
:
if
HAS_TRITON
:
...
@@ -37,9 +37,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
...
@@ -37,9 +37,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
world_size
=
self
.
parallel_config
.
world_size
world_size
=
self
.
parallel_config
.
world_size
tensor_parallel_size
=
self
.
parallel_config
.
tensor_parallel_size
tensor_parallel_size
=
self
.
parallel_config
.
tensor_parallel_size
# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os
.
environ
[
"VLLM_INSTANCE_ID"
]
=
get_vllm_instance_id
()
# Disable torch async compiling which won't work with daemonic processes
# Disable torch async compiling which won't work with daemonic processes
os
.
environ
[
"TORCHINDUCTOR_COMPILE_THREADS"
]
=
"1"
os
.
environ
[
"TORCHINDUCTOR_COMPILE_THREADS"
]
=
"1"
...
...
vllm/executor/ray_gpu_executor.py
View file @
1b62745b
...
@@ -15,8 +15,7 @@ from vllm.logger import init_logger
...
@@ -15,8 +15,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
_run_task_with_lock
,
get_distributed_init_method
,
from
vllm.utils
import
(
_run_task_with_lock
,
get_distributed_init_method
,
get_ip
,
get_open_port
,
get_vllm_instance_id
,
get_ip
,
get_open_port
,
make_async
)
make_async
)
if
ray
is
not
None
:
if
ray
is
not
None
:
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
...
@@ -220,14 +219,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -220,14 +219,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
" environment variable, make sure it is unique for"
" environment variable, make sure it is unique for"
" each node."
)
" each node."
)
VLLM_INSTANCE_ID
=
get_vllm_instance_id
()
# Set environment variables for the driver and workers.
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables
=
[({
all_args_to_update_environment_variables
=
[({
"CUDA_VISIBLE_DEVICES"
:
"CUDA_VISIBLE_DEVICES"
:
","
.
join
(
map
(
str
,
node_gpus
[
node_id
])),
","
.
join
(
map
(
str
,
node_gpus
[
node_id
])),
"VLLM_INSTANCE_ID"
:
VLLM_INSTANCE_ID
,
"VLLM_TRACE_FUNCTION"
:
"VLLM_TRACE_FUNCTION"
:
str
(
envs
.
VLLM_TRACE_FUNCTION
),
str
(
envs
.
VLLM_TRACE_FUNCTION
),
**
({
**
({
...
...
vllm/executor/ray_hpu_executor.py
View file @
1b62745b
...
@@ -15,8 +15,7 @@ from vllm.logger import init_logger
...
@@ -15,8 +15,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
_run_task_with_lock
,
get_distributed_init_method
,
from
vllm.utils
import
(
_run_task_with_lock
,
get_distributed_init_method
,
get_ip
,
get_open_port
,
get_vllm_instance_id
,
get_ip
,
get_open_port
,
make_async
)
make_async
)
if
ray
is
not
None
:
if
ray
is
not
None
:
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
...
@@ -196,12 +195,8 @@ class RayHPUExecutor(DistributedGPUExecutor):
...
@@ -196,12 +195,8 @@ class RayHPUExecutor(DistributedGPUExecutor):
"environment variable, make sure it is unique for"
"environment variable, make sure it is unique for"
" each node."
)
" each node."
)
VLLM_INSTANCE_ID
=
get_vllm_instance_id
()
# Set environment variables for the driver and workers.
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables
=
[({
all_args_to_update_environment_variables
=
[({
"VLLM_INSTANCE_ID"
:
VLLM_INSTANCE_ID
,
"VLLM_TRACE_FUNCTION"
:
"VLLM_TRACE_FUNCTION"
:
str
(
envs
.
VLLM_TRACE_FUNCTION
),
str
(
envs
.
VLLM_TRACE_FUNCTION
),
},
)
for
(
node_id
,
_
)
in
worker_node_and_gpu_ids
]
},
)
for
(
node_id
,
_
)
in
worker_node_and_gpu_ids
]
...
...
vllm/executor/ray_tpu_executor.py
View file @
1b62745b
...
@@ -13,7 +13,7 @@ from vllm.logger import init_logger
...
@@ -13,7 +13,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
get_vllm_instance_id
,
make_async
)
make_async
)
if
ray
is
not
None
:
if
ray
is
not
None
:
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
...
@@ -144,12 +144,8 @@ class RayTPUExecutor(TPUExecutor):
...
@@ -144,12 +144,8 @@ class RayTPUExecutor(TPUExecutor):
for
i
,
(
node_id
,
_
)
in
enumerate
(
worker_node_and_gpu_ids
):
for
i
,
(
node_id
,
_
)
in
enumerate
(
worker_node_and_gpu_ids
):
node_workers
[
node_id
].
append
(
i
)
node_workers
[
node_id
].
append
(
i
)
VLLM_INSTANCE_ID
=
get_vllm_instance_id
()
# Set environment variables for the driver and workers.
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables
=
[({
all_args_to_update_environment_variables
=
[({
"VLLM_INSTANCE_ID"
:
VLLM_INSTANCE_ID
,
"VLLM_TRACE_FUNCTION"
:
"VLLM_TRACE_FUNCTION"
:
str
(
envs
.
VLLM_TRACE_FUNCTION
),
str
(
envs
.
VLLM_TRACE_FUNCTION
),
},
)
for
_
in
worker_node_and_gpu_ids
]
},
)
for
_
in
worker_node_and_gpu_ids
]
...
...
vllm/executor/ray_xpu_executor.py
View file @
1b62745b
...
@@ -5,7 +5,7 @@ import vllm.envs as envs
...
@@ -5,7 +5,7 @@ import vllm.envs as envs
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutor
,
RayGPUExecutorAsync
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutor
,
RayGPUExecutorAsync
from
vllm.executor.xpu_executor
import
XPUExecutor
from
vllm.executor.xpu_executor
import
XPUExecutor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_vllm_instance_id
,
make_async
from
vllm.utils
import
make_async
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -17,12 +17,8 @@ class RayXPUExecutor(RayGPUExecutor, XPUExecutor):
...
@@ -17,12 +17,8 @@ class RayXPUExecutor(RayGPUExecutor, XPUExecutor):
worker_node_and_gpu_ids
=
self
.
_run_workers
(
"get_node_and_gpu_ids"
,
worker_node_and_gpu_ids
=
self
.
_run_workers
(
"get_node_and_gpu_ids"
,
use_dummy_driver
=
True
)
use_dummy_driver
=
True
)
VLLM_INSTANCE_ID
=
get_vllm_instance_id
()
# Set environment variables for the driver and workers.
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables
=
[({
all_args_to_update_environment_variables
=
[({
"VLLM_INSTANCE_ID"
:
VLLM_INSTANCE_ID
,
"VLLM_TRACE_FUNCTION"
:
"VLLM_TRACE_FUNCTION"
:
str
(
envs
.
VLLM_TRACE_FUNCTION
),
str
(
envs
.
VLLM_TRACE_FUNCTION
),
},
)
for
(
_
,
_
)
in
worker_node_and_gpu_ids
]
},
)
for
(
_
,
_
)
in
worker_node_and_gpu_ids
]
...
...
vllm/utils.py
View file @
1b62745b
...
@@ -24,9 +24,9 @@ from collections import UserDict, defaultdict
...
@@ -24,9 +24,9 @@ from collections import UserDict, defaultdict
from
collections.abc
import
Iterable
,
Mapping
from
collections.abc
import
Iterable
,
Mapping
from
functools
import
lru_cache
,
partial
,
wraps
from
functools
import
lru_cache
,
partial
,
wraps
from
platform
import
uname
from
platform
import
uname
from
typing
import
(
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
Dict
,
Generic
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
Hashable
,
List
,
Literal
,
Optional
,
OrderedDict
,
Set
,
Tuple
,
Dict
,
Generic
,
Hashable
,
List
,
Literal
,
Optional
,
Type
,
TypeVar
,
Union
,
overload
)
OrderedDict
,
Set
,
Tuple
,
Type
,
TypeVar
,
Union
,
overload
)
from
uuid
import
uuid4
from
uuid
import
uuid4
import
numpy
as
np
import
numpy
as
np
...
@@ -43,6 +43,9 @@ import vllm.envs as envs
...
@@ -43,6 +43,9 @@ import vllm.envs as envs
from
vllm.logger
import
enable_trace_function_call
,
init_logger
from
vllm.logger
import
enable_trace_function_call
,
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# Exception strings for non-implemented encoder/decoder scenarios
# Exception strings for non-implemented encoder/decoder scenarios
...
@@ -335,17 +338,6 @@ def random_uuid() -> str:
...
@@ -335,17 +338,6 @@ def random_uuid() -> str:
return
str
(
uuid
.
uuid4
().
hex
)
return
str
(
uuid
.
uuid4
().
hex
)
@
lru_cache
(
maxsize
=
None
)
def
get_vllm_instance_id
()
->
str
:
"""
If the environment variable VLLM_INSTANCE_ID is set, return it.
Otherwise, return a random UUID.
Instance id represents an instance of the VLLM. All processes in the same
instance should have the same instance id.
"""
return
envs
.
VLLM_INSTANCE_ID
or
f
"vllm-instance-
{
random_uuid
()
}
"
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
def
in_wsl
()
->
bool
:
def
in_wsl
()
->
bool
:
# Reference: https://github.com/microsoft/WSL/issues/4071
# Reference: https://github.com/microsoft/WSL/issues/4071
...
@@ -997,7 +989,7 @@ def find_nccl_library() -> str:
...
@@ -997,7 +989,7 @@ def find_nccl_library() -> str:
return
so_file
return
so_file
def
enable_trace_function_call_for_thread
()
->
None
:
def
enable_trace_function_call_for_thread
(
vllm_config
:
"VllmConfig"
)
->
None
:
"""Set up function tracing for the current thread,
"""Set up function tracing for the current thread,
if enabled via the VLLM_TRACE_FUNCTION environment variable
if enabled via the VLLM_TRACE_FUNCTION environment variable
"""
"""
...
@@ -1009,7 +1001,8 @@ def enable_trace_function_call_for_thread() -> None:
...
@@ -1009,7 +1001,8 @@ def enable_trace_function_call_for_thread() -> None:
filename
=
(
f
"VLLM_TRACE_FUNCTION_for_process_
{
os
.
getpid
()
}
"
filename
=
(
f
"VLLM_TRACE_FUNCTION_for_process_
{
os
.
getpid
()
}
"
f
"_thread_
{
threading
.
get_ident
()
}
_"
f
"_thread_
{
threading
.
get_ident
()
}
_"
f
"at_
{
datetime
.
datetime
.
now
()
}
.log"
).
replace
(
" "
,
"_"
)
f
"at_
{
datetime
.
datetime
.
now
()
}
.log"
).
replace
(
" "
,
"_"
)
log_path
=
os
.
path
.
join
(
tmp_dir
,
"vllm"
,
get_vllm_instance_id
(),
log_path
=
os
.
path
.
join
(
tmp_dir
,
"vllm"
,
f
"vllm-instance-
{
vllm_config
.
instance_id
}
"
,
filename
)
filename
)
os
.
makedirs
(
os
.
path
.
dirname
(
log_path
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
dirname
(
log_path
),
exist_ok
=
True
)
enable_trace_function_call
(
log_path
)
enable_trace_function_call
(
log_path
)
...
...
vllm/worker/worker_base.py
View file @
1b62745b
...
@@ -439,7 +439,7 @@ class WorkerWrapperBase:
...
@@ -439,7 +439,7 @@ class WorkerWrapperBase:
Here we inject some common logic before initializing the worker.
Here we inject some common logic before initializing the worker.
Arguments are passed to the worker class constructor.
Arguments are passed to the worker class constructor.
"""
"""
enable_trace_function_call_for_thread
()
enable_trace_function_call_for_thread
(
self
.
vllm_config
)
# see https://github.com/NVIDIA/nccl/issues/1234
# see https://github.com/NVIDIA/nccl/issues/1234
os
.
environ
[
'NCCL_CUMEM_ENABLE'
]
=
'0'
os
.
environ
[
'NCCL_CUMEM_ENABLE'
]
=
'0'
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment