Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b9dbc5c4
Unverified
Commit
b9dbc5c4
authored
Mar 26, 2026
by
Divakar Verma
Committed by
GitHub
Mar 26, 2026
Browse files
[Mamba][APC] Add test case to compare apc outputs (#34977)
Signed-off-by:
Divakar Verma
<
divakar.verma@amd.com
>
parent
60af7b96
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
54 additions
and
0 deletions
+54
-0
tests/models/language/generation/test_hybrid.py
tests/models/language/generation/test_hybrid.py
+54
-0
No files found.
tests/models/language/generation/test_hybrid.py
View file @
b9dbc5c4
...
...
@@ -774,6 +774,60 @@ def test_apc_multiple_prompts_partial_cached_outputs(
)
# Test that outputs match whether prefix caching is enabled or not for mamba.
@pytest.mark.parametrize("model", ["tiiuae/falcon-mamba-7b"])
def test_same_mamba_output_apc_on_vs_off(
    vllm_runner,
    model: str,
) -> None:
    """Check that a Mamba model's outputs agree with prefix caching off and on.

    The same two identical prompts are generated twice: once with
    ``enable_prefix_caching=False`` and once with
    ``enable_prefix_caching=True`` (plus ``mamba_block_size=16``). The first
    entry of each result is then handed to ``check_logprobs_close``.

    Args:
        vllm_runner: Fixture used as a context manager to build the model
            runner from keyword arguments.
        model: Name of the model under test, supplied by ``parametrize``.
    """
    shared_prompt = "hello what is one plus one what is one plus one what is one plus one the answer is"  # noqa: E501
    prompts = [shared_prompt, shared_prompt]
    max_tokens = 20
    num_logprobs = 5

    # Prompt length is measured in characters here; presumably a safe upper
    # bound on the token count -- TODO confirm.
    longest_prompt = max(map(len, prompts))
    max_model_len = longest_prompt + max_tokens + 64

    # Settings shared by both runs.
    common_kwargs = _get_vllm_runner_params(model, max_model_len)
    common_kwargs.update(
        {
            "enforce_eager": True,
            "block_size": 16,
            "seed": 42,
            "gpu_memory_utilization": 0.8,
        }
    )

    # Run 1: prefix caching disabled.
    apc_off_kwargs = dict(common_kwargs, enable_prefix_caching=False)
    with vllm_runner(**apc_off_kwargs) as llm:
        apc_off_outputs, _ = _get_vLLM_output(
            vllm_runner,
            apc_off_kwargs,
            prompts,
            max_tokens,
            num_logprobs=num_logprobs,
            vllm_model=llm,
        )

    # Run 2: prefix caching enabled.
    apc_on_kwargs = dict(
        common_kwargs,
        enable_prefix_caching=True,
        mamba_block_size=16,
    )
    with vllm_runner(**apc_on_kwargs) as llm:
        apc_on_outputs, _ = _get_vLLM_output(
            vllm_runner,
            apc_on_kwargs,
            prompts,
            max_tokens,
            num_logprobs=num_logprobs,
            vllm_model=llm,
        )

    check_logprobs_close(
        outputs_0_lst=apc_off_outputs[0],
        outputs_1_lst=apc_on_outputs[0],
        name_0="vllm_no_apc",
        name_1="vllm_with_apc",
    )
# we have to use a real large model to get reasonable results
# the model can't be a hybrid model as we need block_size 16
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"tiiuae/falcon-mamba-7b"
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment