Unverified Commit a65cb160 authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[MISC] Dump model runner inputs when crashing (#8305)

parent 3fd2b0d2
......@@ -30,6 +30,15 @@ body:
</details>
validations:
required: true
- type: textarea
attributes:
label: Model Input Dumps
description: |
If you are facing crashing due to illegal memory access or other issues with model execution, vLLM may dump the problematic input of the model. In this case, you will see the message `Error in model execution (input dumped to /tmp/err_xxx.pkl)`. If you see this message, please zip the file (because GitHub doesn't support .pkl file format) and upload it here. This will help us to reproduce the issue and facilitate the debugging process.
placeholder: |
Upload the dumped input file.
validations:
required: false
- type: textarea
attributes:
label: 🐛 Describe the bug
......
......@@ -3,12 +3,16 @@
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
"""
import os
import pickle
import re
import weakref
from unittest.mock import patch
import pytest
from vllm import LLM
from vllm.utils import is_hip
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from ..models.utils import check_outputs_equal
......@@ -64,3 +68,29 @@ def test_models(
name_0="hf",
name_1="vllm",
)
def test_model_with_failure(vllm_runner) -> None:
try:
with patch("vllm.model_executor.models.opt.OPTForCausalLM.forward",
side_effect=ValueError()):
with pytest.raises(ValueError) as exc_info:
vllm_runner("facebook/opt-125m",
dtype="half",
enforce_eager=False,
gpu_memory_utilization=0.7)
matches = re.search(r"input dumped to (.+).pkl",
str(exc_info.value))
assert matches is not None
filename = f"{matches.group(1)}.pkl"
with open(filename, "rb") as filep:
inputs = pickle.load(filep)
if any(key not in inputs for key in ("arg_1", "arg_2", "arg_3")):
raise AssertionError("Missing keys in dumped inputs. Dumped keys: "
f"{list(inputs.keys())}")
assert isinstance(inputs["arg_1"],
ModelInputForGPUWithSamplingMetadata)
finally:
os.remove(filename)
......@@ -53,7 +53,7 @@ from vllm.worker.model_runner_base import (
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
_init_sampling_metadata_from_tensor_dict, dump_input_when_exception)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
......@@ -1489,6 +1489,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
virtual_engine=virtual_engine)
@torch.inference_mode()
@dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"])
def execute_model(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
......
import dataclasses
import pickle
from abc import ABC, abstractmethod
from datetime import datetime
from functools import wraps
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar)
......@@ -98,6 +101,37 @@ def _init_frozen_model_input_from_tensor_dict(
return tensor_dict
def dump_input_when_exception(exclude_args: Optional[List[int]] = None,
exclude_kwargs: Optional[List[str]] = None):
def _inner(func):
@wraps(func)
def _wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as err:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl"
with open(filename, "wb") as filep:
dumped_inputs = {
k: v
for k, v in kwargs.items()
if k not in (exclude_kwargs or [])
}
for i, arg in enumerate(args):
if i not in (exclude_args or []):
dumped_inputs[f"arg_{i}"] = arg
pickle.dump(dumped_inputs, filep)
raise type(err)(
f"Error in model execution (input dumped to {filename}): "
f"{str(err)}") from err
return _wrapper
return _inner
class BroadcastableModelInput(ABC):
@abstractmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment