Unverified Commit 3f5c2f4c authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

Add an async example (#37)

parent 007eeb4e
import asyncio
from sglang import Runtime
async def generate(
engine,
prompt,
sampling_params,
):
tokenizer = engine.get_tokenizer()
messages = [
{"role": "system", "content": "You will be given question answer tasks.",},
{"role": "user", "content": prompt},
]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
stream = engine.add_request(prompt, sampling_params)
async for output in stream:
print(output, end="", flush=True)
print()
if __name__ == "__main__":
runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
print("runtime ready")
prompt = "Who is Alan Turing?"
sampling_params = {"max_new_tokens": 128}
asyncio.run(generate(runtime, prompt, sampling_params))
runtime.shutdown()
......@@ -78,8 +78,10 @@ class OpenAI(BaseBackend):
if sampling_params.dtype is None:
if self.is_chat_model:
if not s.text_.endswith("ASSISTANT:"):
raise RuntimeError("This use case is not supported. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant")
raise RuntimeError(
"This use case is not supported. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant"
)
prompt = s.messages_
else:
prompt = s.text_
......
......@@ -11,6 +11,7 @@ from typing import List, Optional
# Fix a Python bug
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
import aiohttp
import psutil
import requests
import uvicorn
......@@ -25,6 +26,7 @@ from sglang.srt.conversation import (
generate_chat_conv,
register_conv_template,
)
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.openai_protocol import (
......@@ -402,7 +404,7 @@ class Runtime:
):
host = "127.0.0.1"
port = alloc_usable_network_port(1)[0]
server_args = ServerArgs(
self.server_args = ServerArgs(
model_path=model_path,
tokenizer_path=tokenizer_path,
host=host,
......@@ -417,11 +419,14 @@ class Runtime:
random_seed=random_seed,
log_level=log_level,
)
self.url = server_args.url()
self.url = self.server_args.url()
self.generate_url = (
f"http://{self.server_args.host}:{self.server_args.port}/generate"
)
self.pid = None
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
proc = mp.Process(target=launch_server, args=(server_args, pipe_writer))
proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer))
proc.start()
self.pid = proc.pid
......@@ -443,5 +448,40 @@ class Runtime:
parent.wait(timeout=5)
self.pid = None
def get_tokenizer(self):
return get_tokenizer(
self.server_args.tokenizer_path,
tokenizer_mode=self.server_args.tokenizer_mode,
trust_remote_code=self.server_args.trust_remote_code,
)
async def add_request(
self,
prompt: str,
sampling_params,
) -> None:
json_data = {
"text": prompt,
"sampling_params": sampling_params,
"stream": True,
}
pos = 0
timeout = aiohttp.ClientTimeout(total=3 * 3600)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.post(self.generate_url, json=json_data) as response:
async for chunk, _ in response.content.iter_chunks():
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]\n\n":
break
data = json.loads(chunk[5:].strip("\n"))
cur = data["text"][pos:]
if cur:
yield cur
pos += len(cur)
def __del__(self):
self.shutdown()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment