Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f366f633
Unverified
Commit
f366f633
authored
Aug 16, 2024
by
William Lin
Committed by
GitHub
Aug 16, 2024
Browse files
[spec decode] [4/N] Move update_flash_attn_metadata to attn backend (#7571)
Co-authored-by:
Cody Yu
<
hao.yu.cody@gmail.com
>
parent
855866ca
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
49 additions
and
33 deletions
+49
-33
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+3
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+45
-0
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+1
-33
No files found.
vllm/attention/backends/abstract.py
View file @
f366f633
...
...
@@ -75,6 +75,9 @@ class AttentionBackend(ABC):
)
->
None
:
raise
NotImplementedError
def
advance_step
(
self
,
num_seqs
:
int
,
num_queries
:
int
):
raise
NotImplementedError
@
dataclass
class
AttentionMetadata
:
...
...
vllm/attention/backends/flash_attn.py
View file @
f366f633
...
...
@@ -297,6 +297,51 @@ class FlashAttentionMetadata(AttentionMetadata):
)
return
self
.
_cached_decode_metadata
def
advance_step
(
self
,
num_seqs
:
int
,
num_queries
:
int
):
"""
Update metadata in-place to advance one decode step.
"""
# GPU in-place update is currently called separately through
# custom_ops.advance_step(). See draft_model_runner. TODO(will): Move
# this logic to the backend.
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if
num_seqs
!=
num_queries
:
assert
num_seqs
>
num_queries
assert
self
.
use_cuda_graph
assert
self
.
num_prefills
==
0
assert
self
.
num_prefill_tokens
==
0
assert
self
.
num_decode_tokens
==
num_seqs
assert
self
.
slot_mapping
.
shape
==
(
num_seqs
,
)
assert
self
.
seq_lens
is
not
None
assert
len
(
self
.
seq_lens
)
==
num_seqs
assert
self
.
seq_lens_tensor
is
not
None
assert
self
.
seq_lens_tensor
.
shape
==
(
num_seqs
,
)
assert
self
.
max_query_len
==
1
assert
self
.
max_prefill_seq_len
==
0
assert
self
.
max_decode_seq_len
==
max
(
self
.
seq_lens
)
assert
self
.
query_start_loc
is
not
None
assert
self
.
query_start_loc
.
shape
==
(
num_queries
+
1
,
)
assert
self
.
seq_start_loc
is
not
None
assert
self
.
seq_start_loc
.
shape
==
(
num_seqs
+
1
,
)
assert
self
.
context_lens_tensor
is
not
None
assert
self
.
context_lens_tensor
.
shape
==
(
num_queries
,
)
assert
self
.
block_tables
is
not
None
assert
self
.
block_tables
.
shape
[
0
]
==
num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for
i
in
range
(
num_queries
):
self
.
seq_lens
[
i
]
+=
1
self
.
max_decode_seq_len
=
max
(
self
.
seq_lens
)
class
FlashAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlashAttentionMetadata
]):
...
...
vllm/spec_decode/draft_model_runner.py
View file @
f366f633
...
...
@@ -97,38 +97,6 @@ class TP1DraftModelRunner(ModelRunner):
self
.
flashinfer_prefill_workspace_buffer
=
None
self
.
flashinfer_prefill_wrapper
=
None
def
_update_flash_attn_metadata
(
self
,
attn_metadata
,
num_seqs
,
num_queries
):
assert
isinstance
(
attn_metadata
,
FlashAttentionMetadata
)
if
num_seqs
!=
num_queries
:
assert
num_seqs
>
num_queries
assert
attn_metadata
.
use_cuda_graph
assert
attn_metadata
.
num_prefills
==
0
assert
attn_metadata
.
num_prefill_tokens
==
0
assert
attn_metadata
.
num_decode_tokens
==
num_seqs
assert
attn_metadata
.
slot_mapping
.
shape
==
(
num_seqs
,
)
assert
len
(
attn_metadata
.
seq_lens
)
==
num_seqs
assert
attn_metadata
.
seq_lens_tensor
.
shape
==
(
num_seqs
,
)
assert
attn_metadata
.
max_query_len
==
1
assert
attn_metadata
.
max_prefill_seq_len
==
0
assert
attn_metadata
.
max_decode_seq_len
==
max
(
attn_metadata
.
seq_lens
)
assert
attn_metadata
.
query_start_loc
.
shape
==
(
num_queries
+
1
,
)
assert
attn_metadata
.
seq_start_loc
.
shape
==
(
num_seqs
+
1
,
)
assert
attn_metadata
.
context_lens_tensor
.
shape
==
(
num_queries
,
)
assert
attn_metadata
.
block_tables
.
shape
[
0
]
==
num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for
i
in
range
(
num_queries
):
attn_metadata
.
seq_lens
[
i
]
+=
1
attn_metadata
.
max_decode_seq_len
=
max
(
attn_metadata
.
seq_lens
)
def
_update_sampling_metadata
(
self
,
sampling_metadata
,
num_seqs
,
num_queries
):
...
...
@@ -166,7 +134,7 @@ class TP1DraftModelRunner(ModelRunner):
# Update attn_metadata
attn_metadata
=
model_input
.
attn_metadata
assert
isinstance
(
attn_metadata
,
FlashAttentionMetadata
)
self
.
_update_flash_attn_metadata
(
attn_metadata
,
num_seqs
,
num_queries
)
attn_metadata
.
advance_step
(
num_seqs
,
num_queries
)
# Update GPU tensors
ops
.
advance_step
(
num_seqs
=
num_seqs
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment