Unverified Commit df7fe452 authored by Lianmin Zheng, committed by GitHub

Crash the CI jobs on model import errors (#2072)

parent a7164b62
@@ -8,7 +8,7 @@ from torch import nn
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import crash_on_warnings, is_flashinfer_available

 if is_flashinfer_available():
     from flashinfer.sampling import (
@@ -19,10 +19,6 @@ if is_flashinfer_available():
     )

-# Crash on warning if we are running CI tests
-crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
-
 logger = logging.getLogger(__name__)
@@ -46,7 +42,8 @@ class Sampler(nn.Module):
             logits = torch.where(
                 torch.isnan(logits), torch.full_like(logits, -1e5), logits
             )
-            exit(1) if crash_on_warning else None
+            if crash_on_warnings():
+                raise ValueError("Detected errors during sampling! NaN in the logits.")

         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
...
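In the sampler, the old code handled NaN logits with a one-line `exit(1) if crash_on_warning else None`; the new code raises a ValueError so CI gets a real traceback instead of a bare exit code. Below is a minimal, self-contained sketch of the resulting behavior; the wrapper function sanitize_logits is hypothetical, and crash_on_warnings is stubbed with the same environment check this commit adds to sglang.srt.utils.

import os

import torch


def crash_on_warnings():
    # Same check as the helper this commit adds to sglang.srt.utils.
    return os.getenv("SGLANG_IS_IN_CI", "false") == "true"


def sanitize_logits(logits: torch.Tensor) -> torch.Tensor:
    # Hypothetical wrapper around the pattern shown in the diff above.
    if torch.any(torch.isnan(logits)):
        # Normal serving: clamp NaNs to a very negative logit and continue.
        logits = torch.where(
            torch.isnan(logits), torch.full_like(logits, -1e5), logits
        )
        if crash_on_warnings():
            # CI: fail loudly instead of papering over the bug.
            raise ValueError("Detected errors during sampling! NaN in the logits.")
    return logits


print(sanitize_logits(torch.tensor([0.1, float("nan"), 0.3])))
# Outside CI, the NaN is replaced by -1e5 and sampling proceeds.

An uncaught ValueError propagates upward with a traceback, whereas exit(1) terminated the process with no context about what went wrong.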
@@ -67,6 +67,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
+    crash_on_warnings,
     get_zmq_socket,
     kill_parent_process,
     set_random_seed,
@@ -76,10 +77,6 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)

-# Crash on warning if we are running CI tests
-crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
-
 # Test retract decode
 test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
@@ -662,21 +659,23 @@ class Scheduler:
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
         if available_size != self.max_total_num_tokens:
-            warnings.warn(
-                "Warning: "
-                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
-                "KV cache pool leak detected!"
-            )
-            exit(1) if crash_on_warning else None
+            msg = (
+                "KV cache pool leak detected!"
+                f"{available_size=}, {self.max_total_num_tokens=}\n"
+            )
+            warnings.warn(msg)
+            if crash_on_warnings():
+                raise ValueError(msg)

         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
-            warnings.warn(
-                "Warning: "
-                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
-                f"total slots={self.req_to_token_pool.size}\n"
-                "Memory pool leak detected!"
-            )
-            exit(1) if crash_on_warning else None
+            msg = (
+                "Memory pool leak detected!"
+                f"available_size={len(self.req_to_token_pool.free_slots)}, "
+                f"total_size={self.req_to_token_pool.size}\n"
+            )
+            warnings.warn(msg)
+            if crash_on_warnings():
+                raise ValueError(msg)

     def get_next_batch_to_run(self):
         # Merge the prefill batch into the running batch
...
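Both leak checks now build the message once, emit it through warnings.warn, and re-raise it as a ValueError under CI. Note that f"{available_size=}" is Python 3.8+ debug syntax that expands to "available_size=<value>". A standalone sketch of the pattern, with hypothetical pool sizes standing in for the scheduler's real bookkeeping:

import os
import warnings


def crash_on_warnings():
    # Same environment check as the helper this commit adds to sglang.srt.utils.
    return os.getenv("SGLANG_IS_IN_CI", "false") == "true"


# Hypothetical values; in the scheduler these come from the KV cache pool.
available_size, max_total_num_tokens = 4090, 4096

if available_size != max_total_num_tokens:
    msg = (
        "KV cache pool leak detected!"
        f"{available_size=}, {max_total_num_tokens=}\n"
    )
    warnings.warn(msg)       # always surface the leak as a warning
    if crash_on_warnings():  # make it fatal only when running in CI
        raise ValueError(msg)

Building msg once guarantees the warning and the exception carry identical text, which the old code, with its separate "Warning: " string, did not.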
@@ -20,6 +20,7 @@ import importlib
 import importlib.resources
 import json
 import logging
+import os
 import pkgutil
 from functools import lru_cache
 from typing import Optional, Type
@@ -56,6 +57,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
     monkey_patch_vllm_p2p_access_check,
@@ -665,7 +667,9 @@ def import_model_classes():
        try:
            module = importlib.import_module(name)
        except Exception as e:
-            logger.warning(f"Ignore import error when loading {name}. " f"{e}")
+            logger.warning(f"Ignore import error when loading {name}. {e}")
+            if crash_on_warnings():
+                raise ValueError(f"Ignore import error when loading {name}. {e}")
            continue
        if hasattr(module, "EntryClass"):
            entry = module.EntryClass
...
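This hunk is the headline change of the commit: import_model_classes() still logs and skips a model file that fails to import, but under CI the same failure now raises and crashes the job. A simplified, self-contained sketch of that loop; the log message and the EntryClass convention come from the diff above, while the package name and everything else are illustrative assumptions:

import importlib
import logging
import os
import pkgutil

logger = logging.getLogger(__name__)


def crash_on_warnings():
    return os.getenv("SGLANG_IS_IN_CI", "false") == "true"


def import_model_classes(package_name="sglang.srt.models"):
    # package_name is an assumption for illustration.
    model_classes = {}
    package = importlib.import_module(package_name)
    for _, name, _ in pkgutil.iter_modules(package.__path__, package_name + "."):
        try:
            module = importlib.import_module(name)
        except Exception as e:
            logger.warning(f"Ignore import error when loading {name}. {e}")
            if crash_on_warnings():
                # New behavior: a broken model file fails the CI job
                # instead of being silently skipped.
                raise ValueError(f"Ignore import error when loading {name}. {e}")
            continue
        if hasattr(module, "EntryClass"):
            model_classes[name] = module.EntryClass
    return model_classes

Previously, a syntax error or missing dependency in any model file produced only a log line, so CI could pass while that model silently disappeared from the registry.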
 import math
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Iterable, Optional, Tuple, Union

 import torch
 from torch import nn
 from transformers import Phi3Config
 from transformers.configuration_utils import PretrainedConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.utils import make_layers, maybe_prefix
+from vllm.model_executor.models.utils import make_layers

 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -339,7 +339,7 @@ class Phi3SmallForCausalLM(nn.Module):
         self,
         config: Phi3Config,
         quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
+        cache_config=None,
     ):
         super().__init__()
@@ -349,7 +349,7 @@ class Phi3SmallForCausalLM(nn.Module):
         self.model = Phi3SmallModel(
             config=config,
             quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "model"),
+            prefix="model",
         )
         self.torchao_config = global_server_args_dict["torchao_config"]
         self.vocab_size = config.vocab_size
...
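The phi3_small edits fix exactly the kind of failure the new check is meant to surface: Dict and List were imported but unused, while get_pp_group and maybe_prefix are, presumably, not available in the vllm version the project pins, so merely importing the module could raise before any class was defined; the constructor is likewise reverted to the cache_config-style signature. A small demonstration of why such errors surface at load time (the module and symbol names here are hypothetical):

import importlib

try:
    # A module whose top-level "from vllm... import missing_symbol" fails
    # raises ImportError as soon as the module itself is imported.
    importlib.import_module("some_models.phi3_small")
except Exception as e:
    print(f"Ignore import error when loading some_models.phi3_small. {e}")

With the old loader this was only a log line and the model quietly vanished from the registry; with this commit, the same exception aborts the CI job.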
@@ -816,3 +816,8 @@ def get_nvgpu_memory_capacity():
     raise RuntimeError(
         "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
     )
+
+
+def crash_on_warnings():
+    # Crash on warning if we are running CI tests
+    return os.getenv("SGLANG_IS_IN_CI", "false") == "true"
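crash_on_warnings() replaces the two copies of the module-level crash_on_warning flag deleted above. One behavioral nuance: the old flag read SGLANG_IS_IN_CI once at import time, while the helper re-reads it on every call. A usage sketch, assuming sglang is installed:

import os

os.environ["SGLANG_IS_IN_CI"] = "true"

from sglang.srt.utils import crash_on_warnings

print(crash_on_warnings())  # True

os.environ["SGLANG_IS_IN_CI"] = "false"
print(crash_on_warnings())  # False: the environment is re-checked per call

Centralizing the check also means future call sites need only one import rather than duplicating the os.getenv expression.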