Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8549c826
Unverified
Commit
8549c826
authored
Oct 27, 2024
by
youkaichao
Committed by
GitHub
Oct 27, 2024
Browse files
[core] cudagraph output with tensor weak reference (#9724)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
67a6882d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
50 additions
and
28 deletions
+50
-28
csrc/ops.h
csrc/ops.h
+24
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+3
-0
vllm/utils.py
vllm/utils.py
+9
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+14
-28
No files found.
csrc/ops.h
View file @
8549c826
...
...
@@ -5,6 +5,30 @@
#include "core/scalar_type.hpp"
#include <vector>
/// Create a non-owning alias of a CUDA tensor.
///
/// The result is built with torch::from_blob on top of the input's raw data
/// pointer, reusing the input's sizes, strides and options (dtype, device).
/// It views the same memory as `tensor`, but from_blob does not take
/// ownership of that memory, so the caller is responsible for keeping the
/// underlying allocation alive for as long as the returned tensor is used.
///
/// Declared `inline` because this definition lives in a header: without it,
/// every translation unit that includes the header would emit its own
/// external definition and the link would fail with duplicate symbols.
///
/// @param tensor  Source tensor; must reside on a CUDA device.
/// @return        A tensor aliasing the memory of `tensor` without owning it.
/// @throws std::runtime_error if `tensor` is not on a CUDA device.
inline torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
  // Ensure tensor is on CUDA
  if (!tensor.is_cuda()) {
    throw std::runtime_error("Tensor must be on CUDA device");
  }

  // Get the raw data pointer
  void* data_ptr = tensor.data_ptr();

  // Create a new tensor from the raw data pointer. sizes()/strides() are
  // passed straight through as array views; from_blob reads them while
  // building the new tensor, so no intermediate std::vector copy is needed.
  return torch::from_blob(data_ptr, tensor.sizes(), tensor.strides(),
                          tensor.options());
}
void
paged_attention_v1
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
...
...
csrc/torch_bindings.cpp
View file @
8549c826
...
...
@@ -18,6 +18,9 @@
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
ops
)
{
// vLLM custom ops
ops
.
def
(
"weak_ref_tensor(Tensor input) -> Tensor"
);
ops
.
impl
(
"weak_ref_tensor"
,
torch
::
kCUDA
,
&
weak_ref_tensor
);
// Attention ops
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
...
...
vllm/utils.py
View file @
8549c826
...
...
@@ -1479,3 +1479,12 @@ class LazyDict(Mapping, Generic[T]):
def __len__(self):
    """Report the size of the underlying ``_factory`` container."""
    factory = self._factory
    return len(factory)
def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """
    Build a weak reference to ``tensor``.

    The returned tensor shares its data with ``tensor`` but does not keep
    ``tensor`` alive. The work is delegated to the ``weak_ref_tensor``
    custom op registered under ``torch.ops._C``.
    """
    make_weak_ref = torch.ops._C.weak_ref_tensor
    return make_weak_ref(tensor)
vllm/worker/model_runner.py
View file @
8549c826
...
...
@@ -50,7 +50,7 @@ from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.utils
import
(
DeviceMemoryProfiler
,
PyObjectCache
,
async_tensor_h2d
,
flatten_2d_lists
,
is_hip
,
is_pin_memory_available
,
supports_dynamo
)
supports_dynamo
,
weak_ref_tensor
)
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
_add_attn_metadata_broadcastable_dict
,
...
...
@@ -1426,12 +1426,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
)
# Prepare buffer for outputs. These will be reused for all batch sizes.
# It will be filled after the first graph capture.
hidden_or_intermediate_states
:
List
[
Optional
[
torch
.
Tensor
]]
=
[
None
]
*
self
.
parallel_config
.
pipeline_parallel_size
graph_batch_size
=
self
.
max_batchsize_to_capture
batch_size_capture_list
=
[
bs
for
bs
in
_BATCH_SIZES_TO_CAPTURE
if
bs
<=
graph_batch_size
...
...
@@ -1474,12 +1468,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
input_tokens
[:
batch_size
],
"positions"
:
input_positions
[...,
:
batch_size
],
"hidden_or_intermediate_states"
:
hidden_or_intermediate_states
[
virtual_engine
]
# type: ignore
[:
batch_size
]
if
hidden_or_intermediate_states
[
virtual_engine
]
is
not
None
else
None
,
"intermediate_inputs"
:
intermediate_inputs
[:
batch_size
]
if
intermediate_inputs
is
not
None
else
None
,
...
...
@@ -1762,15 +1750,13 @@ class CUDAGraphRunner(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_or_intermediate_states
:
Optional
[
Union
[
IntermediateTensors
,
torch
.
Tensor
]],
intermediate_inputs
:
Optional
[
IntermediateTensors
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
memory_pool
:
Optional
[
Tuple
[
int
,
int
]],
stream
:
torch
.
cuda
.
Stream
,
**
kwargs
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
):
assert
self
.
_graph
is
None
# Run the model a few times without capturing the graph.
# This is to make sure that the captured graph does not include the
...
...
@@ -1799,20 +1785,21 @@ class CUDAGraphRunner(nn.Module):
intermediate_tensors
=
intermediate_inputs
,
**
kwargs
,
)
if
hidden_or_intermediate_states
is
not
None
:
if
get_pp_group
().
is_last_rank
:
hidden_or_intermediate_states
.
copy_
(
output_hidden_or_intermediate_states
)
else
:
for
key
in
hidden_or_intermediate_states
.
tensors
:
hidden_or_intermediate_states
[
key
].
copy_
(
output_hidden_or_intermediate_states
[
key
])
else
:
hidden_or_intermediate_states
=
(
if
isinstance
(
output_hidden_or_intermediate_states
,
torch
.
Tensor
):
hidden_or_intermediate_states
=
weak_ref_tensor
(
output_hidden_or_intermediate_states
)
elif
isinstance
(
output_hidden_or_intermediate_states
,
IntermediateTensors
):
hidden_or_intermediate_states
=
IntermediateTensors
(
tensors
=
{
key
:
weak_ref_tensor
(
value
)
for
key
,
value
in
output_hidden_or_intermediate_states
.
tensors
.
items
()
})
del
output_hidden_or_intermediate_states
# make sure `output_hidden_states` is deleted
# make sure `output_hidden_
or_intermediate_
states` is deleted
# in the graph's memory pool
gc
.
collect
()
torch
.
cuda
.
synchronize
()
...
...
@@ -1837,7 +1824,6 @@ class CUDAGraphRunner(nn.Module):
}
else
:
self
.
output_buffers
=
hidden_or_intermediate_states
return
hidden_or_intermediate_states
def
forward
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment