Commit 16c971db authored by Andreas Karatzas's avatar Andreas Karatzas Committed by khluu
Browse files

[CI] Fix PaddleOCR-VL HF test failure due to create_causal_mask API rename (#37328)


Signed-off-by: default avatarAndreas Karatzas <akaratza@amd.com>
(cherry picked from commit eaf7c9b9)
parent 262ddd0d
...@@ -777,6 +777,7 @@ VLM_TEST_SETTINGS = { ...@@ -777,6 +777,7 @@ VLM_TEST_SETTINGS = {
max_model_len=8192, max_model_len=8192,
max_num_seqs=2, max_num_seqs=2,
auto_cls=AutoModelForCausalLM, auto_cls=AutoModelForCausalLM,
patch_hf_runner=model_utils.paddleocr_vl_patch_hf_runner,
image_size_factors=[(0.25,)], image_size_factors=[(0.25,)],
marks=[ marks=[
pytest.mark.skipif( pytest.mark.skipif(
......
...@@ -1149,6 +1149,31 @@ def ovis2_5_patch_hf_runner(hf_model: HfRunner) -> HfRunner: ...@@ -1149,6 +1149,31 @@ def ovis2_5_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return hf_model return hf_model
def paddleocr_vl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches the HfRunner to fix create_causal_mask API mismatch.
The PaddleOCR-VL HF model passes `inputs_embeds` to create_causal_mask,
but transformers renamed this parameter to `input_embeds`.
"""
import sys
model_module = sys.modules.get(type(hf_model.model.model).__module__)
if model_module is None:
return hf_model
original_create_causal_mask = getattr(model_module, "create_causal_mask", None)
if original_create_causal_mask is None:
return hf_model
def patched_create_causal_mask(*args, **kwargs):
if "inputs_embeds" in kwargs:
kwargs["input_embeds"] = kwargs.pop("inputs_embeds")
return original_create_causal_mask(*args, **kwargs)
model_module.create_causal_mask = patched_create_causal_mask # type: ignore[attr-defined]
return hf_model
def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner: def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner for Qwen2.5-Omni.""" """Patches and returns an instance of the HfRunner for Qwen2.5-Omni."""
thinker = hf_model.model.thinker thinker = hf_model.model.thinker
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment