Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d5fe3f70
Unverified
Commit
d5fe3f70
authored
Feb 14, 2026
by
Thomas Parnell
Committed by
GitHub
Feb 14, 2026
Browse files
[Hybrid] Enable mamba prefix cache "align" mode with async scheduling (#33997)
Signed-off-by:
Thomas Parnell
<
tpa@zurich.ibm.com
>
parent
73391a1b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
77 additions
and
32 deletions
+77
-32
tests/v1/e2e/test_mamba_prefix_cache.py
tests/v1/e2e/test_mamba_prefix_cache.py
+58
-19
vllm/config/vllm.py
vllm/config/vllm.py
+0
-12
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/core/single_type_kv_cache_manager.py
+13
-0
vllm/v1/worker/mamba_utils.py
vllm/v1/worker/mamba_utils.py
+6
-1
No files found.
tests/v1/e2e/test_mamba_prefix_cache.py
View file @
d5fe3f70
...
@@ -76,14 +76,11 @@ def get_fake_sample_fn() -> SamplerOutput:
...
@@ -76,14 +76,11 @@ def get_fake_sample_fn() -> SamplerOutput:
),
),
logprobs_tensors
=
None
,
logprobs_tensors
=
None
,
)
)
num_sampled_tokens
=
spec_decode_metadata
.
cu_num_sampled_tokens
[
0
].
item
()
+
1
accpeted_tokens
=
prompt_token_ids
[
accpeted_tokens
=
prompt_token_ids
[
first_token_id_index
:
first_token_id_index
first_token_id_index
:
first_token_id_index
+
min
(
num_accepted_tokens
,
logits
.
shape
[
0
])
+
min
(
num_accepted_tokens
,
logits
.
shape
[
0
])
]
]
sampled_token_ids
=
accpeted_tokens
+
[
-
1
]
*
(
sampled_token_ids
=
accpeted_tokens
num_sampled_tokens
-
len
(
accpeted_tokens
)
)
return
SamplerOutput
(
return
SamplerOutput
(
sampled_token_ids
=
torch
.
tensor
(
sampled_token_ids
=
torch
.
tensor
(
[
sampled_token_ids
],
device
=
"cuda"
,
dtype
=
torch
.
int32
[
sampled_token_ids
],
device
=
"cuda"
,
dtype
=
torch
.
int32
...
@@ -124,7 +121,24 @@ def get_fake_propose_draft_token_ids_fn():
...
@@ -124,7 +121,24 @@ def get_fake_propose_draft_token_ids_fn():
first_token_id_index
:
first_token_id_index
+
num_speculative_tokens
first_token_id_index
:
first_token_id_index
+
num_speculative_tokens
]
]
]
]
return
proposed_draft_token_ids
next_token_ids
=
torch
.
tensor
(
prompt_token_ids
[
first_token_id_index
-
1
:
first_token_id_index
-
1
+
num_accepted_tokens
],
device
=
"cuda"
,
dtype
=
torch
.
int32
,
)
valid_sampled_tokens_count
=
torch
.
tensor
(
[
num_accepted_tokens
],
device
=
"cuda"
,
dtype
=
torch
.
int32
)
self
.
_copy_valid_sampled_token_count
(
next_token_ids
,
valid_sampled_tokens_count
)
return
torch
.
tensor
(
proposed_draft_token_ids
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
return
fake_propose_draft_token_ids_fn
return
fake_propose_draft_token_ids_fn
...
@@ -184,6 +198,7 @@ mamba_kv_cache_dict = {}
...
@@ -184,6 +198,7 @@ mamba_kv_cache_dict = {}
def
get_fake_execute_model_fn
(
original_execute_model_fn
:
Callable
):
def
get_fake_execute_model_fn
(
original_execute_model_fn
:
Callable
):
last_num_computed_tokens
=
0
last_num_computed_tokens
=
0
num_prompt_tokens
=
None
def
fake_execute_model_fn
(
def
fake_execute_model_fn
(
self
:
GPUModelRunner
,
self
:
GPUModelRunner
,
...
@@ -201,10 +216,30 @@ def get_fake_execute_model_fn(original_execute_model_fn: Callable):
...
@@ -201,10 +216,30 @@ def get_fake_execute_model_fn(original_execute_model_fn: Callable):
mamba_group_id
mamba_group_id
].
layer_names
[
0
]
].
layer_names
[
0
]
nonlocal
last_num_computed_tokens
nonlocal
last_num_computed_tokens
nonlocal
num_prompt_tokens
if
(
len
(
scheduler_output
.
scheduled_new_reqs
)
>
0
and
scheduler_output
.
scheduled_new_reqs
[
0
].
prompt_token_ids
is
not
None
):
# record number of prompt tokens
num_prompt_tokens
=
len
(
scheduler_output
.
scheduled_new_reqs
[
0
].
prompt_token_ids
)
if
len
(
scheduler_output
.
scheduled_cached_reqs
.
req_ids
)
>
0
:
if
len
(
scheduler_output
.
scheduled_cached_reqs
.
req_ids
)
>
0
:
num_computed_tokens
=
(
num_computed_tokens
=
(
scheduler_output
.
scheduled_cached_reqs
.
num_computed_tokens
[
0
]
scheduler_output
.
scheduled_cached_reqs
.
num_computed_tokens
[
0
]
)
)
if
(
self
.
num_spec_tokens
and
num_prompt_tokens
is
not
None
and
num_computed_tokens
>
num_prompt_tokens
):
# NOTE(tdoublep): with async scheduling, the scheduler does not have an
# accurate measure of the number of computed tokens; we need to subtract
# the number of rejected tokens from the previous timestep.
num_computed_tokens
-=
num_speculative_tokens
+
1
-
num_accepted_tokens
if
(
if
(
num_computed_tokens
//
BLOCK_SIZE
num_computed_tokens
//
BLOCK_SIZE
>
last_num_computed_tokens
//
BLOCK_SIZE
>
last_num_computed_tokens
//
BLOCK_SIZE
...
@@ -493,9 +528,9 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
...
@@ -493,9 +528,9 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
step_actions
=
[
step_actions
=
[
StepAction
(
0
,
554
,
[
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
0
,
554
,
[
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
554
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
554
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
555
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
555
,
4
,
[
1
,
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
556
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
556
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
557
,
4
,
[
1
,
1
,
1
,
1
,
1
],
(
0
,
1
),
(
-
1
,
-
1
)),
StepAction
(
557
,
4
,
[],
(
0
,
1
),
(
-
1
,
-
1
)),
StepAction
(
558
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
558
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
559
,
4
,
[],
(
-
1
,
-
1
),
(
1
,
0
)),
StepAction
(
559
,
4
,
[],
(
-
1
,
-
1
),
(
1
,
0
)),
StepAction
(
560
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
560
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
...
@@ -510,8 +545,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
...
@@ -510,8 +545,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
step_actions
=
[
step_actions
=
[
StepAction
(
0
,
554
,
[
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
0
,
554
,
[
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
554
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
554
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
556
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
556
,
4
,
[
1
,
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
558
,
4
,
[
1
,
1
,
1
,
1
,
1
],
(
1
,
1
),
(
2
,
0
)),
StepAction
(
558
,
4
,
[],
(
1
,
1
),
(
2
,
0
)),
StepAction
(
560
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
560
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
562
,
4
,
[
0
,
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
562
,
4
,
[
0
,
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
],
],
...
@@ -526,7 +561,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
...
@@ -526,7 +561,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
StepAction
(
555
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
555
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
557
,
4
,
[
1
,
1
,
1
,
1
,
1
],
(
1
,
1
),
(
-
1
,
-
1
)),
StepAction
(
557
,
4
,
[
1
,
1
,
1
,
1
,
1
],
(
1
,
1
),
(
-
1
,
-
1
)),
StepAction
(
559
,
4
,
[],
(
-
1
,
-
1
),
(
1
,
0
)),
StepAction
(
559
,
4
,
[],
(
-
1
,
-
1
),
(
1
,
0
)),
StepAction
(
561
,
4
,
[
0
,
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
561
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
563
,
4
,
[
0
,
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
],
],
),
),
"accept_3_1"
:
TestConfig
(
"accept_3_1"
:
TestConfig
(
...
@@ -536,9 +572,10 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
...
@@ -536,9 +572,10 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
step_actions
=
[
step_actions
=
[
StepAction
(
0
,
553
,
[
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
0
,
553
,
[
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
553
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
553
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
556
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
556
,
4
,
[
1
,
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
559
,
4
,
[
1
,
1
,
1
,
1
,
1
],
(
2
,
1
),
(
1
,
0
)),
StepAction
(
559
,
4
,
[],
(
2
,
1
),
(
1
,
0
)),
StepAction
(
562
,
4
,
[
0
,
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
562
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
565
,
4
,
[
0
,
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
],
],
),
),
"accept_3_2"
:
TestConfig
(
"accept_3_2"
:
TestConfig
(
...
@@ -561,7 +598,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
...
@@ -561,7 +598,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
StepAction
(
0
,
555
,
[
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
0
,
555
,
[
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
555
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
555
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
558
,
4
,
[
1
,
1
,
1
,
1
,
1
],
(
2
,
1
),
(
2
,
0
)),
StepAction
(
558
,
4
,
[
1
,
1
,
1
,
1
,
1
],
(
2
,
1
),
(
2
,
0
)),
StepAction
(
561
,
4
,
[
0
,
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
561
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
564
,
4
,
[
0
,
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
],
],
),
),
"accept_4_1"
:
TestConfig
(
"accept_4_1"
:
TestConfig
(
...
@@ -572,8 +610,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
...
@@ -572,8 +610,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
StepAction
(
0
,
553
,
[
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
0
,
553
,
[
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
553
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
553
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
557
,
4
,
[
1
,
1
,
1
,
1
,
1
],
(
3
,
1
),
(
3
,
0
)),
StepAction
(
557
,
4
,
[
1
,
1
,
1
,
1
,
1
],
(
3
,
1
),
(
3
,
0
)),
StepAction
(
561
,
4
,
[
0
,
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
561
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
565
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
565
,
4
,
[
0
,
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
],
],
),
),
"accept_4_2"
:
TestConfig
(
"accept_4_2"
:
TestConfig
(
...
@@ -584,8 +622,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
...
@@ -584,8 +622,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
StepAction
(
0
,
554
,
[
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
0
,
554
,
[
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
554
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
554
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
558
,
4
,
[
1
,
1
,
1
,
1
,
1
],
(
3
,
1
),
(
2
,
0
)),
StepAction
(
558
,
4
,
[
1
,
1
,
1
,
1
,
1
],
(
3
,
1
),
(
2
,
0
)),
StepAction
(
562
,
4
,
[
0
,
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
562
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
566
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
566
,
4
,
[
0
,
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
],
],
),
),
"accept_4_3"
:
TestConfig
(
"accept_4_3"
:
TestConfig
(
...
@@ -596,7 +634,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
...
@@ -596,7 +634,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
StepAction
(
0
,
555
,
[
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
0
,
555
,
[
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
555
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
555
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
559
,
4
,
[
1
,
1
,
1
,
1
,
1
],
(
3
,
1
),
(
1
,
0
)),
StepAction
(
559
,
4
,
[
1
,
1
,
1
,
1
,
1
],
(
3
,
1
),
(
1
,
0
)),
StepAction
(
563
,
4
,
[
0
,
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
563
,
4
,
[],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
StepAction
(
567
,
4
,
[
0
,
1
,
1
,
1
,
1
],
(
-
1
,
-
1
),
(
-
1
,
-
1
)),
],
],
),
),
"accept_4_4"
:
TestConfig
(
"accept_4_4"
:
TestConfig
(
...
...
vllm/config/vllm.py
View file @
d5fe3f70
...
@@ -648,11 +648,6 @@ class VllmConfig:
...
@@ -648,11 +648,6 @@ class VllmConfig:
"`external_launcher` distributed executor backend, but you chose "
"`external_launcher` distributed executor backend, but you chose "
f
"`
{
executor_backend
}
`."
f
"`
{
executor_backend
}
`."
)
)
if
self
.
cache_config
.
mamba_cache_mode
!=
"none"
:
raise
ValueError
(
"Currently, async scheduling is not compatible with "
"prefix caching for Mamba models."
)
elif
self
.
scheduler_config
.
async_scheduling
is
None
:
elif
self
.
scheduler_config
.
async_scheduling
is
None
:
# Enable async scheduling unless there is an incompatible option.
# Enable async scheduling unless there is an incompatible option.
if
(
if
(
...
@@ -685,13 +680,6 @@ class VllmConfig:
...
@@ -685,13 +680,6 @@ class VllmConfig:
scope
=
"local"
,
scope
=
"local"
,
)
)
self
.
scheduler_config
.
async_scheduling
=
False
self
.
scheduler_config
.
async_scheduling
=
False
elif
self
.
cache_config
.
mamba_cache_mode
!=
"none"
:
logger
.
warning_once
(
"Async scheduling is not compatible with "
"prefix caching for Mamba models and will be disabled."
,
scope
=
"local"
,
)
self
.
scheduler_config
.
async_scheduling
=
False
else
:
else
:
self
.
scheduler_config
.
async_scheduling
=
True
self
.
scheduler_config
.
async_scheduling
=
True
...
...
vllm/v1/core/single_type_kv_cache_manager.py
View file @
d5fe3f70
...
@@ -814,6 +814,14 @@ class MambaManager(SingleTypeKVCacheManager):
...
@@ -814,6 +814,14 @@ class MambaManager(SingleTypeKVCacheManager):
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
num_computed_tokens
:
int
)
->
None
:
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
num_computed_tokens
:
int
)
->
None
:
assert
isinstance
(
self
.
kv_cache_spec
,
MambaSpec
)
assert
isinstance
(
self
.
kv_cache_spec
,
MambaSpec
)
# NOTE(tdoublep): with async scheduling, num_computed_tokens can contain
# draft tokens from the previous step that may or may not be rejected later.
# This can make us think we are further ahead in the sequence than we actually
# are, so let's assume that all draft tokens are rejected so we don't free
# blocks that we might actually need.
num_computed_tokens
=
max
(
0
,
num_computed_tokens
-
self
.
num_speculative_blocks
)
super
().
remove_skipped_blocks
(
request_id
,
num_computed_tokens
)
super
().
remove_skipped_blocks
(
request_id
,
num_computed_tokens
)
if
self
.
mamba_cache_mode
==
"align"
:
if
self
.
mamba_cache_mode
==
"align"
:
# `last_state_block_idx` refers to the block index allocated two steps ago.
# `last_state_block_idx` refers to the block index allocated two steps ago.
...
@@ -879,6 +887,9 @@ class MambaManager(SingleTypeKVCacheManager):
...
@@ -879,6 +887,9 @@ class MambaManager(SingleTypeKVCacheManager):
# We can ignore lookahead tokens because current draft models don't have
# We can ignore lookahead tokens because current draft models don't have
# mamba layers.
# mamba layers.
num_tokens
=
num_tokens_main_model
num_tokens
=
num_tokens_main_model
# NOTE(tdoublep): this is an over-estimate of how many blocks we need because
# num_tokens can include draft tokens that will later be rejected.
num_required_blocks
=
(
num_required_blocks
=
(
cdiv
(
num_tokens
,
self
.
block_size
)
+
self
.
num_speculative_blocks
cdiv
(
num_tokens
,
self
.
block_size
)
+
self
.
num_speculative_blocks
)
)
...
@@ -922,6 +933,8 @@ class MambaManager(SingleTypeKVCacheManager):
...
@@ -922,6 +933,8 @@ class MambaManager(SingleTypeKVCacheManager):
# mamba layers.
# mamba layers.
num_tokens
=
num_tokens_main_model
num_tokens
=
num_tokens_main_model
req_blocks
:
list
[
KVCacheBlock
]
=
self
.
req_to_blocks
[
request_id
]
req_blocks
:
list
[
KVCacheBlock
]
=
self
.
req_to_blocks
[
request_id
]
# NOTE(tdoublep): this is an over-estimate of how many blocks we need because
# num_tokens can include draft tokens that will later be rejected.
num_required_blocks
=
(
num_required_blocks
=
(
cdiv
(
num_tokens
,
self
.
block_size
)
+
self
.
num_speculative_blocks
cdiv
(
num_tokens
,
self
.
block_size
)
+
self
.
num_speculative_blocks
)
)
...
...
vllm/v1/worker/mamba_utils.py
View file @
d5fe3f70
...
@@ -10,6 +10,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
...
@@ -10,6 +10,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc
,
MambaStateCopyFunc
,
)
)
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
MambaSpec
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
MambaSpec
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
...
@@ -142,7 +143,11 @@ def preprocess_mamba(
...
@@ -142,7 +143,11 @@ def preprocess_mamba(
# if num_computed_tokens is 0, prev_state_idx will be -1
# if num_computed_tokens is 0, prev_state_idx will be -1
prev_state_idx
=
(
req_state
.
num_computed_tokens
-
1
)
//
block_size
prev_state_idx
=
(
req_state
.
num_computed_tokens
-
1
)
//
block_size
num_blocks
=
len
(
req_state
.
block_ids
[
mamba_group_ids
[
0
]])
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_blocks
:
int
=
(
cdiv
(
req_state
.
num_computed_tokens
+
num_scheduled_tokens
,
block_size
)
+
num_speculative_blocks
)
# We always save the current running state at the last
# We always save the current running state at the last
# (1 + num_speculative_blocks) block.
# (1 + num_speculative_blocks) block.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment