Unverified Commit 5efa206a authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix `ExaoneMoeMTP` test that never ran in Transformers v4 (#36792)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 196802df
......@@ -247,6 +247,7 @@ def _compare_tp(
hf_config = get_config(model_id, trust_remote_code)
require_embed_inputs = model_info.require_embed_inputs
max_num_seqs = model_info.max_num_seqs
enable_prefix_caching = model_info.enable_prefix_caching
dtype = "float16"
if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS:
......@@ -300,6 +301,8 @@ def _compare_tp(
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if not enable_prefix_caching:
common_args.append("--no-enable-prefix-caching")
if require_embed_inputs:
common_args.extend(
[
......
......@@ -74,6 +74,8 @@ def run_test(
if model_info.require_embed_inputs:
for k in ("skip_tokenizer_init", "enable_prompt_embeds", "enable_mm_embeds"):
vllm_runner_kwargs_[k] = model_info.require_embed_inputs
if not model_info.enable_prefix_caching:
vllm_runner_kwargs_["enable_prefix_caching"] = False
if vllm_runner_kwargs:
vllm_runner_kwargs_.update(vllm_runner_kwargs)
......
......@@ -72,6 +72,12 @@ class _HfExamplesInfo:
If False, we will use CUDA graph and eager execution in hybrid.
"""
enable_prefix_caching: bool = True
"""
Whether to enable prefix caching for the model. If True, we will test the model with
prefix caching enabled. If False, we will test the model without prefix caching.
"""
is_available_online: bool = True
"""
Set this to `False` if the name of this architecture no longer exists on
......@@ -1206,6 +1212,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"LGAI-EXAONE/K-EXAONE-236B-A23B",
speculative_model="LGAI-EXAONE/K-EXAONE-236B-A23B",
min_transformers_version="5.1.0",
enable_prefix_caching=False,
),
"ExtractHiddenStatesModel": _HfExamplesInfo(
"Qwen/Qwen3-8B",
......
......@@ -136,6 +136,10 @@ def can_initialize(
if model_arch == "WhisperForConditionalGeneration":
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
kwargs = {}
if not model_info.enable_prefix_caching:
kwargs["enable_prefix_caching"] = False
LLM(
model_info.default,
tokenizer=model_info.tokenizer,
......@@ -165,6 +169,7 @@ def can_initialize(
hf_overrides=hf_overrides_fn,
max_num_seqs=model_info.max_num_seqs,
attention_config=attention_config,
**kwargs,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment