# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test that we handle a startup error and shut down cleanly."""

5
6
import inspect

7
8
9
import pytest

from tests.utils import wait_for_gpu_memory_to_clear
10
11
12
13
from tests.v1.shutdown.utils import (
    SHUTDOWN_TEST_THRESHOLD_BYTES,
    SHUTDOWN_TEST_TIMEOUT_SEC,
)
14
15
16
17
from vllm import LLM
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.llama import LlamaForCausalLM
18
from vllm.platforms import current_platform
19
20
from vllm.v1.engine.async_llm import AsyncLLM

21
# Model identifiers the tests in this module are parametrized over
# (passed as the ``model`` argument). The tests patch LlamaForCausalLM.
MODELS = ["hmellor/tiny-random-LlamaForCausalLM"]
22
23
24
25
26
27
28
29
30
31
32


def evil_method(self, *args, **kwargs):
    """Stand-in model method that fails on tensor-parallel rank 0.

    On every other rank the call is forwarded to ``self.model`` with the
    same arguments plus ``intermediate_tensors=None``.

    NOTE: the source text of this function is copied verbatim by the
    ``rocm_evil_method`` fixture, so it must only reference names that
    fixture also imports.
    """
    on_failing_rank = get_tensor_model_parallel_rank() == 0
    if not on_failing_rank:
        return self.model(*args, **kwargs, intermediate_tensors=None)

    raise Exception("Simulated Error in startup!")


33
34
35
36
37
38
39
40
41
42
43
44
@pytest.fixture
def rocm_evil_method(rocm_sitecustomize_factory, request):
    """Pass the ``evil_method`` patch to ``rocm_sitecustomize_factory`` as source lines.

    The lines import what ``evil_method`` needs, define it, and assign it
    over the ``LlamaForCausalLM`` method named by the test's
    ``failing_method`` parameter.

    NOTE(review): presumably the factory installs these lines as a
    sitecustomize module so the patch also applies in worker processes on
    ROCm -- confirm against the fixture's definition.
    """
    method_name = request.getfixturevalue("failing_method")

    # Names that evil_method's body relies on at call time.
    required_imports = [
        "from vllm.distributed import get_tensor_model_parallel_rank",
        "from vllm.model_executor.models.llama import LlamaForCausalLM",
    ]
    patch_definition = inspect.getsource(evil_method)
    patch_assignment = f"LlamaForCausalLM.{method_name} = {evil_method.__name__}"

    rocm_sitecustomize_factory(
        [*required_imports, patch_definition, patch_assignment]
    )


45
46
47
48
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("failing_method", ["forward", "load_weights"])
def test_async_llm_startup_error(
    monkeypatch,
    rocm_evil_method,
    model: str,
    tensor_parallel_size: int,
    failing_method: str,
) -> None:
    """Check that a failure while constructing AsyncLLM is raised and cleaned up.

    The failure is injected into either ``forward`` (hit during profiling)
    or ``load_weights``. AsyncLLM always uses an MP client.
    """
    available_devices = current_platform.device_count()
    if tensor_parallel_size > available_devices:
        pytest.skip(reason="Not enough CUDA devices")

    # Replace the selected Llama method with one that raises on TP rank 0.
    monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method)

    engine_args = AsyncEngineArgs(
        model=model,
        tensor_parallel_size=tensor_parallel_size,
        enforce_eager=True,
    )

    # Building the engine must fail with an initialization error.
    with pytest.raises(Exception, match=r"initialization fail(ed|ure)"):
        AsyncLLM.from_engine_args(engine_args)

    # Wait until the GPUs used by the test drop below the memory threshold,
    # which indicates the engine processes have been cleaned up.
    used_devices = list(range(tensor_parallel_size))
    wait_for_gpu_memory_to_clear(
        devices=used_devices,
        threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
    )


@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("failing_method", ["forward", "load_weights"])
def test_llm_startup_error(
    monkeypatch,
    rocm_evil_method,
    model: str,
    tensor_parallel_size: int,
    enable_multiprocessing: bool,
    failing_method: str,
) -> None:
    """Test that LLM propagates an __init__ error and frees memory.

    Test profiling (forward()) and load weights failures.
    TODO(andy) - LLM without multiprocessing.
    """
    # The fault is injected into LlamaForCausalLM, so only Llama models can
    # trigger it. The previous guard compared against a single hard-coded
    # model name that is not in MODELS, which skipped every parametrization.
    # NOTE(review): this is a name-based heuristic -- if MODELS grows, each
    # architecture needs its own test variant.
    if "llama" not in model.lower():
        pytest.skip(reason="Only Llama-architecture models are tested")
    if current_platform.device_count() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:
        mp_value = "1" if enable_multiprocessing else "0"
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", mp_value)

        # Monkeypatch an error in the model. Use the scoped patcher so the
        # env var and the method patch are undone together.
        m.setattr(LlamaForCausalLM, failing_method, evil_method)

        # With multiprocessing the caller sees a generic initialization
        # failure; in-process the simulated error itself is expected.
        expected_message = (
            r"initialization fail(ed|ure)"
            if enable_multiprocessing
            else "Simulated Error in startup!"
        )
        with pytest.raises(Exception, match=expected_message):
            _ = LLM(
                model=model,
                enforce_eager=True,
                tensor_parallel_size=tensor_parallel_size,
            )

        # Confirm all the processes are cleaned up.
        wait_for_gpu_memory_to_clear(
            devices=list(range(tensor_parallel_size)),
            threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
        )