"vllm/vscode:/vscode.git/clone" did not exist on "bf6a3d0ff5a69e0a30567f2ad417530c002eaa4e"
Unverified Commit ddc369fb authored by tomeras91's avatar tomeras91 Committed by GitHub
Browse files

[Bugfix] Mamba cache Cuda Graph padding (#6214)

parent 185ad31f
import pytest import pytest
from vllm.worker.model_runner import _get_graph_batch_size
MODELS = ["ai21labs/Jamba-tiny-random"] MODELS = ["ai21labs/Jamba-tiny-random"]
...@@ -32,6 +34,32 @@ def test_models( ...@@ -32,6 +34,32 @@ def test_models(
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [20])
def test_mamba_cache_cg_padding(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
example_prompts.append(example_prompts[0])
try:
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
except RuntimeError:
pytest.fail(
"Couldn't run batch size which is not equal to a Cuda Graph "
"captured batch size. "
"Could be related to mamba cache not padded correctly")
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup( def test_state_cleanup(
......
...@@ -788,12 +788,12 @@ class JambaForCausalLM(nn.Module): ...@@ -788,12 +788,12 @@ class JambaForCausalLM(nn.Module):
key in kwargs key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
batch_size = len(request_ids_to_seq_ids) cg_batch_size = input_buffers['input_ids'].shape[0]
( (
current_mamba_cache, current_mamba_cache,
indices, indices,
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
batch_size) cg_batch_size)
self.current_indices = indices self.current_indices = indices
finished_requests_ids = kwargs["finished_requests_ids"] finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids) self._release_mamba_cache(finished_requests_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment