Unverified Commit 97aa9b32 authored by Lianmin Zheng, committed by GitHub

Improve docs & Add JSON decode example (#121)

parent 06175286
@@ -123,19 +123,21 @@ You can implement your prompt flow in a function decorated by `sgl.function`.
 You can then invoke the function with `run` or `run_batch`.
 The system will manage the state, chat template, parallelism and batching for you.

+The complete code for the examples below can be found at [readme_examples.py](examples/usage/readme_examples.py).

 ### Control Flow
 You can use any Python code within the function body, including control flow, nested function calls, and external libraries.

 ```python
 @sgl.function
-def control_flow(s, question):
-    s += "To answer this question: " + question + ", "
-    s += "I need to use a " + sgl.gen("tool", choices=["calculator", "web browser"]) + ". "
+def tool_use(s, question):
+    s += "To answer this question: " + question + ". "
+    s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"]) + ". "

     if s["tool"] == "calculator":
         s += "The math expression is" + sgl.gen("expression")
-    elif s["tool"] == "web browser":
-        s += "The website url is" + sgl.gen("url")
+    elif s["tool"] == "search engine":
+        s += "The key word to search is" + sgl.gen("word")
 ```
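For reference, a function defined this way is invoked with `run`; a minimal usage sketch (the question text is illustrative, and it assumes a default backend has already been set):

```python
# Minimal sketch, assuming sgl.set_default_backend(...) was called earlier.
state = tool_use.run(question="What is 5 + 7?")
print(state["tool"])   # the option picked by sgl.gen("tool", choices=...)
print(state.text())    # the full generated text accumulated in the state
```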
### Parallelism
@@ -170,6 +172,8 @@ def image_qa(s, image_file, question):
     s += sgl.assistant(sgl.gen("answer", max_tokens=256))
 ```

+See also [srt_example_llava.py](examples/quick_start/srt_example_llava.py).
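As a usage sketch, such a multi-modal function takes the image path and question as keyword arguments (the file name here is a placeholder, and a vision-capable backend is assumed):

```python
# Minimal sketch, assuming a LLaVA-style backend is set as the default.
# "image.png" is an illustrative path, not a file shipped with the docs.
state = image_qa.run(image_file="image.png", question="What is in the image?")
print(state["answer"])
```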
### Constrained Decoding
Use `regex` to specify a regular expression as a decoding constraint.
This is only supported for local models.
@@ -185,6 +189,35 @@ def regular_expression_gen(s):
     )
 ```

+### JSON Decoding
+
+```python
+character_regex = (
+    r"""\{\n"""
+    + r""" "name": "[\w\d\s]{1,16}",\n"""
+    + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
+    + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
+    + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
+    + r""" "wand": \{\n"""
+    + r""" "wood": "[\w\d\s]{1,16}",\n"""
+    + r""" "core": "[\w\d\s]{1,16}",\n"""
+    + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
+    + r""" \},\n"""
+    + r""" "alive": "(Alive|Deceased)",\n"""
+    + r""" "patronus": "[\w\d\s]{1,16}",\n"""
+    + r""" "bogart": "[\w\d\s]{1,16}"\n"""
+    + r"""\}"""
+)
+
+@sgl.function
+def character_gen(s, name):
+    s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
+    s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
+```
+
+See also [json_decode.py](examples/usage/json_decode.py).
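Because the regex constrains the output to well-formed JSON, the generated string should parse directly; a small sketch (the character name is illustrative):

```python
import json

# Run the constrained generation and parse the result.
state = character_gen.run(name="Hermione Granger")
character = json.loads(state["json_output"])  # should parse, by construction
print(character["house"])
```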
### Batching
Use `run_batch` to run a batch of requests with continuous batching.
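For illustration, a batched call looks roughly like the sketch below, matching the `text_qa` example in readme_examples.py (the questions and the `progress_bar` flag are taken as given there):

```python
@sgl.function
def text_qa(s, question):
    s += "Q: " + question + "\n"
    s += "A:" + sgl.gen("answer", stop="\n")

# Run several requests at once; the runtime batches them continuously.
states = text_qa.run_batch(
    [
        {"question": "What is the capital of the United Kingdom?"},
        {"question": "What is the capital of France?"},
        {"question": "What is the capital of Japan?"},
    ],
    progress_bar=True,
)
```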
@@ -34,7 +34,7 @@ character_regex = (
 # fmt: off
 def character_gen(name, generate):
-    s = name+ " is a character in Harry Potter. Please fill in the following information about him/her.\n"
+    s = name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
     s += generate(s, max_tokens=256, regex=character_regex)
     return s
 # fmt: on
@@ -32,7 +32,7 @@ character_regex = (
 # fmt: off
 @sgl.function
 def character_gen(s, name):
-    s += name+ " is a character in Harry Potter. Please fill in the following information about him/her.\n"
+    s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
     s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
 # fmt: on
@@ -4,13 +4,21 @@ This doc describes the sampling parameters of the SGLang Runtime.
 The `/generate` endpoint accepts the following arguments in the JSON format.

 ```python
 @dataclass
 class GenerateReqInput:
+    # The input prompt
     text: Union[List[str], str]
+    # The image input
     image_data: Optional[Union[List[str], str]] = None
+    # The sampling params
     sampling_params: Union[List[Dict], Dict] = None
+    # The request id
     rid: Optional[Union[List[str], str]] = None
+    # Whether to return logprobs of the prompt
     return_logprob: Optional[Union[List[bool], bool]] = None
+    # The start position in the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # Whether to stream the output
     stream: bool = False
 ```
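To make the field mapping concrete, a request exercising these arguments might look like the following sketch (the server address and all values are illustrative):

```python
import requests

# Minimal sketch of a /generate call using the fields above.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Once upon a time,",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
        "return_logprob": True,   # also return logprobs of the prompt
        "logprob_start_len": 0,   # starting from the first prompt token
    },
)
print(response.json())
```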
@@ -84,3 +92,7 @@ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
         prev = len(output)
 print("")
 ```

+### Multi-modal
+
+See [test_httpserver_llava.py](../test/srt/test_httpserver_llava.py).
@@ -46,6 +46,9 @@ if __name__ == "__main__":
     runtime = sgl.Runtime(model_path="liuhaotian/llava-v1.5-7b",
                           tokenizer_path="llava-hf/llava-1.5-7b-hf")
     sgl.set_default_backend(runtime)
+    # Or you can use API models
+    # sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview"))
+    # sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))

     # Run a single request
     print("\n========== single ==========\n")
"""
Usage:
python3 async_io.py
"""
import asyncio
from sglang import Runtime
......@@ -27,8 +31,8 @@ async def generate(
if __name__ == "__main__":
runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
print("runtime ready")
print("--- runtime ready ---\n")
prompt = "Who is Alan Turing?"
sampling_params = {"max_new_tokens": 128}
asyncio.run(generate(runtime, prompt, sampling_params))
......
"""
Usage:
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
python choices_logprob.py
"""
import sglang as sgl
......
"""
Usage:
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
python json_decode.py
"""
from enum import Enum
from pydantic import BaseModel, constr
import sglang as sgl
from sglang.srt.constrained.json_schema import build_regex_from_object
character_regex = (
r"""\{\n"""
+ r""" "name": "[\w\d\s]{1,16}",\n"""
+ r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
+ r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
+ r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
+ r""" "wand": \{\n"""
+ r""" "wood": "[\w\d\s]{1,16}",\n"""
+ r""" "core": "[\w\d\s]{1,16}",\n"""
+ r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
+ r""" \},\n"""
+ r""" "alive": "(Alive|Deceased)",\n"""
+ r""" "patronus": "[\w\d\s]{1,16}",\n"""
+ r""" "bogart": "[\w\d\s]{1,16}"\n"""
+ r"""\}"""
)
@sgl.function
def character_gen(s, name):
s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
def driver_character_gen():
state = character_gen.run(name="Hermione Granger")
print(state.text())
class Weapon(str, Enum):
sword = "sword"
axe = "axe"
mace = "mace"
spear = "spear"
bow = "bow"
crossbow = "crossbow"
class Wizard(BaseModel):
name: str
age: int
weapon: Weapon
@sgl.function
def pydantic_wizard_gen(s):
s += "Give me a description about a wizard in the JSON format.\n"
s += sgl.gen(
"character",
max_tokens=128,
temperature=0,
regex=build_regex_from_object(Wizard), # Requires pydantic >= 2.0
)
def driver_character_gen():
state = character_gen.run(name="Hermione Granger")
print(state.text())
def driver_pydantic_wizard_gen():
state = pydantic_wizard_gen.run()
print(state.text())
if __name__ == "__main__":
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
driver_character_gen()
# driver_pydantic_wizard_gen()
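One way to check the pydantic-constrained output is to round-trip it through the same model; a hedged sketch, assuming pydantic >= 2.0 as noted in the comment above:

```python
# Sketch: parse the constrained output back into the Wizard model.
state = pydantic_wizard_gen.run()
wizard = Wizard.model_validate_json(state["character"])  # pydantic v2 API
print(wizard)
```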
"""
Usage:
python3 openai_speculative.py
"""
from sglang import function, gen, set_default_backend, OpenAI
......
"""
Usage:
python3 parallel_sample.py
"""
import sglang as sgl
......
"""
Usage:
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
python readme_examples.py
"""
import sglang as sgl
@sgl.function
def tool_use(s, question):
s += "To answer this question: " + question + ", "
s += "I need to use a " + sgl.gen("tool", choices=["calculator", "web browser"]) + ". "
s += "To answer this question: " + question + ". "
s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"]) + ". "
if s["tool"] == "calculator":
s += "The math expression is" + sgl.gen("expression")
elif s["tool"] == "web browser":
s += "The website url is" + sgl.gen("url")
elif s["tool"] == "search engine":
s += "The key word to search is" + sgl.gen("word")
@sgl.function
@@ -28,6 +34,16 @@ def tip_suggestion(s):
     s += "In summary" + sgl.gen("summary")


+@sgl.function
+def regular_expression_gen(s):
+    s += "Q: What is the IP address of the Google DNS servers?\n"
+    s += "A: " + sgl.gen(
+        "answer",
+        temperature=0,
+        regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
+    )
+
+
 @sgl.function
 def text_qa(s, question):
     s += "Q: " + question + "\n"
@@ -46,6 +62,12 @@ def driver_tip_suggestion():
     print("\n")


+def driver_regex():
+    state = regular_expression_gen.run()
+    print(state.text())
+    print("\n")
+
+
 def driver_batching():
     states = text_qa.run_batch(
         [
@@ -74,9 +96,11 @@ def driver_stream():
 if __name__ == "__main__":
-    sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
+    # sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
+    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

     driver_tool_use()
     driver_tip_suggestion()
+    driver_regex()
     driver_batching()
     driver_stream()
from sglang import function, gen, set_default_backend, Runtime

IP_ADDR_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"


@function
def regex_gen(s):
    s += "Q: What is the IP address of the Google DNS servers?\n"
    s += "A: " + gen(
        "answer",
        temperature=0,
        regex=IP_ADDR_REGEX,
    )


runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
set_default_backend(runtime)

state = regex_gen.run()
print(state.text())

runtime.shutdown()
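The same pattern can be sanity-checked offline with Python's `re` module before tying up a model; a small sketch:

```python
import re

# Offline check of IP_ADDR_REGEX against known-good and known-bad strings.
assert re.fullmatch(IP_ADDR_REGEX, "8.8.8.8")
assert re.fullmatch(IP_ADDR_REGEX, "255.255.255.255")
assert not re.fullmatch(IP_ADDR_REGEX, "999.1.1.1")
```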
"""
Usage:
python3 streaming.py
"""
import asyncio
import sglang as sgl
......
@@ -20,7 +20,7 @@ dependencies = [
 [project.optional-dependencies]
 srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
        "zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
-       "pydantic", "diskcache", "cloudpickle", "pillow"]
+       "pydantic", "referencing", "diskcache", "cloudpickle", "pillow"]
 openai = ["openai>=1.0", "numpy"]
 anthropic = ["anthropic", "numpy"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
@@ -30,21 +30,8 @@ def create_logit_bias_int(tokenizer):
     return mask

-CHAT_MODEL_NAMES = [
-    # GPT-4
-    "gpt-4",
-    "gpt-4-32k",
-    "gpt-4-1106-preview",
-    "gpt-4-vision-preview",
-    "gpt-4-0613",
-    "gpt-4-0314",
-    # GPT-3.5
-    "gpt-3.5-turbo",
-    "gpt-3.5-turbo-16k",
-    "gpt-3.5-turbo-1106",
-    "gpt-3.5-turbo-16k-0613",
-    "gpt-3.5-turbo-0613",
-    "gpt-3.5-turbo-0301",
+INSTRUCT_MODEL_NAMES = [
     "gpt-3.5-turbo-instruct",
 ]
@@ -60,10 +47,10 @@ class OpenAI(BaseBackend):
         self.tokenizer = tiktoken.encoding_for_model(model_name)
         self.logit_bias_int = create_logit_bias_int(self.tokenizer)

-        if model_name in CHAT_MODEL_NAMES:
-            self.is_chat_model = True
-        else:
+        if model_name in INSTRUCT_MODEL_NAMES:
             self.is_chat_model = False
+        else:
+            self.is_chat_model = True

         self.chat_template = get_chat_template("default")
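The effect of this change: previously any model name missing from `CHAT_MODEL_NAMES` was treated as a completion model, whereas now only the short instruct list is, so newly released chat models work without a code change. Restated as a standalone sketch, for illustration only:

```python
# Standalone restatement of the classification above (sketch only).
INSTRUCT_MODEL_NAMES = ["gpt-3.5-turbo-instruct"]

def is_chat_model(model_name: str) -> bool:
    # New or unlisted model names are assumed to be chat models.
    return model_name not in INSTRUCT_MODEL_NAMES

assert is_chat_model("gpt-4-1106-preview")
assert not is_chat_model("gpt-3.5-turbo-instruct")
```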
@@ -235,6 +222,8 @@ def openai_completion(client, is_chat=None, prompt=None, **kwargs):
 def openai_completion_stream(client, is_chat=None, prompt=None, **kwargs):
     try:
         if is_chat:
+            if kwargs["stop"] is None:
+                kwargs.pop("stop")
             generator = client.chat.completions.create(
                 messages=prompt, stream=True, **kwargs
             )
@@ -7,12 +7,19 @@ from sglang.srt.sampling_params import SamplingParams
 @dataclass
 class GenerateReqInput:
+    # The input prompt
     text: Union[List[str], str]
+    # The image input
     image_data: Optional[Union[List[str], str]] = None
+    # The sampling params
     sampling_params: Union[List[Dict], Dict] = None
+    # The request id
     rid: Optional[Union[List[str], str]] = None
+    # Whether to return logprobs of the prompt
     return_logprob: Optional[Union[List[bool], bool]] = None
+    # The start position in the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # Whether to stream the output
     stream: bool = False

     def post_init(self):
@@ -296,7 +296,7 @@ def test_parallel_encoding(check_answer=True):
 def test_image_qa():
     @sgl.function
     def image_qa(s, question):
-        s += sgl.user(sgl.image("image.png") + question)
+        s += sgl.user(sgl.image("test_image.png") + question)
         s += sgl.assistant(sgl.gen("answer"))

     state = image_qa.run(
curl http://localhost:30000/generate \
-H "Content-Type: application/json" \
-d '{
"text": "Once upon a time,",
"sampling_params": {
"max_new_tokens": 16,
"temperature": 0
}
}'
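For comparison, the same request can be issued from Python with streaming enabled; a hedged sketch that reuses the null-byte-delimited chunk handling from the docs' streaming example above (the `chunk[5:]` slice assumes SSE-style `data:` prefixes, as in that example):

```python
import json
import requests

# Sketch of the curl request above, with "stream": True added.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Once upon a time,",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
        "stream": True,
    },
    stream=True,
)
prev = 0
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk[5:].decode("utf-8"))  # strip the "data:" tag
        output = data["text"].strip()
        print(output[prev:], end="", flush=True)
        prev = len(output)
print("")
```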
@@ -34,7 +34,7 @@ async def test_concurrent(args):
         url + "/generate",
         {
             "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
-            "image_data": "/home/ubuntu/sglang/test/lang/image.png",
+            "image_data": "test_image.png",
             "sampling_params": {
                 "temperature": 0,
                 "max_new_tokens": 16,
@@ -55,7 +55,7 @@ def test_streaming(args):
         url + "/generate",
         json={
             "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
-            "image_data": "/home/ubuntu/sglang/test/lang/image.png",
+            "image_data": "test_image.png",
             "sampling_params": {
                 "temperature": 0,
                 "max_new_tokens": 128,