Unverified Commit 9c939a3d authored by Lianmin Zheng, committed by GitHub

Clean up metrics code (#1972)

parent 549e8b83
@@ -391,8 +391,12 @@ class TokenizerManager:
         async with self.model_update_lock:
             # wait for the previous generation requests to finish
-            while len(self.rid_to_state) > 0:
-                await asyncio.sleep(0.001)
+            for i in range(3):
+                while len(self.rid_to_state) > 0:
+                    await asyncio.sleep(0.001)
+                # FIXME: We add some sleep here to avoid some race conditions.
+                # We can use a read-write lock as a better fix.
+                await asyncio.sleep(0.01)
 
             self.send_to_scheduler.send_pyobj(obj)
             self.model_update_result = asyncio.Future()
......
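The FIXME in the hunk above names a read-write lock as the better fix. As a hypothetical sketch (not part of this commit), generation requests would hold the read side while a weight update takes the write side, replacing the retry-and-sleep loop:

import asyncio

class ReadWriteLock:
    # Hypothetical asyncio read-write lock; illustrative only, not in this commit.
    # Readers (generation requests) run concurrently; a writer (a model update)
    # blocks new readers and waits for in-flight readers to drain.
    def __init__(self):
        self._readers = 0
        self._no_readers = asyncio.Event()
        self._no_readers.set()
        self._writer = asyncio.Lock()

    async def acquire_read(self):
        async with self._writer:  # queue behind an active writer
            self._readers += 1
            self._no_readers.clear()

    def release_read(self):
        self._readers -= 1
        if self._readers == 0:
            self._no_readers.set()

    async def acquire_write(self):
        await self._writer.acquire()   # stop new readers from entering
        await self._no_readers.wait()  # drain readers already in flight

    def release_write(self):
        self._writer.release()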
@@ -25,20 +25,16 @@ import json
 import logging
 import multiprocessing as mp
 import os
-import re
-import tempfile
 import threading
 import time
 from http import HTTPStatus
 from typing import AsyncIterator, Dict, List, Optional, Union
 
-import orjson
-from starlette.routing import Mount
-
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
 import aiohttp
+import orjson
 import requests
 import uvicorn
 import uvloop
@@ -77,6 +73,7 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     add_api_key_middleware,
+    add_prometheus_middleware,
     assert_pkg_version,
     configure_logger,
     delete_directory,
@@ -84,16 +81,13 @@ from sglang.srt.utils import (
     kill_child_process,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
+    set_prometheus_multiproc_dir,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
-# Temporary directory for prometheus multiprocess mode
-# Cleaned up automatically when this object is garbage collected
-prometheus_multiproc_dir: tempfile.TemporaryDirectory
-
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -445,10 +439,6 @@ def launch_server(
     1. The HTTP server and Tokenizer Manager both run in the main process.
     2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
     """
-    if server_args.enable_metrics:
-        _set_prometheus_env()
-
     launch_engine(server_args=server_args)
 
     # Add api key authorization
@@ -487,36 +477,6 @@ def launch_server(
         t.join()
 
 
-def add_prometheus_middleware(app: FastAPI):
-    # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
-    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
-
-    registry = CollectorRegistry()
-    multiprocess.MultiProcessCollector(registry)
-    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
-
-    # Workaround for 307 Redirect for /metrics
-    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
-    app.routes.append(metrics_route)
-
-
-def _set_prometheus_env():
-    # Set prometheus multiprocess directory
-    # sglang uses prometheus multiprocess mode
-    # we need to set this before importing prometheus_client
-    # https://prometheus.github.io/client_python/multiprocess/
-    global prometheus_multiproc_dir
-    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
-        logger.debug(f"User set PROMETHEUS_MULTIPROC_DIR detected.")
-        prometheus_multiproc_dir = tempfile.TemporaryDirectory(
-            dir=os.environ["PROMETHEUS_MULTIPROC_DIR"]
-        )
-    else:
-        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
-        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
-    logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")
-
-
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -543,6 +503,10 @@ def _set_envs_and_config(server_args: ServerArgs):
         "at https://docs.flashinfer.ai/installation.html.",
     )
 
+    # Set prometheus env vars
+    if server_args.enable_metrics:
+        set_prometheus_multiproc_dir()
+
     mp.set_start_method("spawn", force=True)
......
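Taken together with the import changes above, the launch flow after this commit is roughly the following, as a hedged sketch: the two helpers come from this diff, while `app` as the module-level FastAPI instance and the exact call site of add_prometheus_middleware are assumptions, since that hunk is collapsed here.

from fastapi import FastAPI

from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import add_prometheus_middleware, set_prometheus_multiproc_dir

app = FastAPI()  # assumed: the server's module-level app

def launch_sketch(server_args: ServerArgs):
    # _set_envs_and_config sets the env var first, before spawning children
    # or importing prometheus_client anywhere in the process tree.
    if server_args.enable_metrics:
        set_prometheus_multiproc_dir()
        add_prometheus_middleware(app)  # mounts /metrics (assumed call site)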
@@ -22,10 +22,12 @@ import logging
 import os
 import pickle
 import random
+import re
 import resource
 import shutil
 import signal
 import socket
+import tempfile
 import time
 import warnings
 from importlib.metadata import PackageNotFoundError, version
@@ -41,6 +43,7 @@ import triton
 import zmq
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
+from starlette.routing import Mount
 from torch import nn
 from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
@@ -752,3 +755,38 @@ def delete_directory(dirpath):
         shutil.rmtree(dirpath)
     except OSError as e:
         print(f"Warning: {dirpath} : {e.strerror}")
+
+
+# Temporary directory for prometheus multiprocess mode
+# Cleaned up automatically when this object is garbage collected
+prometheus_multiproc_dir: tempfile.TemporaryDirectory
+
+
+def set_prometheus_multiproc_dir():
+    # Set prometheus multiprocess directory
+    # sglang uses prometheus multiprocess mode
+    # we need to set this before importing prometheus_client
+    # https://prometheus.github.io/client_python/multiprocess/
+    global prometheus_multiproc_dir
+
+    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
+        logger.debug("User set PROMETHEUS_MULTIPROC_DIR detected.")
+        prometheus_multiproc_dir = tempfile.TemporaryDirectory(
+            dir=os.environ["PROMETHEUS_MULTIPROC_DIR"]
+        )
+    else:
+        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
+        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
+    logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")
+
+
+def add_prometheus_middleware(app):
+    # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
+    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
+
+    registry = CollectorRegistry()
+    multiprocess.MultiProcessCollector(registry)
+    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
+
+    # Workaround for 307 Redirect for /metrics
+    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
+    app.routes.append(metrics_route)
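The comments in set_prometheus_multiproc_dir say the directory must be set before prometheus_client is imported. A hedged illustration of why, per the linked client_python docs (the counter name is illustrative, not an sglang metric):

import os
import tempfile

# The client picks its value backend at import time: with the env var set,
# metric values live in per-process mmap files under the directory instead
# of ordinary process memory.
_demo_dir = tempfile.mkdtemp()
os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", _demo_dir)  # before the import below

from prometheus_client import Counter

demo_counter = Counter("demo_requests", "Illustrative counter")
demo_counter.inc()  # written to a per-process file; MultiProcessCollector merges them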
@@ -27,6 +27,7 @@ from sglang.utils import get_exception_traceback
 
 DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8"
 DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
+DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
 DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
 DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
@@ -404,7 +405,6 @@ def popen_launch_server(
     other_args: tuple = (),
     env: Optional[dict] = None,
     return_stdout_stderr: Optional[tuple] = None,
-    enable_metrics: bool = False,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]
@@ -423,8 +423,6 @@ def popen_launch_server(
     ]
     if api_key:
         command += ["--api-key", api_key]
-    if enable_metrics:
-        command += ["--enable-metrics"]
 
     if return_stdout_stderr:
         process = subprocess.Popen(
......
@@ -4,5 +4,5 @@ Install the dependency in CI.
 pip install --upgrade pip
 pip install -e "python[all]"
-pip install transformers==4.45.2 sentence_transformers
+pip install transformers==4.45.2 sentence_transformers accelerate peft
 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
@@ -16,6 +16,7 @@ suites = {
         "test_eval_accuracy_mini.py",
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
+        "test_metrics.py",
         "test_openai_server.py",
         "test_overlap_schedule.py",
         "test_pytorch_sampling_backend.py",
......
-import subprocess
 import unittest
 
-from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_MOE_MODEL_NAME_FOR_TEST,
......
@@ -6,7 +6,7 @@ import requests
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
 )
@@ -15,7 +15,7 @@ from sglang.test.test_utils import (
 class TestCacheReport(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.min_cached = 5
         cls.process = popen_launch_server(
......
@@ -3,6 +3,7 @@ python3 -m unittest test_large_max_new_tokens.TestLargeMaxNewTokens.test_chat_co
 """
 
 import os
+import time
 import unittest
 from concurrent.futures import ThreadPoolExecutor
@@ -11,7 +12,7 @@ import openai
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
@@ -21,7 +22,7 @@ from sglang.test.test_utils import (
 class TestLargeMaxNewTokens(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.api_key = "sk-123456"
@@ -33,12 +34,19 @@ class TestLargeMaxNewTokens(unittest.TestCase):
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             api_key=cls.api_key,
-            other_args=("--max-total-token", "1024", "--context-len", "8192"),
+            other_args=(
+                "--max-total-token",
+                "1024",
+                "--context-len",
+                "8192",
+                "--decode-log-interval",
+                "2",
+            ),
             env={"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256", **os.environ},
             return_stdout_stderr=(cls.stdout, cls.stderr),
         )
         cls.base_url += "/v1"
-        cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)
+        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
 
     @classmethod
     def tearDownClass(cls):
@@ -75,6 +83,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
         # Ensure that they are running concurrently
         pt = 0
         while pt >= 0:
+            time.sleep(5)
             lines = open("stderr.txt").readlines()
             for line in lines[pt:]:
                 print(line, end="", flush=True)
......
 import unittest
-from types import SimpleNamespace
 
 import requests
 
 from sglang.srt.utils import kill_child_process
-from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
 )
 
-TEST_MODEL = (
-    DEFAULT_MODEL_NAME_FOR_TEST  # I used "google/gemma-2-2b-it" for testing locally
-)
-
 
 class TestEnableMetrics(unittest.TestCase):
     def test_metrics_enabled(self):
         """Test that metrics endpoint returns data when enabled"""
-        # Launch server with metrics enabled
         process = popen_launch_server(
-            model=TEST_MODEL,
-            base_url=DEFAULT_URL_FOR_TEST,
+            DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+            DEFAULT_URL_FOR_TEST,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            enable_metrics=True,
+            other_args=["--enable-metrics"],
         )
 
         try:
@@ -38,6 +31,8 @@ class TestEnableMetrics(unittest.TestCase):
             self.assertEqual(metrics_response.status_code, 200)
             metrics_content = metrics_response.text
 
+            print(f"{metrics_content=}")
+
             # Verify essential metrics are present
             essential_metrics = [
                 "sglang:prompt_tokens_total",
@@ -53,7 +48,7 @@ class TestEnableMetrics(unittest.TestCase):
                 self.assertIn(metric, metrics_content, f"Missing metric: {metric}")
 
             # Verify model name label is present and correct
-            expected_model_name = TEST_MODEL
+            expected_model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
             self.assertIn(f'model_name="{expected_model_name}"', metrics_content)
 
             # Verify metrics have values (not empty)
             self.assertIn("_sum{", metrics_content)
@@ -63,22 +58,6 @@ class TestEnableMetrics(unittest.TestCase):
         finally:
             kill_child_process(process.pid, include_self=True)
 
-    def test_metrics_disabled(self):
-        """Test that metrics endpoint returns 404 when disabled"""
-        # Launch server with metrics disabled
-        process = popen_launch_server(
-            model=TEST_MODEL,
-            base_url=DEFAULT_URL_FOR_TEST,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            enable_metrics=False,
-        )
-
-        try:
-            response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
-            self.assertEqual(response.status_code, 200)
-
-            # Verify metrics endpoint is not available
-            metrics_response = requests.get(f"{DEFAULT_URL_FOR_TEST}/metrics")
-            self.assertEqual(metrics_response.status_code, 404)
-        finally:
-            kill_child_process(process.pid, include_self=True)
+
+if __name__ == "__main__":
+    unittest.main()
@@ -13,7 +13,7 @@ import openai
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
@@ -23,7 +23,7 @@ from sglang.test.test_utils import (
 class TestOpenAIServer(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.api_key = "sk-123456"
         cls.process = popen_launch_server(
@@ -33,7 +33,7 @@ class TestOpenAIServer(unittest.TestCase):
             api_key=cls.api_key,
         )
         cls.base_url += "/v1"
-        cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)
+        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
 
     @classmethod
     def tearDownClass(cls):
......
@@ -5,7 +5,7 @@ import unittest
 import requests
 
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     kill_child_process,
@@ -62,7 +62,7 @@ def run_test(base_url, nodes):
 class TestRadixCacheFCFS(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model,
@@ -90,7 +90,7 @@ class TestRadixCacheFCFS(unittest.TestCase):
 class TestRadixCacheLPM(TestRadixCacheFCFS):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model,
@@ -110,7 +110,7 @@ class TestRadixCacheLPM(TestRadixCacheFCFS):
 class TestRadixCacheOverlapLPM(TestRadixCacheFCFS):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model,
......
@@ -9,7 +9,7 @@ import requests
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
@@ -19,7 +19,7 @@ from sglang.test.test_utils import (
 class TestSkipTokenizerInit(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model,
......
@@ -10,7 +10,7 @@ import requests
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
@@ -20,7 +20,7 @@ from sglang.test.test_utils import (
 class TestSRTEndpoint(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
......
@@ -11,14 +11,17 @@ from types import SimpleNamespace
 import sglang as sgl
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.test.few_shot_gsm8k_engine import run_eval
-from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+)
 
 
 class TestSRTEngine(unittest.TestCase):
 
     def test_1_engine_runtime_consistency(self):
         prompt = "Today is a sunny day and I like"
-        model_path = DEFAULT_MODEL_NAME_FOR_TEST
+        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
 
         sampling_params = {"temperature": 0, "max_new_tokens": 8}
@@ -40,7 +43,7 @@ class TestSRTEngine(unittest.TestCase):
     def test_2_engine_multiple_generate(self):
         # just to ensure there is no issue running multiple generate calls
         prompt = "Today is a sunny day and I like"
-        model_path = DEFAULT_MODEL_NAME_FOR_TEST
+        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
 
         sampling_params = {"temperature": 0, "max_new_tokens": 8}
@@ -66,7 +69,7 @@ class TestSRTEngine(unittest.TestCase):
         # Create an LLM.
         llm = sgl.Engine(
-            model_path=DEFAULT_MODEL_NAME_FOR_TEST,
+            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
             log_level="error",
         )
@@ -110,7 +113,7 @@ class TestSRTEngine(unittest.TestCase):
     def test_5_prompt_input_ids_consistency(self):
         prompt = "The capital of UK is"
-        model_path = DEFAULT_MODEL_NAME_FOR_TEST
+        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error")
         sampling_params = {"temperature": 0, "max_new_tokens": 8}
         out1 = engine.generate(prompt, sampling_params)["text"]
......
@@ -5,7 +5,7 @@ import requests
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
@@ -15,7 +15,7 @@ from sglang.test.test_utils import (
 class TestUpdateWeights(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
@@ -64,7 +64,7 @@ class TestUpdateWeights(unittest.TestCase):
         origin_response = self.run_decode()
 
         # update weights
-        new_model_path = "meta-llama/Meta-Llama-3.1-8B"
+        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
         ret = self.run_update_weights(new_model_path)
         assert ret["success"]
@@ -92,7 +92,7 @@ class TestUpdateWeights(unittest.TestCase):
         origin_response = self.run_decode()
 
         # update weights
-        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "wrong")
+        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "wrong")
         ret = self.run_update_weights(new_model_path)
         assert not ret["success"]
......