Unverified commit 97aa9b32 authored by Lianmin Zheng, committed by GitHub

Improve docs & Add JSON decode example (#121)

parent 06175286
...@@ -123,19 +123,21 @@ You can implement your prompt flow in a function decorated by `sgl.function`.
You can then invoke the function with `run` or `run_batch`.
The system will manage the state, chat template, parallelism, and batching for you.
The complete code for the examples below can be found at [readme_examples.py](examples/usage/readme_examples.py).
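For orientation, here is a minimal sketch of that workflow. The `text_qa` function mirrors the one in readme_examples.py, and the localhost endpoint is just one possible backend:

```python
import sglang as sgl

@sgl.function
def text_qa(s, question):
    s += "Q: " + question + "\n"
    s += "A: " + sgl.gen("answer", max_tokens=64)

# Assumes a local server backend; any supported backend works here.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

state = text_qa.run(question="Who is Alan Turing?")
print(state["answer"])
```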
### Control Flow
You can use any Python code within the function body, including control flow, nested function calls, and external libraries.
```python
@sgl.function
def tool_use(s, question):
    s += "To answer this question: " + question + ". "
    s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"]) + ". "
    if s["tool"] == "calculator":
        s += "The math expression is" + sgl.gen("expression")
    elif s["tool"] == "search engine":
        s += "The key word to search is" + sgl.gen("word")
```
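Generated variables can be read back from the returned state. A small usage sketch (the question string is illustrative):

```python
state = tool_use.run(question="What is 5 + 7?")
print(state["tool"])   # "calculator" or "search engine"
print(state.text())    # the full prompt with all generations filled in
```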
### Parallelism
...@@ -170,6 +172,8 @@ def image_qa(s, image_file, question):
    s += sgl.assistant(sgl.gen("answer", max_tokens=256))
```
See also [srt_example_llava.py](examples/quick_start/srt_example_llava.py).
### Constrained Decoding
Use `regex` to specify a regular expression as a decoding constraint.
This is only supported for local models.
...@@ -185,6 +189,35 @@ def regular_expression_gen(s):
    )
```
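A corresponding usage sketch (this mirrors the `driver_regex` helper added to readme_examples.py in this commit):

```python
state = regular_expression_gen.run()
print(state.text())  # the generated answer is constrained to an IP-address pattern
```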
### JSON Decoding
```python
character_regex = (
    r"""\{\n"""
    + r"""    "name": "[\w\d\s]{1,16}",\n"""
    + r"""    "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
    + r"""    "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
    + r"""    "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
    + r"""    "wand": \{\n"""
    + r"""        "wood": "[\w\d\s]{1,16}",\n"""
    + r"""        "core": "[\w\d\s]{1,16}",\n"""
    + r"""        "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
    + r"""    \},\n"""
    + r"""    "alive": "(Alive|Deceased)",\n"""
    + r"""    "patronus": "[\w\d\s]{1,16}",\n"""
    + r"""    "bogart": "[\w\d\s]{1,16}"\n"""
    + r"""\}"""
)

@sgl.function
def character_gen(s, name):
    s += name + " is a character in Harry Potter. Please fill in the following information about him/her.\n"
    s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
```
See also [json_decode.py](examples/usage/json_decode.py).
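Because the regex fixes the entire JSON skeleton, the generated text can be parsed directly with the standard library; a brief sketch:

```python
import json

state = character_gen.run(name="Hermione Granger")
profile = json.loads(state["json_output"])
print(profile["house"])  # e.g. "Gryffindor"
```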
### Batching
Use `run_batch` to run a batch of requests with continuous batching.
...
...@@ -34,7 +34,7 @@ character_regex = (
# fmt: off
def character_gen(name, generate):
    s = name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
    s += generate(s, max_tokens=256, regex=character_regex)
    return s
# fmt: on
...
...@@ -32,7 +32,7 @@ character_regex = (
# fmt: off
@sgl.function
def character_gen(s, name):
    s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
    s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
# fmt: on
...
...@@ -4,13 +4,21 @@ This doc describes the sampling parameters of the SGLang Runtime.
The `/generate` endpoint accepts the following arguments in the JSON format.

```python
@dataclass
class GenerateReqInput:
    # The input prompt
    text: Union[List[str], str]
    # The image input
    image_data: Optional[Union[List[str], str]] = None
    # The sampling parameters
    sampling_params: Union[List[Dict], Dict] = None
    # The request id
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs of the prompt
    return_logprob: Optional[Union[List[bool], bool]] = None
    # The start location of the prompt for return_logprob
    logprob_start_len: Optional[Union[List[int], int]] = None
    # Whether to stream output
    stream: bool = False
```
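For illustration, the same request shape in Python (assuming a server running on localhost:30000, matching the curl example later in this doc):

```python
import requests

# Only "text" is required; the other fields of GenerateReqInput are optional.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Once upon a time,",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
    },
)
print(response.json())
```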
...@@ -84,3 +92,7 @@ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        prev = len(output)
print("")
```
### Multi-modal
See [test_httpserver_llava.py](../test/srt/test_httpserver_llava.py).
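Based on that test, an image request simply adds the `image_data` field to the payload above; a sketch (the path and prompt are illustrative, and `<image>` marks where the image is inserted):

```python
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "USER: <image>\nDescribe this picture ASSISTANT:",
        "image_data": "test_image.png",  # path to a local image file
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
)
print(response.json())
```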
...@@ -46,6 +46,9 @@ if __name__ == "__main__":
    runtime = sgl.Runtime(model_path="liuhaotian/llava-v1.5-7b",
                          tokenizer_path="llava-hf/llava-1.5-7b-hf")
    sgl.set_default_backend(runtime)

    # Or you can use API models
    # sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview"))
    # sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))

    # Run a single request
    print("\n========== single ==========\n")
...
"""
Usage:
python3 async_io.py
"""
import asyncio

from sglang import Runtime
...@@ -27,8 +31,8 @@ async def generate(
if __name__ == "__main__":
    runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
    print("--- runtime ready ---\n")

    prompt = "Who is Alan Turing?"
    sampling_params = {"max_new_tokens": 128}
    asyncio.run(generate(runtime, prompt, sampling_params))
...
""" """
Usage:
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
python choices_logprob.py
""" """
import sglang as sgl import sglang as sgl
......
"""
Usage:
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
python json_decode.py
"""
from enum import Enum
from pydantic import BaseModel, constr
import sglang as sgl
from sglang.srt.constrained.json_schema import build_regex_from_object
character_regex = (
r"""\{\n"""
+ r""" "name": "[\w\d\s]{1,16}",\n"""
+ r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
+ r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
+ r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
+ r""" "wand": \{\n"""
+ r""" "wood": "[\w\d\s]{1,16}",\n"""
+ r""" "core": "[\w\d\s]{1,16}",\n"""
+ r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
+ r""" \},\n"""
+ r""" "alive": "(Alive|Deceased)",\n"""
+ r""" "patronus": "[\w\d\s]{1,16}",\n"""
+ r""" "bogart": "[\w\d\s]{1,16}"\n"""
+ r"""\}"""
)
@sgl.function
def character_gen(s, name):
s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
def driver_character_gen():
state = character_gen.run(name="Hermione Granger")
print(state.text())
class Weapon(str, Enum):
sword = "sword"
axe = "axe"
mace = "mace"
spear = "spear"
bow = "bow"
crossbow = "crossbow"
class Wizard(BaseModel):
name: str
age: int
weapon: Weapon
@sgl.function
def pydantic_wizard_gen(s):
s += "Give me a description about a wizard in the JSON format.\n"
s += sgl.gen(
"character",
max_tokens=128,
temperature=0,
regex=build_regex_from_object(Wizard), # Requires pydantic >= 2.0
)
def driver_character_gen():
state = character_gen.run(name="Hermione Granger")
print(state.text())
def driver_pydantic_wizard_gen():
state = pydantic_wizard_gen.run()
print(state.text())
if __name__ == "__main__":
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
driver_character_gen()
# driver_pydantic_wizard_gen()
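# A hedged post-processing sketch (not part of the original example): because
# build_regex_from_object(Wizard) constrains generation to the schema, the
# output parses cleanly back into the model with the pydantic >= 2.0 API.
def driver_pydantic_validate():
    state = pydantic_wizard_gen.run()
    wizard = Wizard.model_validate_json(state["character"])
    print(wizard.name, wizard.age, wizard.weapon)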
"""
Usage:
python3 openai_speculative.py
"""
from sglang import function, gen, set_default_backend, OpenAI
...
"""
Usage:
python3 parallel_sample.py
"""
import sglang as sgl
...
"""
Usage:
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
python readme_examples.py
"""
import sglang as sgl


@sgl.function
def tool_use(s, question):
    s += "To answer this question: " + question + ". "
    s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"]) + ". "
    if s["tool"] == "calculator":
        s += "The math expression is" + sgl.gen("expression")
    elif s["tool"] == "search engine":
        s += "The key word to search is" + sgl.gen("word")
@sgl.function
...@@ -28,6 +34,16 @@ def tip_suggestion(s):
    s += "In summary" + sgl.gen("summary")
@sgl.function
def regular_expression_gen(s):
    s += "Q: What is the IP address of the Google DNS servers?\n"
    s += "A: " + sgl.gen(
        "answer",
        temperature=0,
        regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
    )
@sgl.function
def text_qa(s, question):
    s += "Q: " + question + "\n"
...@@ -46,6 +62,12 @@ def driver_tip_suggestion():
    print("\n")
def driver_regex():
    state = regular_expression_gen.run()
    print(state.text())
    print("\n")
def driver_batching():
    states = text_qa.run_batch(
        [
...@@ -74,9 +96,11 @@ def driver_stream():
if __name__ == "__main__":
    # sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

    driver_tool_use()
    driver_tip_suggestion()
    driver_regex()
    driver_batching()
    driver_stream()
from sglang import function, gen, set_default_backend, Runtime

IP_ADDR_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"


@function
def regex_gen(s):
    s += "Q: What is the IP address of the Google DNS servers?\n"
    s += "A: " + gen(
        "answer",
        temperature=0,
        regex=IP_ADDR_REGEX,
    )


runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
set_default_backend(runtime)

state = regex_gen.run()
print(state.text())

runtime.shutdown()
"""
Usage:
python3 streaming.py
"""
import asyncio

import sglang as sgl
...
...@@ -20,7 +20,7 @@ dependencies = [
[project.optional-dependencies]
srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
       "zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
       "pydantic", "referencing", "diskcache", "cloudpickle", "pillow"]
openai = ["openai>=1.0", "numpy"]
anthropic = ["anthropic", "numpy"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
...
...@@ -30,21 +30,8 @@ def create_logit_bias_int(tokenizer):
    return mask

-CHAT_MODEL_NAMES = [
-    # GPT-4
-    "gpt-4",
-    "gpt-4-32k",
-    "gpt-4-1106-preview",
-    "gpt-4-vision-preview",
-    "gpt-4-0613",
-    "gpt-4-0314",
-    # GPT-3.5
-    "gpt-3.5-turbo",
-    "gpt-3.5-turbo-16k",
-    "gpt-3.5-turbo-1106",
-    "gpt-3.5-turbo-16k-0613",
-    "gpt-3.5-turbo-0613",
-    "gpt-3.5-turbo-0301",
-]
+INSTRUCT_MODEL_NAMES = [
+    "gpt-3.5-turbo-instruct",
+]

...@@ -60,10 +47,10 @@ class OpenAI(BaseBackend):
        self.tokenizer = tiktoken.encoding_for_model(model_name)
        self.logit_bias_int = create_logit_bias_int(self.tokenizer)

-        if model_name in CHAT_MODEL_NAMES:
-            self.is_chat_model = True
-        else:
-            self.is_chat_model = False
+        if model_name in INSTRUCT_MODEL_NAMES:
+            self.is_chat_model = False
+        else:
+            self.is_chat_model = True

        self.chat_template = get_chat_template("default")
...@@ -235,6 +222,8 @@ def openai_completion(client, is_chat=None, prompt=None, **kwargs):
def openai_completion_stream(client, is_chat=None, prompt=None, **kwargs):
    try:
        if is_chat:
            if kwargs["stop"] is None:
                kwargs.pop("stop")
            generator = client.chat.completions.create(
                messages=prompt, stream=True, **kwargs
            )
...
...@@ -7,12 +7,19 @@ from sglang.srt.sampling_params import SamplingParams
@dataclass
class GenerateReqInput:
    # The input prompt
    text: Union[List[str], str]
    # The image input
    image_data: Optional[Union[List[str], str]] = None
    # The sampling parameters
    sampling_params: Union[List[Dict], Dict] = None
    # The request id
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs of the prompt
    return_logprob: Optional[Union[List[bool], bool]] = None
    # The start location of the prompt for return_logprob
    logprob_start_len: Optional[Union[List[int], int]] = None
    # Whether to stream output
    stream: bool = False

    def post_init(self):
...
...@@ -296,7 +296,7 @@ def test_parallel_encoding(check_answer=True):
def test_image_qa():
    @sgl.function
    def image_qa(s, question):
        s += sgl.user(sgl.image("test_image.png") + question)
        s += sgl.assistant(sgl.gen("answer"))

    state = image_qa.run(
...
curl http://localhost:30000/generate \
  -H "Content-Type: application/json" \
  -d '{
    "text": "Once upon a time,",
    "sampling_params": {
      "max_new_tokens": 16,
      "temperature": 0
    }
  }'
...@@ -34,7 +34,7 @@ async def test_concurrent(args):
        url + "/generate",
        {
            "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
            "image_data": "test_image.png",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 16,
...@@ -55,7 +55,7 @@ def test_streaming(args):
        url + "/generate",
        json={
            "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
            "image_data": "test_image.png",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 128,
...