Unverified commit 61f42b57 authored by Lianmin Zheng, committed by GitHub

Move sgl.Runtime under sglang/lang (#2990)

parent e403d237
......@@ -9,7 +9,7 @@ from enum import Enum
from pydantic import BaseModel
import sglang as sgl
from sglang.srt.constrained import build_regex_from_object
from sglang.srt.constrained.outlines_backend import build_regex_from_object
character_regex = (
r"""\{\n"""
......
......@@ -3,8 +3,8 @@ import triton_python_backend_utils as pb_utils
from pydantic import BaseModel
import sglang as sgl
from sglang import function, set_default_backend
from sglang.srt.constrained import build_regex_from_object
from sglang import function
from sglang.srt.constrained.outlines_backend import build_regex_from_object
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
......
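Both call sites above now import build_regex_from_object from the outlines backend module directly instead of the old sglang.srt.constrained package. A minimal sketch of the typical usage with a pydantic schema (the Character model below is an illustrative stand-in, not taken from this diff):

from pydantic import BaseModel

from sglang.srt.constrained.outlines_backend import build_regex_from_object


class Character(BaseModel):
    name: str
    house: str


# Turn the pydantic schema into a regex that constrained decoding can follow
# when generating JSON matching the Character fields.
character_regex = build_regex_from_object(Character)
print(character_regex)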
"""
Usage:
python3 async_io.py
"""
import asyncio
from sglang import Runtime
async def generate(
engine,
prompt,
sampling_params,
):
tokenizer = engine.get_tokenizer()
messages = [
{
"role": "system",
"content": "You will be given question answer tasks.",
},
{"role": "user", "content": prompt},
]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
stream = engine.add_request(prompt, sampling_params)
async for output in stream:
print(output, end="", flush=True)
print()
if __name__ == "__main__":
runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
print("--- runtime ready ---\n")
prompt = "Who is Alan Turing?"
sampling_params = {"max_new_tokens": 128}
asyncio.run(generate(runtime, prompt, sampling_params))
runtime.shutdown()
"""Public APIs of the language."""
import os
import re
from typing import Callable, List, Optional, Union
......@@ -33,17 +32,13 @@ def function(
def Runtime(*args, **kwargs):
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# Avoid importing unnecessary dependency
from sglang.srt.server import Runtime
from sglang.lang.backend.runtime_endpoint import Runtime
return Runtime(*args, **kwargs)
def Engine(*args, **kwargs):
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# Avoid importing unnecessary dependency
from sglang.srt.server import Engine
......
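The sgl.Runtime wrapper above keeps the top-level import lightweight by deferring the backend import until the object is constructed, so the public call site is unchanged after the move. A brief usage sketch (the model path mirrors the async_io.py example above; any other keyword goes straight to ServerArgs):

import sglang as sgl

# Constructing Runtime triggers the lazy import of
# sglang.lang.backend.runtime_endpoint.Runtime and launches the HTTP server.
runtime = sgl.Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
print("runtime ready at", runtime.url)
runtime.shutdown()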
......@@ -27,7 +27,8 @@ from sglang.bench_serving import (
sample_random_requests,
set_ulimit,
)
from sglang.srt.server import Engine, Runtime
from sglang.lang.backend.runtime_endpoint import Runtime
from sglang.srt.server import Engine
from sglang.srt.server_args import ServerArgs
......
import atexit
import json
import multiprocessing
import warnings
from typing import List, Optional
from typing import Dict, List, Optional, Union
import aiohttp
import requests
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
......@@ -14,6 +19,9 @@ from sglang.lang.ir import (
REGEX_STR,
SglSamplingParams,
)
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_port_available, kill_process_tree
from sglang.utils import http_request
......@@ -325,3 +333,162 @@ class RuntimeEndpoint(BaseBackend):
def compute_normalized_prompt_logprobs(input_logprobs):
values = [x[0] for x in input_logprobs if x[0]]
return sum(values) / len(values)
class Runtime:
"""
A wrapper for the HTTP server.
This is used for launching the server in a Python program without
using the command line interface.
It is mainly used for the frontend language.
You should use the Engine class if you want to do normal offline processing.
"""
def __init__(
self,
log_level: str = "error",
*args,
**kwargs,
):
"""See the arguments in server_args.py::ServerArgs"""
from sglang.srt.server import launch_server
self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
# Before the Python program terminates, call shutdown implicitly so users don't have to call .shutdown() explicitly
atexit.register(self.shutdown)
# Pre-allocate ports
for port in range(self.server_args.port, 40000):
if is_port_available(port):
break
self.server_args.port = port
self.url = self.server_args.url()
self.generate_url = self.url + "/generate"
# NOTE: We store pid instead of proc to fix some issues during __del__
self.pid = None
pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)
proc = multiprocessing.Process(
target=launch_server,
args=(self.server_args, pipe_writer),
)
proc.start()
pipe_writer.close()
self.pid = proc.pid
try:
init_state = pipe_reader.recv()
except EOFError:
init_state = ""
if init_state != "ready":
self.shutdown()
raise RuntimeError(
"Initialization failed. Please see the error messages above."
)
self.endpoint = RuntimeEndpoint(self.url)
def shutdown(self):
if self.pid is not None:
kill_process_tree(self.pid)
self.pid = None
def cache_prefix(self, prefix: str):
self.endpoint.cache_prefix(prefix)
def get_tokenizer(self):
return get_tokenizer(
self.server_args.tokenizer_path,
tokenizer_mode=self.server_args.tokenizer_mode,
trust_remote_code=self.server_args.trust_remote_code,
revision=self.server_args.revision,
)
async def async_generate(
self,
prompt: str,
sampling_params: Optional[Dict] = None,
):
if self.server_args.skip_tokenizer_init:
json_data = {
"input_ids": prompt,
"sampling_params": sampling_params,
"stream": True,
}
else:
json_data = {
"text": prompt,
"sampling_params": sampling_params,
"stream": True,
}
pos = 0
timeout = aiohttp.ClientTimeout(total=3 * 3600)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.post(self.generate_url, json=json_data) as response:
async for chunk, _ in response.content.iter_chunks():
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]\n\n":
break
data = json.loads(chunk[5:].strip("\n"))
if "text" in data:
cur = data["text"][pos:]
if cur:
yield cur
pos += len(cur)
else:
yield data
add_request = async_generate
def generate(
self,
prompt: Union[str, List[str]],
sampling_params: Optional[Dict] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
):
json_data = {
"text": prompt,
"sampling_params": sampling_params,
"return_logprob": return_logprob,
"logprob_start_len": logprob_start_len,
"top_logprobs_num": top_logprobs_num,
"lora_path": lora_path,
}
assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
response = requests.post(
self.url + "/generate",
json=json_data,
)
return json.dumps(response.json())
def encode(
self,
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
):
json_data = {"text": prompt}
response = requests.post(self.url + "/encode", json=json_data)
return json.dumps(response.json())
async def get_server_info(self):
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.url}/get_server_info") as response:
if response.status == 200:
return await response.json()
else:
error_data = await response.json()
raise RuntimeError(
f"Failed to get server info. {error_data['error']['message']}"
)
def __del__(self):
self.shutdown()
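A short sketch of consuming the streaming async_generate generator defined above; it mirrors the examples/runtime/async_io.py script shown earlier in this diff (prompt and sampling parameters reuse the same illustrative values):

import asyncio

from sglang.lang.backend.runtime_endpoint import Runtime


async def main():
    runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
    # async_generate parses the server-sent "data:" chunks and yields only the
    # newly generated text until the "data: [DONE]" sentinel arrives.
    async for chunk in runtime.async_generate(
        "Who is Alan Turing?", {"max_new_tokens": 128}
    ):
        print(chunk, end="", flush=True)
    print()
    runtime.shutdown()


asyncio.run(main())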
"""Launch the inference server for Llava-video model."""
import json
import sys
from sglang.srt.server import launch_server, prepare_server_args
if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:])
model_override_args = {}
model_override_args["mm_spatial_pool_stride"] = 2
model_override_args["architectures"] = ["LlavaVidForCausalLM"]
model_override_args["num_frames"] = 16
model_override_args["model_type"] = "llavavid"
if model_override_args["num_frames"] == 32:
model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"}
model_override_args["max_sequence_length"] = 4096 * 2
model_override_args["tokenizer_model_max_length"] = 4096 * 2
model_override_args["model_max_length"] = 4096 * 2
if "34b" in server_args.model_path.lower():
model_override_args["image_token_index"] = 64002
server_args.json_model_override_args = json.dumps(model_override_args)
launch_server(server_args)
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# TODO(lmzheng): make this an optional dependency
from sglang.srt.constrained.outlines_backend import build_regex_from_object
......@@ -18,6 +18,8 @@ from dataclasses import dataclass
from threading import Event, Lock
from typing import Any, Optional, Tuple
from sglang.srt.server_args import ServerArgs
@dataclass
class CacheEntry:
......@@ -69,3 +71,22 @@ class BaseGrammarBackend:
def reset(self):
with self.cache_lock:
self.cache.clear()
def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
if server_args.grammar_backend == "outlines":
from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
grammar_backend = OutlinesGrammarBackend(
tokenizer,
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
allow_jump_forward=not server_args.disable_jump_forward,
)
elif server_args.grammar_backend == "xgrammar":
from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
grammar_backend = XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size)
else:
raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
return grammar_backend
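create_grammar_backend centralizes the backend selection that previously lived inline in Scheduler.__init__ (removed further below). A self-contained sketch of the same select-by-name pattern with stand-in classes, useful for seeing the shape without the sglang dependencies (all names here are illustrative, not the real backends):

class _OutlinesStub:
    def __init__(self, tokenizer, vocab_size):
        self.tokenizer, self.vocab_size = tokenizer, vocab_size


class _XGrammarStub:
    def __init__(self, tokenizer, vocab_size):
        self.tokenizer, self.vocab_size = tokenizer, vocab_size


_BACKENDS = {"outlines": _OutlinesStub, "xgrammar": _XGrammarStub}


def create_backend(name, tokenizer, vocab_size):
    # Look up the backend class by its configured name; unknown names fail fast,
    # matching the ValueError raised by the real factory above.
    try:
        return _BACKENDS[name](tokenizer, vocab_size)
    except KeyError:
        raise ValueError(f"Invalid grammar backend: {name}") from None


backend = create_backend("outlines", tokenizer=None, vocab_size=32000)
print(type(backend).__name__)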
......@@ -34,6 +34,7 @@ import zmq
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
......@@ -149,9 +150,7 @@ class Scheduler:
else 1
)
# Init inter-process communication
context = zmq.Context(2)
# Distributed rank info
self.dp_size = server_args.dp_size
self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
compute_dp_attention_world_info(
......@@ -162,6 +161,8 @@ class Scheduler:
)
)
# Init inter-process communication
context = zmq.Context(2)
if self.attn_tp_rank == 0:
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
......@@ -243,7 +244,7 @@ class Scheduler:
nccl_port=port_args.nccl_port,
)
# Launch worker for speculative decoding if need
# Launch a worker for speculative decoding if needed
if self.spec_algorithm.is_eagle():
from sglang.srt.speculative.eagle_worker import EAGLEWorker
......@@ -316,6 +317,8 @@ class Scheduler:
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0
self.last_decode_stats_tic = time.time()
self.stream_interval = server_args.stream_interval
self.current_stream = torch.get_device_module(self.device).current_stream()
......@@ -337,28 +340,9 @@ class Scheduler:
# Init the grammar backend for constrained generation
self.grammar_queue: List[Req] = []
if not server_args.skip_tokenizer_init:
if server_args.grammar_backend == "outlines":
from sglang.srt.constrained.outlines_backend import (
OutlinesGrammarBackend,
)
self.grammar_backend = OutlinesGrammarBackend(
self.tokenizer,
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
allow_jump_forward=not server_args.disable_jump_forward,
)
elif server_args.grammar_backend == "xgrammar":
from sglang.srt.constrained.xgrammar_backend import (
XGrammarGrammarBackend,
)
self.grammar_backend = XGrammarGrammarBackend(
self.tokenizer, vocab_size=self.model_config.vocab_size
)
else:
raise ValueError(
f"Invalid grammar backend: {server_args.grammar_backend}"
)
self.grammar_backend = create_grammar_backend(
server_args, self.tokenizer, self.model_config.vocab_size
)
else:
self.grammar_backend = None
......@@ -424,7 +408,8 @@ class Scheduler:
},
)
self._dispatcher = TypeBasedDispatcher(
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
(TokenizedGenerateReqInput, self.handle_generate_request),
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
......@@ -480,10 +465,6 @@ class Scheduler:
self.process_input_requests(recv_reqs)
batch = self.get_next_batch_to_run()
if self.server_args.enable_dp_attention: # TODO: simplify this
batch = self.prepare_dp_attn_batch(batch)
self.cur_batch = batch
if batch:
......@@ -506,10 +487,6 @@ class Scheduler:
self.process_input_requests(recv_reqs)
batch = self.get_next_batch_to_run()
if self.server_args.enable_dp_attention: # TODO: simplify this
batch = self.prepare_dp_attn_batch(batch)
self.cur_batch = batch
if batch:
......@@ -517,7 +494,7 @@ class Scheduler:
result_queue.append((batch.copy(), result))
if self.last_batch is None:
# Create a dummy first batch to start the pipeline for overlap scheduler.
# Create a dummy first batch to start the pipeline for overlap schedule.
# It is now used for triggering the sampling_info_done event.
tmp_batch = ScheduleBatch(
reqs=None,
......@@ -593,7 +570,7 @@ class Scheduler:
def process_input_requests(self, recv_reqs: List):
for recv_req in recv_reqs:
output = self._dispatcher(recv_req)
output = self._request_dispatcher(recv_req)
if output is not None:
self.send_to_tokenizer.send_pyobj(output)
......@@ -798,15 +775,32 @@ class Scheduler:
self.num_generated_tokens = 0
self.last_decode_stats_tic = time.time()
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
logger.info(
f"Decode batch. "
f"#running-req: {num_running_reqs}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}"
)
if self.spec_algorithm.is_none():
msg = (
f"Decode batch. "
f"#running-req: {num_running_reqs}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}"
)
else:
accept_length = (
self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
)
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
msg = (
f"Decode batch. "
f"#running-req: {num_running_reqs}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"accept len: {accept_length:.2f}, "
f"gen throughput (token/s): {gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}"
)
logger.info(msg)
if self.enable_metrics:
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
......@@ -855,16 +849,23 @@ class Scheduler:
else:
self.running_batch.merge_batch(self.last_batch)
# Run prefill first if possible
new_batch = self.get_new_batch_prefill()
if new_batch is not None:
return new_batch
# Run prefill first if possible
ret = new_batch
else:
# Run decode
if self.running_batch is None:
ret = None
else:
self.running_batch = self.update_running_batch(self.running_batch)
ret = self.running_batch
# Run decode
if self.running_batch is None:
return None
self.running_batch = self.update_running_batch(self.running_batch)
return self.running_batch
# Handle DP attention
if self.server_args.enable_dp_attention:
ret = self.prepare_dp_attn_batch(ret)
return ret
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
# Check if the grammar is ready in the grammar queue
......@@ -1053,6 +1054,10 @@ class Scheduler:
model_worker_batch,
num_accepted_tokens,
) = self.draft_worker.forward_batch_speculative_generation(batch)
self.spec_num_total_accepted_tokens += (
num_accepted_tokens + batch.batch_size()
)
self.spec_num_total_forward_ct += batch.batch_size()
self.num_generated_tokens += num_accepted_tokens
else:
assert False, "batch.extend_num_tokens == 0, this is unexpected!"
......
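The two counters added above feed the accept-length metric in the decode log: each speculative forward pass contributes the accepted draft tokens plus one guaranteed token per request to the numerator, while the denominator counts request-forwards. A small worked example with illustrative numbers:

# Batch of 4 requests, two speculative forward passes, with the draft worker
# reporting 6 and then 10 accepted draft tokens.
batch_size = 4
spec_num_total_accepted_tokens = (6 + batch_size) + (10 + batch_size)  # 24 tokens emitted
spec_num_total_forward_ct = batch_size + batch_size                    # 8 request-forwards
accept_length = spec_num_total_accepted_tokens / spec_num_total_forward_ct
print(accept_length)  # 3.0 tokens per request per forward pass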
......@@ -224,7 +224,7 @@ class TokenizerManager:
},
)
self._dispatcher = TypeBasedDispatcher(
self._result_dispatcher = TypeBasedDispatcher(
[
(BatchStrOut, self._handle_batch_output),
(BatchEmbeddingOut, self._handle_batch_output),
......@@ -760,7 +760,7 @@ class TokenizerManager:
while True:
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
self._dispatcher(recv_obj)
self._result_dispatcher(recv_obj)
def _handle_batch_output(
self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
......
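The renames above distinguish the scheduler's _request_dispatcher from the tokenizer manager's _result_dispatcher; both are instances of the same type-based dispatch utility. A self-contained sketch of that pattern (the class here is an illustrative re-implementation, not sglang's TypeBasedDispatcher):

from typing import Any, Callable, List, Tuple, Type


class TypeDispatcherSketch:
    """Route an object to the first handler registered for its type."""

    def __init__(self, mapping: List[Tuple[Type, Callable[[Any], Any]]]):
        self._mapping = mapping

    def __call__(self, obj: Any) -> Any:
        for ty, handler in self._mapping:
            if isinstance(obj, ty):
                return handler(obj)
        raise ValueError(f"Invalid object type: {type(obj)}")


dispatcher = TypeDispatcherSketch([(str, lambda s: s.upper()), (int, lambda i: i + 1)])
print(dispatcher("ok"), dispatcher(41))  # OK 42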
......@@ -45,8 +45,6 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.data_parallel_controller import (
run_data_parallel_controller_process,
)
......@@ -90,7 +88,6 @@ from sglang.srt.utils import (
assert_pkg_version,
configure_logger,
delete_directory,
is_port_available,
kill_process_tree,
maybe_set_triton_cache_manager,
prepare_model_and_tokenizer,
......@@ -960,160 +957,3 @@ class Engine:
obj = ResumeMemoryOccupationReqInput()
loop = asyncio.get_event_loop()
loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None))
class Runtime:
"""
A wrapper for the HTTP server.
This is used for launching the server in a Python program without
using the command line interface.
It is mainly used for the frontend language.
You should use the Engine class above if you want to do normal offline processing.
"""
def __init__(
self,
log_level: str = "error",
*args,
**kwargs,
):
"""See the arguments in server_args.py::ServerArgs"""
self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
# Before the Python program terminates, call shutdown implicitly so users don't have to call .shutdown() explicitly
atexit.register(self.shutdown)
# Pre-allocate ports
for port in range(self.server_args.port, 40000):
if is_port_available(port):
break
self.server_args.port = port
self.url = self.server_args.url()
self.generate_url = self.url + "/generate"
# NOTE: We store pid instead of proc to fix some issues during __del__
self.pid = None
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
proc = mp.Process(
target=launch_server,
args=(self.server_args, pipe_writer),
)
proc.start()
pipe_writer.close()
self.pid = proc.pid
try:
init_state = pipe_reader.recv()
except EOFError:
init_state = ""
if init_state != "ready":
self.shutdown()
raise RuntimeError(
"Initialization failed. Please see the error messages above."
)
self.endpoint = RuntimeEndpoint(self.url)
def shutdown(self):
if self.pid is not None:
kill_process_tree(self.pid)
self.pid = None
def cache_prefix(self, prefix: str):
self.endpoint.cache_prefix(prefix)
def get_tokenizer(self):
return get_tokenizer(
self.server_args.tokenizer_path,
tokenizer_mode=self.server_args.tokenizer_mode,
trust_remote_code=self.server_args.trust_remote_code,
revision=self.server_args.revision,
)
async def async_generate(
self,
prompt: str,
sampling_params: Optional[Dict] = None,
):
if self.server_args.skip_tokenizer_init:
json_data = {
"input_ids": prompt,
"sampling_params": sampling_params,
"stream": True,
}
else:
json_data = {
"text": prompt,
"sampling_params": sampling_params,
"stream": True,
}
pos = 0
timeout = aiohttp.ClientTimeout(total=3 * 3600)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.post(self.generate_url, json=json_data) as response:
async for chunk, _ in response.content.iter_chunks():
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]\n\n":
break
data = json.loads(chunk[5:].strip("\n"))
if "text" in data:
cur = data["text"][pos:]
if cur:
yield cur
pos += len(cur)
else:
yield data
add_request = async_generate
def generate(
self,
prompt: Union[str, List[str]],
sampling_params: Optional[Dict] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
):
json_data = {
"text": prompt,
"sampling_params": sampling_params,
"return_logprob": return_logprob,
"logprob_start_len": logprob_start_len,
"top_logprobs_num": top_logprobs_num,
"lora_path": lora_path,
}
assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
response = requests.post(
self.url + "/generate",
json=json_data,
)
return json.dumps(response.json())
def encode(
self,
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
):
json_data = {"text": prompt}
response = requests.post(self.url + "/encode", json=json_data)
return json.dumps(response.json())
async def get_server_info(self):
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.url}/get_server_info") as response:
if response.status == 200:
return await response.json()
else:
error_data = await response.json()
raise RuntimeError(
f"Failed to get server info. {error_data['error']['message']}"
)
def __del__(self):
self.shutdown()
......@@ -23,7 +23,7 @@ import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server import Runtime
from sglang.srt.server import Engine
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
DEFAULT_PROMPTS = [
......@@ -278,7 +278,7 @@ class SRTRunner:
):
self.model_type = model_type
self.is_generation = model_type == "generation"
self.runtime = Runtime(
self.engine = Engine(
model_path=model_path,
tp_size=tp_size,
dtype=get_dtype_str(torch_dtype),
......@@ -306,7 +306,7 @@ class SRTRunner:
top_output_logprobs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
for i, prompt in enumerate(prompts):
response = self.runtime.generate(
response = self.engine.generate(
prompt,
lora_path=lora_paths[i] if lora_paths else None,
sampling_params=sampling_params,
......@@ -314,7 +314,6 @@ class SRTRunner:
logprob_start_len=0,
top_logprobs_num=NUM_TOP_LOGPROBS,
)
response = json.loads(response)
output_strs.append(response["text"])
top_input_logprobs.append(
[
......@@ -343,8 +342,7 @@ class SRTRunner:
top_output_logprobs=top_output_logprobs,
)
else:
response = self.runtime.encode(prompts)
response = json.loads(response)
response = self.engine.encode(prompts)
if self.model_type == "embedding":
logits = [x["embedding"] for x in response]
return ModelOutput(embed_logits=logits)
......@@ -366,20 +364,18 @@ class SRTRunner:
# the return value contains logprobs from prefill
output_strs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
response = self.runtime.generate(
response = self.engine.generate(
prompts,
lora_path=lora_paths if lora_paths else None,
sampling_params=sampling_params,
)
response = json.loads(response)
output_strs = [r["text"] for r in response]
return ModelOutput(
output_strs=output_strs,
)
else:
response = self.runtime.encode(prompts)
response = json.loads(response)
response = self.engine.encode(prompts)
if self.model_type == "embedding":
logits = [x["embedding"] for x in response]
return ModelOutput(embed_logits=logits)
......@@ -391,8 +387,8 @@ class SRTRunner:
return self
def __exit__(self, exc_type, exc_value, traceback):
self.runtime.shutdown()
del self.runtime
self.engine.shutdown()
del self.engine
def monkey_patch_gemma2_sdpa():
......
......@@ -4,7 +4,7 @@ from enum import Enum
from pydantic import BaseModel, constr
import sglang as sgl
from sglang.srt.constrained import build_regex_from_object
from sglang.srt.constrained.outlines_backend import build_regex_from_object
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
......
......@@ -73,7 +73,7 @@ class TestSRTBackend(unittest.TestCase):
# Run twice to capture more bugs
for _ in range(2):
accuracy, latency = test_hellaswag_select()
self.assertGreater(accuracy, 0.71)
self.assertGreater(accuracy, 0.70)
def test_gen_min_new_tokens(self):
test_gen_min_new_tokens()
......
......@@ -71,7 +71,7 @@ class TestQwen2FP8(unittest.TestCase):
metrics = run_eval(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.8)
self.assertGreater(metrics["accuracy"], 0.79)
if __name__ == "__main__":
......
......@@ -20,8 +20,8 @@ import torch
from sglang.test.runners import HFRunner, SRTRunner
MODELS = [
("LxzGordon/URM-LLaMa-3.1-8B", 1, 3e-2),
("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", 1, 3e-2),
("LxzGordon/URM-LLaMa-3.1-8B", 1, 4e-2),
("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", 1, 4e-2),
]
TORCH_DTYPES = [torch.float16]
......