Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
310aca88
Unverified
Commit
310aca88
authored
Jan 09, 2025
by
youkaichao
Committed by
GitHub
Jan 09, 2025
Browse files
[perf]fix current stream (#11870)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
a732900e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
46 additions
and
15 deletions
+46
-15
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/device_communicators/pynccl.py
+8
-7
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+1
-4
vllm/utils.py
vllm/utils.py
+33
-0
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+4
-4
No files found.
vllm/distributed/device_communicators/pynccl.py
View file @
310aca88
...
...
@@ -10,6 +10,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
ncclRedOpTypeEnum
,
ncclUniqueId
)
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.logger
import
init_logger
from
vllm.utils
import
current_stream
logger
=
init_logger
(
__name__
)
...
...
@@ -96,7 +97,7 @@ class PyNcclCommunicator:
self
.
comm
:
ncclComm_t
=
self
.
nccl
.
ncclCommInitRank
(
self
.
world_size
,
self
.
unique_id
,
self
.
rank
)
stream
=
torch
.
cuda
.
current_stream
()
stream
=
current_stream
()
# A small all_reduce for warmup.
data
=
torch
.
zeros
(
1
,
device
=
device
)
self
.
all_reduce
(
data
)
...
...
@@ -119,7 +120,7 @@ class PyNcclCommunicator:
out_tensor
=
torch
.
empty_like
(
in_tensor
)
if
stream
is
None
:
stream
=
torch
.
cuda
.
current_stream
()
stream
=
current_stream
()
self
.
nccl
.
ncclAllReduce
(
buffer_type
(
in_tensor
.
data_ptr
()),
buffer_type
(
out_tensor
.
data_ptr
()),
in_tensor
.
numel
(),
...
...
@@ -141,7 +142,7 @@ class PyNcclCommunicator:
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
f
"but the input tensor is on
{
input_tensor
.
device
}
"
)
if
stream
is
None
:
stream
=
torch
.
cuda
.
current_stream
()
stream
=
current_stream
()
self
.
nccl
.
ncclAllGather
(
buffer_type
(
input_tensor
.
data_ptr
()),
buffer_type
(
output_tensor
.
data_ptr
()),
input_tensor
.
numel
(),
...
...
@@ -162,7 +163,7 @@ class PyNcclCommunicator:
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
f
"but the input tensor is on
{
input_tensor
.
device
}
"
)
if
stream
is
None
:
stream
=
torch
.
cuda
.
current_stream
()
stream
=
current_stream
()
self
.
nccl
.
ncclReduceScatter
(
buffer_type
(
input_tensor
.
data_ptr
()),
buffer_type
(
output_tensor
.
data_ptr
()),
output_tensor
.
numel
(),
...
...
@@ -177,7 +178,7 @@ class PyNcclCommunicator:
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
f
"but the input tensor is on
{
tensor
.
device
}
"
)
if
stream
is
None
:
stream
=
torch
.
cuda
.
current_stream
()
stream
=
current_stream
()
self
.
nccl
.
ncclSend
(
buffer_type
(
tensor
.
data_ptr
()),
tensor
.
numel
(),
ncclDataTypeEnum
.
from_torch
(
tensor
.
dtype
),
dst
,
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
))
...
...
@@ -189,7 +190,7 @@ class PyNcclCommunicator:
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
f
"but the input tensor is on
{
tensor
.
device
}
"
)
if
stream
is
None
:
stream
=
torch
.
cuda
.
current_stream
()
stream
=
current_stream
()
self
.
nccl
.
ncclRecv
(
buffer_type
(
tensor
.
data_ptr
()),
tensor
.
numel
(),
ncclDataTypeEnum
.
from_torch
(
tensor
.
dtype
),
src
,
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
))
...
...
@@ -201,7 +202,7 @@ class PyNcclCommunicator:
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
f
"but the input tensor is on
{
tensor
.
device
}
"
)
if
stream
is
None
:
stream
=
torch
.
cuda
.
current_stream
()
stream
=
current_stream
()
if
src
==
self
.
rank
:
sendbuff
=
buffer_type
(
tensor
.
data_ptr
())
# NCCL requires the sender also to have a receive buffer
...
...
vllm/distributed/parallel_state.py
View file @
310aca88
...
...
@@ -357,10 +357,7 @@ class GroupCoordinator:
return
out
pynccl_comm
=
self
.
pynccl_comm
assert
pynccl_comm
is
not
None
# TODO: pynccl should not use `stream=`
# it can just always use the current stream.
out
=
pynccl_comm
.
all_reduce
(
input_
,
stream
=
torch
.
cuda
.
current_stream
())
out
=
pynccl_comm
.
all_reduce
(
input_
)
if
out
is
None
:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
...
...
vllm/utils.py
View file @
310aca88
...
...
@@ -944,6 +944,39 @@ def find_nccl_library() -> str:
return
so_file
prev_set_stream
=
torch
.
cuda
.
set_stream
_current_stream
=
None
def
_patched_set_stream
(
stream
:
torch
.
cuda
.
Stream
)
->
None
:
global
_current_stream
_current_stream
=
stream
prev_set_stream
(
stream
)
torch
.
cuda
.
set_stream
=
_patched_set_stream
def
current_stream
()
->
torch
.
cuda
.
Stream
:
"""
replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`.
it turns out that `torch.cuda.current_stream()` is quite expensive,
as it will construct a new stream object at each call.
here we patch `torch.cuda.set_stream` to keep track of the current stream
directly, so that we can avoid calling `torch.cuda.current_stream()`.
the underlying hypothesis is that we do not call `torch._C._cuda_setStream`
from C/C++ code.
"""
global
_current_stream
if
_current_stream
is
None
:
# when this function is called before any stream is set,
# we return the default stream.
_current_stream
=
torch
.
cuda
.
current_stream
()
return
_current_stream
def
enable_trace_function_call_for_thread
(
vllm_config
:
"VllmConfig"
)
->
None
:
"""Set up function tracing for the current thread,
if enabled via the VLLM_TRACE_FUNCTION environment variable
...
...
vllm/worker/multi_step_model_runner.py
View file @
310aca88
...
...
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
get_pythonized_sample_results
)
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
Logprob
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.utils
import
PyObjectCache
,
async_tensor_h2d
from
vllm.utils
import
PyObjectCache
,
async_tensor_h2d
,
current_stream
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUWithSamplingMetadata
)
from
vllm.worker.model_runner_base
import
(
...
...
@@ -498,7 +498,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# appended sampler output from last iteration
# - also maybe pythonize if CPU is ahead of GPU
current_
stream
=
torch
.
cuda
.
current_stream
()
stream
=
current_stream
()
if
not
model_input
.
is_first_multi_step
:
# Explicitly block on the previous step's forward to make sure we
# don't clobber any GPU tensors still in use.
...
...
@@ -541,7 +541,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
num_steps
=
1
)
# record the event for the current step so that the next step can sync
model_input
.
record_step_event
(
current_
stream
)
model_input
.
record_step_event
(
stream
)
if
get_pp_group
().
is_last_rank
and
self
.
is_driver_worker
:
assert
isinstance
(
output
,
list
)
...
...
@@ -552,7 +552,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# event for the pythonization so that we only pythonize if the
# tensors are ready. May be able to be combined with the step event
output_ready_event
=
torch
.
cuda
.
Event
()
output_ready_event
.
record
(
current_
stream
)
output_ready_event
.
record
(
stream
)
if
self
.
parallel_config
.
pipeline_parallel_size
>
1
:
output
[
0
].
sampled_token_ids_cpu
=
output
[
0
].
sampled_token_ids
.
cpu
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment