Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ddc369fb
Unverified
Commit
ddc369fb
authored
Jul 08, 2024
by
tomeras91
Committed by
GitHub
Jul 08, 2024
Browse files
[Bugfix] Mamba cache Cuda Graph padding (#6214)
parent
185ad31f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
2 deletions
+30
-2
tests/models/test_jamba.py
tests/models/test_jamba.py
+28
-0
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+2
-2
No files found.
tests/models/test_jamba.py
View file @
ddc369fb
import
pytest
from
vllm.worker.model_runner
import
_get_graph_batch_size
MODELS
=
[
"ai21labs/Jamba-tiny-random"
]
...
...
@@ -32,6 +34,32 @@ def test_models(
f
"Test
{
i
}
:
\n
HF:
{
hf_output_ids
}
\n
vLLM:
{
vllm_output_ids
}
"
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
20
])
def
test_mamba_cache_cg_padding
(
vllm_runner
,
example_prompts
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
)
->
None
:
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while
len
(
example_prompts
)
==
_get_graph_batch_size
(
len
(
example_prompts
)):
example_prompts
.
append
(
example_prompts
[
0
])
try
:
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
except
RuntimeError
:
pytest
.
fail
(
"Couldn't run batch size which is not equal to a Cuda Graph "
"captured batch size. "
"Could be related to mamba cache not padded correctly"
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
def
test_state_cleanup
(
...
...
vllm/model_executor/models/jamba.py
View file @
ddc369fb
...
...
@@ -788,12 +788,12 @@ class JambaForCausalLM(nn.Module):
key
in
kwargs
for
key
in
[
"request_ids_to_seq_ids"
,
"finished_requests_ids"
])
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
batch_size
=
len
(
request_ids_to_seq_ids
)
cg_
batch_size
=
input_buffers
[
'input_ids'
].
shape
[
0
]
(
current_mamba_cache
,
indices
,
)
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
batch_size
)
cg_
batch_size
)
self
.
current_indices
=
indices
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
self
.
_release_mamba_cache
(
finished_requests_ids
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment