norm / vllm / Commits / 25e86b6a

Commit 25e86b6a (unverified), authored Feb 14, 2024 by Woosuk Kwon, committed via GitHub on Feb 14, 2024

Don't use cupy NCCL for AMD backends (#2855)
parent 4efbac6d

Showing 3 changed files with 23 additions and 7 deletions (+23, -7):

  vllm/model_executor/parallel_utils/custom_all_reduce.py   +4   -0
  vllm/worker/model_runner.py                                +16  -6
  vllm/worker/worker.py                                      +3   -1
vllm/model_executor/parallel_utils/custom_all_reduce.py

@@ -67,6 +67,10 @@ def get_handle() -> Optional["CustomAllreduce"]:
     return _CA_HANDLE
 
 
+def is_initialized() -> bool:
+    return _CA_HANDLE is not None
+
+
 @contextmanager
 def capture():
     try:
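The added is_initialized() predicate lets other modules ask whether the custom all-reduce kernel has an active handle without taking ownership of it. A small illustrative sketch (assuming vLLM at this commit is importable; choose_all_reduce_backend is a hypothetical helper, not part of the change) of how a caller could combine it with cupy_utils.is_initialized() to mirror the backend priority described in the model_runner.py comments below:

# Hypothetical helper (not in the diff): report which all-reduce path would
# be taken, following the priority noted in capture_model below.
from vllm.model_executor.parallel_utils import custom_all_reduce, cupy_utils


def choose_all_reduce_backend() -> str:
    if custom_all_reduce.is_initialized():
        return "custom all-reduce kernel"
    if cupy_utils.is_initialized():
        return "CuPy NCCL"
    return "PyTorch NCCL"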
vllm/worker/model_runner.py

+import contextlib
 import time
 from typing import Dict, List, Optional, Tuple, Set, Union

@@ -9,9 +10,9 @@ from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
+from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
-from vllm.model_executor.parallel_utils.cupy_utils import get_nccl_backend
 from vllm.model_executor.parallel_utils.parallel_state import (
     with_cupy_nccl_for_all_reduce)
 from vllm.model_executor.parallel_utils import custom_all_reduce
@@ -659,7 +660,7 @@ class ModelRunner:
     def capture_model(self, kv_caches: List[KVCache]) -> None:
         # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
         # deleted before the CUDA graphs.
-        self.cupy_nccl_backend = get_nccl_backend()
+        self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
 
         assert not self.model_config.enforce_eager
         logger.info("Capturing the model for CUDA graphs. This may lead to "

@@ -689,8 +690,6 @@ class ModelRunner:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
 
-        # NOTE: Capturing the largest batch size first may help reduce the
-        # memory usage of CUDA graph.
         # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
         # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
         # either custom all-reduce kernel or CuPy NCCL. When not using CUDA

@@ -698,6 +697,8 @@ class ModelRunner:
         # We always prioritize using custom all-reduce kernel but fall back
         # to PyTorch or CuPy NCCL if it is disabled or not supported.
         with custom_all_reduce.capture():
+            # NOTE: Capturing the largest batch size first may help reduce the
+            # memory usage of CUDA graph.
             for batch_size in reversed(batch_size_capture_list):
                 # Create dummy input_metadata.
                 input_metadata = InputMetadata(
@@ -765,7 +766,7 @@ class CUDAGraphRunner:
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with with_cupy_nccl_for_all_reduce():
+        with _maybe_cupy_nccl():
             self.model(
                 input_ids,
                 positions,

@@ -779,7 +780,7 @@ class CUDAGraphRunner:
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self.graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
-            with with_cupy_nccl_for_all_reduce():
+            with _maybe_cupy_nccl():
                 hidden_states = self.model(
                     input_ids,
                     positions,

@@ -830,6 +831,15 @@ class CUDAGraphRunner:
         return self.forward(*args, **kwargs)
 
 
+@contextlib.contextmanager
+def _maybe_cupy_nccl():
+    if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized():
+        with with_cupy_nccl_for_all_reduce():
+            yield
+    else:
+        yield
+
+
 def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     assert len(x) <= max_len
     return x + [pad] * (max_len - len(x))
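The key addition in this file is _maybe_cupy_nccl: a generator-based context manager that enters the CuPy NCCL context only when CuPy is initialized and the custom all-reduce kernel is not; otherwise the wrapped block runs unchanged. A self-contained sketch of the same conditional-context-manager pattern, using stand-in names (special_backend, maybe_special_backend, _SPECIAL_BACKEND_READY) rather than the vLLM internals:

import contextlib

_SPECIAL_BACKEND_READY = False  # stand-in for cupy_utils.is_initialized()


@contextlib.contextmanager
def special_backend():
    # Stand-in for with_cupy_nccl_for_all_reduce(): swap the all-reduce
    # implementation for the duration of the block.
    print("entering special backend")
    try:
        yield
    finally:
        print("leaving special backend")


@contextlib.contextmanager
def maybe_special_backend():
    # Same shape as _maybe_cupy_nccl: wrap the block only when the condition
    # holds, otherwise fall through and run the block as-is.
    if _SPECIAL_BACKEND_READY:
        with special_backend():
            yield
    else:
        yield


if __name__ == "__main__":
    with maybe_special_backend():
        print("running the captured region")

With _SPECIAL_BACKEND_READY left False, only the inner print runs; flipping it to True wraps the block with the enter/leave messages. The CUDA-graph capture path uses the same shape to switch between CuPy NCCL and the default backend without duplicating the capture code.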
vllm/worker/worker.py

@@ -19,6 +19,7 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner
 from vllm.lora.request import LoRARequest
+from vllm.utils import is_hip
 
 
 class Worker:

@@ -268,7 +269,8 @@ def init_distributed_environment(
                 "cupy.distributed is already initialized but the cupy world "
                 "size does not match parallel_config.world_size "
                 f"({cupy_world_size} vs. {parallel_config.world_size}).")
-    elif parallel_config.world_size > 1 and cupy_port is not None:
+    elif (parallel_config.world_size > 1 and cupy_port is not None
+          and not is_hip()):
         # NOTE(woosuk): We don't initialize CuPy process group when world size
         # is 1.
         # TODO(woosuk): Support multi-node connection.
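This hunk is the gate named in the commit title: on ROCm builds, where is_hip() is true, the CuPy process group is never initialized, so AMD backends rely on the custom all-reduce kernel or PyTorch NCCL instead. A minimal sketch of the condition; is_hip and should_init_cupy here are illustrative stand-ins, with is_hip assumed to follow the common torch.version.hip convention rather than copied from vllm.utils:

from typing import Optional

import torch


def is_hip() -> bool:
    # Assumption: ROCm builds of PyTorch expose a version string in
    # torch.version.hip, while CUDA builds leave it as None.
    return torch.version.hip is not None


def should_init_cupy(world_size: int, cupy_port: Optional[int]) -> bool:
    # Mirrors the new elif condition in init_distributed_environment: CuPy
    # NCCL is only set up for multi-GPU runs with a port configured, and
    # never on AMD GPUs.
    return world_size > 1 and cupy_port is not None and not is_hip()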