Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e0818808
Unverified
Commit
e0818808
authored
May 16, 2024
by
youkaichao
Committed by
GitHub
May 16, 2024
Browse files
[Core][Distributed] remove graph mode function (#4818)
parent
b5853f99
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
54 deletions
+63
-54
tests/distributed/test_custom_all_reduce.py
tests/distributed/test_custom_all_reduce.py
+3
-2
tests/distributed/test_pynccl.py
tests/distributed/test_pynccl.py
+2
-2
vllm/distributed/communication_op.py
vllm/distributed/communication_op.py
+40
-30
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+18
-20
No files found.
tests/distributed/test_custom_all_reduce.py
View file @
e0818808
...
...
@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
for
sz
in
test_sizes
:
for
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]:
with
graph_capture
():
with
graph_capture
()
as
graph_capture_context
:
# use integers so result matches NCCL exactly
inp1
=
torch
.
randint
(
1
,
16
,
(
sz
,
),
...
...
@@ -62,7 +62,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
device
=
torch
.
cuda
.
current_device
())
torch
.
cuda
.
synchronize
()
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
):
with
torch
.
cuda
.
graph
(
graph
,
stream
=
graph_capture_context
.
stream
):
for
i
in
range
(
num_communication
):
out1
=
tensor_model_parallel_all_reduce
(
inp1
)
# the input buffer is immediately modified to test
...
...
tests/distributed/test_pynccl.py
View file @
e0818808
...
...
@@ -5,7 +5,7 @@ import pytest
import
torch
from
vllm.distributed.communication_op
import
(
# noqa
graph_
mod
e
,
tensor_model_parallel_all_reduce
)
graph_
captur
e
,
tensor_model_parallel_all_reduce
)
from
vllm.distributed.device_communicators.pynccl
import
PyNcclCommunicator
from
vllm.distributed.device_communicators.pynccl_wrapper
import
NCCLLibrary
from
vllm.distributed.parallel_state
import
(
ensure_model_parallel_initialized
,
...
...
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
device
=
torch
.
device
(
f
"cuda:
{
torch
.
distributed
.
get_rank
()
}
"
)
ensure_model_parallel_initialized
(
2
,
2
)
tensor
=
torch
.
ones
(
16
,
1024
,
1024
,
dtype
=
torch
.
float32
,
device
=
device
)
with
graph_
mod
e
():
with
graph_
captur
e
():
# two tp groups can communicate independently
if
torch
.
distributed
.
get_rank
()
in
[
0
,
1
]:
tensor
=
tensor_model_parallel_all_reduce
(
tensor
)
...
...
vllm/distributed/communication_op.py
View file @
e0818808
from
collections
import
namedtuple
from
contextlib
import
contextmanager
,
nullcontext
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
...
...
@@ -13,8 +14,31 @@ from .parallel_state import (get_cpu_world_group,
get_tp_pynccl_communicator
)
@
dataclass
class
GraphCaptureContext
:
stream
:
torch
.
cuda
.
Stream
@
contextmanager
def
graph_mode
():
def
graph_capture
():
"""
`graph_capture` is a context manager which should surround the code that
is capturing the CUDA graph. Its main purpose is to ensure that
some operations will be run after the graph is captured, before the graph
is replayed. It returns a `GraphCaptureContext` object which contains the
necessary data for the graph capture. Currently, it only contains the
stream that the graph capture is running on. This stream is set to the
current CUDA stream when the context manager is entered and reset to the
default stream when the context manager is exited. This is to ensure that
the graph capture is running on a separate stream from the default stream,
in order to explicitly distinguish the kernels to capture
from other kernels possibly launched in the background on the default stream.
"""
stream
=
torch
.
cuda
.
Stream
()
graph_capture_context
=
GraphCaptureContext
(
stream
)
ca_comm
=
get_tp_ca_communicator
()
maybe_ca_context
=
nullcontext
()
if
ca_comm
is
None
else
ca_comm
.
capture
()
with
torch
.
cuda
.
stream
(
stream
),
maybe_ca_context
:
# In graph mode, we have to be very careful about the collective
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
...
...
@@ -23,8 +47,8 @@ def graph_mode():
# PyNccl | disabled| enabled |
# torch.distributed | enabled | disabled|
#
# Note that custom allreduce will have a runtime check: if the tensor size
# is too large, it will fall back to the next available option.
# Note that custom allreduce will have a runtime check: if the tensor
# size is too large, it will fall back to the next available option.
# In summary: When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using CUDA
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
...
...
@@ -32,26 +56,12 @@ def graph_mode():
# to PyTorch or pynccl if it is disabled or not supported.
pynccl_comm
=
get_tp_pynccl_communicator
()
if
pynccl_comm
is
None
:
context
=
nullcontext
()
maybe_pynccl_
context
=
nullcontext
()
else
:
context
=
pynccl_comm
.
change_state
(
enable
=
True
,
stream
=
torch
.
cuda
.
current_stream
())
with
context
:
yield
@
contextmanager
def
graph_capture
():
"""
`graph_capture` is a context manager which should include the code that
is capturing the CUDA graph. Its main purpose is to ensure that
some operations will be run after the graph is captured, before the graph
is replayed.
"""
ca_comm
=
get_tp_ca_communicator
()
context
=
nullcontext
()
if
ca_comm
is
None
else
ca_comm
.
capture
()
with
context
:
yield
maybe_pynccl_context
=
pynccl_comm
.
change_state
(
enable
=
True
,
stream
=
torch
.
cuda
.
current_stream
())
with
maybe_pynccl_context
:
yield
graph_capture_context
def
tensor_model_parallel_all_reduce
(
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
vllm/worker/model_runner.py
View file @
e0818808
...
...
@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.distributed.communication_op
import
graph_capture
,
graph_mode
from
vllm.distributed.communication_op
import
graph_capture
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
...
...
@@ -841,7 +841,7 @@ class ModelRunner:
bs
for
bs
in
_BATCH_SIZES_TO_CAPTURE
if
bs
<=
graph_batch_size
]
with
graph_capture
():
with
graph_capture
()
as
graph_capture_context
:
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for
batch_size
in
reversed
(
batch_size_capture_list
):
...
...
@@ -877,6 +877,7 @@ class ModelRunner:
kv_caches
,
attn_metadata
,
memory_pool
=
self
.
graph_memory_pool
,
stream
=
graph_capture_context
.
stream
,
)
self
.
graph_memory_pool
=
graph_runner
.
graph
.
pool
()
self
.
graph_runners
[
batch_size
]
=
graph_runner
...
...
@@ -921,14 +922,14 @@ class CUDAGraphRunner:
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
memory_pool
,
memory_pool
:
Optional
[
Tuple
[
int
,
int
]],
stream
:
torch
.
cuda
.
Stream
,
**
kwargs
,
)
->
None
:
assert
self
.
_graph
is
None
# Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
with
graph_mode
():
self
.
model
(
input_ids
,
positions
,
...
...
@@ -939,11 +940,8 @@ class CUDAGraphRunner:
torch
.
cuda
.
synchronize
()
# Capture the graph.
# NOTE(woosuk): Python 3.8 does not support multi-line with statements.
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
self
.
_graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
self
.
_graph
,
pool
=
memory_pool
):
# noqa: SIM117
with
graph_mode
():
with
torch
.
cuda
.
graph
(
self
.
_graph
,
pool
=
memory_pool
,
stream
=
stream
):
hidden_states
=
self
.
model
(
input_ids
,
positions
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment