Commit b99ad796 authored by baberabb
Browse files

fix errors

parent a4188e1d
...@@ -22,8 +22,10 @@ def parse_args(): ...@@ -22,8 +22,10 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--pretrained", default="EleutherAI/pythia-70m", help="name of model to compare" "--pretrained", default="EleutherAI/pythia-70m", help="name of model to compare"
) )
parser.add_argument("--hf_args", help="huggingface model args <arg>=<value>") parser.add_argument(
parser.add_argument("--vllm_args", help="vllm model args <arg>=<value>") "--hf_args", help="huggingface model args <arg>=<value>", default=""
)
parser.add_argument("--vllm_args", help="vllm model args <arg>=<value>", default="")
parser.add_argument("--tasks", type=str, default="arc_easy,hellaswag") parser.add_argument("--tasks", type=str, default="arc_easy,hellaswag")
parser.add_argument( parser.add_argument(
"--samples", "--samples",
...@@ -37,7 +39,8 @@ def parse_args(): ...@@ -37,7 +39,8 @@ def parse_args():
) )
parser.add_argument( parser.add_argument(
"--batch", "--batch",
default="auto", type=int,
default=8,
) )
parser.add_argument( parser.add_argument(
"--verbosity", "--verbosity",
...@@ -49,31 +52,34 @@ def parse_args(): ...@@ -49,31 +52,34 @@ def parse_args():
if __name__ == "__main__": if __name__ == "__main__":
tasks.initialize_tasks()
args = parse_args() args = parse_args()
tasks = args.tasks.split(",")
print(tasks)
hf_args = "," + args.hf_args hf_args = "," + args.hf_args
vllm_args = "," + args.vllm_args vllm_args = "," + args.vllm_args
results_hf = lm_eval.evaluator.simple_evaluate( results_hf = lm_eval.evaluator.simple_evaluate(
model="hf", model="hf",
model_args=f"pretrained={args.pretrained}" + hf_args, model_args=f"pretrained={args.pretrained}" + hf_args,
tasks=args.tasks, tasks=tasks,
limit=args.limit, limit=args.samples,
device=args.device, device=args.device,
batch=args.batch, batch_size=args.batch,
) )
results_vllm = lm_eval.evaluator.simple_evaluate( results_vllm = lm_eval.evaluator.simple_evaluate(
model="vllm", model="vllm",
model_args=f"pretrained={args.pretrained}" + vllm_args, model_args=f"pretrained={args.pretrained}" + vllm_args,
tasks=args.tasks, tasks=tasks,
limit=args.limit, limit=args.samples,
device=args.device, device=args.device,
batch=args.batch, batch_size=args.batch,
) )
all_res = {} all_res = {}
for task, res1, task2, res2 in zip( for task1, task2 in zip(
results_hf["results"].items(), results_vllm["results"].items() results_hf["results"].items(), results_vllm["results"].items()
): ):
assert task == task2 assert task1[0] == task2[0]
z, p_value = calculate_z_value(res1, res2, args.limit) z, p_value = calculate_z_value(task1[1], task2[1], args.samples)
all_res["task"] = {"z": z, "p_value": p_value} all_res[task1[0]] = {"z": z, "p_value": p_value}
assert p_value > 0.05 assert p_value > 0.05
eval_logger.info(all_res) eval_logger.info(all_res)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment