Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4dbcbbeb
Unverified
Commit
4dbcbbeb
authored
Nov 04, 2024
by
Yang Zheng
Committed by
GitHub
Nov 04, 2024
Browse files
[Misc] Compute query_start_loc/seq_start_loc on CPU (#9447)
Co-authored-by:
Yang Zheng(SW)(Alex)
<
you@example.com
>
parent
b67feb12
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
36 deletions
+20
-36
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+10
-18
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+10
-18
No files found.
vllm/attention/backends/flash_attn.py
View file @
4dbcbbeb
"""Attention layer with FlashAttention."""
"""Attention layer with FlashAttention."""
from
collections
import
defaultdict
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
itertools
import
accumulate
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
...
@@ -503,6 +504,8 @@ class FlashAttentionMetadataBuilder(
...
@@ -503,6 +504,8 @@ class FlashAttentionMetadataBuilder(
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
num_decode_tokens
=
self
.
num_decode_tokens
query_start_loc
=
list
(
accumulate
(
query_lens
,
initial
=
0
))
seq_start_loc
=
list
(
accumulate
(
seq_lens
,
initial
=
0
))
num_seqs
=
len
(
seq_lens
)
num_seqs
=
len
(
seq_lens
)
if
use_captured_graph
:
if
use_captured_graph
:
...
@@ -525,29 +528,18 @@ class FlashAttentionMetadataBuilder(
...
@@ -525,29 +528,18 @@ class FlashAttentionMetadataBuilder(
device
,
self
.
runner
.
pin_memory
)
device
,
self
.
runner
.
pin_memory
)
seq_lens_tensor
=
async_tensor_h2d
(
seq_lens
,
torch
.
int
,
device
,
seq_lens_tensor
=
async_tensor_h2d
(
seq_lens
,
torch
.
int
,
device
,
self
.
runner
.
pin_memory
)
self
.
runner
.
pin_memory
)
query_lens_tensor
=
async_tensor_h2d
(
query_lens
,
torch
.
long
,
device
,
self
.
runner
.
pin_memory
)
slot_mapping_tensor
=
async_tensor_h2d
(
self
.
slot_mapping
,
torch
.
long
,
slot_mapping_tensor
=
async_tensor_h2d
(
self
.
slot_mapping
,
torch
.
long
,
device
,
self
.
runner
.
pin_memory
)
device
,
self
.
runner
.
pin_memory
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
query_start_loc_tensor
=
async_tensor_h2d
(
query_start_loc
,
torch
.
int32
,
dtype
=
torch
.
int32
,
device
,
device
=
device
)
self
.
runner
.
pin_memory
)
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
seq_start_loc_tensor
=
async_tensor_h2d
(
seq_start_loc
,
torch
.
int32
,
dtype
=
torch
.
int32
,
device
,
self
.
runner
.
pin_memory
)
device
=
device
)
placeholder_index_maps
=
{
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
for
modality
,
placeholder_map
in
self
.
multimodal_placeholder_maps
.
items
()
self
.
multimodal_placeholder_maps
.
items
()
}
}
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
return
FlashAttentionMetadata
(
return
FlashAttentionMetadata
(
num_prefills
=
self
.
num_prefills
,
num_prefills
=
self
.
num_prefills
,
...
@@ -561,8 +553,8 @@ class FlashAttentionMetadataBuilder(
...
@@ -561,8 +553,8 @@ class FlashAttentionMetadataBuilder(
max_decode_query_len
=
max_decode_query_len
,
max_decode_query_len
=
max_decode_query_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
_tensor
,
seq_start_loc
=
seq_start_loc
,
seq_start_loc
=
seq_start_loc
_tensor
,
context_lens_tensor
=
context_lens_tensor
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
use_cuda_graph
=
use_captured_graph
,
...
...
vllm/attention/backends/utils.py
View file @
4dbcbbeb
"""Attention backend utils"""
"""Attention backend utils"""
from
collections
import
defaultdict
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
itertools
import
accumulate
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Tuple
,
Type
,
TypeVar
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Tuple
,
Type
,
TypeVar
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -216,6 +217,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -216,6 +217,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
num_decode_tokens
=
self
.
num_decode_tokens
query_start_loc
=
list
(
accumulate
(
query_lens
,
initial
=
0
))
seq_start_loc
=
list
(
accumulate
(
seq_lens
,
initial
=
0
))
if
use_captured_graph
:
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
...
@@ -244,29 +247,18 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -244,29 +247,18 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
device
,
self
.
runner
.
pin_memory
)
device
,
self
.
runner
.
pin_memory
)
seq_lens_tensor
=
async_tensor_h2d
(
seq_lens
,
torch
.
int
,
device
,
seq_lens_tensor
=
async_tensor_h2d
(
seq_lens
,
torch
.
int
,
device
,
self
.
runner
.
pin_memory
)
self
.
runner
.
pin_memory
)
query_lens_tensor
=
async_tensor_h2d
(
query_lens
,
torch
.
long
,
device
,
self
.
runner
.
pin_memory
)
slot_mapping_tensor
=
async_tensor_h2d
(
self
.
slot_mapping
,
torch
.
long
,
slot_mapping_tensor
=
async_tensor_h2d
(
self
.
slot_mapping
,
torch
.
long
,
device
,
self
.
runner
.
pin_memory
)
device
,
self
.
runner
.
pin_memory
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
query_start_loc_tensor
=
async_tensor_h2d
(
query_start_loc
,
torch
.
int32
,
dtype
=
torch
.
int32
,
device
,
device
=
device
)
self
.
runner
.
pin_memory
)
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
seq_start_loc_tensor
=
async_tensor_h2d
(
seq_start_loc
,
torch
.
int32
,
dtype
=
torch
.
int32
,
device
,
self
.
runner
.
pin_memory
)
device
=
device
)
placeholder_index_maps
=
{
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
for
modality
,
placeholder_map
in
self
.
multimodal_placeholder_maps
.
items
()
self
.
multimodal_placeholder_maps
.
items
()
}
}
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
return
self
.
_metadata_cls
(
# type: ignore
return
self
.
_metadata_cls
(
# type: ignore
num_prefills
=
self
.
num_prefills
,
num_prefills
=
self
.
num_prefills
,
...
@@ -279,8 +271,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -279,8 +271,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
_tensor
,
seq_start_loc
=
seq_start_loc
,
seq_start_loc
=
seq_start_loc
_tensor
,
context_lens_tensor
=
context_lens_tensor
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
use_cuda_graph
=
use_captured_graph
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment