Unverified commit 08ab2a16, authored by Liangsheng Yin, committed by GitHub

Json Decode && Multi-Turns (#4)

parent f652494d
Subproject commit 00cf5f46fdbb4f1dbd9277fe3b842621c1d9e7dc
Subproject commit 88b9496e1a726ddb353eb42887cfc0ab32c99460
## Run benchmark
### Build dataset
```
pip install wikipedia
python3 build_dataset.py
```
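The script writes `questions.jsonl`, one JSON object per line, where `document` holds a truncated Wikipedia page (the text below is only illustrative):
```
{"document": "London is the capital and largest city of England and of the United Kingdom ..."}
```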
### Dependencies
```
llama_cpp_python 0.2.19
guidance 0.1.10
vllm 0.2.5
outlines 0.0.22
```
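To pin these exact versions in one step:
```
pip install llama_cpp_python==0.2.19 guidance==0.1.10 vllm==0.2.5 outlines==0.0.22
```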
### Benchmark sglang
Run llama-7b
```
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
Run mixtral-8x7b
(If you hit a CUDA out-of-memory error, try reducing `--mem-fraction-static`.)
```
python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
```
Benchmark
```
python3 bench_sglang.py --num-questions 10
```
### Benchmark vLLM
Run llama-7b through the Outlines server (vLLM backend), which provides regex-constrained decoding
```
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
Benchmark
```
python3 bench_other.py --backend vllm --num-questions 10
```
### Benchmark guidance
Run llama-7b and benchmark
```
python3 bench_other.py --backend guidance --num-questions 10 --parallel 1
```
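Each benchmark run appends a single JSON record to the result file configured by the common benchmark arguments (`args.result_file`); a record looks roughly like this (values are illustrative):
```
{"task": "json_regex_decode", "backend": "sglang", "num_gpus": 1, "latency": 12.345, "num_requests": 10, "other": {"parallel": 1}}
```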
import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from sglang.test.test_utils import (
add_common_other_args_and_parse,
call_generate_outlines,
)
from sglang.utils import dump_state_text, read_jsonl
from sglang.lang.ir import REGEX_INT, REGEX_STRING, REGEX_FLOAT
from tqdm import tqdm
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"
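# REGEX_LIST matches a JSON-style list of quoted strings, e.g. ["Big Ben", "Tower Bridge", "London Eye"]
# (REGEX_STRING only allows word characters, digits, and whitespace inside the quotes).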
# fmt: off
def json_decode(document, generate):
s = "Please extract the information of a city from the following wikipedia page.\n"
s += "Page begin.\n" + document + "Page end.\n"
s += "Here is the name, country, and symbol of the city in JSON format.\n"
s += "{\n"
s += ' "name": '
s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += ' "country": '
s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += ' "latitude": '
s += generate(s, max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
s += ' "population": '
s += generate(s, max_tokens=8, regex=REGEX_INT + ",") + "\n"
s += ' "top 3 landmarks": '
s += generate(s, max_tokens=24, regex=REGEX_LIST) + "\n"
s += "}\n"
return s
# fmt: on
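# json_decode returns the full prompt followed by the constrained completions, which end
# in a JSON object roughly like (values illustrative):
#   {"name": "London", "country": "United Kingdom", "latitude": 51.5,
#    "population": 8982000, "top 3 landmarks": ["Big Ben", "Tower Bridge", "London Eye"]}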
def main(args):
lines = read_jsonl(args.data_path)
arguments = []
    for line in lines[: args.num_questions]:
        arguments.append({"document": line["document"]})
states = [None] * len(arguments)
# Select backend
if args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_outlines, url=url, temperature=0)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def generate(prompt, max_tokens, stop=None, regex=None):
out = (
model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=0,
stop=stop,
regex=regex,
)
)
return out["answer"]
# warmup
for _ in range(3):
generate("Hello!" * 10, max_tokens=64, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests
def get_one_answer(i):
states[i] = json_decode(generate=generate, **arguments[i])
tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(arguments))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
rets = executor.map(get_one_answer, list(range(len(arguments))))
for _ in rets:
pass
latency = time.time() - tic
    # Report latency
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "json_regex_decode",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="questions.jsonl")
parser.add_argument("--num-questions", type=int, default=20)
args = add_common_other_args_and_parse(parser)
main(args)
import argparse
import json
import time
import sglang as sgl
from sglang.lang.ir import REGEX_INT, REGEX_STRING, REGEX_FLOAT
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"
# fmt: off
@sgl.function
def json_warm_up(s):
s += "The information about Hogwarts is in the following JSON format.\n"
with s.var_scope("json_output"):
s += "{\n"
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
s += "}\n"
    print(f'The warm-up JSON result is:\n{s["json_output"]}')
# fmt: on
# fmt: off
@sgl.function
def json_decode(s, document):
s += "Please extract the information of a city from the following wikipedia page.\n"
s += "Page begin.\n" + document + "Page end.\n"
s += "Here is the name, country, and symbol of the city in JSON format.\n"
with s.var_scope("json_output"):
s += "{\n"
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
s += "}\n"
# fmt: on
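# Everything appended inside the var_scope("json_output") block, including the braces and
# every constrained sgl.gen() output, is also captured as state["json_output"], which
# main() below dumps to tmp_<backend>_json_results.txt.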
def main(args):
lines = read_jsonl(args.data_path)
arguments = []
    for line in lines[: args.num_questions]:
        arguments.append({"document": line["document"]})
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
# Warm up
json_warm_up.run().sync()
# Run requests
tic = time.time()
states = json_decode.run_batch(arguments, temperature=0, num_threads=args.parallel)
for state in states:
state.sync()
latency = time.time() - tic
    # Report latency
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(f"tmp_{args.backend}_json_results.txt", "w") as fout:
for state in states:
fout.write(state["json_output"] + "\n")
with open(args.result_file, "a") as fout:
value = {
"task": "json_regex_decode",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="questions.jsonl")
parser.add_argument("--num-questions", type=int, default=20)
args = add_common_sglang_args_and_parse(parser)
main(args)
import json
import transformers
import wikipedia
model_path = "meta-llama/Llama-2-7b-chat-hf"
t = transformers.AutoTokenizer.from_pretrained(model_path)
city_names = [
"los angles",
"london",
"tokyo",
"beijing",
"singapore",
"paris",
"dubai",
"sydney",
"moscow",
"rome",
"toronto",
"rio de janeiro",
"istanbul",
"berlin",
"auckland",
"buenos aires",
"mexico city",
"mumbai",
"seoul",
"bangkok",
"cairo",
"athens",
"jerusalem",
]
def get_content(city_name):
content = str(wikipedia.page(city_name).content)
content = content.replace("\n\n", "\n")
tokens = t.encode(content)
expected_tokens = 3000
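    # Approximate the ~3000-token budget by truncating the text proportionally at the
    # character level, then re-tokenize to report the actual truncated token count.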
truncate_len = int((expected_tokens / len(tokens)) * len(content))
truncate_content = content[:truncate_len]
truncate_tokens = t.encode(truncate_content)
    # Report token counts
print(
f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
)
return truncate_content
if __name__ == "__main__":
with open("questions.jsonl", "w") as fout:
for city_name in city_names:
truncate_content = get_content(city_name)
fout.write(json.dumps({"document": truncate_content}) + "\n")
### Benchmark sglang
Run llama-7b
```
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
Run mixtral-8x7b
(If you hit a CUDA out-of-memory error, try reducing `--mem-fraction-static`.)
```
python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
```
Benchmark (short output)
```
python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf
```
Benchmark (long output)
```
python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf --long
```
### Benchmark vLLM
Run llama-7b
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
Run mixtral-8x7b
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model mistralai/Mixtral-8x7B-Instruct-v0.1 --disable-log-requests --port 21000 --tensor-parallel-size 8
```
Benchmark (short output)
```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm
```
Benchmark (long output)
```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm --long
```
### Benchmark guidance
Benchmark llama-7b (short output)
```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1
```
Benchmark llama-7b (long output)
```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --long
```
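Each run appends one JSON record to the result file (`args.result_file`); for this task the record also includes the number of turns and the output mode (values illustrative):
```
{"task": "multi_turns", "backend": "vllm", "num_gpus": 1, "latency": 34.567, "num_requests": 20, "num_turns": 4, "other": {"parallel": 1, "output_mode": "short"}}
```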
import json
import time
from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor
import requests
from sglang.test.test_utils import add_common_other_args_and_parse
from sglang.utils import dump_state_text
from tqdm import tqdm
from vllm.transformers_utils.tokenizer import get_tokenizer
from data_gen import gen_arguments
def get_generate(args):
# Select backend
if args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
def generate(prompt, max_tokens, stop=None, temperature=0, url=url, n=1):
data = {
"prompt": prompt,
"temperature": temperature,
"max_tokens": max_tokens,
"ignore_eos": True,
"stop": stop,
"stream": False,
"n": n,
}
res = requests.post(url, json=data)
assert res.status_code == 200
return res.json()["text"][0][len(prompt) :]
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def generate(prompt, max_tokens, stop=None):
out = (
model
+ prompt
+ gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
)
return out["answer"]
# warmup
for _ in range(3):
generate("Hello!" * 10, max_tokens=64, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
return generate
def multi_turns(generate, qas):
s = ""
for qa in qas:
s += qa["prompt"]
s += generate(s, max_tokens=qa["new_tokens"])
return s
def main(args):
print(args)
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
multi_qas = gen_arguments(args, tokenizer)
states = [None] * args.num_qa
generate = get_generate(args)
def get_one_answer(i):
states[i] = multi_turns(generate=generate, **multi_qas[i])
tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(multi_qas))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
rets = executor.map(get_one_answer, list(range(len(multi_qas))))
for _ in rets:
pass
latency = time.time() - tic
    # Report latency
print(f"Latency: {latency:.3f}")
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "multi_turns",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_qa,
"num_turns": args.turns,
"other": {
"parallel": args.parallel,
"output_mode": "long" if args.long else "short",
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--turns", type=int, default=4)
parser.add_argument("--num-qa", type=int, default=20)
parser.add_argument("--min-len-q", type=int, default=256)
parser.add_argument("--max-len-q", type=int, default=512)
parser.add_argument("--min-len-a", type=int, default=4)
parser.add_argument("--max-len-a", type=int, default=8)
parser.add_argument("--tokenizer", type=str, required=True)
parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--long", action="store_true")
args = add_common_other_args_and_parse(parser)
if args.long:
args.min_len_a = 256
args.max_len_a = 512
args.num_qa = 20
main(args)
import json
import time
from argparse import ArgumentParser
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text
from vllm.transformers_utils.tokenizer import get_tokenizer
from data_gen import gen_arguments
@sgl.function
def multi_turns(s, qas):
for qa in qas:
s += qa["prompt"]
s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)
def main(args):
print(args)
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
multi_qas = gen_arguments(args, tokenizer)
backend = select_sglang_backend(args)
tic = time.time()
states = multi_turns.run_batch(
multi_qas, temperature=0, backend=backend, num_threads=args.parallel
)
for state in states:
state.sync()
latency = time.time() - tic
print(f"Latency: {latency:.3f}")
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "multi_turns",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_qa,
"num_turns": args.turns,
"other": {
"parallel": args.parallel,
"output_mode": "long" if args.long else "short",
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--turns", type=int, default=4)
parser.add_argument("--num-qa", type=int, default=20)
parser.add_argument("--min-len-q", type=int, default=256)
parser.add_argument("--max-len-q", type=int, default=512)
parser.add_argument("--min-len-a", type=int, default=4)
parser.add_argument("--max-len-a", type=int, default=8)
parser.add_argument("--tokenizer", type=str, required=True)
parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--long", action="store_true")
args = add_common_sglang_args_and_parse(parser)
if args.long:
args.min_len_a = 256
args.max_len_a = 512
args.num_qa = 20
main(args)
import random
import string
random.seed(42)
def gen_prompt(tokenizer, token_num):
cha_set = string.ascii_letters + string.digits
ret = "".join(random.choices(cha_set, k=token_num))
while len(tokenizer(ret).input_ids) < token_num:
ret += random.choice(cha_set)
return ret
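# gen_arguments builds one dict per conversation:
#   {"qas": [{"prompt": <random text of min_len_q..max_len_q tokens>,
#             "new_tokens": <int in [min_len_a, max_len_a]>}, ...]}
# with args.turns question/answer pairs per conversation and args.num_qa conversations total.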
def gen_arguments(args, tokenizer):
multi_qas = [{"qas": []} for _ in range(args.num_qa)]
for i in range(args.num_qa):
qas = multi_qas[i]["qas"]
for _ in range(args.turns):
prompt_len = random.randint(args.min_len_q, args.max_len_q)
new_tokens = random.randint(args.min_len_a, args.max_len_a)
qas.append(
{
"prompt": gen_prompt(tokenizer, prompt_len),
"new_tokens": new_tokens,
}
)
return multi_qas
......@@ -37,6 +37,7 @@ def gen(
top_k: Optional[int] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
dtype: Optional[type] = None,
choices: Optional[List[str]] = None,
regex: Optional[str] = None,
......@@ -60,6 +61,7 @@ def gen(
top_k,
frequency_penalty,
presence_penalty,
ignore_eos,
dtype,
regex,
)
......@@ -74,6 +76,7 @@ def gen_int(
top_k: Optional[int] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
):
return SglGen(
name,
......@@ -84,6 +87,7 @@ def gen_int(
top_k,
frequency_penalty,
presence_penalty,
ignore_eos,
int,
None,
)
......@@ -98,6 +102,7 @@ def gen_string(
top_k: Optional[int] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
):
return SglGen(
name,
......@@ -108,6 +113,7 @@ def gen_string(
top_k,
frequency_penalty,
presence_penalty,
ignore_eos,
str,
None,
)
......
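The new `ignore_eos` flag is threaded from `sgl.gen` (and `gen_int`/`gen_string`) through `SglGen` into `SglSamplingParams`, so the multi-turn benchmark can request fixed-length completions. A minimal usage sketch (assuming a running sglang backend has already been set as the default):
```
import sglang as sgl

@sgl.function
def fixed_length_answer(s, question):
    s += question
    # ignore_eos=True keeps decoding for the full max_tokens budget even if the model
    # emits an end-of-sequence token, which gives deterministic completion lengths.
    s += sgl.gen("answer", max_tokens=8, ignore_eos=True)
```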
......@@ -4,7 +4,7 @@ import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
from sglang.lang.ir import SglSamplingParams
try:
import anthropic
......@@ -28,7 +28,7 @@ class Anthropic(BaseBackend):
def generate(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
prompt = s.text_
ret = anthropic.Anthropic().completions.create(
......@@ -43,7 +43,7 @@ class Anthropic(BaseBackend):
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
prompt = s.text_
generator = anthropic.Anthropic().completions.create(
......
......@@ -2,7 +2,7 @@ from typing import Callable, List, Optional, Union
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
from sglang.lang.ir import SglSamplingParams
class BaseBackend:
......@@ -48,14 +48,14 @@ class BaseBackend:
def generate(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
raise NotImplementedError()
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
raise NotImplementedError()
......
......@@ -4,7 +4,7 @@ import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
from sglang.lang.ir import SglSamplingParams
try:
import openai
......@@ -73,7 +73,7 @@ class OpenAI(BaseBackend):
def generate(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
if sampling_params.dtype is None:
if self.is_chat_model:
......@@ -122,7 +122,7 @@ class OpenAI(BaseBackend):
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
if sampling_params.dtype is None:
if self.is_chat_model:
......
......@@ -7,7 +7,7 @@ from sglang.backend.base_backend import BaseBackend
from sglang.global_config import global_config
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams, SglArgument
from sglang.lang.ir import SglSamplingParams, SglArgument
from sglang.utils import encode_image_base64, find_printable_text, http_request
......@@ -55,7 +55,7 @@ class RuntimeEndpoint(BaseBackend):
def generate(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
if sampling_params.dtype is None:
data = {
......@@ -87,7 +87,7 @@ class RuntimeEndpoint(BaseBackend):
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
if sampling_params.dtype is None:
data = {
......
......@@ -7,7 +7,7 @@ from typing import List, Optional, Union
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
from sglang.lang.ir import SglSamplingParams
from sglang.utils import http_request
......@@ -138,7 +138,7 @@ class TGI(BaseBackend):
self,
s: StreamExecutor,
choices: List[str],
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
decision = self.retry_for_expected(
s.text_,
......@@ -152,7 +152,7 @@ class TGI(BaseBackend):
s: StreamExecutor,
max_tokens: int,
stop: Union[str, List[str]],
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
dtype: Optional[str] = None,
):
if dtype is None:
......
......@@ -6,7 +6,7 @@ from typing import List, Union
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
from sglang.lang.ir import (
SamplingParams,
SglSamplingParams,
SglArgument,
SglConstantText,
SglExpr,
......@@ -140,7 +140,7 @@ class CompiledFunction:
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
kwargs.update(self.function.bind_arguments)
default_sampling_para = SamplingParams(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
......@@ -173,7 +173,7 @@ class CompiledFunction:
backend = backend or global_config.default_backend
default_sampling_para = SamplingParams(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
......
......@@ -292,7 +292,7 @@ class StreamExecutor:
assert isinstance(other, SglExpr), f"{other}"
if isinstance(other, (SglConstantText, SglArgument)):
if isinstance(other, SglConstantText):
self._execute_fill(other.value)
elif isinstance(other, SglGen):
self._execute_gen(other)
......@@ -332,8 +332,6 @@ class StreamExecutor:
def _execute_image(self, expr: SglImage):
path = expr.path
if isinstance(path, SglArgument):
path = path.value
base64_data = encode_image_base64(path)
......@@ -419,7 +417,7 @@ class StreamExecutor:
"role": expr.role,
"content": [{"type": "text", "text": new_text}],
}
for (image_path, image_base64_data) in self.cur_images:
for image_path, image_base64_data in self.cur_images:
last_msg["content"].append(
{
"type": "image_url",
......@@ -480,6 +478,7 @@ class StreamExecutor:
"top_k",
"frequency_penalty",
"presence_penalty",
"ignore_eos",
"dtype",
"regex",
]:
......
......@@ -13,7 +13,7 @@ REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
@dataclasses.dataclass
class SamplingParams:
class SglSamplingParams:
max_new_tokens: int = 16
stop: Union[str, List[str]] = ()
temperature: float = 1.0
......@@ -21,13 +21,14 @@ class SamplingParams:
top_k: int = -1 # -1 means disable
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
ignore_eos: bool = False
# for constrained generation, not included in to_xxx_kwargs
dtype: Optional[str] = None
regex: Optional[str] = None
def clone(self):
return SamplingParams(
return SglSamplingParams(
self.max_new_tokens,
self.stop,
self.temperature,
......@@ -67,6 +68,7 @@ class SamplingParams:
"top_k": self.top_k,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"ignore_eos": self.ignore_eos,
"regex": self.regex,
}
......@@ -98,13 +100,14 @@ class SglFunction:
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
ignore_eos: bool = False,
stream: bool = False,
backend=None,
**kwargs,
):
from sglang.lang.interpreter import run_program
default_sampling_para = SamplingParams(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
......@@ -112,9 +115,9 @@ class SglFunction:
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
)
backend = backend or global_config.default_backend
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
return run_program(self, backend, args, kwargs, default_sampling_para, stream)
def run_batch(
......@@ -128,6 +131,7 @@ class SglFunction:
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
ignore_eos: bool = False,
backend=None,
num_threads: Union[str, int] = "auto",
progress_bar: bool = False,
......@@ -139,7 +143,7 @@ class SglFunction:
return []
assert isinstance(batch_kwargs[0], dict)
default_sampling_para = SamplingParams(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
......@@ -147,11 +151,9 @@ class SglFunction:
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
)
backend = backend or global_config.default_backend
batch_kwargs = [
{k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
]
return run_program_batch(
self,
backend,
......@@ -321,12 +323,13 @@ class SglGen(SglExpr):
top_k,
frequency_penalty,
presence_penalty,
ignore_eos,
dtype,
regex,
):
super().__init__()
self.name = name
self.sampling_params = SamplingParams(
self.sampling_params = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
......@@ -334,6 +337,7 @@ class SglGen(SglExpr):
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
dtype=dtype,
regex=regex,
)
......
......@@ -40,7 +40,8 @@ def extract_prefix_by_tracing(program, backend):
try:
with TracingScope(tracer):
tracer.ret_value = program.func(tracer, **arguments)
except StopTracing:
except (StopTracing, TypeError):
        # Some exceptions may not be caught
pass
# Run and cache prefix
......
"""
Backend configurations, may vary with different serving platforms.
"""
from dataclasses import dataclass
@dataclass
class BackendConfig:
extend_dependency_time: float = 0.03
GLOBAL_BACKEND_CONFIG = BackendConfig()