Unverified commit 61f42b57, authored by Lianmin Zheng, committed by GitHub

Move sgl.Runtime under sglang/lang (#2990)

parent e403d237
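As a quick orientation before the per-file hunks: this commit relocates the HTTP-server wrapper class, so direct imports change while the lazy top-level factory keeps working. A minimal sketch of the import change, based only on the hunks below:

# Old import path (removed by this commit):
#   from sglang.srt.server import Runtime
# New import path:
from sglang.lang.backend.runtime_endpoint import Runtime

# The lazy factory in sglang/api.py still works and now forwards to the new
# location, so `import sglang as sgl; sgl.Runtime(...)` is unchanged for callers.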
@@ -9,7 +9,7 @@ from enum import Enum
 from pydantic import BaseModel
 import sglang as sgl
-from sglang.srt.constrained import build_regex_from_object
+from sglang.srt.constrained.outlines_backend import build_regex_from_object

 character_regex = (
     r"""\{\n"""
...
@@ -3,8 +3,8 @@ import triton_python_backend_utils as pb_utils
 from pydantic import BaseModel
 import sglang as sgl
-from sglang import function, set_default_backend
-from sglang.srt.constrained import build_regex_from_object
+from sglang import function
+from sglang.srt.constrained.outlines_backend import build_regex_from_object

 sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
...
"""
Usage:
python3 async_io.py
"""
import asyncio
from sglang import Runtime
async def generate(
engine,
prompt,
sampling_params,
):
tokenizer = engine.get_tokenizer()
messages = [
{
"role": "system",
"content": "You will be given question answer tasks.",
},
{"role": "user", "content": prompt},
]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
stream = engine.add_request(prompt, sampling_params)
async for output in stream:
print(output, end="", flush=True)
print()
if __name__ == "__main__":
runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
print("--- runtime ready ---\n")
prompt = "Who is Alan Turing?"
sampling_params = {"max_new_tokens": 128}
asyncio.run(generate(runtime, prompt, sampling_params))
runtime.shutdown()
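The example above streams tokens through Runtime.add_request, which the class added below aliases to async_generate. For comparison, a minimal blocking sketch (not part of this commit) using the same wrapper; note that Runtime.generate posts to the HTTP server and returns a JSON string:

# Blocking counterpart to the async example above (a sketch, same example model path).
import json

from sglang import Runtime

runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
out = runtime.generate("Who is Alan Turing?", sampling_params={"max_new_tokens": 128})
print(json.loads(out)["text"])  # generate() returns a JSON string, so decode it first
runtime.shutdown()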
"""Public APIs of the language.""" """Public APIs of the language."""
import os
import re import re
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
...@@ -33,17 +32,13 @@ def function( ...@@ -33,17 +32,13 @@ def function(
def Runtime(*args, **kwargs): def Runtime(*args, **kwargs):
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# Avoid importing unnecessary dependency # Avoid importing unnecessary dependency
from sglang.srt.server import Runtime from sglang.lang.backend.runtime_endpoint import Runtime
return Runtime(*args, **kwargs) return Runtime(*args, **kwargs)
def Engine(*args, **kwargs): def Engine(*args, **kwargs):
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# Avoid importing unnecessary dependency # Avoid importing unnecessary dependency
from sglang.srt.server import Engine from sglang.srt.server import Engine
......
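sgl.Runtime and sgl.Engine stay thin factories that only import their implementations when called. A hedged usage sketch of the factory from the top-level package (typical usage assumed, not taken verbatim from this commit; the model path is just an example):

import sglang as sgl

runtime = sgl.Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")  # example model path
sgl.set_default_backend(sgl.RuntimeEndpoint(runtime.url))  # or reuse runtime.endpoint
# ... run @sgl.function programs against the default backend ...
runtime.shutdown()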
@@ -27,7 +27,8 @@ from sglang.bench_serving import (
     sample_random_requests,
     set_ulimit,
 )
-from sglang.srt.server import Engine, Runtime
+from sglang.lang.backend.runtime_endpoint import Runtime
+from sglang.srt.server import Engine
 from sglang.srt.server_args import ServerArgs
...
+import atexit
 import json
+import multiprocessing
 import warnings
-from typing import List, Optional
+from typing import Dict, List, Optional, Union
+import aiohttp
+import requests

 from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
@@ -14,6 +19,9 @@ from sglang.lang.ir import (
     REGEX_STR,
     SglSamplingParams,
 )
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import is_port_available, kill_process_tree
 from sglang.utils import http_request
@@ -325,3 +333,162 @@ class RuntimeEndpoint(BaseBackend):
 def compute_normalized_prompt_logprobs(input_logprobs):
     values = [x[0] for x in input_logprobs if x[0]]
     return sum(values) / len(values)
class Runtime:
    """
    A wrapper for the HTTP server.
    This is used for launching the server in a Python program without
    using the command line interface.

    It is mainly used for the frontend language.
    You should use the Engine class if you want to do normal offline processing.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        from sglang.srt.server import launch_server

        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Call shutdown implicitly when the Python program terminates,
        # so users don't have to call .shutdown() explicitly.
        atexit.register(self.shutdown)

        # Pre-allocate ports
        for port in range(self.server_args.port, 40000):
            if is_port_available(port):
                break
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        # NOTE: We store the pid instead of the proc handle to avoid issues during __del__
        self.pid = None
        pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)

        proc = multiprocessing.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        if self.pid is not None:
            kill_process_tree(self.pid)
            self.pid = None

    def cache_prefix(self, prefix: str):
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
            revision=self.server_args.revision,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        json_data = {"text": prompt}
        response = requests.post(self.url + "/encode", json=json_data)
        return json.dumps(response.json())

    async def get_server_info(self):
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.url}/get_server_info") as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_data = await response.json()
                    raise RuntimeError(
                        f"Failed to get server info. {error_data['error']['message']}"
                    )

    def __del__(self):
        self.shutdown()
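Note that both generate() and encode() on this wrapper return JSON strings (json.dumps of the HTTP response), unlike the in-process Engine; this is why the test-runner hunks further down drop their json.loads calls. A short hedged sketch, given a `runtime` constructed as in the earlier example:

# Sketch (assumed usage): the HTTP wrapper returns JSON strings, so decode before use.
import json

result = json.loads(runtime.generate("hello", sampling_params={"max_new_tokens": 8}))
print(result["text"])

# For an embedding model, encode() behaves the same way; the "embedding" field name
# is the one read by the test runner below.
embeddings = json.loads(runtime.encode(["hello world"]))
print(len(embeddings[0]["embedding"]))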
"""Launch the inference server for Llava-video model."""
import json
import sys
from sglang.srt.server import launch_server, prepare_server_args
if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:])
model_override_args = {}
model_override_args["mm_spatial_pool_stride"] = 2
model_override_args["architectures"] = ["LlavaVidForCausalLM"]
model_override_args["num_frames"] = 16
model_override_args["model_type"] = "llavavid"
if model_override_args["num_frames"] == 32:
model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"}
model_override_args["max_sequence_length"] = 4096 * 2
model_override_args["tokenizer_model_max_length"] = 4096 * 2
model_override_args["model_max_length"] = 4096 * 2
if "34b" in server_args.model_path.lower():
model_override_args["image_token_index"] = 64002
server_args.json_model_override_args = json.dumps(model_override_args)
launch_server(server_args)
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# TODO(lmzheng): make this an optional dependency
from sglang.srt.constrained.outlines_backend import build_regex_from_object
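The examples earlier in this diff pair build_regex_from_object with a pydantic model to drive constrained JSON decoding; the re-export above keeps the older import path working. A small hedged sketch of that pattern, assuming the function accepts a pydantic model class as in those examples (the Character schema here is illustrative, not from this commit):

from pydantic import BaseModel

from sglang.srt.constrained.outlines_backend import build_regex_from_object

class Character(BaseModel):  # illustrative schema
    name: str
    age: int

# Build a regex that constrains generation to JSON matching the schema.
character_regex = build_regex_from_object(Character)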
@@ -18,6 +18,8 @@ from dataclasses import dataclass
 from threading import Event, Lock
 from typing import Any, Optional, Tuple

+from sglang.srt.server_args import ServerArgs

 @dataclass
 class CacheEntry:
@@ -69,3 +71,22 @@ class BaseGrammarBackend:
     def reset(self):
         with self.cache_lock:
             self.cache.clear()

+def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
+    if server_args.grammar_backend == "outlines":
+        from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
+
+        grammar_backend = OutlinesGrammarBackend(
+            tokenizer,
+            whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+            allow_jump_forward=not server_args.disable_jump_forward,
+        )
+    elif server_args.grammar_backend == "xgrammar":
+        from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
+
+        grammar_backend = XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size)
+    else:
+        raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
+
+    return grammar_backend
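The new factory centralizes the backend selection that the scheduler previously inlined (see the scheduler hunk below). A sketch of the call site, mirroring how the scheduler now uses it; the variable names here are illustrative:

# Mirrors the scheduler change below.
grammar_backend = create_grammar_backend(
    server_args,              # ServerArgs with grammar_backend == "outlines" or "xgrammar"
    tokenizer,                # tokenizer handle, e.g. from get_tokenizer(...)
    model_config.vocab_size,  # vocabulary size, forwarded to the xgrammar backend
)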
@@ -34,6 +34,7 @@ import zmq
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -149,9 +150,7 @@ class Scheduler:
             else 1
         )

-        # Init inter-process communication
-        context = zmq.Context(2)
+        # Distributed rank info
         self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
             compute_dp_attention_world_info(
@@ -162,6 +161,8 @@ class Scheduler:
             )
         )

+        # Init inter-process communication
+        context = zmq.Context(2)
         if self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
@@ -243,7 +244,7 @@ class Scheduler:
             nccl_port=port_args.nccl_port,
         )

-        # Launch worker for speculative decoding if need
+        # Launch a worker for speculative decoding if needed
        if self.spec_algorithm.is_eagle():
             from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -316,6 +317,8 @@ class Scheduler:
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
+        self.spec_num_total_accepted_tokens = 0
+        self.spec_num_total_forward_ct = 0
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
         self.current_stream = torch.get_device_module(self.device).current_stream()
@@ -337,28 +340,9 @@ class Scheduler:
         # Init the grammar backend for constrained generation
         self.grammar_queue: List[Req] = []
         if not server_args.skip_tokenizer_init:
-            if server_args.grammar_backend == "outlines":
-                from sglang.srt.constrained.outlines_backend import (
-                    OutlinesGrammarBackend,
-                )
-
-                self.grammar_backend = OutlinesGrammarBackend(
-                    self.tokenizer,
-                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
-                    allow_jump_forward=not server_args.disable_jump_forward,
-                )
-            elif server_args.grammar_backend == "xgrammar":
-                from sglang.srt.constrained.xgrammar_backend import (
-                    XGrammarGrammarBackend,
-                )
-
-                self.grammar_backend = XGrammarGrammarBackend(
-                    self.tokenizer, vocab_size=self.model_config.vocab_size
-                )
-            else:
-                raise ValueError(
-                    f"Invalid grammar backend: {server_args.grammar_backend}"
-                )
+            self.grammar_backend = create_grammar_backend(
+                server_args, self.tokenizer, self.model_config.vocab_size
+            )
         else:
             self.grammar_backend = None
@@ -424,7 +408,8 @@ class Scheduler:
             },
         )

-        self._dispatcher = TypeBasedDispatcher(
+        # Init request dispatcher
+        self._request_dispatcher = TypeBasedDispatcher(
             [
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
@@ -480,10 +465,6 @@ class Scheduler:
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()
-            if self.server_args.enable_dp_attention:  # TODO: simplify this
-                batch = self.prepare_dp_attn_batch(batch)
             self.cur_batch = batch

             if batch:
@@ -506,10 +487,6 @@ class Scheduler:
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()
-            if self.server_args.enable_dp_attention:  # TODO: simplify this
-                batch = self.prepare_dp_attn_batch(batch)
             self.cur_batch = batch

             if batch:
@@ -517,7 +494,7 @@ class Scheduler:
                 result_queue.append((batch.copy(), result))

                 if self.last_batch is None:
-                    # Create a dummy first batch to start the pipeline for overlap scheduler.
+                    # Create a dummy first batch to start the pipeline for overlap schedule.
                     # It is now used for triggering the sampling_info_done event.
                     tmp_batch = ScheduleBatch(
                         reqs=None,
@@ -593,7 +570,7 @@ class Scheduler:
     def process_input_requests(self, recv_reqs: List):
         for recv_req in recv_reqs:
-            output = self._dispatcher(recv_req)
+            output = self._request_dispatcher(recv_req)
             if output is not None:
                 self.send_to_tokenizer.send_pyobj(output)
@@ -798,15 +775,32 @@ class Scheduler:
         self.num_generated_tokens = 0
         self.last_decode_stats_tic = time.time()
         num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
-        logger.info(
-            f"Decode batch. "
-            f"#running-req: {num_running_reqs}, "
-            f"#token: {num_used}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-            f"gen throughput (token/s): {gen_throughput:.2f}, "
-            f"#queue-req: {len(self.waiting_queue)}"
-        )
+        if self.spec_algorithm.is_none():
+            msg = (
+                f"Decode batch. "
+                f"#running-req: {num_running_reqs}, "
+                f"#token: {num_used}, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                f"gen throughput (token/s): {gen_throughput:.2f}, "
+                f"#queue-req: {len(self.waiting_queue)}"
+            )
+        else:
+            accept_length = (
+                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+            )
+            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+            msg = (
+                f"Decode batch. "
+                f"#running-req: {num_running_reqs}, "
+                f"#token: {num_used}, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                f"accept len: {accept_length:.2f}, "
+                f"gen throughput (token/s): {gen_throughput:.2f}, "
+                f"#queue-req: {len(self.waiting_queue)}"
+            )
+        logger.info(msg)

         if self.enable_metrics:
             self.stats.num_running_reqs = num_running_reqs
             self.stats.num_used_tokens = num_used
@@ -855,16 +849,23 @@ class Scheduler:
             else:
                 self.running_batch.merge_batch(self.last_batch)

-        # Run prefill first if possible
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
-            return new_batch
+            # Run prefill first if possible
+            ret = new_batch
+        else:
+            # Run decode
+            if self.running_batch is None:
+                ret = None
+            else:
+                self.running_batch = self.update_running_batch(self.running_batch)
+                ret = self.running_batch

-        # Run decode
-        if self.running_batch is None:
-            return None
-        self.running_batch = self.update_running_batch(self.running_batch)
-        return self.running_batch
+        # Handle DP attention
+        if self.server_args.enable_dp_attention:
+            ret = self.prepare_dp_attn_batch(ret)
+
+        return ret

     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Check if the grammar is ready in the grammar queue
@@ -1053,6 +1054,10 @@ class Scheduler:
                     model_worker_batch,
                     num_accepted_tokens,
                 ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.spec_num_total_accepted_tokens += (
+                    num_accepted_tokens + batch.batch_size()
+                )
+                self.spec_num_total_forward_ct += batch.batch_size()
                 self.num_generated_tokens += num_accepted_tokens
             else:
                 assert False, "batch.extend_num_tokens == 0, this is unexpected!"
...
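For the new speculative-decoding counters above: each verify step contributes num_accepted_tokens plus one verified token per request, and batch_size forward slots; the decode log then reports their ratio as "accept len". A small worked example with assumed numbers:

# Assumed numbers, for illustration only.
batch_size = 4
num_accepted_tokens = 12  # draft tokens accepted across the batch in one verify step

spec_num_total_accepted_tokens = num_accepted_tokens + batch_size  # 16
spec_num_total_forward_ct = batch_size                             # 4
accept_length = spec_num_total_accepted_tokens / spec_num_total_forward_ct  # 4.0 tokens per forward pass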
@@ -224,7 +224,7 @@ class TokenizerManager:
             },
         )

-        self._dispatcher = TypeBasedDispatcher(
+        self._result_dispatcher = TypeBasedDispatcher(
             [
                 (BatchStrOut, self._handle_batch_output),
                 (BatchEmbeddingOut, self._handle_batch_output),
@@ -760,7 +760,7 @@ class TokenizerManager:
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
-            self._dispatcher(recv_obj)
+            self._result_dispatcher(recv_obj)

     def _handle_batch_output(
         self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
...
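Both managers route incoming objects through TypeBasedDispatcher; the renames simply say what each instance dispatches (requests in the scheduler, results here). A minimal sketch of the pattern, assuming the dispatcher maps an object's type to a registered handler (the real class lives elsewhere in sglang and may differ):

# Minimal sketch of type-based dispatch (assumed behavior, not sglang's actual code).
class TypeBasedDispatcher:
    def __init__(self, mapping):
        # mapping: list of (type, handler) pairs, as in the constructor calls above
        self._mapping = dict(mapping)

    def __call__(self, obj):
        return self._mapping[type(obj)](obj)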
@@ -45,8 +45,6 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
-from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.data_parallel_controller import (
     run_data_parallel_controller_process,
 )
@@ -90,7 +88,6 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     delete_directory,
-    is_port_available,
     kill_process_tree,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
@@ -960,160 +957,3 @@ class Engine:
         obj = ResumeMemoryOccupationReqInput()
         loop = asyncio.get_event_loop()
         loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None))
class Runtime:
    """
    A wrapper for the HTTP server.
    This is used for launching the server in a Python program without
    using the command line interface.

    It is mainly used for the frontend language.
    You should use the Engine class above if you want to do normal offline processing.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Call shutdown implicitly when the Python program terminates,
        # so users don't have to call .shutdown() explicitly.
        atexit.register(self.shutdown)

        # Pre-allocate ports
        for port in range(self.server_args.port, 40000):
            if is_port_available(port):
                break
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        # NOTE: We store the pid instead of the proc handle to avoid issues during __del__
        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)

        proc = mp.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        if self.pid is not None:
            kill_process_tree(self.pid)
            self.pid = None

    def cache_prefix(self, prefix: str):
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
            revision=self.server_args.revision,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        json_data = {"text": prompt}
        response = requests.post(self.url + "/encode", json=json_data)
        return json.dumps(response.json())

    async def get_server_info(self):
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.url}/get_server_info") as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_data = await response.json()
                    raise RuntimeError(
                        f"Failed to get server info. {error_data['error']['message']}"
                    )

    def __del__(self):
        self.shutdown()
@@ -23,7 +23,7 @@ import torch.nn.functional as F
 from transformers import AutoModelForCausalLM

 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.server import Runtime
+from sglang.srt.server import Engine
 from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER

 DEFAULT_PROMPTS = [
@@ -278,7 +278,7 @@ class SRTRunner:
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
-        self.runtime = Runtime(
+        self.engine = Engine(
             model_path=model_path,
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
@@ -306,7 +306,7 @@ class SRTRunner:
             top_output_logprobs = []
             sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
             for i, prompt in enumerate(prompts):
-                response = self.runtime.generate(
+                response = self.engine.generate(
                     prompt,
                     lora_path=lora_paths[i] if lora_paths else None,
                     sampling_params=sampling_params,
@@ -314,7 +314,6 @@ class SRTRunner:
                     logprob_start_len=0,
                     top_logprobs_num=NUM_TOP_LOGPROBS,
                 )
-                response = json.loads(response)
                 output_strs.append(response["text"])
                 top_input_logprobs.append(
                     [
@@ -343,8 +342,7 @@ class SRTRunner:
                 top_output_logprobs=top_output_logprobs,
             )
         else:
-            response = self.runtime.encode(prompts)
-            response = json.loads(response)
+            response = self.engine.encode(prompts)
             if self.model_type == "embedding":
                 logits = [x["embedding"] for x in response]
                 return ModelOutput(embed_logits=logits)
@@ -366,20 +364,18 @@ class SRTRunner:
             # the return value contains logprobs from prefill
             output_strs = []
             sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-            response = self.runtime.generate(
+            response = self.engine.generate(
                 prompts,
                 lora_path=lora_paths if lora_paths else None,
                 sampling_params=sampling_params,
             )
-            response = json.loads(response)
             output_strs = [r["text"] for r in response]
             return ModelOutput(
                 output_strs=output_strs,
             )
         else:
-            response = self.runtime.encode(prompts)
-            response = json.loads(response)
+            response = self.engine.encode(prompts)
             if self.model_type == "embedding":
                 logits = [x["embedding"] for x in response]
                 return ModelOutput(embed_logits=logits)
@@ -391,8 +387,8 @@ class SRTRunner:
         return self

     def __exit__(self, exc_type, exc_value, traceback):
-        self.runtime.shutdown()
-        del self.runtime
+        self.engine.shutdown()
+        del self.engine

 def monkey_patch_gemma2_sdpa():
...
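The runner switch also explains the dropped json.loads calls: the HTTP-backed Runtime serializes responses to JSON strings, while the in-process Engine returns Python objects directly. A hedged before/after sketch of the call sites changed above:

# Before (Runtime over HTTP): responses arrive as JSON strings.
response = json.loads(runtime.generate(prompt, sampling_params=sampling_params))
text = response["text"]

# After (in-process Engine): responses are already Python dicts/lists.
response = engine.generate(prompt, sampling_params=sampling_params)
text = response["text"]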
@@ -4,7 +4,7 @@ from enum import Enum
 from pydantic import BaseModel, constr

 import sglang as sgl
-from sglang.srt.constrained import build_regex_from_object
+from sglang.srt.constrained.outlines_backend import build_regex_from_object
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
...
@@ -73,7 +73,7 @@ class TestSRTBackend(unittest.TestCase):
         # Run twice to capture more bugs
         for _ in range(2):
             accuracy, latency = test_hellaswag_select()
-            self.assertGreater(accuracy, 0.71)
+            self.assertGreater(accuracy, 0.70)

     def test_gen_min_new_tokens(self):
         test_gen_min_new_tokens()
...
@@ -71,7 +71,7 @@ class TestQwen2FP8(unittest.TestCase):
         metrics = run_eval(args)
         print(metrics)
-        self.assertGreater(metrics["accuracy"], 0.8)
+        self.assertGreater(metrics["accuracy"], 0.79)

 if __name__ == "__main__":
...
@@ -20,8 +20,8 @@ import torch
 from sglang.test.runners import HFRunner, SRTRunner

 MODELS = [
-    ("LxzGordon/URM-LLaMa-3.1-8B", 1, 3e-2),
-    ("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", 1, 3e-2),
+    ("LxzGordon/URM-LLaMa-3.1-8B", 1, 4e-2),
+    ("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", 1, 4e-2),
 ]
 TORCH_DTYPES = [torch.float16]
...