Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2b835032
Unverified
Commit
2b835032
authored
Jan 18, 2025
by
youkaichao
Committed by
GitHub
Jan 18, 2025
Browse files
[misc] fix cross-node TP (#12166)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
7b98a65a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
24 deletions
+36
-24
vllm/executor/mp_distributed_executor.py
vllm/executor/mp_distributed_executor.py
+36
-2
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+0
-22
No files found.
vllm/executor/mp_distributed_executor.py
View file @
2b835032
import
asyncio
import
asyncio
import
os
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Union
import
cloudpickle
import
cloudpickle
...
@@ -10,8 +11,9 @@ from vllm.executor.multiproc_worker_utils import (
...
@@ -10,8 +11,9 @@ from vllm.executor.multiproc_worker_utils import (
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
_run_task_with_lock
,
get_distributed_init_method
,
from
vllm.utils
import
(
_run_task_with_lock
,
cuda_device_count_stateless
,
get_ip
,
get_open_port
,
make_async
,
run_method
)
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
,
run_method
,
update_environment_variables
)
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.worker.worker_base
import
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -22,7 +24,39 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
...
@@ -22,7 +24,39 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
uses_ray
:
bool
=
False
uses_ray
:
bool
=
False
def _check_cuda(self) -> None:
    """Check that the number of GPUs is sufficient for the parallel
    configuration. Separate from _init_executor to reduce the number of
    indented blocks.

    Reads ``world_size`` and ``tensor_parallel_size`` from
    ``self.parallel_config`` and compares each against the value returned
    by ``cuda_device_count_stateless()``. If both checks pass and
    ``CUDA_VISIBLE_DEVICES`` is not already present in the environment, it
    is set to ``"0,1,...,world_size-1"``.

    Raises:
        RuntimeError: if ``tensor_parallel_size`` or ``world_size`` is
            greater than the local CUDA device count.
    """
    parallel_config = self.parallel_config
    world_size = parallel_config.world_size
    tensor_parallel_size = parallel_config.tensor_parallel_size

    cuda_device_count = cuda_device_count_stateless()

    # The TP-only case is the more common one, so it is checked first and
    # gets its own message that names tensor_parallel_size explicitly.
    if tensor_parallel_size > cuda_device_count:
        raise RuntimeError(
            f"please set tensor_parallel_size ({tensor_parallel_size}) "
            f"to less than max local gpu count ({cuda_device_count})")

    if world_size > cuda_device_count:
        raise RuntimeError(
            f"please ensure that world_size ({world_size}) "
            f"is less than max local gpu count ({cuda_device_count})")

    # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers.
    # An existing user-provided value is left untouched.
    if "CUDA_VISIBLE_DEVICES" not in os.environ:
        update_environment_variables({
            "CUDA_VISIBLE_DEVICES": ",".join(map(str, range(world_size)))
        })
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
from
vllm.platforms
import
current_platform
if
current_platform
.
is_cuda_alike
():
self
.
_check_cuda
()
# Create the parallel GPU workers.
# Create the parallel GPU workers.
world_size
=
self
.
parallel_config
.
world_size
world_size
=
self
.
parallel_config
.
world_size
tensor_parallel_size
=
self
.
parallel_config
.
tensor_parallel_size
tensor_parallel_size
=
self
.
parallel_config
.
tensor_parallel_size
...
...
vllm/platforms/cuda.py
View file @
2b835032
...
@@ -139,28 +139,6 @@ class CudaPlatformBase(Platform):
...
@@ -139,28 +139,6 @@ class CudaPlatformBase(Platform):
else
:
else
:
parallel_config
.
worker_cls
=
"vllm.worker.worker.Worker"
parallel_config
.
worker_cls
=
"vllm.worker.worker.Worker"
world_size
=
parallel_config
.
world_size
tensor_parallel_size
=
parallel_config
.
tensor_parallel_size
from
vllm.utils
import
(
cuda_device_count_stateless
,
update_environment_variables
)
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if
"CUDA_VISIBLE_DEVICES"
not
in
os
.
environ
:
update_environment_variables
({
"CUDA_VISIBLE_DEVICES"
:
(
","
.
join
(
map
(
str
,
range
(
world_size
))))
})
cuda_device_count
=
cuda_device_count_stateless
()
# Use confusing message for more common TP-only case.
assert
tensor_parallel_size
<=
cuda_device_count
,
(
f
"please set tensor_parallel_size (
{
tensor_parallel_size
}
) "
f
"to less than max local gpu count (
{
cuda_device_count
}
)"
)
assert
world_size
<=
cuda_device_count
,
(
f
"please ensure that world_size (
{
world_size
}
) "
f
"is less than than max local gpu count (
{
cuda_device_count
}
)"
)
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
if
cache_config
and
cache_config
.
block_size
is
None
:
if
cache_config
and
cache_config
.
block_size
is
None
:
cache_config
.
block_size
=
16
cache_config
.
block_size
=
16
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment