Commit b7c3580a authored by lintangsutawika

reformatted

parent 86db4a4e
@@ -88,7 +88,12 @@ def simple_evaluate(
         if model_args is None:
             model_args = ""
         lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
-            model_args, {"batch_size": batch_size, "max_batch_size": max_batch_size, "device": device}
+            model_args,
+            {
+                "batch_size": batch_size,
+                "max_batch_size": max_batch_size,
+                "device": device,
+            },
         )
     else:
         assert isinstance(model, lm_eval.api.model.LM)
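As an aside on the call above: create_from_arg_string receives model_args as a comma-separated key=value string plus a dict of extra keyword arguments. A rough, self-contained sketch of that shape with hypothetical example values (the real harness uses its own arg-string helper rather than the inline parsing shown here):

# Hypothetical inputs mirroring the call above; the parsing is a simplified
# stand-in for the harness's arg-string helper, not its actual code.
model_args = "pretrained=EleutherAI/pythia-160m,dtype=float32"
extra = {"batch_size": 8, "max_batch_size": None, "device": "cuda:0"}

kwargs = dict(kv.split("=", 1) for kv in model_args.split(",") if kv)
kwargs.update(extra)
print(kwargs)
# {'pretrained': 'EleutherAI/pythia-160m', 'dtype': 'float32',
#  'batch_size': 8, 'max_batch_size': None, 'device': 'cuda:0'}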
@@ -112,11 +117,15 @@ def simple_evaluate(
     if lm.rank == 0:
         # add info about the model and few shot config
         results["config"] = {
-            "model": model if isinstance(model, str) else model.model.config._name_or_path,
+            "model": model
+            if isinstance(model, str)
+            else model.model.config._name_or_path,
             "model_args": model_args,
             "num_fewshot": num_fewshot,
             "batch_size": batch_size,
-            "batch_sizes": list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else [],
+            "batch_sizes": list(lm.batch_sizes.values())
+            if hasattr(lm, "batch_sizes")
+            else [],
             "device": device,
             "no_cache": no_cache,
             "limit": limit,
@@ -4,7 +4,9 @@ from tqdm import tqdm
 import time
-def anthropic_completion(client, model, prompt, max_tokens_to_sample, temperature, stop):
+def anthropic_completion(
+    client, model, prompt, max_tokens_to_sample, temperature, stop
+):
     """Query Anthropic API for completion.
     Retry with back-off until they respond
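The docstring above promises retry with back-off, but the retry loop itself is outside this hunk. Purely as an illustration of that pattern, not the harness's actual implementation (request_fn and max_backoff are made-up names), a minimal version looks like:

import time


def retry_with_backoff(request_fn, max_backoff=60.0):
    # request_fn is a hypothetical zero-argument callable wrapping the API call.
    backoff = 1.0
    while True:
        try:
            return request_fn()
        except Exception as err:  # the real code would catch the SDK's rate-limit error
            print(f"request failed ({err}); retrying in {backoff:.1f}s")
            time.sleep(backoff)
            backoff = min(backoff * 2, max_backoff)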
@@ -46,8 +48,9 @@ class AnthropicLM(BaseLM):
         """
         super().__init__()
         import anthropic
         self.model = model
-        self.client = anthropic.Client(os.environ['ANTHROPIC_API_KEY'])
+        self.client = anthropic.Client(os.environ["ANTHROPIC_API_KEY"])
     @property
     def eot_token_id(self):
@@ -168,8 +168,8 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len
         window_end = predicted + window_pred_len
         yield (
-            token_list[window_end - max_seq_len - 1: window_end - 1],
-            token_list[window_end - window_pred_len: window_end],
+            token_list[window_end - max_seq_len - 1 : window_end - 1],
+            token_list[window_end - window_pred_len : window_end],
         )
         predicted += window_pred_len
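For readers puzzling over the slices above: each yielded pair is up to max_seq_len context tokens, offset by one position, together with the window_pred_len tokens to be scored. A standalone illustration of the arithmetic with made-up numbers (this does not call into the harness):

# Hypothetical values chosen to show the slice arithmetic from the diff above.
token_list = list(range(20))
max_seq_len = 8
window_pred_len = 4
window_end = 16

context = token_list[window_end - max_seq_len - 1 : window_end - 1]  # tokens 7..14
targets = token_list[window_end - window_pred_len : window_end]      # tokens 12..15

assert len(context) == max_seq_len
assert targets == [12, 13, 14, 15]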
@@ -17,17 +17,27 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", required=True)
     parser.add_argument("--model_args", default="")
-    parser.add_argument("--tasks", default=None, choices=utils.MultiChoice(sorted(ALL_TASKS)))
+    parser.add_argument(
+        "--tasks", default=None, choices=utils.MultiChoice(sorted(ALL_TASKS))
+    )
     parser.add_argument("--config", default=None)
     parser.add_argument("--num_fewshot", type=int, default=0)
     parser.add_argument("--batch_size", type=int, default=1)
-    parser.add_argument("--max_batch_size", type=int, default=None,
-                        help="Maximal batch size to try with --batch_size auto")
+    parser.add_argument(
+        "--max_batch_size",
+        type=int,
+        default=None,
+        help="Maximal batch size to try with --batch_size auto",
+    )
     parser.add_argument("--device", type=str, default=None)
     parser.add_argument("--output_path", default=None)
-    parser.add_argument("--limit", type=float, default=None,
+    parser.add_argument(
+        "--limit",
+        type=float,
+        default=None,
         help="Limit the number of examples per task. "
-        "If <1, limit is a percentage of the total number of examples.")
+        "If <1, limit is a percentage of the total number of examples.",
+    )
     parser.add_argument("--data_sampling", type=float, default=None)
     parser.add_argument("--no_cache", action="store_true")
     parser.add_argument("--decontamination_ngrams_path", default=None)
@@ -10,7 +10,12 @@ from lm_eval.api.registry import ALL_TASKS
 seq2seq_models = ["google/flan-t5-small"]
-causal_models = ["gpt2", "facebook/opt-125m", "EleutherAI/gpt-neo-125m", "EleutherAI/pythia-160m"]
+causal_models = [
+    "gpt2",
+    "facebook/opt-125m",
+    "EleutherAI/gpt-neo-125m",
+    "EleutherAI/pythia-160m",
+]
 model_names = seq2seq_models + causal_models
@@ -51,22 +56,41 @@ def eval_models(args, branch=None):
     results = {}
     for model in args.models:
-        model_type = "hf-causal" if model in causal_models \
-            else "hf-seq2seq" if model in seq2seq_models else args.model
+        model_type = (
+            "hf-causal"
+            if model in causal_models
+            else "hf-seq2seq"
+            if model in seq2seq_models
+            else args.model
+        )
         model_args = f"pretrained={model},{args.model_args}"
         # TODO: split_and_pad_windows in AutoSeq2SeqLM doesn"t exist, #527
-        tasks = args.tasks if model in causal_models or model_type == "hf-causal" \
+        tasks = (
+            args.tasks
+            if model in causal_models or model_type == "hf-causal"
             else list(filter(lambda task: task not in perplexity_tasks, args.tasks))
+        )
         # TODO: OOM with auto for seq2seq models, also can OOM with llama
-        batch_size = args.batch_size if model in causal_models or model_type == "hf-causal" \
-            else 64 if args.batch_size == "auto" else args.batch_size
-        output_path = f"data/regression/{int(start_time)}-{branch}-{Path(model).name}.json"
-        command = f"python3 main.py --model {model_type} --model_args {model_args} --tasks {','.join(tasks)} " \
-            f"--num_fewshot {args.num_fewshot}{'' if args.limit is None else f' --limit {args.limit}'} " \
+        batch_size = (
+            args.batch_size
+            if model in causal_models or model_type == "hf-causal"
+            else 64
+            if args.batch_size == "auto"
+            else args.batch_size
+        )
+        output_path = (
+            f"data/regression/{int(start_time)}-{branch}-{Path(model).name}.json"
+        )
+        command = (
+            f"python3 main.py --model {model_type} --model_args {model_args} --tasks {','.join(tasks)} "
+            f"--num_fewshot {args.num_fewshot}{'' if args.limit is None else f' --limit {args.limit}'} "
             f"--batch_size {batch_size} --no_cache --output_path {output_path}"
+        )
-        print(f"{'=' * 80}\nEvaluating {model} on {', '.join(tasks)} at {branch} with:\n\n{command}\n{'=' * 80}")
+        print(
+            f"{'=' * 80}\nEvaluating {model} on {', '.join(tasks)} at {branch} with:\n\n{command}\n{'=' * 80}"
+        )
         ret = os.system(command)
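A readability note on the reformatted conditionals above: Python's conditional expression is right-associative, so the chained model_type expression groups as in this small equivalent sketch (the example values are made up):

# Demonstrates how the chained conditional groups; values are hypothetical.
causal_models = ["gpt2"]
seq2seq_models = ["google/flan-t5-small"]
fallback = "some-other-backend"  # stands in for args.model

model = "google/flan-t5-small"
model_type = (
    "hf-causal"
    if model in causal_models
    else ("hf-seq2seq" if model in seq2seq_models else fallback)
)
assert model_type == "hf-seq2seq"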
@@ -89,7 +113,9 @@ def extract_value(args, results, model, task, err=False):
     if "acc,none" in results:
         return results["acc,none"] if not err else results["acc_stderr,none"]
     if (args.perplexity or "word_perplexity") + ",none" in results:
-        return results[(args.perplexity or "word_perplexity") + ",none"] if not err else 0
+        return (
+            results[(args.perplexity or "word_perplexity") + ",none"] if not err else 0
+        )
     return 0
@@ -109,13 +135,24 @@ def format_diff(args, results1, results2, model, task):
 def main():
     args = parse_args()
-    args.branches = args.branches.split(",") if type(args.branches) == str else args.branches
+    args.branches = (
+        args.branches.split(",") if type(args.branches) == str else args.branches
+    )
     args.models = args.models.split(",") if type(args.models) == str else args.models
-    args.tasks = ALL_TASKS if args.tasks == "all_tasks" \
-        else utils.pattern_match(args.tasks.split(","), ALL_TASKS) if type(args.tasks) == str else args.tasks
+    args.tasks = (
+        ALL_TASKS
+        if args.tasks == "all_tasks"
+        else utils.pattern_match(args.tasks.split(","), ALL_TASKS)
+        if type(args.tasks) == str
+        else args.tasks
+    )
     global initial_branch
-    initial_branch = subprocess.check_output("git branch --show-current", shell=True).decode("ascii").strip()
+    initial_branch = (
+        subprocess.check_output("git branch --show-current", shell=True)
+        .decode("ascii")
+        .strip()
+    )
     # TODO: implement proper timing for each task
     # TODO: reduce IO by sharing tasks between models?
@@ -133,10 +170,16 @@ def main():
     print(f"|task|{'|'.join(map(lambda model: Path(model).name, args.models))}|")
     print(f"|--|{'--|' * len(args.models)}")
     for task in args.tasks:
-        print(f"|{task} ({initial_branch})|{'|'.join(map(lambda model: format_value(args, results, model, task), args.models))}|")
+        print(
+            f"|{task} ({initial_branch})|{'|'.join(map(lambda model: format_value(args, results, model, task), args.models))}|"
+        )
         for branch, branch_results, branch_runtime in runs:
-            print(f"|{task} ({branch})|{'|'.join(map(lambda model: format_value(args, branch_results, model, task), args.models))}|")
-            print(f"|{task} (diff)|{'|'.join(map(lambda model: format_diff(args, results, branch_results, model, task), args.models))}|")
+            print(
+                f"|{task} ({branch})|{'|'.join(map(lambda model: format_value(args, branch_results, model, task), args.models))}|"
+            )
+            print(
+                f"|{task} (diff)|{'|'.join(map(lambda model: format_diff(args, results, branch_results, model, task), args.models))}|"
+            )
     print("")
     print("|branch|runtime|%|")