Unverified Commit 95c4e0df authored by Liangsheng Yin, committed by GitHub

Format Benchmark Code (#399)

parent 19818b9c
@@ -3,7 +3,6 @@ import json
import transformers
import wikipedia

name = "meta-llama/Llama-2-7b-chat-hf"
t = transformers.AutoTokenizer.from_pretrained(name)

city_names = ["los angles", "london", "tokyo", "beijing", "singapore"]
@@ -20,7 +19,9 @@ for city_name in city_names:
    truncate_tokens = t.encode(truncate_content)

    # Count token
    print(
        f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
    )

    with open("questions.jsonl", "a") as fout:
        fout.write(json.dumps({"document": truncate_content}) + "\n")
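A hedged companion snippet (not part of this commit): reading back the questions.jsonl file the loop above appends to. The "document" field name comes from the writer call; everything else here is illustrative.

# Sketch only: load the truncated documents written above.
import json

with open("questions.jsonl") as fin:
    documents = [json.loads(line)["document"] for line in fin]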
import argparse
import asyncio
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import numpy as np
import pandas as pd
import tiktoken
from tqdm import tqdm

from sglang.test.test_utils import (
    add_common_other_args_and_parse,
    call_generate_lightllm,
    call_generate_srt_raw,
    call_generate_vllm,
)

choices = ["A", "B", "C", "D"]
@@ -25,18 +30,22 @@ def format_subject(subject):
        s += " " + entry
    return s


def format_example(df, idx, include_answer=True):
    prompt = df.iloc[idx, 0]
    k = df.shape[1] - 2
    for j in range(k):
        prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1])
    prompt += "\nAnswer:"
    if include_answer:
        prompt += " {}\n\n".format(df.iloc[idx, k + 1])
    return prompt


def gen_prompt(train_df, subject, k=-1):
    prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(
        format_subject(subject)
    )
    if k == -1:
        k = train_df.shape[0]
    for i in range(k):
@@ -63,7 +72,7 @@ def evaluate(args, subject, dev_df, test_df):
        prompt = train_prompt + prompt_end
        prompts.append(prompt)

        label = test_df.iloc[i, test_df.shape[1] - 1]
        labels.append(label)

    preds = [None] * len(prompts)
@@ -82,17 +91,24 @@ def evaluate(args, subject, dev_df, test_df):
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_srt_raw, url=url, stop=None)
    elif args.backend == "guidance":
        from guidance import gen, models

        if model_initialized is None:
            model = models.LlamaCpp(
                "/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
                n_gpu_layers=-1,
                n_ctx=4096,
            )
            model_initialized = model
        else:
            model = model_initialized

        def call_generate(prompt, temperature, max_tokens):
            out = (
                model
                + prompt
                + gen(name="answer", max_tokens=max_tokens, temperature=0)
            )
            return out["answer"]

        # warmup
@@ -100,8 +116,10 @@ def evaluate(args, subject, dev_df, test_df):
    elif args.backend == "lmql":
        import lmql

        model = lmql.model(
            "meta-llama/Llama-2-7b-chat-hf", endpoint=f"{args.host}:{args.port}"
        )

        @lmql.query(model=model)
        async def program(question):
@@ -112,6 +130,7 @@ def evaluate(args, subject, dev_df, test_df):
        async def call_generate(prompt, temperature, max_tokens):
            return await program(question=prompt, temperature=temperature)

    else:
        raise ValueError(f"Invalid backend: {args.backend}")
@@ -119,8 +138,7 @@ def evaluate(args, subject, dev_df, test_df):
    if args.backend != "lmql":
        # Use thread pool
        def get_one_answer(i):
            pred = call_generate(prompts[i], temperature=0, max_tokens=max_tokens)
            preds[i] = pred.strip()[0]

        tic = time.time()
@@ -135,12 +153,11 @@ def evaluate(args, subject, dev_df, test_df):
        async def batched_call(batch_size):
            for i in range(0, len(prompts), batch_size):
                tasks = []
                for p in prompts[i : i + batch_size]:
                    tasks.append(call_generate(p, temperature=0, max_tokens=max_tokens))
                rets = await asyncio.gather(*tasks)
                for j in range(len(rets)):
                    preds[i + j] = rets[j].strip()[0]

        tic = time.time()
        asyncio.run(batched_call(batch_size=args.parallel))
@@ -151,22 +168,35 @@ def evaluate(args, subject, dev_df, test_df):
    acc = np.mean(cors)
    cors = np.array(cors)

    print(
        "Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format(
            acc, latency, len(prompts), subject
        )
    )

    return cors, acc, latency
def main(args):
    subjects = sorted(
        [
            f.split("_test.csv")[0]
            for f in os.listdir(os.path.join(args.data_dir, "test"))
            if "_test.csv" in f
        ]
    )

    all_cors = []
    all_latencies = []
    num_requests = 0
    for subject in tqdm(subjects[: args.nsub]):
        dev_df = pd.read_csv(
            os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
        )[: args.ntrain]
        test_df = pd.read_csv(
            os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
        )

        cors, acc, latency = evaluate(args, subject, dev_df, test_df)
        all_cors.append(cors)
@@ -191,7 +221,7 @@ def main(args):
            "other": {
                "nsub": args.nsub,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")
......
@@ -7,8 +7,11 @@ import numpy as np
import pandas as pd
import tiktoken
from tqdm import tqdm

from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)

choices = ["A", "B", "C", "D"]
@@ -22,24 +25,29 @@ def format_subject(subject):
        s += " " + entry
    return s


def format_example(df, idx, include_answer=True):
    prompt = df.iloc[idx, 0]
    k = df.shape[1] - 2
    for j in range(k):
        prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1])
    prompt += "\nAnswer:"
    if include_answer:
        prompt += " {}\n\n".format(df.iloc[idx, k + 1])
    return prompt


def gen_prompt(train_df, subject, k=-1):
    prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(
        format_subject(subject)
    )
    if k == -1:
        k = train_df.shape[0]
    for i in range(k):
        prompt += format_example(train_df, i)
    return prompt
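A minimal sketch (not from this commit) of the prompt layout that format_example and gen_prompt produce. The one-row DataFrame is hypothetical and only mirrors the MMLU CSV column order: question, choices A-D, answer.

# Toy illustration, assuming the standard MMLU column layout.
import pandas as pd

toy = pd.DataFrame([["What is 2 + 2?", "3", "4", "5", "6", "B"]])
# format_example(toy, 0) returns:
#   "What is 2 + 2?\nA. 3\nB. 4\nC. 5\nD. 6\nAnswer: B\n\n"
# gen_prompt(toy, "abstract_algebra", k=1) prepends the header
#   "The following are multiple choice questions (with answers) about abstract algebra."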

def evaluate(args, subject, dev_df, test_df):
    prompts = []
    labels = []
@@ -54,7 +62,7 @@ def evaluate(args, subject, dev_df, test_df):
        prompt_end = format_example(test_df, i, include_answer=False)
        prompts.append(prompt_end)

        label = test_df.iloc[i, test_df.shape[1] - 1]
        labels.append(label)

    arguments = [{"question": p} for p in prompts]
@@ -66,11 +74,14 @@ def evaluate(args, subject, dev_df, test_df):
    import sglang as sgl

    if args.backend.startswith("gpt-"):

        @sgl.function
        def few_shot_mmlu(s, examples, question):
            s += sgl.user(examples + question)
            s += sgl.assistant(sgl.gen("answer"))

    else:

        @sgl.function
        def few_shot_mmlu(s, examples, question):
            s += examples + question + sgl.gen("answer")
@@ -84,32 +95,50 @@ def evaluate(args, subject, dev_df, test_df):
    tic = time.time()
    states = few_shot_mmlu.bind(examples=few_shot_examples).run_batch(
        arguments,
        temperature=0,
        max_new_tokens=1,
        backend=backend,
        num_threads=args.parallel,
    )
    preds = [
        s["answer"].strip()[0] if len(s["answer"].strip()) > 0 else "" for s in states
    ]
    latency = time.time() - tic

    cors = [pred == label for pred, label in zip(preds, labels)]
    acc = np.mean(cors)
    cors = np.array(cors)

    print(
        "Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format(
            acc, latency, len(prompts), subject
        )
    )

    return cors, acc, latency
def main(args):
    subjects = sorted(
        [
            f.split("_test.csv")[0]
            for f in os.listdir(os.path.join(args.data_dir, "test"))
            if "_test.csv" in f
        ]
    )

    all_cors = []
    all_latencies = []
    num_requests = 0
    for subject in tqdm(subjects[: args.nsub]):
        dev_df = pd.read_csv(
            os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
        )[: args.ntrain]
        test_df = pd.read_csv(
            os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
        )

        cors, acc, latency = evaluate(args, subject, dev_df, test_df)
        all_cors.append(cors)
@@ -134,7 +163,7 @@ def main(args):
            "other": {
                "nsub": args.nsub,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")
......
import argparse
import json
import os
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import requests
from fastchat.model import get_conversation_template

from sglang.test.test_utils import (
    add_common_other_args_and_parse,
    call_generate_lightllm,
    call_generate_srt,
    call_generate_vllm,
)


def load_questions(filename):
@@ -38,7 +43,7 @@ def write_answers(filename, model_id, questions, answers):
def main(args):
    questions = load_questions(args.question_file)
    questions = (questions * 10)[: args.num_questions]
    max_tokens = 256
    model_id = "llama-2-chat"
@@ -67,9 +72,8 @@ def main(args):
            conv.append_message(conv.roles[0], q)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            output = call_generate(prompt, temperature=0, max_tokens=max_tokens).strip()

            cur_answers.append(output)
            conv.update_last_message(output)
@@ -102,7 +106,7 @@ def main(args):
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")
......
@@ -5,7 +5,10 @@ import time
import uuid

import sglang as sgl
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)


def load_questions(filename):
@@ -44,10 +47,9 @@ def answer_mt_bench(s, question_1, question_2):
def main(args):
    # Construct prompts
    questions = load_questions(args.question_file)[: args.num_questions]
    arguments = [
        {"question_1": q["turns"][0], "question_2": q["turns"][1]} for q in questions
    ]

    # Select backend
@@ -83,7 +85,7 @@ def main(args):
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")
......
import argparse
import ast
import asyncio
import json
import re
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import numpy as np

from sglang.test.test_utils import (
    add_common_other_args_and_parse,
    call_generate_lightllm,
    call_generate_srt_raw,
    call_generate_vllm,
)
from sglang.utils import dump_state_text, read_jsonl

INVALID = -9999999


def get_answer_value(answer_str):
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r"\d+", answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
@@ -44,14 +49,20 @@ def multi_chain_gsm8k(question, num_chains, call_generate):
    comps = []
    for i in range(num_chains):
        comps.append(
            call_generate(
                s + "Answer: " + prompt_lib[i % num_chains],
                max_tokens=256,
                temperature=0.3,
                stop="Question",
            )
        )

    s += "Answer: To answer this question, here are some possible solutions. "
    s += "After considering all of them, I will do a majority vote.\n\n"
    for i in range(num_chains):
        s += f"Solution {i+1}: " + comps[i].strip() + "\n\n"
    s += "\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
    s += call_generate(s, max_tokens=16, temperature=0, stop=None)
    return s
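The function above asks the model itself to do the majority vote over its sampled chains. A hedged alternative sketch (not in this commit): tally the chains' numeric answers directly with get_answer_value, which is defined earlier in this file; majority_vote is a hypothetical helper name.

from collections import Counter


def majority_vote(comps):
    # Extract a numeric answer from each chain and drop invalid ones.
    values = [get_answer_value(c) for c in comps]
    values = [v for v in values if v != INVALID]
    if not values:
        return INVALID
    # The most common value across chains wins.
    return Counter(values).most_common(1)[0][0]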
@@ -64,7 +75,7 @@ def main(args):
    questions = []
    labels = []
    for i in range(len(lines[: args.num_questions])):
        questions.append(lines[i]["question"])
        labels.append(get_answer_value(lines[i]["answer"]))
    assert all(l != INVALID for l in labels)
@@ -82,16 +93,28 @@ def main(args):
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_srt_raw, url=url)
    elif args.backend == "guidance":
        from guidance import gen, models

        model = models.LlamaCpp(
            "/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
            n_gpu_layers=-1,
            n_ctx=4096,
        )

        def call_generate(prompt, temperature, max_tokens, stop):
            out = (
                model
                + prompt
                + gen(
                    name="answer",
                    max_tokens=max_tokens,
                    temperature=temperature,
                    stop=stop,
                )
            )
            return out["answer"]

        # def multi_chain_gsm8k(question, num_chains, call_generate):
        #     s = model + "Question: " + question + "\n"
        #     comps = []
@@ -108,8 +131,10 @@ def main(args):
    elif args.backend == "lmql":
        import lmql

        model = lmql.model(
            "meta-llama/Llama-2-7b-chat-hf", endpoint=f"{args.host}:{args.port}"
        )

        @lmql.query(model=model)
        async def program(question):
@@ -128,8 +153,7 @@ def main(args):
    if args.backend != "lmql":
        # Use thread pool
        def get_one_answer(i):
            answer = multi_chain_gsm8k(questions[i], args.num_chains, call_generate)
            states[i] = answer

        tic = time.time()
@@ -144,12 +168,18 @@ def main(args):
        async def batched_call(batch_size):
            for i in range(0, len(questions), batch_size):
                tasks = []
                for q in questions[i : i + batch_size]:
                    tasks.append(
                        call_generate(
                            few_shot_examples + q,
                            temperature=0,
                            max_tokens=256,
                            stop="Question",
                        )
                    )
                rets = await asyncio.gather(*tasks)
                for j in range(len(rets)):
                    states[i + j] = get_answer_value(rets[j])

        tic = time.time()
        asyncio.run(batched_call(batch_size=args.parallel))
@@ -180,7 +210,7 @@ def main(args):
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")
......
@@ -5,16 +5,19 @@ import re
import time

import numpy as np

from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl

INVALID = -9999999


def get_answer_value(answer_str):
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r"\d+", answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
@@ -37,12 +40,12 @@ def main(args):
    lines = read_jsonl(args.data_path)

    # Construct prompts
    # k = args.num_shot
    # few_shot_examples = get_few_shot_examples(lines, k)

    questions = []
    labels = []
    for i in range(len(lines[: args.num_questions])):
        questions.append(lines[i]["question"])
        labels.append(get_answer_value(lines[i]["answer"]))
    assert all(l != INVALID for l in labels)
@@ -59,21 +62,24 @@ def main(args):
    @sgl.function
    def multi_chain_gsm8k(s, question):
        s += "Question: " + question + "\n"
        # s += "Answer: " + prompt_lib[0] + sgl.gen("answer", max_tokens=256, stop="Question",
        #                                           temperature=0)
        # return

        forks = s.fork(num_chains)
        for i in range(num_chains):
            forks[i] += (
                "Answer: "
                + prompt_lib[i % num_chains]
                + sgl.gen("chain", max_tokens=256, temperature=0.3, stop="Question")
            )
        forks.join()

        s += "Answer: To answer this question, here are some possible solutions. "
        s += "After considering all of them, I will do a majority vote.\n\n"
        for i in range(num_chains):
            s += f"Solution {i+1}: " + forks[i]["chain"].strip() + "\n\n"
        s += "\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
        s += sgl.gen("answer", max_tokens=16)

    #####################################
@@ -86,7 +92,12 @@ def main(args):
    # Run requests
    tic = time.time()
    states = multi_chain_gsm8k.run_batch(
        arguments,
        temperature=0,
        backend=backend,
        num_threads=args.parallel,
        progress_bar=True,
    )
    latency = time.time() - tic

    preds = []
@@ -114,7 +125,7 @@ def main(args):
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")
......
import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial

from tqdm import tqdm

from sglang.test.test_utils import (
    add_common_other_args_and_parse,
    call_generate_lightllm,
    call_generate_srt_raw,
    call_generate_vllm,
)
from sglang.utils import dump_state_text, read_jsonl

USER_PREFIX = "[INST] "
USER_SUFFIX = " [/INST]"
@@ -25,7 +28,11 @@ def multi_document_qa(docs, question, generate):
    s += "".join(docs)
    s += "\nDocuments end."
    s += (
        "\n\nBased on the above documents, please answer this question:\n"
        + question
        + "\nAnswer in three words or fewer."
    )
    s += USER_SUFFIX

    s += ASSISTANT_PREFIX
    answer = generate(s, max_tokens=16, stop=None)
@@ -42,11 +49,13 @@ def main(args):
    if args.backend == "guidance":
        num_docs = 7  # due to OOM

    for i in range(len(l["questions"][: args.num_questions])):
        arguments.append(
            {
                "docs": l["documents"][:num_docs],
                "question": l["questions"][i],
            }
        )
        labels.append(l["answers"][i])

    states = [None] * len(arguments)
@@ -61,13 +70,20 @@ def main(args):
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_srt_raw, url=url, temperature=0)
    elif args.backend == "guidance":
        from guidance import gen, models

        model = models.LlamaCpp(
            "/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf",
            n_gpu_layers=-1,
            n_ctx=11000,
        )

        def generate(prompt, max_tokens, stop):
            out = (
                model
                + prompt
                + gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
            )
            return out["answer"]

        # warmup
@@ -113,7 +129,7 @@ def main(args):
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")
......
@@ -2,10 +2,12 @@ import argparse
import json
import time

import sglang as sgl
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl


@sgl.function
@@ -19,7 +21,11 @@ def multi_document_qa(s, docs, question):
    forks.join("concate_and_append")
    s += "\nDocuments end."
    s += (
        "\n\nBased on the above documents, please answer this question:\n"
        + question
        + "\nAnswer in three words or fewer."
    )
    s += sgl.user_end()
    s += sgl.assistant(sgl.gen("answer", max_tokens=16))
@@ -29,11 +35,13 @@ def main(args):
    l = lines[0]
    arguments = []
    labels = []
    for i in range(len(l["questions"][: args.num_questions])):
        arguments.append(
            {
                "docs": l["documents"][:10],
                "question": l["questions"][i],
            }
        )
        labels.append(l["answers"][i])

    # Select backend
@@ -43,10 +51,11 @@ def main(args):
    # Run requests
    tic = time.time()
    states = multi_document_qa.run_batch(
        arguments, temperature=0, num_threads=args.parallel, progress_bar=True
    )
    latency = time.time() - tic

    # Compute accuracy
    print([s["answer"] for s in states])
    correct = 0
    for s, label in zip(states, labels):
@@ -71,7 +80,7 @@ def main(args):
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")
......
@@ -3,7 +3,8 @@ import json
import transformers

content = "\n".join(
    open("llama2.txt", "r", encoding="utf-8", errors="ignore").readlines()
)
content = content.replace("\n\n", "\n")

# Count token
@@ -35,30 +36,35 @@ for i, s in enumerate(segments):
# Dump
with open("questions.jsonl", "w") as fout:
    fout.write(
        json.dumps(
            {
                "documents": segments[:30],
                "questions": [
                    "What is the name of the fine-tuned LLMs?",
                    "Which figure shows the helpfulness human evaluation results for Llama 2-Chat?",
                    "What is the number of parameters in the largest Llama 2 model?",
                    "What is the batch size of fine-tuning?",
                    "Where can we find the details of potential data contamination?",
                    "What is the full name of MPT?",
                    "What is the power consumption of RSC in Watt?",
                    "How many tokens of data do they train on?",
                    "Which model's release is delayed due to a lack of time to sufficiently red team?",
                    "Which activation function is used in Llama?",
                ],
                "answers": [
                    "Llama 2 Chat",
                    "1",
                    "70 B",
                    "64",
                    "A 6",
                    "MosaicML",
                    "400",
                    "2 trillion",
                    "34 B",
                    "SwiGLU",
                ],
            }
        )
        + "\n"
    )
@@ -4,12 +4,12 @@ from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor

import requests
from data_gen import gen_arguments
from tqdm import tqdm
from vllm.transformers_utils.tokenizer import get_tokenizer

from sglang.test.test_utils import add_common_other_args_and_parse
from sglang.utils import dump_state_text


def get_generate(args):
@@ -61,7 +61,7 @@ def multi_turns(generate, qas):
    s = ""
    for qa in qas:
        s += qa["prompt"]
        s += generate(s, max_tokens=qa["new_tokens"])

    return s
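A hedged example (not from this commit) of the qas structure multi_turns consumes. gen_arguments from data_gen produces lists of this shape; the field values below are made up.

# Hypothetical two-turn conversation: each entry supplies the next user
# prompt and the number of tokens to generate for the reply.
example_qas = [
    {"prompt": "USER: Tell me a joke.\nASSISTANT:", "new_tokens": 64},
    {"prompt": "\nUSER: Now explain it.\nASSISTANT:", "new_tokens": 64},
]
# multi_turns(generate, example_qas) appends each prompt, generates the
# reply, and feeds the growing transcript into the next call.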
......
@@ -2,22 +2,22 @@ import json
import time
from argparse import ArgumentParser

from data_gen import gen_arguments
from vllm.transformers_utils.tokenizer import get_tokenizer

import sglang as sgl
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import dump_state_text


@sgl.function
def multi_turns(s, qas):
    for qa in qas:
        s += qa["prompt"]
        s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)


def main(args):
@@ -29,7 +29,11 @@ def main(args):
    tic = time.time()
    states = multi_turns.run_batch(
        multi_qas,
        temperature=0,
        backend=backend,
        num_threads=args.parallel,
        progress_bar=True,
    )
    latency = time.time() - tic
......
import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path

from tqdm import tqdm

from sglang.test.test_utils import (
    add_common_other_args_and_parse,
    call_generate_lightllm,
    call_generate_srt_raw,
    call_generate_vllm,
)
from sglang.utils import dump_state_text, read_jsonl


def get_prompt(question):
@@ -83,16 +84,15 @@ Action 2: Search[Leonid Levin]
Observation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist.
Thought 3: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work.
Action 3: Finish[yes]
"""
        + question
    )
    return prompt


def main(args):
    lines = read_jsonl(args.data_path)[: args.num_questions]
    arguments = [{"question": k, "triplets": v} for l in lines for k, v in l.items()]

    states = []
@@ -107,7 +107,7 @@ def main(args):
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_srt_raw, url=url)
    elif args.backend == "guidance":
        from guidance import gen, models

        model = models.LlamaCpp(
            str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf",
@@ -116,12 +116,16 @@ def main(args):
        )

        def call_generate(prompt, temperature, max_tokens, stop):
            out = (
                model
                + prompt
                + gen(
                    name="result",
                    max_tokens=max_tokens,
                    temperature=temperature,
                    stop=stop,
                )
            )
            return out["result"]

        # warmup
@@ -137,15 +141,23 @@ def main(args):
        for i in range(1, len(triplets) + 2):
            prompt += "Thought " + str(i) + ":"
            states.append(prompt)
            answer = call_generate(
                prompt, max_tokens=200, temperature=0, stop="Observation"
            )
            if i > len(triplets):
                break
            prompt += (
                triplets[i - 1]["thought"]
                + "\nAction "
                + str(i)
                + ":"
                + triplets[i - 1]["action"]
                + "\nObservation "
                + str(i)
                + ":"
                + triplets[i - 1]["observation"]
                + "\n"
            )
        states.append(answer)
......
@@ -7,7 +7,7 @@ from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl


@sgl.function
@@ -79,7 +79,9 @@ Action 2: Search[Leonid Levin]
Observation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist.
Thought 3: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work.
Action 3: Finish[yes]
"""
        + question
    )

    for i in range(1, len(triplets) + 2):
        s += "Thought " + str(i) + ":"
        # NOTE: This is an implementation for replaying a given trace for benchmark purposes. It is not an actual ReAct agent implementation.
@@ -90,17 +92,23 @@ Action 3: Finish[yes]
        # print(ss[0]["thought_action"])
        if i > len(triplets):
            break
        s += (
            triplets[i - 1]["thought"]
            + "\nAction "
            + str(i)
            + ":"
            + triplets[i - 1]["action"]
            + "\nObservation "
            + str(i)
            + ":"
            + triplets[i - 1]["observation"]
            + "\n"
        )
def main(args):
    lines = read_jsonl(args.data_path)[: args.num_questions]
    arguments = [{"question": k, "triplets": v} for l in lines for k, v in l.items()]

    # Select backend
    backend = select_sglang_backend(args)
@@ -108,11 +116,12 @@ def main(args):
    states = []
    tic = time.time()
    states = webthink.run_batch(
        arguments,
        temperature=0,
        num_threads=args.parallel,
        progress_bar=True,
    )
    latency = time.time() - tic

    # Compute accuracy
......
import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial

from tqdm import tqdm

from sglang.test.test_utils import (
    add_common_other_args_and_parse,
    call_generate_lightllm,
    call_generate_srt_raw,
    call_generate_vllm,
)
from sglang.utils import dump_state_text, read_jsonl

number = 5


def expand_tip(topic, tip, generate):
    s = (
        """Please expand a tip for a topic into a detailed paragraph.

Topic: staying healthy
Tip: Regular Exercise
@@ -30,14 +33,23 @@ Topic: writing a blog post
Tip: structure your content effectively
Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.

Topic: """
        + topic
        + "\nTip: "
        + tip
        + "\nParagraph:"
    )
    return generate(s, max_tokens=128, stop=["\n\n"])


def suggest_tips(topic, generate):
    s = "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n"
    s += "USER: Give some tips for " + topic + ".\n"
    s += (
        "ASSISTANT: Okay. Here are "
        + str(number)
        + " concise tips, each under 8 words:\n"
    )

    tips = []
    for i in range(1, 1 + number):
@@ -49,12 +61,12 @@ def suggest_tips(topic, generate):
    paragraphs = [expand_tip(topic, tip, generate=generate) for tip in tips]

    for i in range(1, 1 + number):
        s += f"Tip {i}:" + paragraphs[i - 1] + "\n"

    return s


def main(args):
    lines = read_jsonl(args.data_path)[: args.num_questions]
    states = [None] * len(lines)

    # Select backend
@@ -68,13 +80,20 @@ def main(args):
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_srt_raw, url=url, temperature=0)
    elif args.backend == "guidance":
        from guidance import gen, models

        model = models.LlamaCpp(
            "/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
            n_gpu_layers=-1,
            n_ctx=4096,
        )

        def generate(prompt, max_tokens, stop):
            out = (
                model
                + prompt
                + gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
            )
            return out["answer"]

        # warmup
@@ -111,7 +130,7 @@ def main(args):
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")
......
@@ -2,11 +2,12 @@ import argparse
import json
import time

import sglang as sgl
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl

number = 5
@@ -14,7 +15,7 @@ number = 5
@sgl.function
def expand_tip(s, topic, tip):
    s += (
        """Please expand a tip for a topic into a detailed paragraph.

Topic: staying healthy
Tip: Regular Exercise
@@ -28,7 +29,12 @@ Topic: writing a blog post
Tip: structure your content effectively
Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.

Topic: """
        + topic
        + "\nTip: "
        + tip
        + "\nParagraph:"
    )
    s += sgl.gen("paragraph", max_tokens=128, stop=["\n\n"], temperature=0)
@@ -36,7 +42,11 @@ Topic: """ + topic + "\nTip: " + tip + "\nParagraph:")
def suggest_tips(s, topic):
    s += "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n"
    s += "USER: Give some tips for " + topic + ".\n"
    s += (
        "ASSISTANT: Okay. Here are "
        + str(number)
        + " concise tips, each under 8 words:\n"
    )

    paragraphs = []
    for i in range(1, 1 + number):
@@ -44,14 +54,12 @@ def suggest_tips(s, topic):
        paragraphs.append(expand_tip(topic=topic, tip=s[f"tip_{i}"]))

    for i in range(1, 1 + number):
        s += f"Tip {i}:" + paragraphs[i - 1]["paragraph"] + "\n"


def main(args):
    lines = read_jsonl(args.data_path)[: args.num_questions]
    arguments = [{"topic": l["topic"]} for l in lines]

    # Select backend
    sgl.set_default_backend(select_sglang_backend(args))
@@ -59,7 +67,8 @@ def main(args):
    # Run requests
    tic = time.time()
    states = suggest_tips.run_batch(
        arguments, temperature=0, num_threads=args.parallel, progress_bar=True
    )
    latency = time.time() - tic

    # Compute accuracy
@@ -78,7 +87,7 @@ def main(args):
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")
......
import argparse
import ast
import json
import re
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import numpy as np
from tqdm import tqdm

from sglang.test.test_utils import (
    add_common_other_args_and_parse,
    call_generate_lightllm,
    call_generate_srt_raw,
    call_generate_vllm,
)
from sglang.utils import dump_state_text, read_jsonl

INVALID = -9999999


def get_answer_value(answer_str):
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r"\d+", answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
@@ -47,35 +51,56 @@ temp = 0.001
def propose_plan(s, question, num_branches, call_generate):
    s += (
        USER_PREFIX
        + """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
        + question
        + USER_SUFFIX
    )
    s += ASSISTANT_PREFIX
    comps = call_generate(
        s, max_tokens=256, temperature=temp, stop=None, n=num_branches
    )
    return [s + comp + ASSISTANT_SUFFIX for comp in comps]


def execute_plan(s, num_branches, call_generate):
    s += (
        USER_PREFIX
        + """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
        + USER_SUFFIX
    )
    s += ASSISTANT_PREFIX
    comps = call_generate(
        s, max_tokens=256, temperature=temp, stop=None, n=num_branches
    )
    return [s + comp + ASSISTANT_SUFFIX for comp in comps]


def reflect_solution(s, num_branches, call_generate):
    s += (
        USER_PREFIX
        + """Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
        + USER_SUFFIX
    )
    s += ASSISTANT_PREFIX
    comps = call_generate(
        s, max_tokens=256, temperature=temp, stop=None, n=num_branches
    )
    return [s + comp + ASSISTANT_SUFFIX for comp in comps]


def get_final_answer(s, num_branches, call_generate):
    s += (
        USER_PREFIX
        + """Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration."""
        + USER_SUFFIX
    )
    s += ASSISTANT_PREFIX
    comps = call_generate(
        s, max_tokens=256, temperature=temp, stop=None, n=num_branches
    )
    return [s + comp + ASSISTANT_SUFFIX for comp in comps]
...@@ -107,7 +132,7 @@ def main(args): ...@@ -107,7 +132,7 @@ def main(args):
num_branches = 2 num_branches = 2
questions = [] questions = []
labels = [] labels = []
for i in range(len(lines[:args.num_questions])): for i in range(len(lines[: args.num_questions])):
questions.append(lines[i]["question"]) questions.append(lines[i]["question"])
labels.append(get_answer_value(lines[i]["answer"])) labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels) assert all(l != INVALID for l in labels)
...@@ -124,20 +149,40 @@ def main(args): ...@@ -124,20 +149,40 @@ def main(args):
url = f"{args.host}:{args.port}/generate" url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url) call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance": elif args.backend == "guidance":
from guidance import models, gen from guidance import gen, models
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096) model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def call_generate(prompt, temperature, max_tokens, stop, n): def call_generate(prompt, temperature, max_tokens, stop, n):
if n == 1: if n == 1:
out = model + prompt + gen(name="answer", out = (
max_tokens=max_tokens, temperature=temperature, stop=stop) model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
)
return out["answer"] return out["answer"]
else: else:
rets = [] rets = []
for i in range(n): for i in range(n):
out = model + prompt + gen(name="answer", out = (
max_tokens=max_tokens, temperature=temperature, stop=stop) model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
)
rets.append(out["answer"]) rets.append(out["answer"])
return rets return rets
...@@ -146,6 +191,7 @@ def main(args): ...@@ -146,6 +191,7 @@ def main(args):
# Run requests # Run requests
states = [None] * len(questions) states = [None] * len(questions)
def get_one_answer(i): def get_one_answer(i):
states[i] = tree_search(**arguments[i], call_generate=call_generate) states[i] = tree_search(**arguments[i], call_generate=call_generate)
...@@ -188,7 +234,7 @@ def main(args): ...@@ -188,7 +234,7 @@ def main(args):
"other": { "other": {
"num_questions": args.num_questions, "num_questions": args.num_questions,
"parallel": args.parallel, "parallel": args.parallel,
} },
} }
fout.write(json.dumps(value) + "\n") fout.write(json.dumps(value) + "\n")
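Note: the loop that actually dispatches `get_one_answer` falls in the region elided between the hunks above. Given the `ThreadPoolExecutor` import and the `args.parallel` flag, a plausible reconstruction of the elided dispatch (an assumption, not part of this diff) looks like:

# Assumed reconstruction of the elided dispatch loop; not shown in the diff.
tic = time.time()
if args.parallel == 1:
    for i in tqdm(range(len(questions))):
        get_one_answer(i)
else:
    with ThreadPoolExecutor(args.parallel) as executor:
        # list() forces submission of every task; shutdown waits for completion.
        list(executor.map(get_one_answer, range(len(questions))))
latency = time.time() - tic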
......
 import argparse
 import ast
-from collections import Counter
 import json
 import re
 import time
+from collections import Counter
 
 import numpy as np
 
-from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
-from sglang.utils import read_jsonl, dump_state_text
-import sglang as sgl
+import sglang as sgl
+from sglang.test.test_utils import (
+    add_common_sglang_args_and_parse,
+    select_sglang_backend,
+)
+from sglang.utils import dump_state_text, read_jsonl
 
 INVALID = -9999999
 
 
 def get_answer_value(answer_str):
     answer_str = answer_str.replace(",", "")
-    numbers = re.findall(r'\d+', answer_str)
+    numbers = re.findall(r"\d+", answer_str)
     if len(numbers) < 1:
         return INVALID
     try:
@@ -40,7 +43,9 @@ temp = 0.001
 
 def propose_plan(s, question, num_branches):
     s += sgl.user(
-        """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question)
+        """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
+        + question
+    )
     forks = s.fork(num_branches)
     forks += sgl.assistant(sgl.gen("plan", max_tokens=256, temperature=temp))
     return forks
@@ -48,7 +53,8 @@ def propose_plan(s, question, num_branches):
 
 def execute_plan(s, num_branches):
     s += sgl.user(
-        """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""")
+        """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
+    )
     forks = s.fork(num_branches)
     forks += sgl.assistant(sgl.gen("answer", max_tokens=256, temperature=temp))
     return forks
@@ -56,7 +62,8 @@ def execute_plan(s, num_branches):
 
 def reflect_solution(s, num_branches):
     s += sgl.user(
-        """Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""")
+        """Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
+    )
     forks = s.fork(num_branches)
     forks += sgl.assistant(sgl.gen("score", max_tokens=256, temperature=temp))
     return forks
@@ -64,13 +71,13 @@ def reflect_solution(s, num_branches):
 
 def get_final_answer(s, num_branches):
     s += sgl.user(
-        """Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration.""")
+        """Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration."""
+    )
     forks = s.fork(num_branches)
     forks += sgl.assistant(sgl.gen("final_answer", max_tokens=256, temperature=temp))
     return forks
 
 
 @sgl.function
 def tree_search(s, question, num_branches):
     plan_forks = propose_plan(s, question, num_branches)
@@ -93,6 +100,7 @@ def tree_search(s, question, num_branches):
     return solutions
 
+
 def main(args):
     lines = read_jsonl(args.data_path)
@@ -100,7 +108,7 @@ def main(args):
     num_branches = 2
     questions = []
     labels = []
-    for i in range(len(lines[:args.num_questions])):
+    for i in range(len(lines[: args.num_questions])):
         questions.append(lines[i]["question"])
         labels.append(get_answer_value(lines[i]["answer"]))
     assert all(l != INVALID for l in labels)
@@ -112,7 +120,12 @@ def main(args):
     # Run requests
     tic = time.time()
     states = tree_search.run_batch(
-        arguments, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True)
+        arguments,
+        temperature=0,
+        backend=backend,
+        num_threads=args.parallel,
+        progress_bar=True,
+    )
     latency = time.time() - tic
 
     answers_text = []
     for s in states:
@@ -144,7 +157,7 @@ def main(args):
             "other": {
                 "num_questions": args.num_questions,
                 "parallel": args.parallel,
-            }
+            },
         }
         fout.write(json.dumps(value) + "\n")
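Note: the body of `tree_search` is collapsed between the hunks above, so only its first and last lines are visible. A hedged sketch of how the four fork-based stages could compose (illustrative only; the actual elided body may differ):

# Illustrative sketch of the elided tree_search body; not taken from the diff.
# Each stage forks the state num_branches ways, so the search fans out
# multiplicatively: with num_branches = 2 this yields up to 16 leaf states.
def tree_search_sketch(s, question, num_branches):
    solutions = []
    for plan in propose_plan(s, question, num_branches):
        for answer in execute_plan(plan, num_branches):
            for scored in reflect_solution(answer, num_branches):
                for final in get_final_answer(scored, num_branches):
                    solutions.append(final)
    return solutions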
......
 import argparse
 import ast
-import asyncio
-from collections import Counter
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
 import json
 import re
 import time
+from collections import Counter
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 
 import numpy as np
 from tqdm import tqdm
 
-from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
-from sglang.utils import read_jsonl, dump_state_text
+from sglang.test.test_utils import (
+    add_common_other_args_and_parse,
+    call_generate_lightllm,
+    call_generate_srt_raw,
+    call_generate_vllm,
+)
+from sglang.utils import dump_state_text, read_jsonl
 
 INVALID = -9999999
 
 
 def get_answer_value(answer_str):
     answer_str = answer_str.replace(",", "")
-    numbers = re.findall(r'\d+', answer_str)
+    numbers = re.findall(r"\d+", answer_str)
     if len(numbers) < 1:
         return INVALID
     try:
@@ -47,27 +51,43 @@ temp = 0.3
 
 
 def propose_plan(s, question, num_branches, call_generate):
-    s += (USER_PREFIX +
-        """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question + USER_SUFFIX)
+    s += (
+        USER_PREFIX
+        + """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
+        + question
+        + USER_SUFFIX
+    )
     s += ASSISTANT_PREFIX
-    comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
+    comps = call_generate(
+        s, max_tokens=256, temperature=temp, stop=None, n=num_branches
+    )
     return [s + comp + ASSISTANT_SUFFIX for comp in comps]
 
 
 def execute_plan(s, num_branches, call_generate):
-    s += (USER_PREFIX +
-        """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""" + USER_SUFFIX)
+    s += (
+        USER_PREFIX
+        + """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
+        + USER_SUFFIX
+    )
     s += ASSISTANT_PREFIX
-    comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
+    comps = call_generate(
+        s, max_tokens=256, temperature=temp, stop=None, n=num_branches
+    )
     return [s + comp + ASSISTANT_SUFFIX for comp in comps]
 
 
 def reflect_solution(s, num_branches, call_generate):
-    s += (USER_PREFIX +
-        """Okay. Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""" + USER_SUFFIX)
+    s += (
+        USER_PREFIX
+        + """Okay. Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
+        + USER_SUFFIX
+    )
     s += ASSISTANT_PREFIX
-    comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
+    comps = call_generate(
+        s, max_tokens=256, temperature=temp, stop=None, n=num_branches
+    )
     return [s + comp + ASSISTANT_SUFFIX for comp in comps]
@@ -92,7 +112,7 @@ def main(args):
     num_branches = 3
     questions = []
     labels = []
-    for i in range(len(lines[:args.num_questions])):
+    for i in range(len(lines[: args.num_questions])):
         questions.append(lines[i]["question"])
         labels.append(get_answer_value(lines[i]["answer"]))
     assert all(l != INVALID for l in labels)
@@ -109,25 +129,46 @@ def main(args):
         url = f"{args.host}:{args.port}/generate"
         call_generate = partial(call_generate_srt_raw, url=url)
     elif args.backend == "guidance":
-        from guidance import models, gen
+        from guidance import gen, models
 
-        model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
+        model = models.LlamaCpp(
+            "/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
+            n_gpu_layers=-1,
+            n_ctx=4096,
+        )
 
         def call_generate(prompt, temperature, max_tokens, stop, n):
             if n == 1:
-                out = model + prompt + gen(name="answer",
-                    max_tokens=max_tokens, temperature=temperature, stop=stop)
+                out = (
+                    model
+                    + prompt
+                    + gen(
+                        name="answer",
+                        max_tokens=max_tokens,
+                        temperature=temperature,
+                        stop=stop,
+                    )
+                )
                 return out["answer"]
             else:
                 rets = []
                 for i in range(n):
-                    out = model + prompt + gen(name="answer",
-                        max_tokens=max_tokens, temperature=temperature, stop=stop)
+                    out = (
+                        model
+                        + prompt
+                        + gen(
+                            name="answer",
+                            max_tokens=max_tokens,
+                            temperature=temperature,
+                            stop=stop,
+                        )
+                    )
                     rets.append(out["answer"])
                 return rets
 
     # Run requests
     states = [None] * len(questions)
 
     def get_one_answer(i):
         states[i] = tree_search(**arguments[i], call_generate=call_generate)
@@ -170,7 +211,7 @@ def main(args):
             "other": {
                 "num_questions": args.num_questions,
                 "parallel": args.parallel,
-            }
+            },
         }
         fout.write(json.dumps(value) + "\n")
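Note: the answer-aggregation step of this variant is elided. The `Counter` import above suggests a majority vote over the parsed branch answers; a hedged sketch of that aggregation (the `branch_answer_texts` name is illustrative, not from the diff):

# Illustrative majority-vote aggregation, assumed from the Counter import.
branch_values = [get_answer_value(text) for text in branch_answer_texts]
valid = [v for v in branch_values if v != INVALID]
pred = Counter(valid).most_common(1)[0][0] if valid else INVALID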
......
 import argparse
 import ast
-from collections import Counter
 import json
 import re
 import time
+from collections import Counter
 
 import numpy as np
 
-from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
-from sglang.utils import read_jsonl, dump_state_text
-import sglang as sgl
+import sglang as sgl
+from sglang.test.test_utils import (
+    add_common_sglang_args_and_parse,
+    select_sglang_backend,
+)
+from sglang.utils import dump_state_text, read_jsonl
 
 INVALID = -9999999
 
 
 def get_answer_value(answer_str):
     answer_str = answer_str.replace(",", "")
-    numbers = re.findall(r'\d+', answer_str)
+    numbers = re.findall(r"\d+", answer_str)
     if len(numbers) < 1:
         return INVALID
     try:
@@ -40,7 +43,9 @@ temp = 0.3
 
 def propose_plan(s, question, num_branches):
     s += sgl.user(
-        """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question)
+        """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
+        + question
+    )
     forks = s.fork(num_branches)
     forks += sgl.assistant(sgl.gen("plan", max_tokens=256, temperature=temp))
     return forks
@@ -48,7 +53,8 @@ def propose_plan(s, question, num_branches):
 
 def execute_plan(s, num_branches):
     s += sgl.user(
-        """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""")
+        """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
+    )
     forks = s.fork(num_branches)
     forks += sgl.assistant(sgl.gen("answer", max_tokens=256, temperature=temp))
     return forks
@@ -56,7 +62,8 @@ def execute_plan(s, num_branches):
 
 def reflect_solution(s, num_branches):
     s += sgl.user(
-        """Okay. Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""")
+        """Okay. Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
+    )
     forks = s.fork(num_branches)
     forks += sgl.assistant(sgl.gen("score", max_tokens=256, temperature=temp))
     return forks
@@ -90,7 +97,7 @@ def main(args):
     num_branches = 3
     questions = []
     labels = []
-    for i in range(len(lines[:args.num_questions])):
+    for i in range(len(lines[: args.num_questions])):
         questions.append(lines[i]["question"])
         labels.append(get_answer_value(lines[i]["answer"]))
     assert all(l != INVALID for l in labels)
@@ -102,7 +109,12 @@ def main(args):
     # Run requests
     tic = time.time()
     states = tree_search.run_batch(
-        arguments, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True)
+        arguments,
+        temperature=0,
+        backend=backend,
+        num_threads=args.parallel,
+        progress_bar=True,
+    )
     latency = time.time() - tic
 
     answers_text = []
     for s in states:
@@ -134,7 +146,7 @@ def main(args):
             "other": {
                 "num_questions": args.num_questions,
                 "parallel": args.parallel,
-            }
+            },
         }
         fout.write(json.dumps(value) + "\n")
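Note: every copy of `get_answer_value` above is truncated right after `try:`. Given the `ast` import in each file, the elided tail presumably evaluates the last matched number; a self-contained sketch under that assumption:

import ast
import re

INVALID = -9999999


def get_answer_value_sketch(answer_str):
    # Drop thousands separators, then take the last integer in the string.
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r"\d+", answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
        return ast.literal_eval(numbers[-1])  # assumed elided tail
    except (SyntaxError, ValueError):
        return INVALID


assert get_answer_value_sketch("So the total is 1,234 apples.") == 1234
assert get_answer_value_sketch("no digits here") == INVALID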
......