Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4dbcbbeb
Unverified
Commit
4dbcbbeb
authored
Nov 04, 2024
by
Yang Zheng
Committed by
GitHub
Nov 04, 2024
Browse files
[Misc] Compute query_start_loc/seq_start_loc on CPU (#9447)
Co-authored-by:
Yang Zheng(SW)(Alex)
<
you@example.com
>
parent
b67feb12
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
36 deletions
+20
-36
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+10
-18
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+10
-18
No files found.
vllm/attention/backends/flash_attn.py
View file @
4dbcbbeb
"""Attention layer with FlashAttention."""
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
itertools
import
accumulate
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
...
...
@@ -503,6 +504,8 @@ class FlashAttentionMetadataBuilder(
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
query_start_loc
=
list
(
accumulate
(
query_lens
,
initial
=
0
))
seq_start_loc
=
list
(
accumulate
(
seq_lens
,
initial
=
0
))
num_seqs
=
len
(
seq_lens
)
if
use_captured_graph
:
...
...
@@ -525,29 +528,18 @@ class FlashAttentionMetadataBuilder(
device
,
self
.
runner
.
pin_memory
)
seq_lens_tensor
=
async_tensor_h2d
(
seq_lens
,
torch
.
int
,
device
,
self
.
runner
.
pin_memory
)
query_lens_tensor
=
async_tensor_h2d
(
query_lens
,
torch
.
long
,
device
,
self
.
runner
.
pin_memory
)
slot_mapping_tensor
=
async_tensor_h2d
(
self
.
slot_mapping
,
torch
.
long
,
device
,
self
.
runner
.
pin_memory
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
query_start_loc_tensor
=
async_tensor_h2d
(
query_start_loc
,
torch
.
int32
,
device
,
self
.
runner
.
pin_memory
)
seq_start_loc_tensor
=
async_tensor_h2d
(
seq_start_loc
,
torch
.
int32
,
device
,
self
.
runner
.
pin_memory
)
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
self
.
multimodal_placeholder_maps
.
items
()
}
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
return
FlashAttentionMetadata
(
num_prefills
=
self
.
num_prefills
,
...
...
@@ -561,8 +553,8 @@ class FlashAttentionMetadataBuilder(
max_decode_query_len
=
max_decode_query_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
query_start_loc
=
query_start_loc
,
seq_start_loc
=
seq_start_loc
,
query_start_loc
=
query_start_loc
_tensor
,
seq_start_loc
=
seq_start_loc
_tensor
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
...
...
vllm/attention/backends/utils.py
View file @
4dbcbbeb
"""Attention backend utils"""
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
itertools
import
accumulate
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Tuple
,
Type
,
TypeVar
,
Union
import
numpy
as
np
...
...
@@ -216,6 +217,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
query_start_loc
=
list
(
accumulate
(
query_lens
,
initial
=
0
))
seq_start_loc
=
list
(
accumulate
(
seq_lens
,
initial
=
0
))
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
...
...
@@ -244,29 +247,18 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
device
,
self
.
runner
.
pin_memory
)
seq_lens_tensor
=
async_tensor_h2d
(
seq_lens
,
torch
.
int
,
device
,
self
.
runner
.
pin_memory
)
query_lens_tensor
=
async_tensor_h2d
(
query_lens
,
torch
.
long
,
device
,
self
.
runner
.
pin_memory
)
slot_mapping_tensor
=
async_tensor_h2d
(
self
.
slot_mapping
,
torch
.
long
,
device
,
self
.
runner
.
pin_memory
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
query_start_loc_tensor
=
async_tensor_h2d
(
query_start_loc
,
torch
.
int32
,
device
,
self
.
runner
.
pin_memory
)
seq_start_loc_tensor
=
async_tensor_h2d
(
seq_start_loc
,
torch
.
int32
,
device
,
self
.
runner
.
pin_memory
)
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
self
.
multimodal_placeholder_maps
.
items
()
}
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
return
self
.
_metadata_cls
(
# type: ignore
num_prefills
=
self
.
num_prefills
,
...
...
@@ -279,8 +271,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
max_query_len
=
max_query_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
query_start_loc
=
query_start_loc
,
seq_start_loc
=
seq_start_loc
,
query_start_loc
=
query_start_loc
_tensor
,
seq_start_loc
=
seq_start_loc
_tensor
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment