"vscode:/vscode.git/clone" did not exist on "d5e0fca264f7d277ed372f7c075827ed9a0c5e7e"
Unverified commit 8e8a3bec, authored by Lalithnarayan C, committed by GitHub
Browse files

[ZenCPU] Make PT Backport Patch Accessible to vLLM (#38205)


Signed-off-by: Lalithnarayan C <Lalithnarayan.C@amd.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
parent 1dfd64c1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the FxGraphCachePickler.dumps ValueError patch in env_override.py.
Validates that _apply_fxgraphcache_pickle_patch correctly wraps a pickler's
dumps method to convert ValueError into a bypass exception, without affecting
other exception types or normal return values.
"""
import pytest
from vllm.env_override import _apply_fxgraphcache_pickle_patch
class _BypassStub(Exception):
    """Test-only exception used in place of BypassFxGraphCache."""
class TestApplyFxgraphcachePicklePatch:
    """Unit tests for ``_apply_fxgraphcache_pickle_patch`` using stub picklers."""

    def test_valueerror_converted_to_bypass(self):
        """A ValueError raised by ``dumps`` surfaces as the bypass exception."""

        class FailingPickler:
            def dumps(self, obj):
                raise ValueError("can't serialize blocked layout")

        _apply_fxgraphcache_pickle_patch(FailingPickler, _BypassStub)
        pickler = FailingPickler()

        with pytest.raises(_BypassStub, match="Failed to pickle cache key"):
            pickler.dumps(object())

    def test_original_valueerror_chained(self):
        """The bypass exception keeps the original ValueError as its cause."""

        class FailingPickler:
            def dumps(self, obj):
                raise ValueError("bad tensor layout")

        _apply_fxgraphcache_pickle_patch(FailingPickler, _BypassStub)
        pickler = FailingPickler()

        with pytest.raises(_BypassStub) as caught:
            pickler.dumps(object())

        original_error = caught.value.__cause__
        assert isinstance(original_error, ValueError)
        assert str(original_error) == "bad tensor layout"

    def test_non_valueerror_propagates(self):
        """Exceptions other than ValueError pass through untouched."""

        class FailingPickler:
            def dumps(self, obj):
                raise TypeError("unexpected type")

        _apply_fxgraphcache_pickle_patch(FailingPickler, _BypassStub)
        pickler = FailingPickler()

        with pytest.raises(TypeError, match="unexpected type"):
            pickler.dumps(object())

    def test_normal_return_preserved(self):
        """A successful ``dumps`` call returns the very same object."""
        expected = b"serialized-graph-key"

        class WorkingPickler:
            def dumps(self, obj):
                return expected

        _apply_fxgraphcache_pickle_patch(WorkingPickler, _BypassStub)
        result = WorkingPickler().dumps(object())

        assert result is expected

    def test_idempotent(self):
        """Applying the patch a second time leaves the first wrapper in place."""

        class WorkingPickler:
            def dumps(self, obj):
                return b"ok"

        _apply_fxgraphcache_pickle_patch(WorkingPickler, _BypassStub)
        wrapper_after_first_apply = WorkingPickler.dumps

        _apply_fxgraphcache_pickle_patch(WorkingPickler, _BypassStub)
        assert WorkingPickler.dumps is wrapper_after_first_apply

    def test_sentinel_attribute_set(self):
        """Both the class flag and the wrapper marker appear only after patching."""

        class WorkingPickler:
            def dumps(self, obj):
                return b"ok"

        # Before patching: neither marker is present.
        assert not hasattr(WorkingPickler.dumps, "_vllm_patched")
        assert not getattr(WorkingPickler, "_vllm_fxgraph_dumps_patched", False)

        _apply_fxgraphcache_pickle_patch(WorkingPickler, _BypassStub)

        # After patching: both markers are exactly ``True``.
        assert WorkingPickler.dumps._vllm_patched is True  # type: ignore[attr-defined]
        assert WorkingPickler._vllm_fxgraph_dumps_patched is True  # type: ignore[attr-defined]
def test_patch_applied_in_current_environment():
    """Integration: the live pickler's patch state matches the torch version.

    The patch is expected exactly when torch is >= 2.10.0 and < 2.11.0.
    """
    from torch._inductor.codecache import FxGraphCachePickler

    from vllm.utils.torch_utils import is_torch_equal_or_newer

    # Evaluate the upper bound only when the lower bound holds, mirroring
    # the short-circuit of a single ``and`` expression.
    expect_patched = is_torch_equal_or_newer("2.10.0")
    if expect_patched:
        expect_patched = not is_torch_equal_or_newer("2.11.0")

    class_flag = getattr(FxGraphCachePickler, "_vllm_fxgraph_dumps_patched", False)
    assert class_flag == expect_patched

    wrapper_marked = hasattr(FxGraphCachePickler.dumps, "_vllm_patched")
    assert wrapper_marked == expect_patched
...@@ -586,3 +586,52 @@ if is_torch_equal_or_newer("2.10.0") and not is_torch_equal_or_newer("2.12.0"): ...@@ -586,3 +586,52 @@ if is_torch_equal_or_newer("2.10.0") and not is_torch_equal_or_newer("2.12.0"):
return runtime_env return runtime_env
GraphCaptureOutput.get_runtime_env = _patched_get_runtime_env GraphCaptureOutput.get_runtime_env = _patched_get_runtime_env
# ===================================================
# torch 2.10 FxGraphCachePickler.dumps ValueError fix
# ===================================================
# PyTorch 2.10's FxGraphCachePickler.dumps() doesn't catch ValueError,
# causing torch.compile cache failures when tensors with non-standard
# layouts (e.g. blocked-layout prepacked weights) are serialized.
# PyTorch mainline fixed this in pytorch/pytorch#176557 (merged 2026-03-04).
# This is a thin backport for 2.10 users; remove once 2.10 is dropped.
def _apply_fxgraphcache_pickle_patch(pickler_cls, bypass_cls):
    """Wrap ``pickler_cls.dumps`` to convert ``ValueError`` into ``bypass_cls``.

    Idempotent: sets ``_vllm_fxgraph_dumps_patched`` on the class after the
    first apply to prevent re-application. The wrapper function is also
    marked with ``_vllm_patched`` as an additional safeguard.

    The wrapper is built with ``functools.wraps`` so the patched method keeps
    the original's ``__name__``, ``__qualname__`` and ``__doc__`` (and exposes
    the original via ``__wrapped__``), which keeps tracebacks and
    introspection pointing at ``dumps`` rather than at the wrapper.

    Args:
        pickler_cls: Class whose ``dumps(self, obj)`` method is replaced
            in place.
        bypass_cls: Exception type raised (chained from the original
            ``ValueError``) when ``dumps`` fails with ``ValueError``.

    Returns:
        None. ``pickler_cls`` is mutated in place; nothing happens if it is
        already patched.
    """
    # Local import keeps this block self-contained in the file.
    import functools

    if getattr(pickler_cls, "_vllm_fxgraph_dumps_patched", False):
        return

    original_dumps = pickler_cls.dumps
    # Second guard: the current ``dumps`` is already one of our wrappers even
    # though the class-level flag is missing.
    if hasattr(original_dumps, "_vllm_patched"):
        return

    @functools.wraps(original_dumps)
    def patched_dumps(self, obj):
        try:
            return original_dumps(self, obj)
        except ValueError as e:
            # Only ValueError is translated; every other exception type
            # propagates unchanged.
            raise bypass_cls("Failed to pickle cache key") from e

    patched_dumps._vllm_patched = True  # type: ignore[attr-defined]
    pickler_cls.dumps = patched_dumps
    pickler_cls._vllm_fxgraph_dumps_patched = True  # type: ignore[attr-defined]
def _patch_fxgraphcache_pickle_if_needed():
    """Install the FxGraphCachePickler.dumps ValueError backport on torch 2.10.x.

    Does nothing when torch is older than 2.10.0 or is 2.11.0 and newer.
    """
    from vllm.utils.torch_utils import is_torch_equal_or_newer

    # The upper bound is only checked when the lower bound holds.
    on_torch_2_10 = is_torch_equal_or_newer(
        "2.10.0"
    ) and not is_torch_equal_or_newer("2.11.0")

    if on_torch_2_10:
        # Imported lazily so torch's inductor module is only touched when
        # the patch is actually applied.
        from torch._inductor.codecache import BypassFxGraphCache, FxGraphCachePickler

        _apply_fxgraphcache_pickle_patch(FxGraphCachePickler, BypassFxGraphCache)


_patch_fxgraphcache_pickle_if_needed()
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cpu import CpuPlatform
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__) logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.config import VllmConfig
class ZenCpuPlatform(CpuPlatform): class ZenCpuPlatform(CpuPlatform):
"""CPU platform with AMD Zen (ZenDNN/zentorch) optimizations. """CPU platform with AMD Zen (ZenDNN/zentorch) optimizations.
...@@ -28,40 +22,3 @@ class ZenCpuPlatform(CpuPlatform): ...@@ -28,40 +22,3 @@ class ZenCpuPlatform(CpuPlatform):
def is_zen_cpu(self) -> bool: def is_zen_cpu(self) -> bool:
# is_cpu() also returns True for this platform (inherited from CpuPlatform). # is_cpu() also returns True for this platform (inherited from CpuPlatform).
return True return True
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
    """Run the parent platform's config update, then apply PyTorch backports."""
    # Parent handling runs first; the backports are applied afterwards.
    super().check_and_update_config(vllm_config)
    cls._apply_pytorch_backports()
@classmethod
def _apply_pytorch_backports(cls):
    """Apply PyTorch mainline fixes that are missing from torch 2.10.

    PyTorch 2.10's FxGraphCachePickler.dumps does not catch ValueError,
    which causes torch.compile cache misses. PT mainline already carries
    the fix; drop this once PyTorch 2.10 support is removed.
    """
    # Only torch 2.10.x (>= 2.10.0 and < 2.11.0) needs the backport.
    needs_backport = is_torch_equal_or_newer(
        "2.10.0"
    ) and not is_torch_equal_or_newer("2.11.0")

    if needs_backport:
        cls._patch_fxgraphcache_pickle()
@classmethod
def _patch_fxgraphcache_pickle(cls):
    """Backport the mainline ValueError fix to FxGraphCachePickler.dumps().

    Replaces ``FxGraphCachePickler.dumps`` with a wrapper that re-raises
    ``ValueError`` as ``BypassFxGraphCache``. The wrapper carries a
    ``_zen_patched`` marker so a repeated call is a no-op.
    """
    from torch._inductor.codecache import BypassFxGraphCache, FxGraphCachePickler

    unpatched_dumps = FxGraphCachePickler.dumps
    already_patched = hasattr(unpatched_dumps, "_zen_patched")

    if not already_patched:

        def patched_dumps(self, obj):
            try:
                return unpatched_dumps(self, obj)
            except ValueError as err:
                raise BypassFxGraphCache("Failed to pickle cache key") from err

        # Mark the wrapper so later calls detect it and leave it in place.
        patched_dumps._zen_patched = True  # type: ignore[attr-defined]
        FxGraphCachePickler.dumps = patched_dumps
        logger.info("[zen_cpu] Patched FxGraphCachePickler.dumps (ValueError fix)")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment