Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e0818808
Unverified
Commit
e0818808
authored
May 16, 2024
by
youkaichao
Committed by
GitHub
May 16, 2024
Browse files
[Core][Distributed] remove graph mode function (#4818)
parent
b5853f99
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
54 deletions
+63
-54
tests/distributed/test_custom_all_reduce.py
tests/distributed/test_custom_all_reduce.py
+3
-2
tests/distributed/test_pynccl.py
tests/distributed/test_pynccl.py
+2
-2
vllm/distributed/communication_op.py
vllm/distributed/communication_op.py
+40
-30
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+18
-20
No files found.
tests/distributed/test_custom_all_reduce.py
View file @
e0818808
...
@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
...
@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
for
sz
in
test_sizes
:
for
sz
in
test_sizes
:
for
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]:
for
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]:
with
graph_capture
():
with
graph_capture
()
as
graph_capture_context
:
# use integers so result matches NCCL exactly
# use integers so result matches NCCL exactly
inp1
=
torch
.
randint
(
1
,
inp1
=
torch
.
randint
(
1
,
16
,
(
sz
,
),
16
,
(
sz
,
),
...
@@ -62,7 +62,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
...
@@ -62,7 +62,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
device
=
torch
.
cuda
.
current_device
())
device
=
torch
.
cuda
.
current_device
())
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
graph
=
torch
.
cuda
.
CUDAGraph
()
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
):
with
torch
.
cuda
.
graph
(
graph
,
stream
=
graph_capture_context
.
stream
):
for
i
in
range
(
num_communication
):
for
i
in
range
(
num_communication
):
out1
=
tensor_model_parallel_all_reduce
(
inp1
)
out1
=
tensor_model_parallel_all_reduce
(
inp1
)
# the input buffer is immediately modified to test
# the input buffer is immediately modified to test
...
...
tests/distributed/test_pynccl.py
View file @
e0818808
...
@@ -5,7 +5,7 @@ import pytest
...
@@ -5,7 +5,7 @@ import pytest
import
torch
import
torch
from
vllm.distributed.communication_op
import
(
# noqa
from
vllm.distributed.communication_op
import
(
# noqa
graph_
mod
e
,
tensor_model_parallel_all_reduce
)
graph_
captur
e
,
tensor_model_parallel_all_reduce
)
from
vllm.distributed.device_communicators.pynccl
import
PyNcclCommunicator
from
vllm.distributed.device_communicators.pynccl
import
PyNcclCommunicator
from
vllm.distributed.device_communicators.pynccl_wrapper
import
NCCLLibrary
from
vllm.distributed.device_communicators.pynccl_wrapper
import
NCCLLibrary
from
vllm.distributed.parallel_state
import
(
ensure_model_parallel_initialized
,
from
vllm.distributed.parallel_state
import
(
ensure_model_parallel_initialized
,
...
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
...
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
device
=
torch
.
device
(
f
"cuda:
{
torch
.
distributed
.
get_rank
()
}
"
)
device
=
torch
.
device
(
f
"cuda:
{
torch
.
distributed
.
get_rank
()
}
"
)
ensure_model_parallel_initialized
(
2
,
2
)
ensure_model_parallel_initialized
(
2
,
2
)
tensor
=
torch
.
ones
(
16
,
1024
,
1024
,
dtype
=
torch
.
float32
,
device
=
device
)
tensor
=
torch
.
ones
(
16
,
1024
,
1024
,
dtype
=
torch
.
float32
,
device
=
device
)
with
graph_
mod
e
():
with
graph_
captur
e
():
# two tp groups can communicate independently
# two tp groups can communicate independently
if
torch
.
distributed
.
get_rank
()
in
[
0
,
1
]:
if
torch
.
distributed
.
get_rank
()
in
[
0
,
1
]:
tensor
=
tensor_model_parallel_all_reduce
(
tensor
)
tensor
=
tensor_model_parallel_all_reduce
(
tensor
)
...
...
vllm/distributed/communication_op.py
View file @
e0818808
from
collections
import
namedtuple
from
collections
import
namedtuple
from
contextlib
import
contextmanager
,
nullcontext
from
contextlib
import
contextmanager
,
nullcontext
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
...
@@ -13,45 +14,54 @@ from .parallel_state import (get_cpu_world_group,
...
@@ -13,45 +14,54 @@ from .parallel_state import (get_cpu_world_group,
get_tp_pynccl_communicator
)
get_tp_pynccl_communicator
)
@
contextmanager
@
dataclass
def
graph_mode
():
class
GraphCaptureContext
:
# In graph mode, we have to be very careful about the collective
stream
:
torch
.
cuda
.
Stream
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
# --------------------------------------------
# custom allreduce | enabled | enabled |
# PyNccl | disabled| enabled |
# torch.distributed | enabled | disabled|
#
# Note that custom allreduce will have a runtime check, if the tensor size
# is too large, it will fallback to the next available option.
# In summary: When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using CUDA
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
# We always prioritize using custom all-reduce kernel but fall back
# to PyTorch or pynccl if it is disabled or not supported.
pynccl_comm
=
get_tp_pynccl_communicator
()
if
pynccl_comm
is
None
:
context
=
nullcontext
()
else
:
context
=
pynccl_comm
.
change_state
(
enable
=
True
,
stream
=
torch
.
cuda
.
current_stream
())
with
context
:
yield
@
contextmanager
@
contextmanager
def
graph_capture
():
def
graph_capture
():
"""
"""
`graph_capture` is a context manager which should
include
the code that
`graph_capture` is a context manager which should
surround
the code that
is capturing the CUDA graph. Its main purpose is to ensure that the
is capturing the CUDA graph. Its main purpose is to ensure that the
some operations will be run after the graph is captured, before the graph
some operations will be run after the graph is captured, before the graph
is replayed.
is replayed. It returns a `GraphCaptureContext` object which contains the
necessary data for the graph capture. Currently, it only contains the
stream that the graph capture is running on. This stream is set to the
current CUDA stream when the context manager is entered and reset to the
default stream when the context manager is exited. This is to ensure that
the graph capture is running on a separate stream from the default stream,
in order to explicitly distinguish the kernels to capture
from other kernels possibly launched on background in the default stream.
"""
"""
stream
=
torch
.
cuda
.
Stream
()
graph_capture_context
=
GraphCaptureContext
(
stream
)
ca_comm
=
get_tp_ca_communicator
()
ca_comm
=
get_tp_ca_communicator
()
context
=
nullcontext
()
if
ca_comm
is
None
else
ca_comm
.
capture
()
maybe_ca_context
=
nullcontext
()
if
ca_comm
is
None
else
ca_comm
.
capture
()
with
context
:
with
torch
.
cuda
.
stream
(
stream
),
maybe_ca_context
:
yield
# In graph mode, we have to be very careful about the collective
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
# --------------------------------------------
# custom allreduce | enabled | enabled |
# PyNccl | disabled| enabled |
# torch.distributed | enabled | disabled|
#
# Note that custom allreduce will have a runtime check, if the tensor
# size is too large, it will fallback to the next available option.
# In summary: When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using CUDA
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
# We always prioritize using custom all-reduce kernel but fall back
# to PyTorch or pynccl if it is disabled or not supported.
pynccl_comm
=
get_tp_pynccl_communicator
()
if
pynccl_comm
is
None
:
maybe_pynccl_context
=
nullcontext
()
else
:
maybe_pynccl_context
=
pynccl_comm
.
change_state
(
enable
=
True
,
stream
=
torch
.
cuda
.
current_stream
())
with
maybe_pynccl_context
:
yield
graph_capture_context
def
tensor_model_parallel_all_reduce
(
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
tensor_model_parallel_all_reduce
(
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
vllm/worker/model_runner.py
View file @
e0818808
...
@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
...
@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
VisionLanguageConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.distributed.communication_op
import
graph_capture
,
graph_mode
from
vllm.distributed.communication_op
import
graph_capture
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -841,7 +841,7 @@ class ModelRunner:
...
@@ -841,7 +841,7 @@ class ModelRunner:
bs
for
bs
in
_BATCH_SIZES_TO_CAPTURE
if
bs
<=
graph_batch_size
bs
for
bs
in
_BATCH_SIZES_TO_CAPTURE
if
bs
<=
graph_batch_size
]
]
with
graph_capture
():
with
graph_capture
()
as
graph_capture_context
:
# NOTE: Capturing the largest batch size first may help reduce the
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
# memory usage of CUDA graph.
for
batch_size
in
reversed
(
batch_size_capture_list
):
for
batch_size
in
reversed
(
batch_size_capture_list
):
...
@@ -877,6 +877,7 @@ class ModelRunner:
...
@@ -877,6 +877,7 @@ class ModelRunner:
kv_caches
,
kv_caches
,
attn_metadata
,
attn_metadata
,
memory_pool
=
self
.
graph_memory_pool
,
memory_pool
=
self
.
graph_memory_pool
,
stream
=
graph_capture_context
.
stream
,
)
)
self
.
graph_memory_pool
=
graph_runner
.
graph
.
pool
()
self
.
graph_memory_pool
=
graph_runner
.
graph
.
pool
()
self
.
graph_runners
[
batch_size
]
=
graph_runner
self
.
graph_runners
[
batch_size
]
=
graph_runner
...
@@ -921,15 +922,27 @@ class CUDAGraphRunner:
...
@@ -921,15 +922,27 @@ class CUDAGraphRunner:
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
memory_pool
,
memory_pool
:
Optional
[
Tuple
[
int
,
int
]],
stream
:
torch
.
cuda
.
Stream
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
assert
self
.
_graph
is
None
assert
self
.
_graph
is
None
# Run the model once without capturing the graph.
# Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
# kernel launches for initial benchmarking (e.g., Triton autotune).
with
graph_mode
():
self
.
model
(
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
**
kwargs
,
)
torch
.
cuda
.
synchronize
()
# Capture the graph.
self
.
_graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
self
.
_graph
,
pool
=
memory_pool
,
stream
=
stream
):
hidden_states
=
self
.
model
(
input_ids
,
input_ids
,
positions
,
positions
,
kv_caches
,
kv_caches
,
...
@@ -938,21 +951,6 @@ class CUDAGraphRunner:
...
@@ -938,21 +951,6 @@ class CUDAGraphRunner:
)
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
# Capture the graph.
# NOTE(woosuk): Python 3.8 does not support multi-line with statements.
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
self
.
_graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
self
.
_graph
,
pool
=
memory_pool
):
# noqa: SIM117
with
graph_mode
():
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
**
kwargs
,
)
torch
.
cuda
.
synchronize
()
# Save the input and output buffers.
# Save the input and output buffers.
self
.
input_buffers
=
{
self
.
input_buffers
=
{
"input_ids"
:
input_ids
,
"input_ids"
:
input_ids
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment