Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ef527be0
Unverified
Commit
ef527be0
authored
Aug 05, 2024
by
Cody Yu
Committed by
GitHub
Aug 05, 2024
Browse files
[MISC] Use non-blocking transfer in prepare_input (#7172)
parent
89b8db6b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
43 additions
and
49 deletions
+43
-49
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+12
-15
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+11
-12
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+12
-15
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+8
-7
No files found.
vllm/attention/backends/flash_attn.py
View file @
ef527be0
...
@@ -13,7 +13,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
...
@@ -13,7 +13,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.utils
import
make_tensor_with_pad
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
...
@@ -310,7 +310,8 @@ class FlashAttentionMetadataBuilder(
...
@@ -310,7 +310,8 @@ class FlashAttentionMetadataBuilder(
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
tensor
(
input_block_tables
,
device
=
device
)
block_tables
=
torch
.
from_numpy
(
input_block_tables
).
to
(
device
=
device
,
non_blocking
=
True
)
else
:
else
:
block_tables
=
make_tensor_with_pad
(
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
self
.
block_tables
,
...
@@ -320,15 +321,15 @@ class FlashAttentionMetadataBuilder(
...
@@ -320,15 +321,15 @@ class FlashAttentionMetadataBuilder(
)
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
context_lens_tensor
=
torch
.
tensor
(
self
.
context_lens
,
assert
device
is
not
None
dtype
=
torch
.
int
,
context_lens_tensor
=
async_tensor_h2d
(
self
.
context_lens
,
torch
.
int
,
device
=
device
)
device
,
self
.
runner
.
pin_memory
)
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
seq_lens_tensor
=
async_
tensor
_h2d
(
seq_lens
,
torch
.
int
,
device
,
dtype
=
torch
.
int
,
self
.
runner
.
pin_memory
)
device
=
device
)
query_lens_tensor
=
async_tensor_h2d
(
query_lens
,
torch
.
long
,
device
,
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
self
.
runner
.
pin_memory
)
dtype
=
torch
.
long
,
slot_mapping_tensor
=
async_tensor_h2d
(
self
.
slot_mapping
,
torch
.
long
,
device
=
device
)
device
,
self
.
runner
.
pin_memory
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
...
@@ -344,10 +345,6 @@ class FlashAttentionMetadataBuilder(
...
@@ -344,10 +345,6 @@ class FlashAttentionMetadataBuilder(
dtype
=
query_start_loc
.
dtype
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
out
=
query_start_loc
[
1
:])
slot_mapping_tensor
=
torch
.
tensor
(
self
.
slot_mapping
,
dtype
=
torch
.
long
,
device
=
device
)
return
FlashAttentionMetadata
(
return
FlashAttentionMetadata
(
num_prefills
=
self
.
num_prefills
,
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
slot_mapping
=
slot_mapping_tensor
,
...
...
vllm/attention/backends/flashinfer.py
View file @
ef527be0
...
@@ -21,7 +21,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
...
@@ -21,7 +21,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.utils
import
get_kv_cache_torch_dtype
,
make_tensor_with_pad
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
make_tensor_with_pad
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
...
@@ -356,7 +357,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -356,7 +357,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
tensor
(
input_block_tables
,
device
=
device
)
block_tables
=
torch
.
from_numpy
(
input_block_tables
).
to
(
device
,
non_blocking
=
True
)
last_paged_kv_indptr
=
self
.
paged_kv_indptr
[
-
1
]
last_paged_kv_indptr
=
self
.
paged_kv_indptr
[
-
1
]
self
.
paged_kv_indptr
.
extend
([
last_paged_kv_indptr
]
*
self
.
paged_kv_indptr
.
extend
([
last_paged_kv_indptr
]
*
...
@@ -371,12 +373,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -371,12 +373,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
assert
device
is
not
None
dtype
=
torch
.
int
,
seq_lens_tensor
=
async_tensor_h2d
(
seq_lens
,
torch
.
int
,
device
,
device
=
device
)
self
.
runner
.
pin_memory
)
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
query_lens_tensor
=
async_tensor_h2d
(
query_lens
,
torch
.
long
,
device
,
dtype
=
torch
.
long
,
self
.
runner
.
pin_memory
)
device
=
device
)
slot_mapping_tensor
=
async_tensor_h2d
(
self
.
slot_mapping
,
torch
.
long
,
device
,
self
.
runner
.
pin_memory
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
...
@@ -392,10 +395,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -392,10 +395,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype
=
query_start_loc
.
dtype
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
out
=
query_start_loc
[
1
:])
slot_mapping_tensor
=
torch
.
tensor
(
self
.
slot_mapping
,
dtype
=
torch
.
long
,
device
=
device
)
if
len
(
self
.
paged_kv_indptr
)
>
0
:
if
len
(
self
.
paged_kv_indptr
)
>
0
:
paged_kv_indices_tensor
=
torch
.
tensor
(
self
.
paged_kv_indices
,
paged_kv_indices_tensor
=
torch
.
tensor
(
self
.
paged_kv_indices
,
device
=
"cpu"
,
device
=
"cpu"
,
...
...
vllm/attention/backends/utils.py
View file @
ef527be0
...
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
...
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
import
torch
import
torch
from
vllm.attention
import
AttentionMetadata
,
AttentionMetadataBuilder
from
vllm.attention
import
AttentionMetadata
,
AttentionMetadataBuilder
from
vllm.utils
import
make_tensor_with_pad
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
# Error string(s) for encoder/decoder
# Error string(s) for encoder/decoder
# unsupported attention scenarios
# unsupported attention scenarios
...
@@ -181,7 +181,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -181,7 +181,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
tensor
(
input_block_tables
,
device
=
device
)
block_tables
=
torch
.
from_numpy
(
input_block_tables
).
to
(
device
,
non_blocking
=
True
)
else
:
else
:
block_tables
=
make_tensor_with_pad
(
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
self
.
block_tables
,
...
@@ -191,15 +192,15 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -191,15 +192,15 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
)
)
assert
max_query_len
>
0
,
"query_lens: {}"
.
format
(
query_lens
)
assert
max_query_len
>
0
,
"query_lens: {}"
.
format
(
query_lens
)
context_lens_tensor
=
torch
.
tensor
(
self
.
context_lens
,
assert
device
is
not
None
dtype
=
torch
.
int
,
context_lens_tensor
=
async_tensor_h2d
(
self
.
context_lens
,
torch
.
int
,
device
=
device
)
device
,
self
.
runner
.
pin_memory
)
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
seq_lens_tensor
=
async_
tensor
_h2d
(
seq_lens
,
torch
.
int
,
device
,
dtype
=
torch
.
int
,
self
.
runner
.
pin_memory
)
device
=
device
)
query_lens_tensor
=
async_tensor_h2d
(
query_lens
,
torch
.
long
,
device
,
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
self
.
runner
.
pin_memory
)
dtype
=
torch
.
long
,
slot_mapping_tensor
=
async_tensor_h2d
(
self
.
slot_mapping
,
torch
.
long
,
device
=
device
)
device
,
self
.
runner
.
pin_memory
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
...
@@ -215,10 +216,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -215,10 +216,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
dtype
=
query_start_loc
.
dtype
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
out
=
query_start_loc
[
1
:])
slot_mapping_tensor
=
torch
.
tensor
(
self
.
slot_mapping
,
dtype
=
torch
.
long
,
device
=
device
)
return
self
.
_metadata_cls
(
# type: ignore
return
self
.
_metadata_cls
(
# type: ignore
num_prefills
=
self
.
num_prefills
,
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
slot_mapping
=
slot_mapping_tensor
,
...
...
vllm/worker/model_runner.py
View file @
ef527be0
...
@@ -50,7 +50,7 @@ from vllm.prompt_adapter.worker_manager import (
...
@@ -50,7 +50,7 @@ from vllm.prompt_adapter.worker_manager import (
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
flatten_2d_lists
,
from
vllm.utils
import
(
CudaMemoryProfiler
,
async_tensor_h2d
,
flatten_2d_lists
,
get_kv_cache_torch_dtype
,
is_hip
,
get_kv_cache_torch_dtype
,
is_hip
,
is_pin_memory_available
)
is_pin_memory_available
)
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
...
@@ -549,12 +549,13 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -549,12 +549,13 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Tokens and positions.
# Tokens and positions.
input_tokens
.
extend
([
0
]
*
cuda_graph_pad_size
)
input_tokens
.
extend
([
0
]
*
cuda_graph_pad_size
)
input_positions
.
extend
([
0
]
*
cuda_graph_pad_size
)
input_positions
.
extend
([
0
]
*
cuda_graph_pad_size
)
input_tokens_tensor
=
torch
.
tensor
(
input_tokens
,
assert
self
.
runner
.
device
is
not
None
dtype
=
torch
.
long
,
input_tokens_tensor
=
async_tensor_h2d
(
input_tokens
,
torch
.
long
,
device
=
self
.
runner
.
device
)
self
.
runner
.
device
,
input_positions_tensor
=
torch
.
tensor
(
input_positions
,
self
.
runner
.
pin_memory
)
dtype
=
torch
.
long
,
input_positions_tensor
=
async_tensor_h2d
(
input_positions
,
torch
.
long
,
device
=
self
.
runner
.
device
)
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
# Sequence and query lengths.
# Sequence and query lengths.
seq_lens
.
extend
([
1
]
*
cuda_graph_pad_size
)
seq_lens
.
extend
([
1
]
*
cuda_graph_pad_size
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment