Unverified Commit 2e3e3c86 authored by Vlad Tiberiu Mihailescu's avatar Vlad Tiberiu Mihailescu Committed by GitHub
Browse files

Export NaNs in logits to scheduler_stats if output is corrupted (#18777)


Signed-off-by: default avatarVlad Mihailescu <vtmihailescu@gmail.com>
parent 7e8977fc
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import random import random
import pytest import pytest
import torch
from vllm.attention import Attention from vllm.attention import Attention
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
...@@ -277,6 +278,54 @@ def test_update_states_request_resumed(model_runner): ...@@ -277,6 +278,54 @@ def test_update_states_request_resumed(model_runner):
assert _is_req_state_block_table_match(model_runner, req_id) assert _is_req_state_block_table_match(model_runner, req_id)
def test_get_nans_in_logits(model_runner):
req_ids = ("req_0", "req_1")
scheduler_output = _schedule_new_request(*req_ids)
model_runner._update_states(scheduler_output)
logits = torch.tensor([
[1.0, 2.0, 3.0],
[3.0, 2.0, 1.0],
], device=DEVICE)
result = model_runner._get_nans_in_logits(logits)
assert result == {"req_0": 0, "req_1": 0}
logits = torch.tensor([
[1.0, float('nan'), 3.0],
[4.0, float('nan'), float('nan')],
],
device=DEVICE)
result = model_runner._get_nans_in_logits(logits)
assert result == {"req_0": 1, "req_1": 2}
logits = torch.tensor([
[1.0, 2.0, 3.0],
[4.0, float('nan'), float('nan')],
],
device=DEVICE)
result = model_runner._get_nans_in_logits(logits)
assert result == {"req_0": 0, "req_1": 2}
result = model_runner._get_nans_in_logits(logits=None)
assert result == {"req_0": 0, "req_1": 0}
logits = torch.tensor([
[1.0, float('nan'), 3.0],
], device=DEVICE)
result = model_runner._get_nans_in_logits(logits)
assert result == {'req_0': 1, 'req_1': 0}
logits = torch.tensor([
[float('nan'), float('nan'), 2.0],
[1.0, 2.0, 3.0],
[float('nan'), 2.0, 3.0],
],
device=DEVICE)
result = model_runner._get_nans_in_logits(logits)
assert result == {'req_0': 2, 'req_1': 0}
def test_update_states_no_changes(model_runner): def test_update_states_no_changes(model_runner):
req_id = "req_0" req_id = "req_0"
......
...@@ -130,6 +130,7 @@ if TYPE_CHECKING: ...@@ -130,6 +130,7 @@ if TYPE_CHECKING:
VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_SLEEP_WHEN_IDLE: bool = False
VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
VLLM_KV_CACHE_LAYOUT: Optional[str] = None VLLM_KV_CACHE_LAYOUT: Optional[str] = None
VLLM_COMPUTE_NANS_IN_LOGITS: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -897,7 +898,13 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -897,7 +898,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
# leave the layout choice to the backend. Mind that backends may only # leave the layout choice to the backend. Mind that backends may only
# implement and support a subset of all possible layouts. # implement and support a subset of all possible layouts.
"VLLM_KV_CACHE_LAYOUT": "VLLM_KV_CACHE_LAYOUT":
lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None) lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None),
# Enable checking whether the generated logits contain NaNs,
# indicating corrupted output. Useful for debugging low level bugs
# or bad hardware but it may add compute overhead.
"VLLM_COMPUTE_NANS_IN_LOGITS":
lambda: bool(int(os.getenv("VLLM_COMPUTE_NANS_IN_LOGITS", "0"))),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -717,6 +717,7 @@ class Scheduler(SchedulerInterface): ...@@ -717,6 +717,7 @@ class Scheduler(SchedulerInterface):
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens num_scheduled_tokens = scheduler_output.num_scheduled_tokens
pooler_outputs = model_runner_output.pooler_output pooler_outputs = model_runner_output.pooler_output
num_nans_in_logits = model_runner_output.num_nans_in_logits
new_running: list[Request] = [] new_running: list[Request] = []
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
...@@ -810,6 +811,10 @@ class Scheduler(SchedulerInterface): ...@@ -810,6 +811,10 @@ class Scheduler(SchedulerInterface):
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
req_id, new_token_ids) req_id, new_token_ids)
# spec_token_ids comes from the model runner output
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id]
# Add newly generated spec token ids to the request. # Add newly generated spec token ids to the request.
if spec_token_ids is not None: if spec_token_ids is not None:
if self.structured_output_manager.should_advance(request): if self.structured_output_manager.should_advance(request):
...@@ -972,6 +977,8 @@ class Scheduler(SchedulerInterface): ...@@ -972,6 +977,8 @@ class Scheduler(SchedulerInterface):
kv_cache_usage=self.kv_cache_manager.usage, kv_cache_usage=self.kv_cache_manager.usage,
prefix_cache_stats=prefix_cache_stats, prefix_cache_stats=prefix_cache_stats,
spec_decoding_stats=spec_decoding_stats, spec_decoding_stats=spec_decoding_stats,
num_corrupted_reqs=sum(req.is_output_corrupted
for req in self.running),
) )
def make_spec_decoding_stats( def make_spec_decoding_stats(
......
...@@ -40,6 +40,8 @@ class SchedulerStats: ...@@ -40,6 +40,8 @@ class SchedulerStats:
spec_decoding_stats: Optional[SpecDecodingStats] = None spec_decoding_stats: Optional[SpecDecodingStats] = None
num_corrupted_reqs: int = 0
@dataclass @dataclass
class LoRAStats: class LoRAStats:
......
...@@ -108,6 +108,9 @@ class ModelRunnerOutput: ...@@ -108,6 +108,9 @@ class ModelRunnerOutput:
finished_sending: Optional[set[str]] = None finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None finished_recving: Optional[set[str]] = None
# req_id -> num_nans_in_logits
num_nans_in_logits: Optional[dict[str, int]] = None
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
req_id_to_index={}, req_id_to_index={},
...@@ -117,4 +120,5 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], ...@@ -117,4 +120,5 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[], pooler_output=[],
finished_sending=None, finished_sending=None,
finished_recving=None) finished_recving=None,
num_nans_in_logits=None)
...@@ -97,6 +97,10 @@ class Request: ...@@ -97,6 +97,10 @@ class Request:
# The number of tokens with prefix cache hits. # The number of tokens with prefix cache hits.
self.num_cached_tokens = -1 self.num_cached_tokens = -1
# The number of NaNs in logits. A value greater than 0
# indicates that the output is corrupted
self.num_nans_in_logits = 0
@classmethod @classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
if request.mm_inputs is not None: if request.mm_inputs is not None:
...@@ -132,6 +136,10 @@ class Request: ...@@ -132,6 +136,10 @@ class Request:
self._output_token_ids.extend(token_ids) self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids) self._all_token_ids.extend(token_ids)
@property
def is_output_corrupted(self) -> bool:
return self.num_nans_in_logits > 0
@property @property
def num_tokens(self) -> int: def num_tokens(self) -> int:
return len(self._all_token_ids) return len(self._all_token_ids)
......
...@@ -1431,6 +1431,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1431,6 +1431,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
sampler_output.sampled_token_ids = output_token_ids sampler_output.sampled_token_ids = output_token_ids
num_nans_in_logits = {}
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
num_nans_in_logits = self._get_nans_in_logits(logits)
# TODO(woosuk): The following loop can be slow since it iterates over # TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize. # the requests one by one. Optimize.
discard_sampled_tokens_req_indices = [] discard_sampled_tokens_req_indices = []
...@@ -1601,6 +1605,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1601,6 +1605,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pooler_output=[], pooler_output=[],
finished_sending=finished_sending, finished_sending=finished_sending,
finished_recving=finished_recving, finished_recving=finished_recving,
num_nans_in_logits=num_nans_in_logits,
) )
def kv_connector_no_forward( def kv_connector_no_forward(
...@@ -1826,6 +1831,26 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1826,6 +1831,26 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return prompt_logprobs_dict return prompt_logprobs_dict
def _get_nans_in_logits(
self,
logits: Optional[torch.Tensor],
) -> dict[str, int]:
try:
if logits is None:
return {req_id: 0 for req_id in self.input_batch.req_ids}
num_nans_in_logits = {}
num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy()
for req_id in self.input_batch.req_ids:
req_index = self.input_batch.req_id_to_index[req_id]
num_nans_in_logits[req_id] = (
int(num_nans_for_index[req_index])
if num_nans_for_index is not None
and req_index < logits.shape[0] else 0)
return num_nans_in_logits
except IndexError:
return {}
@contextmanager @contextmanager
def maybe_randomize_inputs(self, input_ids: torch.Tensor): def maybe_randomize_inputs(self, input_ids: torch.Tensor):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment