norm / vllm / Commits / 31bff691

Commit 31bff691 (Unverified)
authored Dec 19, 2023 by Hanzhi Zhou, committed by GitHub on Dec 19, 2023
parent ba4f8267

Make _prepare_sample non-blocking and use pinned memory for input buffers (#2207)

Showing 1 changed file with 38 additions and 17 deletions.

vllm/worker/model_runner.py (+38, -17)
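This commit attacks two host-side stalls in the decode loop: input staging tensors are now allocated in pinned (page-locked) host memory so host-to-device copies can run asynchronously, and _prepare_sample is moved so that it overlaps with the GPU forward pass. As background, here is a minimal sketch of the pinned-copy pattern (illustrative names, not code from this diff): non_blocking=True only overlaps with other work when the source tensor is pinned; from pageable memory the copy is effectively synchronous.

import torch

data = list(range(1024))

# Pageable source: the non_blocking flag buys little; the host stalls.
pageable = torch.tensor(data, dtype=torch.long)
gpu_a = pageable.to("cuda", non_blocking=True)

# Pinned source: the DMA engine copies while the CPU keeps working.
pinned = torch.tensor(data, dtype=torch.long, pin_memory=True)
gpu_b = pinned.to("cuda", non_blocking=True)  # returns immediately
# Kernels enqueued later on the same CUDA stream are ordered after the copy.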
@@ -10,6 +10,7 @@ from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.utils import in_wsl

 logger = init_logger(__name__)
@@ -52,6 +53,8 @@ class ModelRunner:
         # The shape of the cached block table will be
         # (max batch size to capture, max context len to capture / block size).
         self.graph_block_tables = None  # Set after initial profiling.
+        # cache in_wsl result
+        self.in_wsl = in_wsl()

     def load_model(self) -> None:
         self.model = get_model(self.model_config)
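The in_wsl() helper is imported from vllm.utils (its body is not part of this diff) and its result is cached on the instance so the pinning paths below can cheaply disable pinned memory under WSL, where the diff treats it as unavailable, without re-running the check on every batch. A plausible detection sketch, offered as an assumption rather than the actual vllm.utils code:

import platform

def in_wsl() -> bool:
    # WSL kernels report "microsoft" in the release string.
    return "microsoft" in platform.uname().release.lower()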
@@ -203,24 +206,29 @@ class ModelRunner:
         # When using CUDA graph, we don't need to make the tensors on the GPU
         # because they will be eventually copied to the designated GPU buffer.
         device = "cpu" if use_captured_graph else "cuda"
+        pin_memory = use_captured_graph and not self.in_wsl
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_len=1,
                                              pad=0,
                                              dtype=torch.long,
-                                             device=device)
+                                             device=device,
+                                             pin_memory=pin_memory)
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
                                                 dtype=torch.long,
-                                                device=device)
+                                                device=device,
+                                                pin_memory=pin_memory)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
-                                             device=device)
+                                             device=device,
+                                             pin_memory=pin_memory)
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device=device)
+                                    device=device,
+                                    pin_memory=pin_memory)

         if use_captured_graph:
             # The shape of graph_block_tables is
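Note the gating: pin_memory is requested only when use_captured_graph is true, which is exactly the case where device is "cpu" and the tensors are staged on the host for a later copy into the graph's GPU buffers. torch.tensor's pin_memory flag is only meaningful for CPU allocations. A standalone illustration of that invariant (variable names are illustrative):

import torch

use_captured_graph = True
device = "cpu" if use_captured_graph else "cuda"
pin_memory = use_captured_graph  # the diff additionally requires `not self.in_wsl`

t = torch.tensor([1, 2, 3], dtype=torch.long,
                 device=device, pin_memory=pin_memory)
assert not pin_memory or t.is_pinned()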
@@ -229,7 +237,7 @@ class ModelRunner:
             for i, block_table in enumerate(block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.from_numpy(input_block_tables).to(device)
+            block_tables = torch.tensor(input_block_tables, device=device)
         else:
             block_tables = _make_tensor_with_pad(
                 block_tables,
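The switch from torch.from_numpy(...).to(device) to torch.tensor(..., device=device) is meaningful when device is "cpu": .to() is a no-op for a tensor already on its target device, so the old form returned a view sharing memory with the NumPy scratch array, whereas torch.tensor always materializes an independent copy directly on the target device. A small demonstration:

import numpy as np
import torch

arr = np.zeros((2, 2), dtype=np.int64)
view = torch.from_numpy(arr).to("cpu")   # shares memory with arr
copy = torch.tensor(arr, device="cpu")   # independent copy
arr[0, 0] = 7
assert view[0, 0].item() == 7
assert copy[0, 0].item() == 0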
@@ -297,11 +305,11 @@ class ModelRunner:
                     categorized_sample_indices_start_idx + num_seqs))
                 categorized_sample_indices_start_idx += num_seqs

-        selected_token_indices = torch.tensor(selected_token_indices,
-                                              dtype=torch.long,
-                                              device="cuda")
+        selected_token_indices = _async_h2d(selected_token_indices,
+                                            dtype=torch.long,
+                                            pin_memory=not self.in_wsl)
         categorized_sample_indices = {
-            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
+            t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
             for t, seq_ids in categorized_sample_indices.items()
         }
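These index tensors are tiny (a few entries per sequence group), which is exactly where a per-call device synchronization costs the most relative to the work done. The old form constructed each tensor directly on the GPU and blocked the host; the new form stages through pinned memory via _async_h2d, defined at the bottom of this diff. On WSL, pin_memory=False is passed and the copy quietly degrades to a synchronous one, so behavior stays correct either way. In spirit:

import torch

# old: host waits for the copy to finish
t_old = torch.tensor([1, 2], dtype=torch.int, device="cuda")
# new: stage in pinned memory, enqueue the copy, and return at once
t_new = torch.tensor([1, 2], dtype=torch.int,
                     pin_memory=True).to("cuda", non_blocking=True)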
@@ -334,8 +342,6 @@ class ModelRunner:
         else:
             inputs = self._prepare_decode(seq_group_metadata_list)
             input_tokens, input_positions, input_metadata = inputs
-        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                 input_metadata.prompt_lens)

         # Execute the model.
         if input_metadata.use_cuda_graph:
@@ -350,6 +356,9 @@ class ModelRunner:
             input_metadata=input_metadata,
         )

+        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
+                                                 input_metadata.prompt_lens)
+
         # Sample the next token.
         output = self.model.sample(
             hidden_states=hidden_states,
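Moving _prepare_sample below the forward-pass launch is the other half of the commit title: CUDA kernel launches are asynchronous, so once the forward pass has been enqueued, the Python-side work of building sampling metadata runs concurrently with the GPU. A generic, self-contained illustration of this launch-then-prepare overlap (prepare_metadata is a hypothetical stand-in for the CPU work, not a vLLM function):

import torch

def prepare_metadata():
    # stand-in for CPU-side work such as building sampling indices
    return sorted(range(10000), key=lambda i: -i)

a = torch.randn(4096, 4096, device="cuda")
b = a @ a                  # enqueued asynchronously; returns at once
meta = prepare_metadata()  # CPU work overlaps the matmul on the GPU
total = b.sum().item()     # .item() is the first synchronization point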
@@ -502,11 +511,14 @@ class CUDAGraphRunner:
         del kv_caches

         # Copy the input tensors to the input buffers.
-        self.input_buffers["input_ids"].copy_(input_ids)
-        self.input_buffers["positions"].copy_(positions)
-        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping)
-        self.input_buffers["context_lens"].copy_(input_metadata.context_lens)
-        self.input_buffers["block_tables"].copy_(input_metadata.block_tables)
+        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
+        self.input_buffers["positions"].copy_(positions, non_blocking=True)
+        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
+                                                 non_blocking=True)
+        self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
+                                                 non_blocking=True)
+        self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
+                                                 non_blocking=True)

         # Run the graph.
         self.graph.replay()
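A CUDA graph replays a fixed set of kernels against fixed device buffers, so new inputs must be copied into those buffers in place; with the pinned staging tensors prepared earlier, all five copy_ calls can be in flight together, and graph.replay() is stream-ordered after them. A minimal, self-contained capture/replay sketch in the same spirit (not vLLM's actual runner):

import torch

static_in = torch.zeros(8, device="cuda")

# Warm up on a side stream, then capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out = static_in * 2
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = static_in * 2

new_input = torch.arange(8, dtype=torch.float32).pin_memory()
static_in.copy_(new_input, non_blocking=True)  # async: pinned source
g.replay()                                     # stream-ordered after the copy
torch.cuda.synchronize()
print(static_out)  # tensor([0., 2., 4., 6., 8., 10., 12., 14.])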
@@ -529,9 +541,13 @@ def _make_tensor_with_pad(
     pad: int,
     dtype: torch.dtype,
     device: Union[str, torch.device] = "cuda",
+    pin_memory: bool = False,
 ) -> torch.Tensor:
     padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(padded_x, dtype=dtype, device=device)
+    return torch.tensor(padded_x,
+                        dtype=dtype,
+                        device=device,
+                        pin_memory=pin_memory and str(device) == "cpu")


 def _get_graph_batch_size(batch_size: int) -> int:
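The str(device) == "cpu" guard exists because pinning is a property of host memory: requesting a pinned CUDA tensor is invalid, and _make_tensor_with_pad is called with both devices. Illustration:

import torch

t = torch.tensor([1, 2], device="cpu", pin_memory=True)
assert t.is_pinned()
# pin_memory=True together with device="cuda" is invalid; the guard above
# drops the flag whenever the target is not the CPU.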
@@ -541,3 +557,8 @@ def _get_graph_batch_size(batch_size: int) -> int:
         return 4
     else:
         return (batch_size + 7) // 8 * 8
+
+
+def _async_h2d(data: list, dtype, pin_memory):
+    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
+    return t.to(device="cuda", non_blocking=True)
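_async_h2d packages the whole pattern the sampling path needs: build the tensor once in pinned host memory, enqueue a non-blocking copy, and hand back the GPU tensor immediately. No explicit synchronization is needed, because CUDA stream ordering places the copy before any kernel enqueued afterwards, and PyTorch's pinned-memory allocator defers reuse of the staging buffer until pending copies complete. A usage sketch (the helper is redefined from the diff so the snippet is self-contained):

import torch

def _async_h2d(data: list, dtype, pin_memory):
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
    return t.to(device="cuda", non_blocking=True)

indices = _async_h2d([0, 3, 5], dtype=torch.long, pin_memory=True)
# The copy may still be in flight here, but a kernel launched next on the
# same stream (e.g. the sampler) is ordered after it, so this is safe.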