Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e0818808
Unverified
Commit
e0818808
authored
May 16, 2024
by
youkaichao
Committed by
GitHub
May 16, 2024
Browse files
[Core][Distributed] remove graph mode function (#4818)
parent
b5853f99
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
54 deletions
+63
-54
tests/distributed/test_custom_all_reduce.py
tests/distributed/test_custom_all_reduce.py
+3
-2
tests/distributed/test_pynccl.py
tests/distributed/test_pynccl.py
+2
-2
vllm/distributed/communication_op.py
vllm/distributed/communication_op.py
+40
-30
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+18
-20
No files found.
tests/distributed/test_custom_all_reduce.py
View file @
e0818808
...
@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
...
@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
for
sz
in
test_sizes
:
for
sz
in
test_sizes
:
for
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]:
for
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]:
with
graph_capture
():
with
graph_capture
()
as
graph_capture_context
:
# use integers so result matches NCCL exactly
# use integers so result matches NCCL exactly
inp1
=
torch
.
randint
(
1
,
inp1
=
torch
.
randint
(
1
,
16
,
(
sz
,
),
16
,
(
sz
,
),
...
@@ -62,7 +62,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
...
@@ -62,7 +62,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
device
=
torch
.
cuda
.
current_device
())
device
=
torch
.
cuda
.
current_device
())
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
graph
=
torch
.
cuda
.
CUDAGraph
()
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
):
with
torch
.
cuda
.
graph
(
graph
,
stream
=
graph_capture_context
.
stream
):
for
i
in
range
(
num_communication
):
for
i
in
range
(
num_communication
):
out1
=
tensor_model_parallel_all_reduce
(
inp1
)
out1
=
tensor_model_parallel_all_reduce
(
inp1
)
# the input buffer is immediately modified to test
# the input buffer is immediately modified to test
...
...
tests/distributed/test_pynccl.py
View file @
e0818808
...
@@ -5,7 +5,7 @@ import pytest
...
@@ -5,7 +5,7 @@ import pytest
import
torch
import
torch
from
vllm.distributed.communication_op
import
(
# noqa
from
vllm.distributed.communication_op
import
(
# noqa
graph_
mod
e
,
tensor_model_parallel_all_reduce
)
graph_
captur
e
,
tensor_model_parallel_all_reduce
)
from
vllm.distributed.device_communicators.pynccl
import
PyNcclCommunicator
from
vllm.distributed.device_communicators.pynccl
import
PyNcclCommunicator
from
vllm.distributed.device_communicators.pynccl_wrapper
import
NCCLLibrary
from
vllm.distributed.device_communicators.pynccl_wrapper
import
NCCLLibrary
from
vllm.distributed.parallel_state
import
(
ensure_model_parallel_initialized
,
from
vllm.distributed.parallel_state
import
(
ensure_model_parallel_initialized
,
...
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
...
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
device
=
torch
.
device
(
f
"cuda:
{
torch
.
distributed
.
get_rank
()
}
"
)
device
=
torch
.
device
(
f
"cuda:
{
torch
.
distributed
.
get_rank
()
}
"
)
ensure_model_parallel_initialized
(
2
,
2
)
ensure_model_parallel_initialized
(
2
,
2
)
tensor
=
torch
.
ones
(
16
,
1024
,
1024
,
dtype
=
torch
.
float32
,
device
=
device
)
tensor
=
torch
.
ones
(
16
,
1024
,
1024
,
dtype
=
torch
.
float32
,
device
=
device
)
with
graph_
mod
e
():
with
graph_
captur
e
():
# two tp groups can communicate independently
# two tp groups can communicate independently
if
torch
.
distributed
.
get_rank
()
in
[
0
,
1
]:
if
torch
.
distributed
.
get_rank
()
in
[
0
,
1
]:
tensor
=
tensor_model_parallel_all_reduce
(
tensor
)
tensor
=
tensor_model_parallel_all_reduce
(
tensor
)
...
...
vllm/distributed/communication_op.py
View file @
e0818808
from
collections
import
namedtuple
from
collections
import
namedtuple
from
contextlib
import
contextmanager
,
nullcontext
from
contextlib
import
contextmanager
,
nullcontext
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
...
@@ -13,8 +14,31 @@ from .parallel_state import (get_cpu_world_group,
...
@@ -13,8 +14,31 @@ from .parallel_state import (get_cpu_world_group,
get_tp_pynccl_communicator
)
get_tp_pynccl_communicator
)
@
dataclass
class
GraphCaptureContext
:
stream
:
torch
.
cuda
.
Stream
@
contextmanager
@
contextmanager
def
graph_mode
():
def
graph_capture
():
"""
`graph_capture` is a context manager which should surround the code that
is capturing the CUDA graph. Its main purpose is to ensure that the
some operations will be run after the graph is captured, before the graph
is replayed. It returns a `GraphCaptureContext` object which contains the
necessary data for the graph capture. Currently, it only contains the
stream that the graph capture is running on. This stream is set to the
current CUDA stream when the context manager is entered and reset to the
default stream when the context manager is exited. This is to ensure that
the graph capture is running on a separate stream from the default stream,
in order to explicitly distinguish the kernels to capture
from other kernels possibly launched on background in the default stream.
"""
stream
=
torch
.
cuda
.
Stream
()
graph_capture_context
=
GraphCaptureContext
(
stream
)
ca_comm
=
get_tp_ca_communicator
()
maybe_ca_context
=
nullcontext
()
if
ca_comm
is
None
else
ca_comm
.
capture
()
with
torch
.
cuda
.
stream
(
stream
),
maybe_ca_context
:
# In graph mode, we have to be very careful about the collective
# In graph mode, we have to be very careful about the collective
# operations. The current status is:
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
# allreduce \ Mode | Eager | Graph |
...
@@ -23,8 +47,8 @@ def graph_mode():
...
@@ -23,8 +47,8 @@ def graph_mode():
# PyNccl | disabled| enabled |
# PyNccl | disabled| enabled |
# torch.distributed | enabled | disabled|
# torch.distributed | enabled | disabled|
#
#
# Note that custom allreduce will have a runtime check, if the tensor
size
# Note that custom allreduce will have a runtime check, if the tensor
# is too large, it will fallback to the next available option.
#
size
is too large, it will fallback to the next available option.
# In summary: When using CUDA graph, we use
# In summary: When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using CUDA
# either custom all-reduce kernel or pynccl. When not using CUDA
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
...
@@ -32,26 +56,12 @@ def graph_mode():
...
@@ -32,26 +56,12 @@ def graph_mode():
# to PyTorch or pynccl if it is disabled or not supported.
# to PyTorch or pynccl if it is disabled or not supported.
pynccl_comm
=
get_tp_pynccl_communicator
()
pynccl_comm
=
get_tp_pynccl_communicator
()
if
pynccl_comm
is
None
:
if
pynccl_comm
is
None
:
context
=
nullcontext
()
maybe_pynccl_
context
=
nullcontext
()
else
:
else
:
context
=
pynccl_comm
.
change_state
(
enable
=
True
,
maybe_pynccl_context
=
pynccl_comm
.
change_state
(
stream
=
torch
.
cuda
.
current_stream
())
enable
=
True
,
stream
=
torch
.
cuda
.
current_stream
())
with
context
:
with
maybe_pynccl_context
:
yield
yield
graph_capture_context
@
contextmanager
def
graph_capture
():
"""
`graph_capture` is a context manager which should include the code that
is capturing the CUDA graph. Its main purpose is to ensure that the
some operations will be run after the graph is captured, before the graph
is replayed.
"""
ca_comm
=
get_tp_ca_communicator
()
context
=
nullcontext
()
if
ca_comm
is
None
else
ca_comm
.
capture
()
with
context
:
yield
def
tensor_model_parallel_all_reduce
(
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
tensor_model_parallel_all_reduce
(
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
vllm/worker/model_runner.py
View file @
e0818808
...
@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
...
@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
VisionLanguageConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.distributed.communication_op
import
graph_capture
,
graph_mode
from
vllm.distributed.communication_op
import
graph_capture
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -841,7 +841,7 @@ class ModelRunner:
...
@@ -841,7 +841,7 @@ class ModelRunner:
bs
for
bs
in
_BATCH_SIZES_TO_CAPTURE
if
bs
<=
graph_batch_size
bs
for
bs
in
_BATCH_SIZES_TO_CAPTURE
if
bs
<=
graph_batch_size
]
]
with
graph_capture
():
with
graph_capture
()
as
graph_capture_context
:
# NOTE: Capturing the largest batch size first may help reduce the
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
# memory usage of CUDA graph.
for
batch_size
in
reversed
(
batch_size_capture_list
):
for
batch_size
in
reversed
(
batch_size_capture_list
):
...
@@ -877,6 +877,7 @@ class ModelRunner:
...
@@ -877,6 +877,7 @@ class ModelRunner:
kv_caches
,
kv_caches
,
attn_metadata
,
attn_metadata
,
memory_pool
=
self
.
graph_memory_pool
,
memory_pool
=
self
.
graph_memory_pool
,
stream
=
graph_capture_context
.
stream
,
)
)
self
.
graph_memory_pool
=
graph_runner
.
graph
.
pool
()
self
.
graph_memory_pool
=
graph_runner
.
graph
.
pool
()
self
.
graph_runners
[
batch_size
]
=
graph_runner
self
.
graph_runners
[
batch_size
]
=
graph_runner
...
@@ -921,14 +922,14 @@ class CUDAGraphRunner:
...
@@ -921,14 +922,14 @@ class CUDAGraphRunner:
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
memory_pool
,
memory_pool
:
Optional
[
Tuple
[
int
,
int
]],
stream
:
torch
.
cuda
.
Stream
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
assert
self
.
_graph
is
None
assert
self
.
_graph
is
None
# Run the model once without capturing the graph.
# Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
# kernel launches for initial benchmarking (e.g., Triton autotune).
with
graph_mode
():
self
.
model
(
self
.
model
(
input_ids
,
input_ids
,
positions
,
positions
,
...
@@ -939,11 +940,8 @@ class CUDAGraphRunner:
...
@@ -939,11 +940,8 @@ class CUDAGraphRunner:
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
# Capture the graph.
# Capture the graph.
# NOTE(woosuk): Python 3.8 does not support multi-line with statements.
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
self
.
_graph
=
torch
.
cuda
.
CUDAGraph
()
self
.
_graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
self
.
_graph
,
pool
=
memory_pool
):
# noqa: SIM117
with
torch
.
cuda
.
graph
(
self
.
_graph
,
pool
=
memory_pool
,
stream
=
stream
):
with
graph_mode
():
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
input_ids
,
input_ids
,
positions
,
positions
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment