Unverified Commit 61d4c939 authored by Cody Yu, committed by GitHub

Support stream=True in v1/completions (#49)

parent 98a3e8ef
......@@ -238,9 +238,25 @@ curl http://localhost:30000/generate \
}
}'
```
Learn more about the argument format [here](docs/sampling_params.md).
### OpenAI Compatible API
In addition, the server supports an experimental OpenAI-compatible API.
```python
import openai
client = openai.Client(
base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
response = client.completions.create(
model="default",
prompt="The capital of France is",
temperature=0,
max_tokens=32,
)
print(response)
```
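Streaming is supported on the same endpoint. A minimal sketch, assuming the server above is running on port 30000 and using the same client setup; each chunk carries the newly generated text:
```python
import openai

client = openai.Client(
    base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# Request a streamed completion; text arrives incrementally as it is generated.
stream = client.completions.create(
    model="default",
    prompt="The capital of France is",
    temperature=0,
    max_tokens=32,
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].text, end="", flush=True)
print()
```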
### Additional Arguments
- Add `--tp 2` to enable tensor parallelism.
```
......
......@@ -19,7 +19,7 @@ dependencies = [
[project.optional-dependencies]
srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
"interegular", "lark", "numba"]
"interegular", "lark", "numba", "pydantic"]
openai = ["openai>=1.0"]
anthropic = ["anthropic"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
......
......@@ -116,9 +116,12 @@ class RuntimeEndpoint(BaseBackend):
pos = 0
incomplete_text = ""
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
text = find_printable_text(data["text"][pos:])
meta_info = data["meta_info"]
pos += len(text)
......
from dataclasses import dataclass
from typing import Any, List, Optional, Union
import time
from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field
@dataclass
class CompletionRequest:
prompt: Union[str, List[Any]]
model: str = "default"
temperature: Optional[float] = 0.7
class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class CompletionRequest(BaseModel):
model: str
prompt: Union[str, List[str]]
suffix: Optional[str] = None
max_tokens: Optional[int] = 16
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
logprobs: Optional[int] = None
echo: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
best_of: Optional[int] = None
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[str] = None
class CompletionResponse(BaseModel):
id: str
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo
class CompletionResponseStreamChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[str] = None
class CompletionStreamResponse(BaseModel):
id: str
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
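For reference, a minimal sketch of how one stream chunk built from these models serializes into a server-sent event, mirroring the framing used by the `/v1/completions` handler below (the `id` value here is made up for illustration):
```python
from sglang.srt.managers.openai_protocol import (
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
)

# Build one chunk and frame it as an SSE "data:" line.
choice = CompletionResponseStreamChoice(index=0, text=" Paris")
chunk = CompletionStreamResponse(
    id="cmpl-example",  # hypothetical id; the server uses meta_info["id"]
    model="default",
    choices=[choice],
)
print(f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n")
```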
"""SRT: SGLang Runtime"""
import argparse
import asyncio
import dataclasses
import json
import multiprocessing as mp
import sys
......@@ -16,12 +14,19 @@ import psutil
import requests
import uvicorn
import uvloop
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.openai_protocol import CompletionRequest
from sglang.srt.managers.openai_protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
UsageInfo
)
from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
......@@ -41,39 +46,97 @@ async def get_model_info():
}
return result
async def stream_generator(obj):
async for out in tokenizer_manager.generate_request(obj):
yield out
@app.post("/generate")
async def generate_request(obj: GenerateReqInput):
obj.post_init()
result_generator = tokenizer_manager.generate_request(obj)
if obj.stream:
async def stream_results():
async for out in result_generator:
yield (json.dumps(out) + "\0").encode("utf-8")
async for out in stream_generator(obj):
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(stream_results(), media_type="text/event-stream")
else:
ret = await result_generator.__anext__()
return ret
ret = await tokenizer_manager.generate_request(obj).__anext__()
return ret
@app.post("/v1/completions")
async def v1_completions(obj: CompletionRequest):
assert obj.n == 1
obj = GenerateReqInput(
text=obj.prompt,
async def v1_completions(raw_request: Request):
request_json = await raw_request.json()
request = CompletionRequest(**request_json)
# TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
assert request.n == 1
adapted_request = GenerateReqInput(
text=request.prompt,
sampling_params={
"temperature": obj.temperature,
"max_new_tokens": obj.max_tokens,
"stop": obj.stop,
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
"stop": request.stop,
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
},
stream=request.stream,
)
ret = await generate_request(obj)
return {
"choices": [{"text": ret["text"]}],
}
adapted_request.post_init()
if adapted_request.stream:
async def generate_stream_resp():
stream_buffer = ""
async for content in stream_generator(adapted_request):
text = content["text"]
delta = text[len(stream_buffer):]
stream_buffer = text
choice_data = CompletionResponseStreamChoice(
index=0,
text=delta,
logprobs=None,
finish_reason=None,
)
chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
object="text_completion",
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
# Non-streaming response.
ret = await generate_request(adapted_request)
choice_data = CompletionResponseChoice(
index=0,
text=ret["text"],
logprobs=None,
finish_reason=None, # TODO(comaniac): Add finish reason.
)
prompt_tokens = ret["meta_info"]["prompt_tokens"]
completion_tokens = ret["meta_info"]["completion_tokens"]
response = CompletionResponse(
id=ret["meta_info"]["id"],
model=request.model,
choices=[choice_data],
usage=UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
return response
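For a rough end-to-end check of the streaming path above, a hypothetical raw-`requests` client (not part of this change; it assumes the server listens on 127.0.0.1:30000 and parses the `data:` framing produced by the handler):
```python
import json

import requests

response = requests.post(
    "http://127.0.0.1:30000/v1/completions",
    json={
        "model": "default",
        "prompt": "The capital of France is",
        "temperature": 0,
        "max_tokens": 32,
        "stream": True,
    },
    stream=True,
)
for line in response.iter_lines(decode_unicode=False):
    line = line.decode("utf-8")
    if not line.startswith("data:"):
        continue
    if line == "data: [DONE]":  # defensive; the /generate endpoint emits this sentinel
        break
    # Each event body is a CompletionStreamResponse serialized as JSON.
    payload = json.loads(line[len("data:"):].strip())
    print(payload["choices"][0]["text"], end="", flush=True)
print()
```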
def launch_server(server_args, pipe_finish_writer):
......
......@@ -25,7 +25,7 @@ if __name__ == "__main__":
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 1024,
"max_new_tokens": 512,
},
"stream": True,
},
......@@ -33,9 +33,12 @@ if __name__ == "__main__":
)
prev = 0
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
output = data["text"].strip()
print(output[prev:], end="", flush=True)
prev = len(output)
......
"""
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
"""
import argparse
import openai
def test_completion(args):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.completions.create(
model="default",
prompt="The capital of France is",
temperature=0,
max_tokens=32,
)
print(response.choices[0].text)
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def test_completion_stream(args):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.completions.create(
model="default",
prompt="The capital of France is",
temperature=0,
max_tokens=32,
stream=True,
)
for r in response:
print(r.choices[0].text, end="", flush=True)
assert r.id
assert r.created
assert r.usage.prompt_tokens > 0
assert r.usage.completion_tokens > 0
assert r.usage.total_tokens > 0
print()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
args = parser.parse_args()
test_completion(args)
test_completion_stream(args)