Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9a939737
Unverified
Commit
9a939737
authored
Dec 10, 2024
by
Tyler Michael Smith
Committed by
GitHub
Dec 11, 2024
Browse files
[Bugfix] Fix Mamba multistep (#11071)
Signed-off-by:
Tyler Michael Smith
<
tyler@neuralmagic.com
>
parent
134810b3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
66 additions
and
2 deletions
+66
-2
vllm/attention/backends/placeholder_attn.py
vllm/attention/backends/placeholder_attn.py
+63
-1
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+3
-1
No files found.
vllm/attention/backends/placeholder_attn.py
View file @
9a939737
...
@@ -11,7 +11,8 @@ from vllm.attention.backends.utils import CommonAttentionState
...
@@ -11,7 +11,8 @@ from vllm.attention.backends.utils import CommonAttentionState
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.multimodal
import
MultiModalPlaceholderMap
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
# Placeholder attention backend for models like Mamba and embedding models that
# Placeholder attention backend for models like Mamba and embedding models that
# lack attention.
# lack attention.
...
@@ -186,6 +187,67 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
...
@@ -186,6 +187,67 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
)
)
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
def advance_step(self,
                 model_input: "ModelInputForGPUWithSamplingMetadata",
                 sampled_token_ids: Optional[torch.Tensor],
                 block_size: int,
                 num_seqs: int,
                 num_queries: int,
                 turn_prefills_into_decodes: bool = False):
    """
    Update metadata in-place to advance one decode step.

    Increments the first ``num_queries`` entries of ``self.seq_lens`` and
    ``self.seq_lens_tensor`` by one, refreshes ``self.max_decode_seq_len``
    and, when ``sampled_token_ids`` is given, scatters its first
    ``num_queries`` rows into ``model_input.input_tokens`` in place.

    Args:
        model_input: Model input whose ``input_tokens`` tensor is
            overwritten with the newly sampled tokens.
        sampled_token_ids: Tokens sampled in the previous step, or None
            to leave ``model_input.input_tokens`` untouched.
        block_size: Not used by this implementation.
        num_seqs: Number of sequences the metadata tensors are sized for
            (may include padding when CUDA graphs are used).
        num_queries: Actual number of requests in the batch.
        turn_prefills_into_decodes: Must be False; see the assertion
            message below.

    Raises:
        AssertionError: If the metadata is not in the pure-decode state
            checked below, or if ``turn_prefills_into_decodes`` is True.
    """
    # When using cudagraph, the num_seqs is padded to the next captured
    # batch size, but num_queries tracks the actual number of requests in
    # the batch. For --enforce-eager mode, num_seqs == num_queries
    if num_seqs != num_queries:
        assert num_seqs > num_queries
        assert self.use_cuda_graph

    # The adjacent literals are concatenated by the parser, so each piece
    # must carry its own separating space.
    assert not turn_prefills_into_decodes, \
        ("Multi-Step + Chunked-Prefill is not supported for attention-free "
         "models. turn_prefills_into_decodes is a "
         "Multi-Step + Chunked-Prefill specific parameter.")

    # Sanity-check that the metadata describes a decode-only batch.
    assert self.seq_lens is not None
    assert self.max_decode_seq_len == max(self.seq_lens)

    assert self.num_prefills == 0
    assert self.num_prefill_tokens == 0
    assert self.num_decode_tokens == num_seqs

    assert len(self.seq_lens) == num_seqs
    assert self.seq_lens_tensor is not None
    assert self.seq_lens_tensor.shape == (num_seqs, )
    assert self.max_query_len == 1
    assert self.max_prefill_seq_len == 0

    assert self.query_start_loc is not None
    assert self.query_start_loc.shape == (num_queries + 1, )
    assert self.seq_start_loc is not None
    assert self.seq_start_loc.shape == (num_seqs + 1, )

    assert self.context_lens_tensor is not None
    assert self.context_lens_tensor.shape == (num_queries, )

    assert self.block_tables is not None

    # Update query lengths. Note that we update only queries and not seqs,
    # since tensors may be padded due to captured cuda graph batch size
    for i in range(num_queries):
        self.seq_lens[i] += 1
    self.max_decode_seq_len = max(self.seq_lens)

    # Update sequences, masking off the padded entries, i.e. those at
    # index >= num_queries
    device = self.seq_lens_tensor.device
    mask = torch.arange(self.seq_lens_tensor.size(0),
                        device=device) < num_queries
    self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype)

    if sampled_token_ids is not None:
        # Write the new tokens only into the real (unpadded) slots.
        model_input.input_tokens.masked_scatter_(
            mask, sampled_token_ids[:num_queries])
class
PlaceholderAttentionMetadataBuilder
(
class
PlaceholderAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
PlaceholderAttentionMetadata
]):
AttentionMetadataBuilder
[
PlaceholderAttentionMetadata
]):
...
...
vllm/worker/multi_step_model_runner.py
View file @
9a939737
...
@@ -29,7 +29,9 @@ if TYPE_CHECKING:
...
@@ -29,7 +29,9 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
MULTI_STEP_ATTENTION_BACKENDS
=
[
"FLASH_ATTN"
,
"ROCM_FLASH"
,
"FLASHINFER"
]
MULTI_STEP_ATTENTION_BACKENDS
=
[
"FLASH_ATTN"
,
"ROCM_FLASH"
,
"FLASHINFER"
,
"NO_ATTENTION"
]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
=
[
"FLASH_ATTN"
]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
=
[
"FLASH_ATTN"
]
def
_get_supported_attention_backends
(
chunked_prefill_enabled
:
bool
)
\
def
_get_supported_attention_backends
(
chunked_prefill_enabled
:
bool
)
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment