Unverified commit 61d4c939 authored by Cody Yu, committed by GitHub

Support stream=True in v1/completions (#49)

parent 98a3e8ef
@@ -238,9 +238,25 @@ curl http://localhost:30000/generate \
     }
   }'
 ```
 Learn more about the argument format [here](docs/sampling_params.md).
 
+### OpenAI Compatible API
+
+In addition, the server supports an experimental OpenAI-compatible API.
+
+```python
+import openai
+
+client = openai.Client(
+    base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
+
+response = client.completions.create(
+    model="default",
+    prompt="The capital of France is",
+    temperature=0,
+    max_tokens=32,
+)
+print(response)
+```
+
 ### Additional Arguments
 - Add `--tp 2` to enable tensor parallelism.
 ```
...
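Since this commit adds `stream=True` support to `v1/completions`, the README snippet above can also be run in streaming mode. A minimal sketch, assuming the same local server on port 30000 (this example is not part of the diff):

```python
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# stream=True yields incremental chunks instead of one final response.
stream = client.completions.create(
    model="default",
    prompt="The capital of France is",
    temperature=0,
    max_tokens=32,
    stream=True,
)
for chunk in stream:
    # Each chunk carries only the newly generated text delta.
    print(chunk.choices[0].text, end="", flush=True)
print()
```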
@@ -19,7 +19,7 @@ dependencies = [
 [project.optional-dependencies]
 srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
-       "interegular", "lark", "numba"]
+       "interegular", "lark", "numba", "pydantic"]
 openai = ["openai>=1.0"]
 anthropic = ["anthropic"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
...
@@ -116,9 +116,12 @@ class RuntimeEndpoint(BaseBackend):
         pos = 0
         incomplete_text = ""
-        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-            if chunk:
-                data = json.loads(chunk.decode())
+        for chunk in response.iter_lines(decode_unicode=False):
+            chunk = chunk.decode("utf-8")
+            if chunk and chunk.startswith("data:"):
+                if chunk == "data: [DONE]":
+                    break
+                data = json.loads(chunk[5:].strip("\n"))
                 text = find_printable_text(data["text"][pos:])
                 meta_info = data["meta_info"]
                 pos += len(text)
...
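The hunk above switches the `/generate` stream from NUL-delimited JSON blobs to SSE-style `data: ...` lines ending with a `data: [DONE]` sentinel. A standalone sketch of that line format and the parsing it implies (the helper name `parse_stream_line` is illustrative, not from the diff):

```python
import json
from typing import Optional


def parse_stream_line(line: str) -> Optional[dict]:
    """Decode one line of the new /generate SSE stream.

    Returns the JSON payload, or None for blank lines and the
    final "data: [DONE]" sentinel.
    """
    if not line.startswith("data:") or line == "data: [DONE]":
        return None
    return json.loads(line[len("data:"):].strip())


# Illustrative frames in the new format.
assert parse_stream_line('data: {"text": "Paris", "meta_info": {}}') == {
    "text": "Paris",
    "meta_info": {},
}
assert parse_stream_line("data: [DONE]") is None
```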
-from dataclasses import dataclass
-from typing import Any, List, Optional, Union
-
-
-@dataclass
-class CompletionRequest:
-    prompt: Union[str, List[Any]]
-    model: str = "default"
-    temperature: Optional[float] = 0.7
+import time
+from typing import Dict, List, Optional, Union
+
+from pydantic import BaseModel, Field
+
+
+class LogProbs(BaseModel):
+    text_offset: List[int] = Field(default_factory=list)
+    token_logprobs: List[Optional[float]] = Field(default_factory=list)
+    tokens: List[str] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
+
+
+class UsageInfo(BaseModel):
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+    completion_tokens: Optional[int] = 0
+
+
+class CompletionRequest(BaseModel):
+    model: str
+    prompt: Union[str, List[str]]
+    suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 1.0
     n: Optional[int] = 1
-    stop: Optional[Union[str, List[str]]] = None
+    stream: Optional[bool] = False
+    logprobs: Optional[int] = None
+    echo: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    best_of: Optional[int] = None
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+
+
+class CompletionResponseChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class CompletionResponse(BaseModel):
+    id: str
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseChoice]
+    usage: UsageInfo
+
+
+class CompletionResponseStreamChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class CompletionStreamResponse(BaseModel):
+    id: str
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseStreamChoice]
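For context, these pydantic models are what the new endpoint serializes onto the wire. A small sketch, assuming pydantic v1's `.json()` method (which the server code below also uses); the `id` value here is made up for illustration:

```python
from sglang.srt.managers.openai_protocol import (
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
)

# Build one streaming chunk the way v1_completions below does.
choice = CompletionResponseStreamChoice(index=0, text="Paris", finish_reason=None)
chunk = CompletionStreamResponse(id="cmpl-example", model="default", choices=[choice])

# pydantic v1: .json() serializes the model; exclude_unset drops fields left at defaults.
print(f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n")
```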
"""SRT: SGLang Runtime""" """SRT: SGLang Runtime"""
import argparse
import asyncio import asyncio
import dataclasses
import json import json
import multiprocessing as mp import multiprocessing as mp
import sys import sys
@@ -16,12 +14,19 @@ import psutil
 import requests
 import uvicorn
 import uvloop
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.openai_protocol import CompletionRequest
+from sglang.srt.managers.openai_protocol import (
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseChoice,
+    CompletionResponseStreamChoice,
+    CompletionStreamResponse,
+    UsageInfo,
+)
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -41,39 +46,97 @@ async def get_model_info():
     }
     return result
 
 
+async def stream_generator(obj):
+    async for out in tokenizer_manager.generate_request(obj):
+        yield out
+
+
 @app.post("/generate")
 async def generate_request(obj: GenerateReqInput):
     obj.post_init()
-    result_generator = tokenizer_manager.generate_request(obj)
     if obj.stream:
         async def stream_results():
-            async for out in result_generator:
-                yield (json.dumps(out) + "\0").encode("utf-8")
+            async for out in stream_generator(obj):
+                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            yield "data: [DONE]\n\n"
+
         return StreamingResponse(stream_results(), media_type="text/event-stream")
-    else:
-        ret = await result_generator.__anext__()
-        return ret
+
+    ret = await tokenizer_manager.generate_request(obj).__anext__()
+    return ret
 
 
 @app.post("/v1/completions")
-async def v1_completions(obj: CompletionRequest):
-    assert obj.n == 1
-    obj = GenerateReqInput(
-        text=obj.prompt,
+async def v1_completions(raw_request: Request):
+    request_json = await raw_request.json()
+    request = CompletionRequest(**request_json)
+
+    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
+    assert request.n == 1
+
+    adapted_request = GenerateReqInput(
+        text=request.prompt,
         sampling_params={
-            "temperature": obj.temperature,
-            "max_new_tokens": obj.max_tokens,
-            "stop": obj.stop,
+            "temperature": request.temperature,
+            "max_new_tokens": request.max_tokens,
+            "stop": request.stop,
+            "top_p": request.top_p,
+            "presence_penalty": request.presence_penalty,
+            "frequency_penalty": request.frequency_penalty,
         },
+        stream=request.stream,
     )
-    ret = await generate_request(obj)
-    return {
-        "choices": [{"text": ret["text"]}],
-    }
+    adapted_request.post_init()
+
+    if adapted_request.stream:
+
+        async def generate_stream_resp():
+            stream_buffer = ""
+            async for content in stream_generator(adapted_request):
+                text = content["text"]
+                delta = text[len(stream_buffer):]
+                stream_buffer = text
+                choice_data = CompletionResponseStreamChoice(
+                    index=0,
+                    text=delta,
+                    logprobs=None,
+                    finish_reason=None,
+                )
+                chunk = CompletionStreamResponse(
+                    id=content["meta_info"]["id"],
+                    object="text_completion",
+                    choices=[choice_data],
+                    model=request.model,
+                )
+                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+
+        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
+
+    # Non-streaming response.
+    ret = await generate_request(adapted_request)
+    choice_data = CompletionResponseChoice(
+        index=0,
+        text=ret["text"],
+        logprobs=None,
+        finish_reason=None,  # TODO(comaniac): Add finish reason.
+    )
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    response = CompletionResponse(
+        id=ret["meta_info"]["id"],
+        model=request.model,
+        choices=[choice_data],
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        ),
+    )
+    return response
 
 
 def launch_server(server_args, pipe_finish_writer):
...
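Taken together, the handler above means `/v1/completions` can now also be exercised with plain HTTP; a rough sketch using `requests` (the prompt and port mirror the tests below, everything else is illustrative):

```python
import json

import requests

response = requests.post(
    "http://127.0.0.1:30000/v1/completions",
    json={
        "model": "default",
        "prompt": "The capital of France is",
        "temperature": 0,
        "max_tokens": 32,
        "stream": True,
    },
    stream=True,
)
for line in response.iter_lines(decode_unicode=False):
    line = line.decode("utf-8")
    # Skip blank keep-alive lines and any [DONE] sentinel.
    if not line.startswith("data:") or line == "data: [DONE]":
        continue
    chunk = json.loads(line[len("data:"):].strip())
    # Each chunk is a serialized CompletionStreamResponse.
    print(chunk["choices"][0]["text"], end="", flush=True)
print()
```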
@@ -25,7 +25,7 @@ if __name__ == "__main__":
             "text": "The capital of France is",
             "sampling_params": {
                 "temperature": 0,
-                "max_new_tokens": 1024,
+                "max_new_tokens": 512,
             },
             "stream": True,
         },
@@ -33,9 +33,12 @@ if __name__ == "__main__":
     )
 
     prev = 0
-    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        if chunk:
-            data = json.loads(chunk.decode())
+    for chunk in response.iter_lines(decode_unicode=False):
+        chunk = chunk.decode("utf-8")
+        if chunk and chunk.startswith("data:"):
+            if chunk == "data: [DONE]":
+                break
+            data = json.loads(chunk[5:].strip("\n"))
             output = data["text"].strip()
             print(output[prev:], end="", flush=True)
             prev = len(output)
...
"""
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
"""
import argparse
import openai
def test_completion(args):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.completions.create(
model="default",
prompt="The capital of France is",
temperature=0,
max_tokens=32,
)
print(response.choices[0].text)
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def test_completion_stream(args):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.completions.create(
model="default",
prompt="The capital of France is",
temperature=0,
max_tokens=32,
stream=True,
)
for r in response:
print(r.choices[0].text, end="", flush=True)
assert r.id
assert r.created
assert r.usage.prompt_tokens > 0
assert r.usage.completion_tokens > 0
assert r.usage.total_tokens > 0
print()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
args = parser.parse_args()
test_completion(args)
test_completion_stream(args)