Unverified Commit 15422ed3 authored by Ryan Rock's avatar Ryan Rock Committed by GitHub
Browse files

[CI/Build][Hardware][AMD] Fix v1/shutdown (#31997)


Signed-off-by: default avatarRyan Rock <ryan.rock@amd.com>
parent 8471b27d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Iterable
from pathlib import Path
import pytest
from vllm.platforms import current_platform
@pytest.fixture
def rocm_sitecustomize_factory(monkeypatch, tmp_path: Path):
"""Return a function that installs a given sitecustomize payload."""
if not current_platform.is_rocm():
return lambda _: None
def install(lines: Iterable[str]) -> None:
sc = tmp_path / "sitecustomize.py"
sc.write_text("\n".join(lines) + "\n")
monkeypatch.setenv(
"PYTHONPATH",
os.pathsep.join(filter(None, [str(tmp_path), os.getenv("PYTHONPATH")])),
)
return install
......@@ -3,6 +3,7 @@
"""Test that we handle an Error in model forward and shutdown."""
import asyncio
import inspect
import pytest
......@@ -38,11 +39,22 @@ def evil_forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
@pytest.fixture
def rocm_evil_forward(rocm_sitecustomize_factory):
lines = [
"from vllm.distributed import get_tensor_model_parallel_rank",
"from vllm.model_executor.models.llama import LlamaForCausalLM",
inspect.getsource(evil_forward),
f"LlamaForCausalLM.forward = {evil_forward.__name__}",
]
rocm_sitecustomize_factory(lines)
@pytest.mark.asyncio
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("model", MODELS)
async def test_async_llm_model_error(
monkeypatch, tensor_parallel_size: int, model: str
monkeypatch, rocm_evil_forward, tensor_parallel_size: int, model: str
) -> None:
"""Test that AsyncLLM propagates a forward pass error and frees memory.
......@@ -104,7 +116,11 @@ async def test_async_llm_model_error(
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("model", MODELS)
def test_llm_model_error(
monkeypatch, tensor_parallel_size: int, enable_multiprocessing: bool, model: str
monkeypatch,
rocm_evil_forward,
tensor_parallel_size: int,
enable_multiprocessing: bool,
model: str,
) -> None:
"""Test that LLM propagates a forward pass error and frees memory.
TODO(andy) - LLM without multiprocessing; LLM with multiprocessing
......
......@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test that we handle a startup Error and shutdown."""
import inspect
import pytest
from tests.utils import wait_for_gpu_memory_to_clear
......@@ -28,12 +30,28 @@ def evil_method(self, *args, **kwargs):
return self.model(*args, **kwargs, intermediate_tensors=None)
@pytest.fixture
def rocm_evil_method(rocm_sitecustomize_factory, request):
failing_method = request.getfixturevalue("failing_method")
lines = [
"from vllm.distributed import get_tensor_model_parallel_rank",
"from vllm.model_executor.models.llama import LlamaForCausalLM",
inspect.getsource(evil_method),
f"LlamaForCausalLM.{failing_method} = {evil_method.__name__}",
]
rocm_sitecustomize_factory(lines)
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("failing_method", ["forward", "load_weights"])
def test_async_llm_startup_error(
monkeypatch, model: str, tensor_parallel_size: int, failing_method: str
monkeypatch,
rocm_evil_method,
model: str,
tensor_parallel_size: int,
failing_method: str,
) -> None:
"""Test that AsyncLLM propagates an __init__ error & frees memory.
Test profiling (forward()) and load weights failures.
......@@ -67,6 +85,7 @@ def test_async_llm_startup_error(
@pytest.mark.parametrize("failing_method", ["forward", "load_weights"])
def test_llm_startup_error(
monkeypatch,
rocm_evil_method,
model: str,
tensor_parallel_size: int,
enable_multiprocessing: bool,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment