Commit 629bcfba authored by Benjamin Fattori's avatar Benjamin Fattori
Browse files

bugfixes missed from local branch

parent 09e91431
......@@ -200,7 +200,7 @@ def evaluate(
# calculate values for each filter setup (TODO: make getting list of keys cleaner)
# TODO: make it possible to use a different metric per key
for key in task.instances[0].filtered_resps.keys():
for doc_id, doc in enumerate(itertools.islice(task.test_docs(), 0, limit) if task.has_test_docs() else task.validation_docs()):
for doc_id, doc in itertools.islice(enumerate(task.test_docs()), lm.rank, None, lm.world_size) if task.has_test_docs() else itertools.islice(enumerate(task.validation_docs()), lm.rank, None, lm.world_size):
# subset instances to only this document id ; sort by idx
requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
requests.sort(key=lambda x: x.idx)
......
......@@ -45,6 +45,9 @@ class HFLM(LM):
else torch.device("cpu")
)
else:
self._device = 'cpu'
# TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "")
......
......@@ -89,19 +89,18 @@ def main():
print(f"Selected Tasks: {task_names}")
results = evaluator.simple_evaluate(
model=args.model,
model_args=args.model_args,
tasks=task_names,
num_fewshot=args.num_fewshot,
batch_size=args.batch_size,
device=args.device,
limit=args.limit,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
)
if results is not None:
results = evaluator.simple_evaluate(
model=args.model,
model_args=args.model_args,
tasks=task_names,
num_fewshot=args.num_fewshot,
batch_size=args.batch_size,
device=args.device,
limit=args.limit,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
)
dumped = json.dumps(results, indent=2)
print(dumped)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment