Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
309aaef8
Unverified
Commit
309aaef8
authored
Jul 24, 2024
by
Cody Yu
Committed by
GitHub
Jul 24, 2024
Browse files
[Bugfix] Fix decode tokens w. CUDA graph (#6757)
parent
9e169a4c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
31 additions
and
4 deletions
+31
-4
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+1
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+10
-2
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+10
-1
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+10
-1
No files found.
tests/worker/test_model_runner.py
View file @
309aaef8
...
...
@@ -193,6 +193,7 @@ def test_prepare_decode_cuda_graph(batch_size):
for
_
in
range
(
expected_bs
-
len
(
seq_lens
)):
seq_lens
.
append
(
1
)
assert
attn_metadata
.
seq_lens
==
seq_lens
assert
attn_metadata
.
num_decode_tokens
==
len
(
seq_lens
)
start_idx
=
0
start_loc
=
[
start_idx
]
for
_
in
context_lens
:
...
...
vllm/attention/backends/flash_attn.py
View file @
309aaef8
...
...
@@ -272,7 +272,15 @@ class FlashAttentionMetadataBuilder(
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
"""Build attention metadata with on-device tensors."""
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
...
...
@@ -297,7 +305,7 @@ class FlashAttentionMetadataBuilder(
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
+
cuda_graph_pad_size
num_decode_tokens
=
batch_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
...
...
vllm/attention/backends/flashinfer.py
View file @
309aaef8
...
...
@@ -320,6 +320,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
...
...
@@ -334,7 +343,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
+
cuda_graph_pad_size
num_decode_tokens
=
batch_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
...
...
vllm/attention/backends/utils.py
View file @
309aaef8
...
...
@@ -149,6 +149,15 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
...
...
@@ -173,7 +182,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
+
cuda_graph_pad_size
num_decode_tokens
=
batch_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment