Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9e5ec35b
Unverified
Commit
9e5ec35b
authored
Sep 19, 2024
by
William Lin
Committed by
GitHub
Sep 19, 2024
Browse files
[bugfix] [AMD] add multi-step advance_step to ROCmFlashAttentionMetadata (#8474)
parent
18ae428a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
58 additions
and
2 deletions
+58
-2
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+57
-1
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+1
-1
No files found.
vllm/attention/backends/rocm_flash_attn.py
View file @
9e5ec35b
"""Attention layer ROCm GPUs."""
"""Attention layer ROCm GPUs."""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
...
@@ -15,6 +15,9 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
...
@@ -15,6 +15,9 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_PARTITION_SIZE_ROCM
=
512
_PARTITION_SIZE_ROCM
=
512
...
@@ -180,6 +183,59 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -180,6 +183,59 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
)
)
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
def
advance_step
(
self
,
model_input
:
"ModelInputForGPUWithSamplingMetadata"
,
sampled_token_ids
:
Optional
[
torch
.
Tensor
],
block_size
:
int
,
num_seqs
:
int
,
num_queries
:
int
):
"""
Update metadata in-place to advance one decode step.
"""
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if
num_seqs
!=
num_queries
:
assert
num_seqs
>
num_queries
assert
self
.
use_cuda_graph
assert
self
.
num_prefills
==
0
assert
self
.
num_prefill_tokens
==
0
assert
self
.
num_decode_tokens
==
num_seqs
assert
self
.
slot_mapping
.
shape
==
(
num_seqs
,
)
assert
self
.
seq_lens
is
not
None
assert
len
(
self
.
seq_lens
)
==
num_seqs
assert
self
.
seq_lens_tensor
is
not
None
assert
self
.
seq_lens_tensor
.
shape
==
(
num_seqs
,
)
assert
self
.
max_query_len
==
1
assert
self
.
max_prefill_seq_len
==
0
assert
self
.
max_decode_seq_len
==
max
(
self
.
seq_lens
)
assert
self
.
query_start_loc
is
not
None
assert
self
.
query_start_loc
.
shape
==
(
num_queries
+
1
,
)
assert
self
.
seq_start_loc
is
not
None
assert
self
.
seq_start_loc
.
shape
==
(
num_seqs
+
1
,
)
assert
self
.
context_lens_tensor
is
not
None
assert
self
.
context_lens_tensor
.
shape
==
(
num_queries
,
)
assert
self
.
block_tables
is
not
None
assert
self
.
block_tables
.
shape
[
0
]
==
num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for
i
in
range
(
num_queries
):
self
.
seq_lens
[
i
]
+=
1
self
.
max_decode_seq_len
=
max
(
self
.
seq_lens
)
ops
.
advance_step_flashattn
(
num_seqs
=
num_seqs
,
num_queries
=
num_queries
,
block_size
=
block_size
,
input_tokens
=
model_input
.
input_tokens
,
sampled_token_ids
=
sampled_token_ids
,
input_positions
=
model_input
.
input_positions
,
seq_lens
=
self
.
seq_lens_tensor
,
slot_mapping
=
self
.
slot_mapping
,
block_tables
=
self
.
block_tables
)
class
ROCmFlashAttentionMetadataBuilder
(
class
ROCmFlashAttentionMetadataBuilder
(
CommonMetadataBuilder
[
ROCmFlashAttentionMetadata
]):
CommonMetadataBuilder
[
ROCmFlashAttentionMetadata
]):
...
...
vllm/worker/multi_step_model_runner.py
View file @
9e5ec35b
...
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
...
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
MULTI_STEP_ATTENTION_BACKENDS
=
[
"flash-attn"
,
"flashinfer"
]
MULTI_STEP_ATTENTION_BACKENDS
=
[
"flash-attn"
,
"rocm-flash-attn"
,
"flashinfer"
]
def
seq_output_builder
():
def
seq_output_builder
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment