Unverified Commit 14522e6a authored by Liangsheng Yin, committed by GitHub

Organize Benchmark (#381)

parent 183df472
......@@ -5,16 +5,11 @@ import json
import re
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import (
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
INVALID = -9999999
......@@ -67,6 +62,32 @@ def multi_chain_gsm8k(question, num_chains, call_generate):
return s
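# Async variant of multi_chain_gsm8k for backends whose call_generate is a coroutine
# (the lmql path below); the prompting logic is otherwise identical.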
async def multi_chain_gsm8k_async(question, num_chains, call_generate):
    s = "Question: " + question + "\n"
    # s += call_generate(s + "Answer: " + prompt_lib[0], max_tokens=256,
    #                    stop="Question", temperature=0)
    # return s

    comps = []
    for i in range(num_chains):
        comps.append(
            await call_generate(
                s + "Answer: " + prompt_lib[i % num_chains],
                max_tokens=256,
                temperature=0.3,
                stop="Question",
            )
        )

    s += "Answer: To answer this question, here are some possible solutions. "
    s += "After considering all of them, I will do a majority vote.\n\n"
    for i in range(num_chains):
        s += f"Solution {i+1}: " + comps[i].strip() + "\n\n"
    s += "\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
    s += await call_generate(s, max_tokens=16, temperature=0, stop=None)
    return s
def main(args):
lines = read_jsonl(args.data_path)
......@@ -83,71 +104,7 @@ def main(args):
states = [None] * len(labels)
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def call_generate(prompt, temperature, max_tokens, stop):
out = (
model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
)
return out["answer"]
# def multi_chain_gsm8k(question, num_chains, call_generate):
# s = model + "Question: " + question + "\n"
# comps = []
# for i in range(num_chains):
# comps.append(call_generate(s + "Answer: " + prompt_lib[i % num_chains],
# max_tokens=256, temperature=0.3, stop="Question"))
# s += "Answer: To answer this question, here are some possible solutions. "
# s += "After considering all of them, I will do a majority vote.\n\n"
# for i in range(num_chains):
# s += f"Solution {i+1}: " + comps[i].strip() + "\n\n"
# s += f"\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
# return call_generate(s, max_tokens=16, temperature=0, stop=None)
elif args.backend == "lmql":
import lmql
model = lmql.model(
"meta-llama/Llama-2-7b-chat-hf", endpoint=f"{args.host}:{args.port}"
)
@lmql.query(model=model)
async def program(question):
'''lmql
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < 257 and STOPS_AT(ANSWER, "Question")
return ANSWER
'''
async def call_generate(prompt, temperature, max_tokens, stop):
return await program(question=prompt, temperature=0)
else:
raise ValueError(f"Invalid backend: {args.backend}")
call_generate = get_call_generate(args)
# Run requests
if args.backend != "lmql":
......@@ -158,31 +115,35 @@ def main(args):
tic = time.time()
if args.parallel == 1:
for i in range(len(questions)):
for i in tqdm(range(len(questions))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(questions))))
list(
tqdm(
executor.map(get_one_answer, list(range(len(questions)))),
total=len(questions),
)
)
else:
# Use asyncio
async def batched_call(batch_size):
for i in range(0, len(questions), batch_size):
tasks = []
for q in questions[i : i + batch_size]:
tasks.append(
call_generate(
few_shot_examples + q,
temperature=0,
max_tokens=256,
stop="Question",
)
)
rets = await asyncio.gather(*tasks)
for j in range(len(rets)):
states[i + j] = get_answer_value(rets[j])
        async def get_one_answer_asyncio(i):
            answer = await multi_chain_gsm8k_async(
                questions[i], args.num_chains, call_generate
            )
            states[i] = answer

        tic = time.time()
        asyncio.run(batched_call(batch_size=args.parallel))
        loop = asyncio.get_event_loop()
        batches = [
            list(range(i, min(i + args.parallel, len(questions))))
            for i in range(0, len(questions), args.parallel)
        ]
        for bt in tqdm(batches):
            tasks = [get_one_answer_asyncio(k) for k in bt]
            loop.run_until_complete(asyncio.gather(*tasks))
latency = time.time() - tic
preds = []
......
......@@ -22,7 +22,7 @@ python3 bench_other.py --backend vllm --num-questions 64
### Benchmark guidance
```
python3 bench_other.py --backend guidance --num-questions 32 --parallel 1
python3 bench_other.py --backend guidance --num-questions 32 --parallel 1 --n-ctx 11000 --model-path path/to/code-llama/gguf
```
......
......@@ -6,12 +6,7 @@ from functools import partial
from tqdm import tqdm
from sglang.test.test_utils import (
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
USER_PREFIX = "[INST] "
......@@ -60,40 +55,11 @@ def main(args):
states = [None] * len(arguments)
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_lightllm, url=url, temperature=0)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_vllm, url=url, temperature=0)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_srt_raw, url=url, temperature=0)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf",
n_gpu_layers=-1,
n_ctx=11000,
)
def generate(prompt, max_tokens, stop):
out = (
model
+ prompt
+ gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
)
return out["answer"]
# warmup
generate("Hello!", max_tokens=8, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
call_generate = partial(get_call_generate(args), temperature=0)
# Run requests
def get_one_answer(i):
states[i] = multi_document_qa(generate=generate, **arguments[i])
states[i] = multi_document_qa(generate=call_generate, **arguments[i])
tic = time.time()
if args.parallel == 1:
......@@ -101,7 +67,13 @@ def main(args):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(labels))))
list(
tqdm(
executor.map(get_one_answer, list(range(len(labels)))),
total=len(labels),
)
)
latency = time.time() - tic
# Compute accuracy
......
......@@ -56,11 +56,11 @@ python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm
Benchmark Llama-7B (short output)
```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
Benchmark Llama-7B (long output)
```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --long
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf --long
```
......@@ -2,61 +2,16 @@ import json
import time
from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import requests
from data_gen import gen_arguments
from tqdm import tqdm
from vllm.transformers_utils.tokenizer import get_tokenizer
from sglang.test.test_utils import add_common_other_args_and_parse
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text
def get_generate(args):
# Select backend
if args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
def generate(prompt, max_tokens, stop=None, temperature=0, url=url, n=1):
data = {
"prompt": prompt,
"temperature": temperature,
"max_tokens": max_tokens,
"ignore_eos": True,
"stop": stop,
"stream": False,
"n": n,
}
res = requests.post(url, json=data)
assert res.status_code == 200
return res.json()["text"][0][len(prompt) :]
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def generate(prompt, max_tokens, stop=None):
out = (
model
+ prompt
+ gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
)
return out["answer"]
# warmup
for _ in range(3):
generate("Hello!" * 10, max_tokens=64, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
return generate
def multi_turns(generate, qas):
s = ""
for qa in qas:
......@@ -75,10 +30,10 @@ def main(args):
states = [None] * args.num_qa
generate = get_generate(args)
call_generate = partial(get_call_generate(args), temperature=0)
def get_one_answer(i):
states[i] = multi_turns(generate=generate, **multi_qas[i])
states[i] = multi_turns(generate=call_generate, **multi_qas[i])
tic = time.time()
if args.parallel == 1:
......@@ -86,7 +41,12 @@ def main(args):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
rets = executor.map(get_one_answer, list(range(len(multi_qas))))
rets = list(
tqdm(
executor.map(get_one_answer, list(range(len(multi_qas)))),
total=len(multi_qas),
)
)
for _ in rets:
pass
......
......@@ -24,5 +24,11 @@ python3 bench_other.py --num-questions 100 --backend vllm
### Benchmark guidance
```
python3 bench_other.py --num-questions 100 --backend guidance --parallel 1
python3 bench_other.py --num-questions 100 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
### Benchmark lmql
```
python3 bench_other.py --num-questions 100 --backend lmql --parallel 1
```
\ No newline at end of file
......@@ -2,17 +2,10 @@ import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from tqdm import tqdm
from sglang.test.test_utils import (
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
......@@ -97,42 +90,7 @@ def main(args):
states = []
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def call_generate(prompt, temperature, max_tokens, stop):
out = (
model
+ prompt
+ gen(
name="result",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
)
return out["result"]
# warmup
call_generate("Hello,", 1.0, 8, ".")
else:
raise ValueError(f"Invalid backend: {args.backend}")
call_generate = get_call_generate(args)
def run_single_agent(argument):
question = argument["question"]
......@@ -161,13 +119,60 @@ def main(args):
states.append(answer)
    async def run_single_agent_async(argument):
        question = argument["question"]
        triplets = argument["triplets"]
        prompt = get_prompt(question)
        for i in range(1, len(triplets) + 2):
            prompt += "Thought " + str(i) + ":"
            states.append(prompt)
            answer = await call_generate(
                prompt, max_tokens=200, temperature=0, stop="Observation", max_len=4096
            )
            if i > len(triplets):
                break
            prompt += (
                triplets[i - 1]["thought"]
                + "\nAction "
                + str(i)
                + ":"
                + triplets[i - 1]["action"]
                + "\nObservation "
                + str(i)
                + ":"
                + triplets[i - 1]["observation"]
                + "\n"
            )
        states.append(answer)
tic = time.time()
if args.parallel == 1:
for arg in tqdm(arguments):
run_single_agent(arg)
if args.backend != "lmql":
if args.parallel == 1:
for arg in tqdm(arguments):
run_single_agent(arg)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(run_single_agent, arguments), total=len(arguments)
)
)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(run_single_agent, arguments)
        import asyncio

        loop = asyncio.get_event_loop()
        batches = [
            [] for _ in range((len(arguments) + args.parallel - 1) // args.parallel)
        ]
        for i, arg in enumerate(arguments):
            batches[i // args.parallel].append(arg)
        for bt in tqdm(batches):
            tasks = [run_single_agent_async(arg) for arg in bt]
            loop.run_until_complete(asyncio.gather(*tasks))
latency = time.time() - tic
print(f"Latency: {latency:.3f}")
......
......@@ -23,5 +23,11 @@ python3 bench_other.py --backend vllm --num-questions 64
### Benchmark guidance
```
python3 bench_other.py --backend guidance --num-questions 32 --parallel 1
python3 bench_other.py --backend guidance --num-questions 32 --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
### Benchmark lmql
```
python3 bench_other.py --backend lmql --num-questions 32 --parallel 1
```
\ No newline at end of file
......@@ -6,12 +6,7 @@ from functools import partial
from tqdm import tqdm
from sglang.test.test_utils import (
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
number = 5
......@@ -70,48 +65,43 @@ def main(args):
states = [None] * len(lines)
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_lightllm, url=url, temperature=0)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_vllm, url=url, temperature=0)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_srt_raw, url=url, temperature=0)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def generate(prompt, max_tokens, stop):
out = (
model
+ prompt
+ gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
)
return out["answer"]
# warmup
generate("Hello!", max_tokens=8, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
call_generate = partial(get_call_generate(args), temperature=0)
# Run requests
def get_one_answer(i):
states[i] = suggest_tips(lines[i]["topic"], generate)
tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(lines))):
get_one_answer(i)
if args.backend != "lmql":
def get_one_answer(i):
states[i] = suggest_tips(lines[i]["topic"], call_generate)
if args.parallel == 1:
for i in tqdm(range(len(lines))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(get_one_answer, list(range(len(lines)))),
total=len(lines),
)
)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(lines))))
        import asyncio

        from lmql_funcs import suggest_tips_async

        async def get_one_answer_async(i):
            states[i] = await suggest_tips_async(lines[i]["topic"], call_generate)

        batches = []
        for i in range(0, len(lines), args.parallel):
            batches.append(list(range(i, min(i + args.parallel, len(lines)))))

        loop = asyncio.get_event_loop()
        for batch in tqdm(batches):
            loop.run_until_complete(
                asyncio.gather(*[get_one_answer_async(i) for i in batch])
            )
latency = time.time() - tic
# Compute accuracy
......
number = 5
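# Async counterparts of the prompting helpers in bench_other.py: the lmql backend's
# call_generate is a coroutine, so the benchmark awaits these with asyncio.gather
# (in groups of --parallel) instead of fanning out over a ThreadPoolExecutor.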
async def expand_tip_async(topic, tip, generate):
    s = (
        """Please expand a tip for a topic into a detailed paragraph.

Topic: staying healthy
Tip: Regular Exercise
Paragraph: Incorporate physical activity into your daily routine. This doesn't necessarily mean intense gym workouts; it can be as simple as walking, cycling, or yoga. Regular exercise helps in maintaining a healthy weight, improves cardiovascular health, boosts mental health, and can enhance cognitive function, which is crucial for fields that require intense intellectual engagement.

Topic: building a campfire
Tip: Choose the Right Location
Paragraph: Always build your campfire in a safe spot. This means selecting a location that's away from trees, bushes, and other flammable materials. Ideally, use a fire ring if available. If you're building a fire pit, it should be on bare soil or on a bed of stones, not on grass or near roots which can catch fire underground. Make sure the area above is clear of low-hanging branches.

Topic: writing a blog post
Tip: structure your content effectively
Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.

Topic: """
        + topic
        + "\nTip: "
        + tip
        + "\nParagraph:"
    )
    return await generate(s, max_tokens=128, stop="\n\n")


async def suggest_tips_async(topic, generate):
    s = "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n"
    s += "USER: Give some tips for " + topic + ".\n"
    s += (
        "ASSISTANT: Okay. Here are "
        + str(number)
        + " concise tips, each under 8 words:\n"
    )

    tips = []
    for i in range(1, 1 + number):
        s += f"{i}."
        # NOTE: the stop string differs from the sync version because lmql does not
        # support a list of stop tokens.
        tip = await generate(s, max_tokens=24, stop=".\n")
        s += tip + ".\n"
        tips.append(tip)

    paragraphs = [await expand_tip_async(topic, tip, generate=generate) for tip in tips]

    for i in range(1, 1 + number):
        s += f"Tip {i}:" + paragraphs[i - 1] + "\n"

    return s
......@@ -41,5 +41,11 @@ python3 bench_other.py --num-questions 32 --backend lightllm
### Benchmark guidance
```
python3 bench_other.py --num-questions 8 --backend guidance --parallel 1
python3 bench_other.py --num-questions 8 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
### Benchmark lmql
```
python3 bench_other.py --num-questions 8 --backend lmql --parallel 1
```
......@@ -5,17 +5,11 @@ import re
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import (
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
INVALID = -9999999
......@@ -139,69 +133,50 @@ def main(args):
arguments = [{"question": q, "num_branches": num_branches} for q in questions]
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def call_generate(prompt, temperature, max_tokens, stop, n):
if n == 1:
out = (
model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
)
return out["answer"]
else:
rets = []
for i in range(n):
out = (
model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
)
rets.append(out["answer"])
return rets
# warmup
call_generate("Hello,", 1.0, 8, ".", 1)
call_generate = get_call_generate(args)
# Run requests
states = [None] * len(questions)
def get_one_answer(i):
states[i] = tree_search(**arguments[i], call_generate=call_generate)
tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(questions))):
get_one_answer(i)
if args.backend != "lmql":
def get_one_answer(i):
states[i] = tree_search(**arguments[i], call_generate=call_generate)
if args.parallel == 1:
for i in tqdm(range(len(questions))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(get_one_answer, list(range(len(questions)))),
total=len(questions),
)
)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(questions))))
        import asyncio

        from lmql_funcs import tree_search_async

        async def get_one_answer_async(i):
            states[i] = await tree_search_async(
                **arguments[i], call_generate=call_generate
            )

        batches = [
            [] for _ in range((len(questions) + args.parallel - 1) // args.parallel)
        ]
        for i in range(len(questions)):
            batches[i // args.parallel].append(i)

        loop = asyncio.get_event_loop()
        for bt in tqdm(batches):
            tasks = [get_one_answer_async(k) for k in bt]
            loop.run_until_complete(asyncio.gather(*tasks))
latency = time.time() - tic
answers_text = []
......
from bench_other import (
ASSISTANT_PREFIX,
ASSISTANT_SUFFIX,
USER_PREFIX,
USER_SUFFIX,
temp,
)
async def propose_plan_async(s, question, num_branches, call_generate):
    s += (
        USER_PREFIX
        + """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
        + question
        + USER_SUFFIX
    )
    s += ASSISTANT_PREFIX
    comps = await call_generate(
        s, max_tokens=256, temperature=temp, stop=None, n=num_branches
    )
    return [s + comp + ASSISTANT_SUFFIX for comp in comps]


async def execute_plan_async(s, num_branches, call_generate):
    s += (
        USER_PREFIX
        + """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
        + USER_SUFFIX
    )
    s += ASSISTANT_PREFIX
    comps = await call_generate(
        s, max_tokens=256, temperature=temp, stop=None, n=num_branches
    )
    return [s + comp + ASSISTANT_SUFFIX for comp in comps]


async def reflect_solution_async(s, num_branches, call_generate):
    s += (
        USER_PREFIX
        + """Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
        + USER_SUFFIX
    )
    s += ASSISTANT_PREFIX
    comps = await call_generate(
        s, max_tokens=256, temperature=temp, stop=None, n=num_branches
    )
    return [s + comp + ASSISTANT_SUFFIX for comp in comps]


async def get_final_answer_async(s, num_branches, call_generate):
    s += (
        USER_PREFIX
        + """Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration."""
        + USER_SUFFIX
    )
    s += ASSISTANT_PREFIX
    comps = await call_generate(
        s, max_tokens=256, temperature=temp, stop=None, n=num_branches
    )
    return [s + comp + ASSISTANT_SUFFIX for comp in comps]


async def tree_search_async(question, num_branches, call_generate):
    plan_forks = await propose_plan_async("", question, num_branches, call_generate)

    sol_states = []
    for plan in plan_forks:
        forks = await execute_plan_async(plan, num_branches, call_generate)
        sol_states.extend(forks)

    ref_states = []
    for sol in sol_states:
        forks = await reflect_solution_async(sol, num_branches, call_generate)
        ref_states.extend(forks)

    solutions = []
    for sol in ref_states:
        ans = await get_final_answer_async(sol, num_branches, call_generate)
        solutions.append(ans)

    return solutions
......@@ -39,5 +39,5 @@ python3 bench_other.py --num-questions 32 --backend lightllm
### Benchmark guidance
```
python3 bench_other.py --num-questions 32 --backend guidance --parallel 1
python3 bench_other.py --num-questions 32 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
......@@ -5,17 +5,11 @@ import re
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import (
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
INVALID = -9999999
......@@ -119,52 +113,7 @@ def main(args):
arguments = [{"question": q, "num_branches": num_branches} for q in questions]
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def call_generate(prompt, temperature, max_tokens, stop, n):
if n == 1:
out = (
model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
)
return out["answer"]
else:
rets = []
for i in range(n):
out = (
model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
)
rets.append(out["answer"])
return rets
call_generate = get_call_generate(args)
# Run requests
states = [None] * len(questions)
......@@ -178,7 +127,13 @@ def main(args):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(questions))))
list(
tqdm(
executor.map(get_one_answer, list(range(len(questions)))),
total=len(questions),
)
)
latency = time.time() - tic
answers_text = []
......
"""Common utilities for testing and benchmarking"""
import asyncio
from functools import partial
import numpy as np
import requests
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.global_config import global_config
from sglang.srt.utils import get_exception_traceback
def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
assert url is not None
def call_generate_lightllm(prompt, temperature, max_tokens, stop, url):
data = {
"inputs": prompt,
"parameters": {
......@@ -23,7 +29,9 @@ def call_generate_lightllm(prompt, temperature, max_tokens, stop, url):
return pred
def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None):
assert url is not None
data = {
"prompt": prompt,
"temperature": temperature,
......@@ -41,8 +49,10 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
def call_generate_outlines(
prompt, temperature, max_tokens, url, stop=[], regex=None, n=1
prompt, temperature, max_tokens, stop=[], regex=None, n=1, url=None
):
assert url is not None
data = {
"prompt": prompt,
"temperature": temperature,
......@@ -60,7 +70,9 @@ def call_generate_outlines(
return pred
def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
assert url is not None
data = {
"text": prompt,
"sampling_params": {
......@@ -76,7 +88,71 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
return pred
def call_select_lightllm(context, choices, url):
def call_generate_guidance(
    prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
):
    assert model is not None
    from guidance import gen

    rets = []
    for _ in range(n):
        out = (
            model
            + prompt
            + gen(
                name="answer",
                max_tokens=max_tokens,
                temperature=temperature,
                stop=stop,
                regex=regex,
            )
        )
        rets.append(out["answer"])
    return rets if n > 1 else rets[0]
async def call_generate_lmql(
    prompt, temperature, max_tokens, stop=None, n=1, max_len=4096, model=None, **kwargs
):
    assert model is not None
    import lmql

    if stop is not None:

        @lmql.query(model=model)
        async def program(question, max_tokens, stop):
            '''lmql
            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and STOPS_AT(ANSWER, stop)
            return ANSWER
            '''

    else:

        @lmql.query(model=model)
        async def program(question, max_tokens):
            '''lmql
            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens
            return ANSWER
            '''

    tasks = [
        program(
            question=prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            max_len=max_len,
            **kwargs,
        )
        for _ in range(n)
    ]
    rets = await asyncio.gather(*tasks)
    return rets if n > 1 else rets[0]
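# Note: call_generate_guidance and call_generate_lmql return a single string when n == 1
# and a list of n completions when n > 1, so tree-search style callers can request
# num_branches samples in one call.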
def call_select_lightllm(context, choices, url=None):
assert url is not None
scores = []
for i in range(len(choices)):
data = {
......@@ -91,7 +167,9 @@ def call_select_lightllm(context, choices, url):
return np.argmax(scores)
def call_select_vllm(context, choices, url):
def call_select_vllm(context, choices, url=None):
assert url is not None
scores = []
for i in range(len(choices)):
data = {
......@@ -113,6 +191,31 @@ def call_select_vllm(context, choices, url):
"""
def call_select_guidance(context, choices, model=None):
    assert model is not None
    from guidance import select

    out = model + context + select(choices, name="answer")
    return choices.index(out["answer"])


async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=None):
    assert model is not None
    import lmql

    @lmql.query(model=model)
    async def program(ctx, choices):
        '''lmql
        """{ctx}[ANSWER]""" where ANSWER in set(choices)
        return ANSWER
        '''

    answer = await program(
        ctx=context, choices=choices, temperature=temperature, max_len=max_len
    )
    return choices.index(answer)
def add_common_other_args_and_parse(parser):
parser.add_argument("--parallel", type=int, default=64)
parser.add_argument("--host", type=str, default="http://127.0.0.1")
......@@ -121,8 +224,17 @@ def add_common_other_args_and_parse(parser):
"--backend",
type=str,
required=True,
choices=["vllm", "lightllm", "guidance", "lmql", "srt-raw", "llama.cpp"],
choices=[
"vllm",
"outlines",
"lightllm",
"guidance",
"lmql",
"srt-raw",
"llama.cpp",
],
)
parser.add_argument("--n-ctx", type=int, default=4096)
parser.add_argument(
"--model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
)
......@@ -132,6 +244,7 @@ def add_common_other_args_and_parse(parser):
if args.port is None:
default_port = {
"vllm": 21000,
"outlines": 21000,
"lightllm": 22000,
"lmql": 23000,
"srt-raw": 30000,
......@@ -161,3 +274,77 @@ def select_sglang_backend(args):
else:
raise ValueError(f"Invalid backend: {args.backend}")
return backend
def _get_call_generate(args):
    if args.backend == "lightllm":
        return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate")
    elif args.backend == "vllm":
        return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
    elif args.backend == "srt-raw":
        return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
    elif args.backend == "outlines":
        return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
    elif args.backend == "guidance":
        from guidance import models

        model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
        call_generate = partial(call_generate_guidance, model=model)
        call_generate("Hello,", 1.0, 8, ".")  # warmup
        return call_generate
    elif args.backend == "lmql":
        import lmql

        model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
        return partial(call_generate_lmql, model=model)
    else:
        raise ValueError(f"Invalid backend: {args.backend}")


def _get_call_select(args):
    if args.backend == "lightllm":
        return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate")
    elif args.backend == "vllm":
        return partial(call_select_vllm, url=f"{args.host}:{args.port}/generate")
    elif args.backend == "guidance":
        from guidance import models

        model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
        call_select = partial(call_select_guidance, model=model)
        call_select("Hello,", ["world", "earth"])  # warmup
        return call_select
    elif args.backend == "lmql":
        import lmql

        model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
        return partial(call_select_lmql, model=model)
    else:
        raise ValueError(f"Invalid backend: {args.backend}")


def get_call_generate(args):
    call_generate = _get_call_generate(args)

    def func(*args, **kwargs):
        try:
            return call_generate(*args, **kwargs)
        except Exception:
            print("Exception in call_generate:\n" + get_exception_traceback())
            raise

    return func


def get_call_select(args):
    call_select = _get_call_select(args)

    def func(*args, **kwargs):
        try:
            return call_select(*args, **kwargs)
        except Exception:
            print("Exception in call_select:\n" + get_exception_traceback())
            raise

    return func
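For reference, a minimal sketch of how a benchmark script consumes these new helpers. This is an illustrative sketch rather than part of the commit; it assumes `add_common_other_args_and_parse` parses and returns the argparse namespace, which the `args.port` handling above implies.

```
import argparse

from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate

parser = argparse.ArgumentParser()
# Adds --backend, --host, --port, --parallel, --n-ctx, --model-path, then parses.
args = add_common_other_args_and_parse(parser)

# get_call_generate dispatches on args.backend (lightllm, vllm, srt-raw, outlines,
# guidance, lmql) and wraps the call so exceptions print a traceback before re-raising.
call_generate = get_call_generate(args)

if args.backend != "lmql":  # the lmql helper is a coroutine and must be awaited instead
    out = call_generate(
        "Question: 1 + 1 = ?\nAnswer:", temperature=0, max_tokens=8, stop="Question"
    )
    print(out)
```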