Commit 853c371f (unverified), authored by Asaf Joseph Gardin, committed by GitHub
Browse files

[V1][Mamba] - Enable V1 by default for Mamba Models (#23650)


Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
parent 8bf6266a
...@@ -100,6 +100,8 @@ def test_models( ...@@ -100,6 +100,8 @@ def test_models(
else: else:
hf_outputs = None hf_outputs = None
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "0")
if model not in V0_UNSUPPORTED_MODELS: if model not in V0_UNSUPPORTED_MODELS:
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_v0_outputs = vllm_model.generate_greedy_logprobs( vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
...@@ -108,11 +110,7 @@ def test_models( ...@@ -108,11 +110,7 @@ def test_models(
vllm_v0_outputs = None vllm_v0_outputs = None
if model in V1_SUPPORTED_MODELS: if model in V1_SUPPORTED_MODELS:
with monkeypatch.context() as m: with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
m.setenv("VLLM_USE_V1", "1")
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
enable_prefix_caching=False) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs( vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
else: else:
...@@ -137,7 +135,7 @@ def test_models( ...@@ -137,7 +135,7 @@ def test_models(
) )
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
def test_batching( def test_batching(
...@@ -147,10 +145,6 @@ def test_batching( ...@@ -147,10 +145,6 @@ def test_batching(
max_tokens: int, max_tokens: int,
num_logprobs: int, num_logprobs: int,
) -> None: ) -> None:
if model in V0_UNSUPPORTED_MODELS:
pytest.skip(
f"Unsupported V0 Engine. Skipping `test_batching` on {model}.")
try: try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip") model_info.check_available_online(on_fail="skip")
...@@ -188,16 +182,19 @@ def test_chunked_prefill( ...@@ -188,16 +182,19 @@ def test_chunked_prefill(
max_tokens: int, max_tokens: int,
num_logprobs: int, num_logprobs: int,
chunked_prefill_token_size: int, chunked_prefill_token_size: int,
monkeypatch,
) -> None: ) -> None:
max_num_seqs = chunked_prefill_token_size max_num_seqs = chunked_prefill_token_size
max_num_batched_tokens = chunked_prefill_token_size max_num_batched_tokens = chunked_prefill_token_size
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "0")
with vllm_runner(model, with vllm_runner(model,
enable_chunked_prefill=True, enable_chunked_prefill=True,
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs) as vllm_model: max_num_seqs=max_num_seqs) as vllm_model:
chunked = vllm_model.generate_greedy_logprobs(example_prompts, chunked = vllm_model.generate_greedy_logprobs(
max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
with vllm_runner(model, with vllm_runner(model,
enable_chunked_prefill=False, enable_chunked_prefill=False,
...@@ -281,10 +278,13 @@ def test_models_preemption_recompute( ...@@ -281,10 +278,13 @@ def test_models_preemption_recompute(
example_prompts, example_prompts,
model: str, model: str,
max_tokens: int, max_tokens: int,
monkeypatch,
) -> None: ) -> None:
""" """
Tests that outputs are identical with and w/o preemptions (recompute). Tests that outputs are identical with and w/o preemptions (recompute).
""" """
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "0")
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
scheduler = vllm_model.llm.llm_engine.scheduler[0] scheduler = vllm_model.llm.llm_engine.scheduler[0]
scheduler.ENABLE_ARTIFICIAL_PREEMPT = True scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
...@@ -292,7 +292,8 @@ def test_models_preemption_recompute( ...@@ -292,7 +292,8 @@ def test_models_preemption_recompute(
example_prompts, max_tokens) example_prompts, max_tokens)
scheduler.ENABLE_ARTIFICIAL_PREEMPT = False scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)
check_outputs_equal( check_outputs_equal(
outputs_0_lst=preempt_vllm_outputs, outputs_0_lst=preempt_vllm_outputs,
...@@ -402,6 +403,8 @@ def test_full_cuda_graph( ...@@ -402,6 +403,8 @@ def test_full_cuda_graph(
else: else:
hf_outputs = None hf_outputs = None
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "0")
if model not in V0_UNSUPPORTED_MODELS: if model not in V0_UNSUPPORTED_MODELS:
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_v0_outputs = vllm_model.generate_greedy_logprobs( vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
...@@ -409,15 +412,7 @@ def test_full_cuda_graph( ...@@ -409,15 +412,7 @@ def test_full_cuda_graph(
else: else:
vllm_v0_outputs = None vllm_v0_outputs = None
with monkeypatch.context() as m: with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
m.setenv("VLLM_USE_V1", "1")
if model in HYBRID_MODELS:
# required due to reorder_batch behaviour
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
compilation_config={'full_cuda_graph': True},
enable_prefix_caching=False) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs( vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
...@@ -466,21 +461,17 @@ def test_fp32_state( ...@@ -466,21 +461,17 @@ def test_fp32_state(
else: else:
hf_outputs = None hf_outputs = None
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "0")
with vllm_runner(model, with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS, max_num_seqs=MAX_NUM_SEQS,
mamba_ssm_cache_dtype="float32") as vllm_model: mamba_ssm_cache_dtype="float32") as vllm_model:
vllm_v0_outputs = vllm_model.generate_greedy_logprobs( vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
if model in HYBRID_MODELS:
# required due to reorder_batch behaviour
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
with vllm_runner(model, with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS, max_num_seqs=MAX_NUM_SEQS,
mamba_ssm_cache_dtype="float32", mamba_ssm_cache_dtype="float32") as vllm_model:
enable_prefix_caching=False) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs( vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
......
...@@ -1463,11 +1463,6 @@ class EngineArgs: ...@@ -1463,11 +1463,6 @@ class EngineArgs:
recommend_to_remove=False) recommend_to_remove=False)
return False return False
# V1 mamba models are unoptimized.
if model_config.has_inner_state and _warn_or_fallback(
feature_name="Mamba"):
return False
# No Concurrent Partial Prefills so far. # No Concurrent Partial Prefills so far.
if (self.max_num_partial_prefills if (self.max_num_partial_prefills
!= SchedulerConfig.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
......
...@@ -417,4 +417,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { ...@@ -417,4 +417,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"GptOssForCausalLM": GptOssForCausalLMConfig, "GptOssForCausalLM": GptOssForCausalLMConfig,
"MambaForCausalLM": MambaModelConfig, "MambaForCausalLM": MambaModelConfig,
"Mamba2ForCausalLM": MambaModelConfig, "Mamba2ForCausalLM": MambaModelConfig,
"FalconMambaForCausalLM": MambaModelConfig,
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment