Unverified commit df7fe452, authored by Lianmin Zheng and committed via GitHub
Browse files

Crash the CI jobs on model import errors (#2072)

parent a7164b62
......@@ -8,7 +8,7 @@ from torch import nn
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import is_flashinfer_available
from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
if is_flashinfer_available():
from flashinfer.sampling import (
......@@ -19,10 +19,6 @@ if is_flashinfer_available():
)
# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
logger = logging.getLogger(__name__)
......@@ -46,7 +42,8 @@ class Sampler(nn.Module):
logits = torch.where(
torch.isnan(logits), torch.full_like(logits, -1e5), logits
)
exit(1) if crash_on_warning else None
if crash_on_warnings():
raise ValueError("Detected errors during sampling! NaN in the logits.")
if sampling_info.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling
......
......@@ -67,6 +67,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
broadcast_pyobj,
configure_logger,
crash_on_warnings,
get_zmq_socket,
kill_parent_process,
set_random_seed,
......@@ -76,10 +77,6 @@ from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
# Test retract decode
test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
......@@ -662,21 +659,23 @@ class Scheduler:
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
if available_size != self.max_total_num_tokens:
warnings.warn(
"Warning: "
f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
msg = (
"KV cache pool leak detected!"
f"{available_size=}, {self.max_total_num_tokens=}\n"
)
exit(1) if crash_on_warning else None
warnings.warn(msg)
if crash_on_warnings():
raise ValueError(msg)
if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
warnings.warn(
"Warning: "
f"available req slots={len(self.req_to_token_pool.free_slots)}, "
f"total slots={self.req_to_token_pool.size}\n"
msg = (
"Memory pool leak detected!"
f"available_size={len(self.req_to_token_pool.free_slots)}, "
f"total_size={self.req_to_token_pool.size}\n"
)
exit(1) if crash_on_warning else None
warnings.warn(msg)
if crash_on_warnings():
raise ValueError(msg)
def get_next_batch_to_run(self):
# Merge the prefill batch into the running batch
......
......@@ -20,6 +20,7 @@ import importlib
import importlib.resources
import json
import logging
import os
import pkgutil
from functools import lru_cache
from typing import Optional, Type
......@@ -56,6 +57,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
crash_on_warnings,
enable_show_time_cost,
get_available_gpu_memory,
monkey_patch_vllm_p2p_access_check,
......@@ -665,7 +667,9 @@ def import_model_classes():
try:
module = importlib.import_module(name)
except Exception as e:
logger.warning(f"Ignore import error when loading {name}. " f"{e}")
logger.warning(f"Ignore import error when loading {name}. {e}")
if crash_on_warnings():
raise ValueError(f"Ignore import error when loading {name}. {e}")
continue
if hasattr(module, "EntryClass"):
entry = module.EntryClass
......
import math
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Iterable, Optional, Tuple, Union
import torch
from torch import nn
from transformers import Phi3Config
from transformers.configuration_utils import PretrainedConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import make_layers, maybe_prefix
from vllm.model_executor.models.utils import make_layers
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
......@@ -339,7 +339,7 @@ class Phi3SmallForCausalLM(nn.Module):
self,
config: Phi3Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
cache_config=None,
):
super().__init__()
......@@ -349,7 +349,7 @@ class Phi3SmallForCausalLM(nn.Module):
self.model = Phi3SmallModel(
config=config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "model"),
prefix="model",
)
self.torchao_config = global_server_args_dict["torchao_config"]
self.vocab_size = config.vocab_size
......
......@@ -816,3 +816,8 @@ def get_nvgpu_memory_capacity():
raise RuntimeError(
"nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
)
def crash_on_warnings():
    """Return True when warnings should be escalated to hard crashes.

    This is the case only inside CI runs, signalled by the environment
    variable ``SGLANG_IS_IN_CI`` being set to the exact string "true".
    """
    in_ci = os.environ.get("SGLANG_IS_IN_CI", "false")
    return in_ci == "true"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment