Unverified commit d84c5e70, authored by Lianmin Zheng and committed by GitHub

Test the case when max_new_tokens is very large (#1038)

parent d7854120
Detokenizer manager: remove the graceful_registry import and its call in start_detokenizer_process.

@@ -32,7 +32,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
+from sglang.utils import find_printable_text, get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -164,8 +164,6 @@ def start_detokenizer_process(
     port_args: PortArgs,
     pipe_writer,
 ):
-    graceful_registry(inspect.currentframe().f_code.co_name)
-
     try:
         manager = DetokenizerManager(server_args, port_args)
     except Exception:
Request policy scheduler: make CLIP_MAX_NEW_TOKENS configurable via the SGLANG_CLIP_MAX_NEW_TOKENS environment variable and clarify in the comment that it only clips the scheduler's estimate.

@@ -15,6 +15,7 @@ limitations under the License.

 """Request policy scheduler"""

+import os
 import random
 from collections import defaultdict
 from contextlib import contextmanager

@@ -24,9 +25,11 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.radix_cache import TreeNode

-# Clip the max new tokens for the request whose max_new_tokens is very large.
+# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
 # This can prevent the server from being too conservative.
-CLIP_MAX_NEW_TOKENS = 4096
+# Note that this only clips the estimation in the scheduler but does not change the stop
+# condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
+CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))


 class PolicyScheduler:
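To see what this clipping means for scheduling, the sketch below (not the actual PolicyScheduler code; the helper name and its arguments are hypothetical) shows how a per-request token-budget estimate would be capped while the real stop condition is left untouched:

import os

# Same default and environment override as the constant introduced above.
CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))


def estimate_tokens_to_reserve(requested_max_new_tokens: int, decoded_so_far: int) -> int:
    """Hypothetical helper: how many more tokens the scheduler should budget for a request.

    The estimate is capped by CLIP_MAX_NEW_TOKENS so that a request asking for an enormous
    max_new_tokens does not force the scheduler to reserve memory it will likely never use.
    The cap does not change the stop condition: generation can continue past the clipped
    estimate, up to the request's real max_new_tokens.
    """
    remaining = requested_max_new_tokens - decoded_so_far
    return min(remaining, CLIP_MAX_NEW_TOKENS)


# With the default clip value, a request asking for 1,000,000 new tokens is budgeted
# as if it will generate at most 4096 more tokens.
print(estimate_tokens_to_reserve(1_000_000, 0))  # -> 4096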
OpenAI API adapter: fix the SGLang spelling in a comment and in an error message.

@@ -77,7 +77,7 @@ class FileMetadata:

 batch_storage: Dict[str, BatchResponse] = {}
 file_id_request: Dict[str, FileMetadata] = {}
 file_id_response: Dict[str, FileResponse] = {}
-# map file id to file path in SGlang backend
+# map file id to file path in SGLang backend
 file_id_storage: Dict[str, str] = {}

@@ -335,7 +335,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
        }
    except Exception as e:
-        print("error in SGlang:", e)
+        print("error in SGLang:", e)
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"
Server arguments: fix the SGLang spelling in the default file_storage_pth.

@@ -64,7 +64,7 @@ class ServerArgs:

     # Other
     api_key: Optional[str] = None
-    file_storage_pth: str = "SGlang_storage"
+    file_storage_pth: str = "SGLang_storage"

     # Data parallelism
     dp_size: int = 1
Test utilities: let popen_launch_server pass extra environment variables to the server process and optionally capture its stdout/stderr.

@@ -398,6 +398,8 @@ def popen_launch_server(
     timeout: float,
     api_key: Optional[str] = None,
     other_args: tuple = (),
+    env: Optional[dict] = None,
+    return_stdout_stderr: bool = False,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]

@@ -417,7 +419,16 @@ def popen_launch_server(
     if api_key:
         command += ["--api-key", api_key]

-    process = subprocess.Popen(command, stdout=None, stderr=None)
+    if return_stdout_stderr:
+        process = subprocess.Popen(
+            command,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            env=env,
+            text=True,
+        )
+    else:
+        process = subprocess.Popen(command, stdout=None, stderr=None, env=env)

     start_time = time.time()
     while time.time() - start_time < timeout:
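A minimal usage sketch of the extended helper follows (illustrative only; the address, timeout, and environment values are arbitrary, and the imports mirror the new test added below):

import os

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server

# Launch the server with an extra environment variable and capture its output streams.
process = popen_launch_server(
    DEFAULT_MODEL_NAME_FOR_TEST,
    "http://127.0.0.1:8157",
    timeout=300,
    env={"SGLANG_CLIP_MAX_NEW_TOKENS": "256", **os.environ},
    return_stdout_stderr=True,
)
try:
    # With return_stdout_stderr=True, stderr is a text pipe that a test can scan
    # for scheduler log lines such as "#running-req: ...".
    print(process.stderr.readline(), end="")
finally:
    kill_child_process(process.pid)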
Test suite list: add test_large_max_new_tokens.py and test_skip_tokenizer_init.py to the minimal suite, drop test_models_from_modelscope.py, and reorder the entries.

@@ -5,13 +5,15 @@ from sglang.test.test_utils import run_unittest_files

 suites = {
     "minimal": [
+        "test_chunked_prefill.py",
+        "test_embedding_openai_server.py",
         "test_eval_accuracy.py",
+        "test_large_max_new_tokens.py",
         "test_openai_server.py",
-        "test_vision_openai_server.py",
-        "test_embedding_openai_server.py",
-        "test_chunked_prefill.py",
+        "test_skip_tokenizer_init.py",
         "test_torch_compile.py",
-        "test_models_from_modelscope.py",
+        "test_vision_openai_server.py",
+        "test_large_max_new_tokens.py",
         "models/test_generation_models.py",
         "models/test_embedding_models.py",
         "sampling/penaltylib",
New test file (the test_large_max_new_tokens.py added to the minimal suite above): it launches a server with a small total token budget and SGLANG_CLIP_MAX_NEW_TOKENS=256, sends several long-generation requests concurrently, and checks from the server log that they all run at the same time.

import json
import os
import time
import unittest
from concurrent.futures import ThreadPoolExecutor

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server


class TestOpenAIServer(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = "http://127.0.0.1:8157"
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=300,
            api_key=cls.api_key,
            other_args=("--max-total-token", "1024"),
            env={"SGLANG_CLIP_MAX_NEW_TOKENS": "256", **os.environ},
            return_stdout_stderr=True,
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_chat_completion(self):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": "Please repeat the world 'hello' for 10000 times.",
                },
            ],
            temperature=0,
        )
        return response

    def test_chat_completion(self):
        num_requests = 4
        futures = []
        with ThreadPoolExecutor(16) as executor:
            for i in range(num_requests):
                futures.append(executor.submit(self.run_chat_completion))

        all_requests_running = False
        for line in iter(self.process.stderr.readline, ""):
            line = str(line)
            print(line, end="")
            if f"#running-req: {num_requests}" in line:
                all_requests_running = True
                break

        assert all_requests_running


if __name__ == "__main__":
    unittest.main()
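For context, a client can also request a very large completion budget explicitly, as sketched below. This snippet is illustrative and not part of the commit: it assumes the server launched by the test above is still reachable at this address, and the max_tokens value is arbitrary. As the scheduler comment notes, CLIP_MAX_NEW_TOKENS only caps the scheduling estimate, so generation can still run up to the requested limit (or until the model's context window or another stop condition is hit).

import openai

from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST

# Assumes the OpenAI-compatible SGLang server from the test above is reachable here.
client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:8157/v1")

response = client.chat.completions.create(
    model=DEFAULT_MODEL_NAME_FOR_TEST,
    messages=[{"role": "user", "content": "Please repeat the word 'hello' 10000 times."}],
    # Much larger than the clipped scheduler estimate; the server may further cap or
    # reject values beyond the model's context length.
    max_tokens=32768,
    temperature=0,
)
print(response.choices[0].message.content[:200])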