Unverified Commit 95c4e0df authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Format Benchmark Code (#399)

parent 19818b9c
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
Adapted from Adapted from
https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/intro.ipynb#L9 https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/intro.ipynb#L9
""" """
import argparse import argparse
import dspy import dspy
...@@ -37,29 +38,41 @@ class RAG(dspy.Module): ...@@ -37,29 +38,41 @@ class RAG(dspy.Module):
def main(args): def main(args):
#lm = dspy.OpenAI(model='gpt-3.5-turbo') # lm = dspy.OpenAI(model='gpt-3.5-turbo')
if args.backend == "tgi": if args.backend == "tgi":
lm = dspy.HFClientTGI(model="meta-llama/Llama-2-7b-chat-hf", port=args.port, lm = dspy.HFClientTGI(
url="http://localhost") model="meta-llama/Llama-2-7b-chat-hf",
port=args.port,
url="http://localhost",
)
elif args.backend == "sglang": elif args.backend == "sglang":
lm = dspy.HFClientSGLang(model="meta-llama/Llama-2-7b-chat-hf", port=args.port, lm = dspy.HFClientSGLang(
url="http://localhost") model="meta-llama/Llama-2-7b-chat-hf",
port=args.port,
url="http://localhost",
)
elif args.backend == "vllm": elif args.backend == "vllm":
lm = dspy.HFClientVLLM(model="meta-llama/Llama-2-7b-chat-hf", port=args.port, lm = dspy.HFClientVLLM(
url="http://localhost") model="meta-llama/Llama-2-7b-chat-hf",
port=args.port,
url="http://localhost",
)
else: else:
raise ValueError(f"Invalid backend: {args.backend}") raise ValueError(f"Invalid backend: {args.backend}")
colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts') colbertv2_wiki17_abstracts = dspy.ColBERTv2(
url="http://20.102.90.50:2017/wiki17_abstracts"
)
dspy.settings.configure(lm=lm, rm=colbertv2_wiki17_abstracts) dspy.settings.configure(lm=lm, rm=colbertv2_wiki17_abstracts)
# Load the dataset. # Load the dataset.
dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=args.dev_size, dataset = HotPotQA(
test_size=0) train_seed=1, train_size=20, eval_seed=2023, dev_size=args.dev_size, test_size=0
)
# Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata. # Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.
trainset = [x.with_inputs('question') for x in dataset.train] trainset = [x.with_inputs("question") for x in dataset.train]
devset = [x.with_inputs('question') for x in dataset.dev] devset = [x.with_inputs("question") for x in dataset.dev]
print(len(trainset), len(devset)) print(len(trainset), len(devset))
...@@ -72,8 +85,12 @@ def main(args): ...@@ -72,8 +85,12 @@ def main(args):
print(f"Answer: {dev_example.answer}") print(f"Answer: {dev_example.answer}")
print(f"Relevant Wikipedia Titles: {dev_example.gold_titles}") print(f"Relevant Wikipedia Titles: {dev_example.gold_titles}")
print(f"For this dataset, training examples have input keys {train_example.inputs().keys()} and label keys {train_example.labels().keys()}") print(
print(f"For this dataset, dev examples have input keys {dev_example.inputs().keys()} and label keys {dev_example.labels().keys()}") f"For this dataset, training examples have input keys {train_example.inputs().keys()} and label keys {train_example.labels().keys()}"
)
print(
f"For this dataset, dev examples have input keys {dev_example.inputs().keys()} and label keys {dev_example.labels().keys()}"
)
# Define the predictor. # Define the predictor.
generate_answer = dspy.Predict(BasicQA) generate_answer = dspy.Predict(BasicQA)
...@@ -101,10 +118,14 @@ def main(args): ...@@ -101,10 +118,14 @@ def main(args):
retrieve = dspy.Retrieve(k=3) retrieve = dspy.Retrieve(k=3)
topK_passages = retrieve(dev_example.question).passages topK_passages = retrieve(dev_example.question).passages
print(f"Top {retrieve.k} passages for question: {dev_example.question} \n", '-' * 30, '\n') print(
f"Top {retrieve.k} passages for question: {dev_example.question} \n",
"-" * 30,
"\n",
)
for idx, passage in enumerate(topK_passages): for idx, passage in enumerate(topK_passages):
print(f'{idx+1}]', passage, '\n') print(f"{idx+1}]", passage, "\n")
retrieve("When was the first FIFA World Cup held?").passages[0] retrieve("When was the first FIFA World Cup held?").passages[0]
...@@ -137,7 +158,12 @@ def main(args): ...@@ -137,7 +158,12 @@ def main(args):
from dspy.evaluate.evaluate import Evaluate from dspy.evaluate.evaluate import Evaluate
# Set up the `evaluate_on_hotpotqa` function. We'll use this many times below. # Set up the `evaluate_on_hotpotqa` function. We'll use this many times below.
evaluate_on_hotpotqa = Evaluate(devset=devset, num_threads=args.num_threads, display_progress=True, display_table=5) evaluate_on_hotpotqa = Evaluate(
devset=devset,
num_threads=args.num_threads,
display_progress=True,
display_table=5,
)
# Evaluate the `compiled_rag` program with the `answer_exact_match` metric. # Evaluate the `compiled_rag` program with the `answer_exact_match` metric.
metric = dspy.evaluate.answer_exact_match metric = dspy.evaluate.answer_exact_match
...@@ -149,8 +175,9 @@ if __name__ == "__main__": ...@@ -149,8 +175,9 @@ if __name__ == "__main__":
parser.add_argument("--port", type=int) parser.add_argument("--port", type=int)
parser.add_argument("--num-threads", type=int, default=32) parser.add_argument("--num-threads", type=int, default=32)
parser.add_argument("--dev-size", type=int, default=150) parser.add_argument("--dev-size", type=int, default=150)
parser.add_argument("--backend", type=str, choices=["sglang", "tgi", "vllm"], parser.add_argument(
default="sglang") "--backend", type=str, choices=["sglang", "tgi", "vllm"], default="sglang"
)
args = parser.parse_args() args = parser.parse_args()
if args.port is None: if args.port is None:
......
...@@ -122,16 +122,36 @@ Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs ...@@ -122,16 +122,36 @@ Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs
* Must be one of the "Area options," verbatim. * Must be one of the "Area options," verbatim.
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe} For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
---""" ---"""
s += (persona_name + " lives in " + living_sector + " that has " + s += (
living_sector_areas + ".\n") persona_name
s += (persona_name + " is currently in " + current_sector + " that has " + + " lives in "
current_sector_areas + ".\n") + living_sector
+ " that has "
+ living_sector_areas
+ ".\n"
)
s += (
persona_name
+ " is currently in "
+ current_sector
+ " that has "
+ current_sector_areas
+ ".\n"
)
s += daily_plan + ".\n" s += daily_plan + ".\n"
s += "Area options: " + sector_options + ".\n" s += "Area options: " + sector_options + ".\n"
s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place. s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.\n""" * Must be one of the "Area options," verbatim.\n"""
s += (persona_name + " is " + current_action + ". For " + next_action + s += (
", " + persona_name + " should go to the following area: {") persona_name
+ " is "
+ current_action
+ ". For "
+ next_action
+ ", "
+ persona_name
+ " should go to the following area: {"
)
s += sgl.gen(name="Location", max_tokens=10, stop="}") s += sgl.gen(name="Location", max_tokens=10, stop="}")
...@@ -162,22 +182,43 @@ Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs ...@@ -162,22 +182,43 @@ Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs
* Must be one of the "Area options," verbatim. * Must be one of the "Area options," verbatim.
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe} For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
---""" ---"""
s += (persona_name + " lives in " + living_sector + " that has " + s += (
living_sector_areas + ".\n") persona_name
s += (persona_name + " is currently in " + current_sector + " that has " + + " lives in "
current_sector_areas + ".\n") + living_sector
+ " that has "
+ living_sector_areas
+ ".\n"
)
s += (
persona_name
+ " is currently in "
+ current_sector
+ " that has "
+ current_sector_areas
+ ".\n"
)
s += daily_plan + ".\n" s += daily_plan + ".\n"
s += "Area options: " + sector_options + ".\n" s += "Area options: " + sector_options + ".\n"
s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place. s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.\n""" * Must be one of the "Area options," verbatim.\n"""
s += (persona_name + " is " + current_action + ". For " + next_action + s += (
", " + persona_name + " should go to the following area: {") persona_name
+ " is "
+ current_action
+ ". For "
+ next_action
+ ", "
+ persona_name
+ " should go to the following area: {"
)
return {"prompt": s, "max_tokens": 10, "stop": "}"} return {"prompt": s, "max_tokens": 10, "stop": "}"}
@sgl.function @sgl.function
def action_location_object(s, persona_name, target_sector, target_sector_areas, def action_location_object(
current_action, next_action): s, persona_name, target_sector, target_sector_areas, current_action, next_action
):
s += """ s += """
Jane Anderson is in kitchen in Jane Anderson's house. Jane Anderson is in kitchen in Jane Anderson's house.
Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom} Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom}
...@@ -191,20 +232,34 @@ Stay in the current area if the activity can be done there. Never go into other ...@@ -191,20 +232,34 @@ Stay in the current area if the activity can be done there. Never go into other
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe: For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
Answer: {cafe} Answer: {cafe}
---""" ---"""
s += (persona_name + " is going to " + target_sector + s += (
" that has the following areas: {" + target_sector_areas + "}\n") persona_name
+ " is going to "
+ target_sector
+ " that has the following areas: {"
+ target_sector_areas
+ "}\n"
)
s += """* Stay in the current area if the activity can be done there. s += """* Stay in the current area if the activity can be done there.
* NEVER go into other people's rooms unless necessary.""" * NEVER go into other people's rooms unless necessary."""
s += (persona_name + " is " + current_action + ". For " + next_action + s += (
", " + persona_name + "should go to the following area in " + persona_name
target_sector) + " is "
+ current_action
+ ". For "
+ next_action
+ ", "
+ persona_name
+ "should go to the following area in "
+ target_sector
)
s += " (MUST pick one of {" + target_sector_areas + "}):\n" s += " (MUST pick one of {" + target_sector_areas + "}):\n"
s += "Answer: {" + sgl.gen(name="Area", max_tokens=5, stop="}") s += "Answer: {" + sgl.gen(name="Area", max_tokens=5, stop="}")
def action_location_object_prompt(persona_name, target_sector, def action_location_object_prompt(
target_sector_areas, current_action, persona_name, target_sector, target_sector_areas, current_action, next_action
next_action): ):
s = "" s = ""
s += """ s += """
Jane Anderson is in kitchen in Jane Anderson's house. Jane Anderson is in kitchen in Jane Anderson's house.
...@@ -219,13 +274,27 @@ Stay in the current area if the activity can be done there. Never go into other ...@@ -219,13 +274,27 @@ Stay in the current area if the activity can be done there. Never go into other
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe: For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
Answer: {cafe} Answer: {cafe}
---""" ---"""
s += (persona_name + " is going to " + target_sector + s += (
" that has the following areas: {" + target_sector_areas + "}\n") persona_name
+ " is going to "
+ target_sector
+ " that has the following areas: {"
+ target_sector_areas
+ "}\n"
)
s += """* Stay in the current area if the activity can be done there. s += """* Stay in the current area if the activity can be done there.
* NEVER go into other people's rooms unless necessary.""" * NEVER go into other people's rooms unless necessary."""
s += (persona_name + " is " + current_action + ". For " + next_action + s += (
", " + persona_name + "should go to the following area in " + persona_name
target_sector) + " is "
+ current_action
+ ". For "
+ next_action
+ ", "
+ persona_name
+ "should go to the following area in "
+ target_sector
)
s += " (MUST pick one of {" + target_sector_areas + "}):\n" s += " (MUST pick one of {" + target_sector_areas + "}):\n"
s += "Answer: {" s += "Answer: {"
return {"prompt": s, "max_tokens": 5, "stop": "}"} return {"prompt": s, "max_tokens": 5, "stop": "}"}
import argparse import argparse
from functools import partial
import json import json
import time import time
from functools import partial
from pathlib import Path from pathlib import Path
from agent_functions import (
action_location_object_prompt,
action_location_sector_prompt,
generate_event_triple_prompt,
generate_pronunciatio_prompt,
poignancy_event_prompt,
)
from tqdm import tqdm from tqdm import tqdm
from sglang.test.test_utils import ( from sglang.test.test_utils import (
add_common_other_args_and_parse, add_common_other_args_and_parse,
call_generate_lightllm, call_generate_lightllm,
call_generate_vllm,
call_generate_srt_raw, call_generate_srt_raw,
call_generate_vllm,
) )
from sglang.utils import read_jsonl, dump_state_text from sglang.utils import dump_state_text, read_jsonl
from agent_functions import (
poignancy_event_prompt,
generate_event_triple_prompt,
generate_pronunciatio_prompt,
action_location_sector_prompt,
action_location_object_prompt,
)
def main(args): def main(args):
lines = read_jsonl(args.data_path)[:args.num_events] lines = read_jsonl(args.data_path)[: args.num_events]
mapping = { mapping = {
"poignancy_event": poignancy_event_prompt, "poignancy_event": poignancy_event_prompt,
"generate_event_triple": generate_event_triple_prompt, "generate_event_triple": generate_event_triple_prompt,
...@@ -46,7 +46,7 @@ def main(args): ...@@ -46,7 +46,7 @@ def main(args):
url = f"{args.host}:{args.port}/generate" url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url) call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance": elif args.backend == "guidance":
from guidance import models, gen from guidance import gen, models
model = models.LlamaCpp( model = models.LlamaCpp(
str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf", str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf",
...@@ -55,12 +55,16 @@ def main(args): ...@@ -55,12 +55,16 @@ def main(args):
) )
def call_generate(prompt, temperature, max_tokens, stop): def call_generate(prompt, temperature, max_tokens, stop):
out = model + prompt + gen( out = (
model
+ prompt
+ gen(
name="result", name="result",
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=temperature, temperature=temperature,
stop=stop, stop=stop,
) )
)
return out["result"] return out["result"]
else: else:
......
...@@ -2,24 +2,24 @@ import argparse ...@@ -2,24 +2,24 @@ import argparse
import json import json
import time import time
from agent_functions import (
action_location_object,
action_location_sector,
generate_event_triple,
generate_pronunciatio,
poignancy_event,
)
import sglang as sgl import sglang as sgl
from sglang.test.test_utils import ( from sglang.test.test_utils import (
add_common_sglang_args_and_parse, add_common_sglang_args_and_parse,
select_sglang_backend, select_sglang_backend,
) )
from sglang.utils import read_jsonl, dump_state_text from sglang.utils import dump_state_text, read_jsonl
from agent_functions import (
poignancy_event,
generate_event_triple,
generate_pronunciatio,
action_location_sector,
action_location_object,
)
def main(args): def main(args):
lines = read_jsonl(args.data_path)[:args.num_events] lines = read_jsonl(args.data_path)[: args.num_events]
mapping = { mapping = {
"poignancy_event": poignancy_event, "poignancy_event": poignancy_event,
"generate_event_triple": generate_event_triple, "generate_event_triple": generate_event_triple,
......
import argparse import argparse
import ast import ast
import asyncio import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json import json
import re import re
import time import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
from sglang.utils import read_jsonl, dump_state_text
from sglang.test.test_utils import (
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.utils import dump_state_text, read_jsonl
INVALID = -9999999 INVALID = -9999999
...@@ -32,7 +37,7 @@ def get_few_shot_examples(lines, k): ...@@ -32,7 +37,7 @@ def get_few_shot_examples(lines, k):
def get_answer_value(answer_str): def get_answer_value(answer_str):
answer_str = answer_str.replace(",", "") answer_str = answer_str.replace(",", "")
numbers = re.findall(r'\d+', answer_str) numbers = re.findall(r"\d+", answer_str)
if len(numbers) < 1: if len(numbers) < 1:
return INVALID return INVALID
try: try:
...@@ -50,7 +55,7 @@ def main(args): ...@@ -50,7 +55,7 @@ def main(args):
questions = [] questions = []
labels = [] labels = []
for i in range(len(lines[:args.num_questions])): for i in range(len(lines[: args.num_questions])):
questions.append(get_one_example(lines, i, False)) questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"])) labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels) assert all(l != INVALID for l in labels)
...@@ -68,19 +73,31 @@ def main(args): ...@@ -68,19 +73,31 @@ def main(args):
url = f"{args.host}:{args.port}/generate" url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url) call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance": elif args.backend == "guidance":
from guidance import models, gen from guidance import gen, models
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096) model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def call_generate(prompt, temperature, max_tokens, stop): def call_generate(prompt, temperature, max_tokens, stop):
out = model + prompt + gen(name="answer", out = (
max_tokens=max_tokens, temperature=temperature, stop=stop) model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
)
return out["answer"] return out["answer"]
elif args.backend == "lmql": elif args.backend == "lmql":
import lmql import lmql
model = lmql.model(args.model_path,
endpoint=f"{args.host}:{args.port}") model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
@lmql.query(model=model) @lmql.query(model=model)
async def program(question): async def program(question):
...@@ -103,7 +120,8 @@ def main(args): ...@@ -103,7 +120,8 @@ def main(args):
prompt=few_shot_examples + questions[i], prompt=few_shot_examples + questions[i],
temperature=0, temperature=0,
max_tokens=256, max_tokens=256,
stop="Question") stop="Question",
)
states[i] = answer states[i] = answer
tic = time.time() tic = time.time()
...@@ -118,12 +136,18 @@ def main(args): ...@@ -118,12 +136,18 @@ def main(args):
async def batched_call(batch_size): async def batched_call(batch_size):
for i in range(0, len(questions), batch_size): for i in range(0, len(questions), batch_size):
tasks = [] tasks = []
for q in questions[i:i+batch_size]: for q in questions[i : i + batch_size]:
tasks.append(call_generate(few_shot_examples + q, tasks.append(
temperature=0, max_tokens=256, stop="Question")) call_generate(
few_shot_examples + q,
temperature=0,
max_tokens=256,
stop="Question",
)
)
rets = await asyncio.gather(*tasks) rets = await asyncio.gather(*tasks)
for j in range(len(rets)): for j in range(len(rets)):
states[i+j] = rets[j] states[i + j] = rets[j]
tic = time.time() tic = time.time()
asyncio.run(batched_call(batch_size=args.parallel)) asyncio.run(batched_call(batch_size=args.parallel))
...@@ -154,7 +178,7 @@ def main(args): ...@@ -154,7 +178,7 @@ def main(args):
"other": { "other": {
"num_questions": args.num_questions, "num_questions": args.num_questions,
"parallel": args.parallel, "parallel": args.parallel,
} },
} }
fout.write(json.dumps(value) + "\n") fout.write(json.dumps(value) + "\n")
......
...@@ -5,9 +5,12 @@ import re ...@@ -5,9 +5,12 @@ import re
import time import time
import numpy as np import numpy as np
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
INVALID = -9999999 INVALID = -9999999
...@@ -28,7 +31,7 @@ def get_few_shot_examples(lines, k): ...@@ -28,7 +31,7 @@ def get_few_shot_examples(lines, k):
def get_answer_value(answer_str): def get_answer_value(answer_str):
answer_str = answer_str.replace(",", "") answer_str = answer_str.replace(",", "")
numbers = re.findall(r'\d+', answer_str) numbers = re.findall(r"\d+", answer_str)
if len(numbers) < 1: if len(numbers) < 1:
return INVALID return INVALID
try: try:
...@@ -46,7 +49,7 @@ def main(args): ...@@ -46,7 +49,7 @@ def main(args):
questions = [] questions = []
labels = [] labels = []
for i in range(len(lines[:args.num_questions])): for i in range(len(lines[: args.num_questions])):
questions.append(get_one_example(lines, i, False)) questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"])) labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels) assert all(l != INVALID for l in labels)
...@@ -73,7 +76,12 @@ def main(args): ...@@ -73,7 +76,12 @@ def main(args):
# Run requests # Run requests
tic = time.time() tic = time.time()
states = few_shot_gsm8k.run_batch( states = few_shot_gsm8k.run_batch(
arguments, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True) arguments,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.time() - tic latency = time.time() - tic
preds = [] preds = []
...@@ -101,7 +109,7 @@ def main(args): ...@@ -101,7 +109,7 @@ def main(args):
"other": { "other": {
"num_questions": args.num_questions, "num_questions": args.num_questions,
"parallel": args.parallel, "parallel": args.parallel,
} },
} }
fout.write(json.dumps(value) + "\n") fout.write(json.dumps(value) + "\n")
......
import argparse import argparse
import asyncio import asyncio
from concurrent.futures import ThreadPoolExecutor
import json import json
from functools import partial
import time import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import numpy as np import numpy as np
from sglang.test.test_utils import add_common_other_args_and_parse, call_select_lightllm, call_select_vllm
from sglang.test.test_utils import (
add_common_other_args_and_parse,
call_select_lightllm,
call_select_vllm,
)
from sglang.utils import read_jsonl from sglang.utils import read_jsonl
...@@ -34,7 +39,7 @@ def main(args): ...@@ -34,7 +39,7 @@ def main(args):
questions = [] questions = []
choices = [] choices = []
labels = [] labels = []
for i in range(len(lines[:args.num_questions])): for i in range(len(lines[: args.num_questions])):
questions.append(get_one_example(lines, i, False)) questions.append(get_one_example(lines, i, False))
choices.append(lines[i]["endings"]) choices.append(lines[i]["endings"])
labels.append(lines[i]["label"]) labels.append(lines[i]["label"])
...@@ -51,7 +56,11 @@ def main(args): ...@@ -51,7 +56,11 @@ def main(args):
elif args.backend == "guidance": elif args.backend == "guidance":
from guidance import models, select from guidance import models, select
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096) model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def call_select(context, choices): def call_select(context, choices):
out = model + context + select(choices, name="answer") out = model + context + select(choices, name="answer")
...@@ -61,8 +70,10 @@ def main(args): ...@@ -61,8 +70,10 @@ def main(args):
elif args.backend == "lmql": elif args.backend == "lmql":
import lmql import lmql
model = lmql.model("meta-llama/Llama-2-7b-chat-hf",
endpoint=f"{args.host}:{args.port}") model = lmql.model(
"meta-llama/Llama-2-7b-chat-hf", endpoint=f"{args.host}:{args.port}"
)
@lmql.query(model=model) @lmql.query(model=model)
async def program(ctx, choices): async def program(ctx, choices):
...@@ -83,8 +94,8 @@ def main(args): ...@@ -83,8 +94,8 @@ def main(args):
# Use thread pool # Use thread pool
def get_one_answer(i): def get_one_answer(i):
preds[i] = call_select( preds[i] = call_select(
context=few_shot_examples + questions[i], context=few_shot_examples + questions[i], choices=choices[i]
choices=choices[i]) )
tic = time.time() tic = time.time()
if args.parallel == 1: if args.parallel == 1:
...@@ -98,13 +109,13 @@ def main(args): ...@@ -98,13 +109,13 @@ def main(args):
async def batched_call(batch_size): async def batched_call(batch_size):
for i in range(0, len(questions), batch_size): for i in range(0, len(questions), batch_size):
tasks = [] tasks = []
for q, c in zip(questions[i:i+batch_size], choices[i:i+batch_size]): for q, c in zip(
tasks.append(call_select( questions[i : i + batch_size], choices[i : i + batch_size]
context=few_shot_examples + q, ):
choices=c)) tasks.append(call_select(context=few_shot_examples + q, choices=c))
rets = await asyncio.gather(*tasks) rets = await asyncio.gather(*tasks)
for j in range(len(rets)): for j in range(len(rets)):
preds[i+j] = rets[j] preds[i + j] = rets[j]
tic = time.time() tic = time.time()
asyncio.run(batched_call(batch_size=args.parallel)) asyncio.run(batched_call(batch_size=args.parallel))
...@@ -128,7 +139,7 @@ def main(args): ...@@ -128,7 +139,7 @@ def main(args):
"other": { "other": {
"num_questions": args.num_questions, "num_questions": args.num_questions,
"parallel": args.parallel, "parallel": args.parallel,
} },
} }
fout.write(json.dumps(value) + "\n") fout.write(json.dumps(value) + "\n")
......
...@@ -3,7 +3,11 @@ import json ...@@ -3,7 +3,11 @@ import json
import time import time
import numpy as np import numpy as np
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import read_jsonl from sglang.utils import read_jsonl
...@@ -31,14 +35,11 @@ def main(args): ...@@ -31,14 +35,11 @@ def main(args):
questions = [] questions = []
choices = [] choices = []
labels = [] labels = []
for i in range(len(lines[:args.num_questions])): for i in range(len(lines[: args.num_questions])):
questions.append(get_one_example(lines, i, False)) questions.append(get_one_example(lines, i, False))
choices.append(lines[i]["endings"]) choices.append(lines[i]["endings"])
labels.append(lines[i]["label"]) labels.append(lines[i]["label"])
arguments = [ arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)]
{"question": q, "choices": c}
for q, c in zip(questions, choices)
]
##################################### #####################################
######### SGL Program Begin ######### ######### SGL Program Begin #########
...@@ -61,7 +62,12 @@ def main(args): ...@@ -61,7 +62,12 @@ def main(args):
# Run requests # Run requests
tic = time.time() tic = time.time()
rets = few_shot_hellaswag.run_batch( rets = few_shot_hellaswag.run_batch(
arguments, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True) arguments,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))] preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
latency = time.time() - tic latency = time.time() - tic
...@@ -82,7 +88,7 @@ def main(args): ...@@ -82,7 +88,7 @@ def main(args):
"other": { "other": {
"num_questions": args.num_questions, "num_questions": args.num_questions,
"parallel": args.parallel, "parallel": args.parallel,
} },
} }
fout.write(json.dumps(value) + "\n") fout.write(json.dumps(value) + "\n")
......
...@@ -4,13 +4,14 @@ import time ...@@ -4,13 +4,14 @@ import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial from functools import partial
from tqdm import tqdm
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
from sglang.test.test_utils import ( from sglang.test.test_utils import (
add_common_other_args_and_parse, add_common_other_args_and_parse,
call_generate_outlines, call_generate_outlines,
) )
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import dump_state_text, read_jsonl
from sglang.lang.ir import REGEX_INT, REGEX_STRING, REGEX_FLOAT
from tqdm import tqdm
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]" REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"
......
...@@ -3,7 +3,7 @@ import json ...@@ -3,7 +3,7 @@ import json
import time import time
import sglang as sgl import sglang as sgl
from sglang.lang.ir import REGEX_INT, REGEX_STRING, REGEX_FLOAT from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
from sglang.test.test_utils import ( from sglang.test.test_utils import (
add_common_sglang_args_and_parse, add_common_sglang_args_and_parse,
select_sglang_backend, select_sglang_backend,
...@@ -63,7 +63,9 @@ def main(args): ...@@ -63,7 +63,9 @@ def main(args):
# Run requests # Run requests
tic = time.time() tic = time.time()
states = json_decode.run_batch(arguments, temperature=0, num_threads=args.parallel, progress_bar=True) states = json_decode.run_batch(
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
)
latency = time.time() - tic latency = time.time() - tic
# Compute accuracy # Compute accuracy
......
...@@ -5,12 +5,13 @@ from concurrent.futures import ThreadPoolExecutor ...@@ -5,12 +5,13 @@ from concurrent.futures import ThreadPoolExecutor
from functools import partial from functools import partial
import guidance import guidance
from tqdm import tqdm
from sglang.test.test_utils import ( from sglang.test.test_utils import (
add_common_other_args_and_parse, add_common_other_args_and_parse,
call_generate_outlines, call_generate_outlines,
) )
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import dump_state_text, read_jsonl
from tqdm import tqdm
# there are some FSM bugs with json regex converted from pydantic model # there are some FSM bugs with json regex converted from pydantic model
# here use a string regex instead # here use a string regex instead
......
...@@ -15,16 +15,17 @@ On the client side, run: ...@@ -15,16 +15,17 @@ On the client side, run:
--tokenizer <your_model> --dataset <target_dataset> \ --tokenizer <your_model> --dataset <target_dataset> \
--request-rate <request_rate> --request-rate <request_rate>
""" """
import argparse import argparse
import asyncio import asyncio
import json import json
import random import random
import time import time
from typing import AsyncGenerator, List, Tuple from typing import AsyncGenerator, List, Tuple
from tqdm.asyncio import tqdm_asyncio
import aiohttp import aiohttp
import numpy as np import numpy as np
from tqdm.asyncio import tqdm_asyncio
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
...@@ -41,10 +42,7 @@ def sample_requests( ...@@ -41,10 +42,7 @@ def sample_requests(
with open(dataset_path) as f: with open(dataset_path) as f:
dataset = json.load(f) dataset = json.load(f)
# Filter out the conversations with less than 2 turns. # Filter out the conversations with less than 2 turns.
dataset = [ dataset = [data for data in dataset if len(data["conversations"]) >= 2]
data for data in dataset
if len(data["conversations"]) >= 2
]
# Only keep the first two turns of each conversation. # Only keep the first two turns of each conversation.
dataset = [ dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"]) (data["conversations"][0]["value"], data["conversations"][1]["value"])
...@@ -185,9 +183,17 @@ async def benchmark( ...@@ -185,9 +183,17 @@ async def benchmark(
tasks: List[asyncio.Task] = [] tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate): async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = request prompt, prompt_len, output_len = request
task = asyncio.create_task(send_request(backend, api_url, prompt, task = asyncio.create_task(
prompt_len, output_len, send_request(
best_of, use_beam_search)) backend,
api_url,
prompt,
prompt_len,
output_len,
best_of,
use_beam_search,
)
)
tasks.append(task) tasks.append(task)
await tqdm_asyncio.gather(*tasks) await tqdm_asyncio.gather(*tasks)
...@@ -202,8 +208,16 @@ def main(args: argparse.Namespace): ...@@ -202,8 +208,16 @@ def main(args: argparse.Namespace):
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer) input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of, asyncio.run(
args.use_beam_search, args.request_rate)) benchmark(
args.backend,
api_url,
input_requests,
args.best_of,
args.use_beam_search,
args.request_rate,
)
)
benchmark_end_time = time.perf_counter() benchmark_end_time = time.perf_counter()
benchmark_time = benchmark_end_time - benchmark_start_time benchmark_time = benchmark_end_time - benchmark_start_time
print(f"Total time: {benchmark_time:.2f} s") print(f"Total time: {benchmark_time:.2f} s")
...@@ -212,43 +226,61 @@ def main(args: argparse.Namespace): ...@@ -212,43 +226,61 @@ def main(args: argparse.Namespace):
# Compute the latency statistics. # Compute the latency statistics.
avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY]) avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
print(f"Average latency: {avg_latency:.2f} s") print(f"Average latency: {avg_latency:.2f} s")
avg_per_token_latency = np.mean([ avg_per_token_latency = np.mean(
[
latency / (prompt_len + output_len) latency / (prompt_len + output_len)
for prompt_len, output_len, latency in REQUEST_LATENCY for prompt_len, output_len, latency in REQUEST_LATENCY
]) ]
)
print(f"Average latency per token: {avg_per_token_latency:.2f} s") print(f"Average latency per token: {avg_per_token_latency:.2f} s")
avg_per_output_token_latency = np.mean([ avg_per_output_token_latency = np.mean(
latency / output_len [latency / output_len for _, output_len, latency in REQUEST_LATENCY]
for _, output_len, latency in REQUEST_LATENCY )
]) print("Average latency per output token: " f"{avg_per_output_token_latency:.2f} s")
print("Average latency per output token: "
f"{avg_per_output_token_latency:.2f} s")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Benchmark the online serving throughput.") description="Benchmark the online serving throughput."
parser.add_argument("--backend", type=str, default="vllm", )
choices=["vllm", "tgi", "srt", "lightllm"]) parser.add_argument(
"--backend",
type=str,
default="vllm",
choices=["vllm", "tgi", "srt", "lightllm"],
)
parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000) parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--dataset", type=str, required=True, parser.add_argument(
help="Path to the dataset.") "--dataset", type=str, required=True, help="Path to the dataset."
parser.add_argument("--tokenizer", type=str, required=True, )
help="Name or path of the tokenizer.") parser.add_argument(
parser.add_argument("--best-of", type=int, default=1, "--tokenizer", type=str, required=True, help="Name or path of the tokenizer."
help="Generates `best_of` sequences per prompt and " )
"returns the best one.") parser.add_argument(
"--best-of",
type=int,
default=1,
help="Generates `best_of` sequences per prompt and " "returns the best one.",
)
parser.add_argument("--use-beam-search", action="store_true") parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts", type=int, default=1000, parser.add_argument(
help="Number of prompts to process.") "--num-prompts", type=int, default=1000, help="Number of prompts to process."
parser.add_argument("--request-rate", type=float, default=float("inf"), )
parser.add_argument(
"--request-rate",
type=float,
default=float("inf"),
help="Number of requests per second. If this is inf, " help="Number of requests per second. If this is inf, "
"then all the requests are sent at time 0. " "then all the requests are sent at time 0. "
"Otherwise, we use Poisson process to synthesize " "Otherwise, we use Poisson process to synthesize "
"the request arrival times.") "the request arrival times.",
)
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument('--trust-remote-code', action='store_true', parser.add_argument(
help='trust remote code from huggingface') "--trust-remote-code",
action="store_true",
help="trust remote code from huggingface",
)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
import argparse import argparse
import json import json
import time
import re import re
import time
import numpy as np import numpy as np
import sglang as sgl import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text from sglang.utils import dump_state_text
...@@ -35,23 +39,30 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents): ...@@ -35,23 +39,30 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
dst_percent = dst_percents[j] dst_percent = dst_percents[j]
query_indices = line_obj["group_by_num_hoops"][str(num_hoops)] query_indices = line_obj["group_by_num_hoops"][str(num_hoops)]
query_indices = [q for q in query_indices if query_indices = [
all(l <= src_index for l in line_obj["links"][q]) and q < src_index] q
dst_index = query_indices[min(int(len(query_indices) * dst_percent), len(query_indices)-1)] for q in query_indices
if all(l <= src_index for l in line_obj["links"][q]) and q < src_index
]
dst_index = query_indices[
min(int(len(query_indices) * dst_percent), len(query_indices) - 1)
]
label = line_obj["values"][dst_index] label = line_obj["values"][dst_index]
body = line_obj["lines"][:src_index+1] body = line_obj["lines"][: src_index + 1]
suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index]) suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index])
body_part_len = len(body) // 4 body_part_len = len(body) // 4
arguments.append({ arguments.append(
{
"prefix": line_obj["prefix"], "prefix": line_obj["prefix"],
"body_0": "\n".join(body[:body_part_len]), "body_0": "\n".join(body[:body_part_len]),
"body_1": "\n".join(body[body_part_len: 2 * body_part_len]), "body_1": "\n".join(body[body_part_len : 2 * body_part_len]),
"body_2": "\n".join(body[2 * body_part_len: 3 * body_part_len]), "body_2": "\n".join(body[2 * body_part_len : 3 * body_part_len]),
"body_3": "\n".join(body[3 * body_part_len:]), "body_3": "\n".join(body[3 * body_part_len :]),
"suffix": suffix, "suffix": suffix,
}) }
)
labels.append(label) labels.append(label)
sum_src_indices.append(src_index) sum_src_indices.append(src_index)
sum_dst_indices.append(dst_index) sum_dst_indices.append(dst_index)
...@@ -61,7 +72,12 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents): ...@@ -61,7 +72,12 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
tic = time.time() tic = time.time()
states = line_retrieval.run_batch( states = line_retrieval.run_batch(
arguments, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True) arguments,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.time() - tic latency = time.time() - tic
corrects = [] corrects = []
...@@ -79,7 +95,7 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents): ...@@ -79,7 +95,7 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
if response_number == label: if response_number == label:
break break
correct = (response_number == label) correct = response_number == label
corrects.append(correct) corrects.append(correct)
# Log results # Log results
...@@ -107,7 +123,7 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents): ...@@ -107,7 +123,7 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
"other": { "other": {
"num_questions": len(arguments), "num_questions": len(arguments),
"parallel": args.parallel, "parallel": args.parallel,
} },
} }
fout.write(json.dumps(value) + "\n") fout.write(json.dumps(value) + "\n")
......
...@@ -4,12 +4,13 @@ Generate line data for line retrieval task. ...@@ -4,12 +4,13 @@ Generate line data for line retrieval task.
Usage: Usage:
python3 gen_data.py --number 1000 python3 gen_data.py --number 1000
""" """
import argparse import argparse
from collections import defaultdict
import json import json
from collections import defaultdict
from tqdm import tqdm
import numpy as np import numpy as np
from tqdm import tqdm
def generate_lines(random_words, num_lines, redirect_ratio): def generate_lines(random_words, num_lines, redirect_ratio):
...@@ -42,11 +43,14 @@ def generate_lines(random_words, num_lines, redirect_ratio): ...@@ -42,11 +43,14 @@ def generate_lines(random_words, num_lines, redirect_ratio):
# Add redirect # Add redirect
if redirect_ratio > 0: if redirect_ratio > 0:
num_redirect_lines = int(len(lines) * redirect_ratio) num_redirect_lines = int(len(lines) * redirect_ratio)
redirect_indices = np.random.choice(np.arange(len(lines)), redirect_indices = np.random.choice(
size=(num_redirect_lines,), replace=False) np.arange(len(lines)), size=(num_redirect_lines,), replace=False
)
for i in redirect_indices: for i in redirect_indices:
target_idx = np.random.choice(min(i * 2 + 100, num_lines)) target_idx = np.random.choice(min(i * 2 + 100, num_lines))
lines[i] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}." lines[i] = (
f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
)
redirects[i] = target_idx redirects[i] = target_idx
# Build links and find sources # Build links and find sources
......
import argparse import argparse
import json import json
import time
import os import os
import time
import sglang as sgl
import tqdm import tqdm
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text import sglang as sgl
from PIL import Image from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
@sgl.function @sgl.function
...@@ -17,17 +20,19 @@ def image_qa(s, image_file, question): ...@@ -17,17 +20,19 @@ def image_qa(s, image_file, question):
def main(args): def main(args):
lines = read_jsonl(args.question_file)[:args.num_questions] lines = read_jsonl(args.question_file)[: args.num_questions]
arguments = [ arguments = [
{"image_file": {
os.path.abspath(args.image_folder + "/" + l["image"]), "image_file": os.path.abspath(args.image_folder + "/" + l["image"]),
"question": l["text"]} for l in lines "question": l["text"],
}
for l in lines
] ]
#arguments = [ # arguments = [
# {"image_file": # {"image_file":
# Image.open(os.path.abspath(args.image_folder + "/" + l["image"])), # Image.open(os.path.abspath(args.image_folder + "/" + l["image"])),
# "question": l["text"]} for l in lines # "question": l["text"]} for l in lines
#] # ]
states = [None] * len(lines) states = [None] * len(lines)
...@@ -41,17 +46,12 @@ def main(args): ...@@ -41,17 +46,12 @@ def main(args):
for i in tqdm.tqdm(range(len(lines))): for i in tqdm.tqdm(range(len(lines))):
image_file = arguments[i]["image_file"] image_file = arguments[i]["image_file"]
question = arguments[i]["question"] question = arguments[i]["question"]
ret = image_qa.run( ret = image_qa.run(image_file=image_file, question=question, temperature=0)
image_file=image_file,
question=question,
temperature=0)
states[i] = ret states[i] = ret
else: else:
states = image_qa.run_batch( states = image_qa.run_batch(
arguments, arguments, temperature=0, num_threads=args.parallel, progress_bar=True
temperature=0, )
num_threads=args.parallel,
progress_bar=True)
latency = time.time() - tic latency = time.time() - tic
print(f"Latency: {latency:.3f}") print(f"Latency: {latency:.3f}")
......
import os import os
# Create the 'images' directory if it doesn't exist # Create the 'images' directory if it doesn't exist
if not os.path.exists('images'): if not os.path.exists("images"):
os.makedirs('images') os.makedirs("images")
# Base URL # Base URL
base_url = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/" base_url = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/"
......
import argparse import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json import json
import time import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import numpy as np
from tqdm import tqdm from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
from sglang.utils import read_jsonl, dump_state_text
from sglang.test.test_utils import (
system_prompt = ( add_common_other_args_and_parse,
"Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency." call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
) )
from sglang.utils import dump_state_text, read_jsonl
system_prompt = "Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
dimension_prompts = [ dimension_prompts = [
"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.", "Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.", "Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.", "Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.", "Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.", "Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.", "Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
] ]
...@@ -31,12 +32,16 @@ def multi_dimension_judge(article, generate): ...@@ -31,12 +32,16 @@ def multi_dimension_judge(article, generate):
judges = [] judges = []
for i in range(len(dimension_prompts)): for i in range(len(dimension_prompts)):
comp = generate(s + comp = generate(
"USER: Please judge the quality based on the following metric. " + s
dimension_prompts[i] + " Please provide a single-paragraph judgement. " + + "USER: Please judge the quality based on the following metric. "
"Focus on the provided metric and do not say other things. " + dimension_prompts[i]
+ " Please provide a single-paragraph judgement. "
+ "Focus on the provided metric and do not say other things. "
'End your judgement paragraph with the word "END"\nJUDGE:', 'End your judgement paragraph with the word "END"\nJUDGE:',
max_tokens=256, stop="END") max_tokens=256,
stop="END",
)
judges.append(comp) judges.append(comp)
s += "I will judge the quality based on the following metrics.\n" s += "I will judge the quality based on the following metrics.\n"
...@@ -50,7 +55,7 @@ def multi_dimension_judge(article, generate): ...@@ -50,7 +55,7 @@ def multi_dimension_judge(article, generate):
def main(args): def main(args):
lines = read_jsonl(args.data_path)[:args.num_questions] lines = read_jsonl(args.data_path)[: args.num_questions]
states = [None] * len(lines) states = [None] * len(lines)
# Select backend # Select backend
...@@ -64,13 +69,20 @@ def main(args): ...@@ -64,13 +69,20 @@ def main(args):
url = f"{args.host}:{args.port}/generate" url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_srt_raw, url=url, temperature=0) generate = partial(call_generate_srt_raw, url=url, temperature=0)
elif args.backend == "guidance": elif args.backend == "guidance":
from guidance import models, gen from guidance import gen, models
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096) model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def generate(prompt, max_tokens, stop): def generate(prompt, max_tokens, stop):
out = model + prompt + gen(name="answer", out = (
max_tokens=max_tokens, temperature=0, stop=stop) model
+ prompt
+ gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
)
return out["answer"] return out["answer"]
# warmup # warmup
...@@ -107,7 +119,7 @@ def main(args): ...@@ -107,7 +119,7 @@ def main(args):
"other": { "other": {
"num_questions": args.num_questions, "num_questions": args.num_questions,
"parallel": args.parallel, "parallel": args.parallel,
} },
} }
fout.write(json.dumps(value) + "\n") fout.write(json.dumps(value) + "\n")
......
...@@ -2,23 +2,22 @@ import argparse ...@@ -2,23 +2,22 @@ import argparse
import json import json
import time import time
import numpy as np
import sglang as sgl import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend from sglang.test.test_utils import (
from sglang.utils import read_jsonl, dump_state_text add_common_sglang_args_and_parse,
select_sglang_backend,
system_prompt = (
"Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
) )
from sglang.utils import dump_state_text, read_jsonl
system_prompt = "Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
dimension_prompts = [ dimension_prompts = [
"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.", "Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.", "Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.", "Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.", "Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.", "Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.", "Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
] ]
...@@ -29,23 +28,31 @@ def multi_dimension_judge(s, article): ...@@ -29,23 +28,31 @@ def multi_dimension_judge(s, article):
forks = s.fork(len(dimension_prompts)) forks = s.fork(len(dimension_prompts))
for i in range(len(dimension_prompts)): for i in range(len(dimension_prompts)):
forks[i] += ("USER: Please judge the quality based on the following metric. " + forks[i] += (
dimension_prompts[i] + " Please provide a single-paragraph judgement. " + "USER: Please judge the quality based on the following metric. "
"Focus on the provided metric and do not say other things. " + dimension_prompts[i]
'End your judgement paragraph with the word "END"\nJUDGE:') + " Please provide a single-paragraph judgement. "
+ "Focus on the provided metric and do not say other things. "
'End your judgement paragraph with the word "END"\nJUDGE:'
)
forks[i] += sgl.gen("judgement", max_tokens=256, stop="END") forks[i] += sgl.gen("judgement", max_tokens=256, stop="END")
forks.join() forks.join()
s += "I will judge the quality based on the following metrics.\n" s += "I will judge the quality based on the following metrics.\n"
for i in range(len(dimension_prompts)): for i in range(len(dimension_prompts)):
s += dimension_prompts[i].split(":")[0] + ": " + forks[i]["judgement"].strip() + "\n" s += (
dimension_prompts[i].split(":")[0]
+ ": "
+ forks[i]["judgement"].strip()
+ "\n"
)
s += "In summary, on a scale of 1 to 10, I would give the article a score of" s += "In summary, on a scale of 1 to 10, I would give the article a score of"
s += sgl.gen("score", max_tokens=2) s += sgl.gen("score", max_tokens=2)
def main(args): def main(args):
lines = read_jsonl(args.data_path)[:args.num_questions] lines = read_jsonl(args.data_path)[: args.num_questions]
arguments = [{"article": l} for l in lines] arguments = [{"article": l} for l in lines]
# Select backend # Select backend
...@@ -54,7 +61,12 @@ def main(args): ...@@ -54,7 +61,12 @@ def main(args):
# Run requests # Run requests
tic = time.time() tic = time.time()
states = multi_dimension_judge.run_batch( states = multi_dimension_judge.run_batch(
arguments, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True) arguments,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.time() - tic latency = time.time() - tic
print(f"Latency: {latency:.3f}") print(f"Latency: {latency:.3f}")
...@@ -72,7 +84,7 @@ def main(args): ...@@ -72,7 +84,7 @@ def main(args):
"other": { "other": {
"num_questions": args.num_questions, "num_questions": args.num_questions,
"parallel": args.parallel, "parallel": args.parallel,
} },
} }
fout.write(json.dumps(value) + "\n") fout.write(json.dumps(value) + "\n")
......
import argparse import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json import json
import time import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from tqdm import tqdm from tqdm import tqdm
import numpy as np
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw from sglang.test.test_utils import (
from sglang.utils import read_jsonl, dump_state_text add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.utils import dump_state_text, read_jsonl
def json_decode(document, generate): def json_decode(document, generate):
s = "Please extract the information of a city from the following wikipedia page.\n" s = "Please extract the information of a city from the following wikipedia page.\n"
s += "Page begin.\n" + document + "Page end.\n" s += "Page begin.\n" + document + "Page end.\n"
s += "Here is the name, country, and symbol of the city in JSON format.\n" s += "Here is the name, country, and symbol of the city in JSON format.\n"
s += '{\n' s += "{\n"
s += ' "name": "' s += ' "name": "'
s += generate(s, max_tokens=8, stop='"') + '",\n' s += generate(s, max_tokens=8, stop='"') + '",\n'
s += ' "country": "' s += ' "country": "'
...@@ -24,17 +28,19 @@ def json_decode(document, generate): ...@@ -24,17 +28,19 @@ def json_decode(document, generate):
s += generate(s, max_tokens=8, stop='"') + '",\n' s += generate(s, max_tokens=8, stop='"') + '",\n'
s += ' "top 3 landmarks": "' s += ' "top 3 landmarks": "'
s += generate(s, max_tokens=24, stop='"') + '",\n' s += generate(s, max_tokens=24, stop='"') + '",\n'
s += '}\n' s += "}\n"
return s return s
def main(args): def main(args):
lines = read_jsonl(args.data_path) lines = read_jsonl(args.data_path)
arguments = [] arguments = []
for i in range(len(lines[:args.num_questions])): for i in range(len(lines[: args.num_questions])):
arguments.append({ arguments.append(
{
"document": lines[i]["document"], "document": lines[i]["document"],
}) }
)
states = [None] * len(arguments) states = [None] * len(arguments)
# Select backend # Select backend
...@@ -48,13 +54,20 @@ def main(args): ...@@ -48,13 +54,20 @@ def main(args):
url = f"{args.host}:{args.port}/generate" url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_srt_raw, url=url, temperature=0) generate = partial(call_generate_srt_raw, url=url, temperature=0)
elif args.backend == "guidance": elif args.backend == "guidance":
from guidance import models, gen from guidance import gen, models
model = models.LlamaCpp("/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf", n_gpu_layers=-1, n_ctx=11000) model = models.LlamaCpp(
"/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf",
n_gpu_layers=-1,
n_ctx=11000,
)
def generate(prompt, max_tokens, stop): def generate(prompt, max_tokens, stop):
out = model + prompt + gen(name="answer", out = (
max_tokens=max_tokens, temperature=0, stop=stop) model
+ prompt
+ gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
)
return out["answer"] return out["answer"]
# warmup # warmup
...@@ -91,7 +104,7 @@ def main(args): ...@@ -91,7 +104,7 @@ def main(args):
"other": { "other": {
"num_questions": args.num_questions, "num_questions": args.num_questions,
"parallel": args.parallel, "parallel": args.parallel,
} },
} }
fout.write(json.dumps(value) + "\n") fout.write(json.dumps(value) + "\n")
......
...@@ -2,10 +2,12 @@ import argparse ...@@ -2,10 +2,12 @@ import argparse
import json import json
import time import time
import numpy as np
import sglang as sgl import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend from sglang.test.test_utils import (
from sglang.utils import read_jsonl, dump_state_text add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
@sgl.function @sgl.function
...@@ -13,21 +15,31 @@ def json_decode(s, document): ...@@ -13,21 +15,31 @@ def json_decode(s, document):
s += "Please extract the information of a city from the following wikipedia page.\n" s += "Please extract the information of a city from the following wikipedia page.\n"
s += "Page begin.\n" + document + "Page end.\n" s += "Page begin.\n" + document + "Page end.\n"
s += "Here is the name, country, and symbol of the city in JSON format.\n" s += "Here is the name, country, and symbol of the city in JSON format.\n"
s += '{\n' s += "{\n"
s += ' "name": "' + sgl.gen("name", max_tokens=8, stop='"') + '",\n' s += ' "name": "' + sgl.gen("name", max_tokens=8, stop='"') + '",\n'
s += ' "country": "' + sgl.gen("country", max_tokens=8, stop='"') + '",\n' s += ' "country": "' + sgl.gen("country", max_tokens=8, stop='"') + '",\n'
s += ' "air port code": "' + sgl.gen("air port code", max_tokens=8, stop='"') + '",\n' s += (
s += ' "top 3 landmarks": "' + sgl.gen("landmarks", max_tokens=24, stop='"') + '",\n' ' "air port code": "'
s += '}\n' + sgl.gen("air port code", max_tokens=8, stop='"')
+ '",\n'
)
s += (
' "top 3 landmarks": "'
+ sgl.gen("landmarks", max_tokens=24, stop='"')
+ '",\n'
)
s += "}\n"
def main(args): def main(args):
lines = read_jsonl(args.data_path) lines = read_jsonl(args.data_path)
arguments = [] arguments = []
for i in range(len(lines[:args.num_questions])): for i in range(len(lines[: args.num_questions])):
arguments.append({ arguments.append(
{
"document": lines[i]["document"], "document": lines[i]["document"],
}) }
)
# Select backend # Select backend
backend = select_sglang_backend(args) backend = select_sglang_backend(args)
...@@ -36,7 +48,8 @@ def main(args): ...@@ -36,7 +48,8 @@ def main(args):
# Run requests # Run requests
tic = time.time() tic = time.time()
states = json_decode.run_batch( states = json_decode.run_batch(
arguments, temperature=0, num_threads=args.parallel, progress_bar=True) arguments, temperature=0, num_threads=args.parallel, progress_bar=True
)
latency = time.time() - tic latency = time.time() - tic
# Compute accuracy # Compute accuracy
...@@ -55,7 +68,7 @@ def main(args): ...@@ -55,7 +68,7 @@ def main(args):
"other": { "other": {
"num_questions": args.num_questions, "num_questions": args.num_questions,
"parallel": args.parallel, "parallel": args.parallel,
} },
} }
fout.write(json.dumps(value) + "\n") fout.write(json.dumps(value) + "\n")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment