Commit 629bcfba authored by Benjamin Fattori's avatar Benjamin Fattori
Browse files

bugfixes missed from local branch

parent 09e91431
...@@ -200,7 +200,7 @@ def evaluate( ...@@ -200,7 +200,7 @@ def evaluate(
# calculate values for each filter setup (TODO: make getting list of keys cleaner) # calculate values for each filter setup (TODO: make getting list of keys cleaner)
# TODO: make it possible to use a different metric per key # TODO: make it possible to use a different metric per key
for key in task.instances[0].filtered_resps.keys(): for key in task.instances[0].filtered_resps.keys():
for doc_id, doc in enumerate(itertools.islice(task.test_docs(), 0, limit) if task.has_test_docs() else task.validation_docs()): for doc_id, doc in itertools.islice(enumerate(task.test_docs()), lm.rank, None, lm.world_size) if task.has_test_docs() else itertools.islice(enumerate(task.validation_docs()), lm.rank, None, lm.world_size):
# subset instances to only this document id ; sort by idx # subset instances to only this document id ; sort by idx
requests = list(filter(lambda x: x.doc_id == doc_id, task.instances)) requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
requests.sort(key=lambda x: x.idx) requests.sort(key=lambda x: x.idx)
......
...@@ -45,6 +45,9 @@ class HFLM(LM): ...@@ -45,6 +45,9 @@ class HFLM(LM):
else torch.device("cpu") else torch.device("cpu")
) )
else:
self._device = 'cpu'
# TODO: update this to be less of a hack once subfolder is fixed in HF # TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "") revision = revision + ("/" + subfolder if subfolder is not None else "")
......
...@@ -89,7 +89,6 @@ def main(): ...@@ -89,7 +89,6 @@ def main():
print(f"Selected Tasks: {task_names}") print(f"Selected Tasks: {task_names}")
if results is not None:
results = evaluator.simple_evaluate( results = evaluator.simple_evaluate(
model=args.model, model=args.model,
model_args=args.model_args, model_args=args.model_args,
...@@ -101,7 +100,7 @@ def main(): ...@@ -101,7 +100,7 @@ def main():
decontamination_ngrams_path=args.decontamination_ngrams_path, decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity, check_integrity=args.check_integrity,
) )
if results is not None:
dumped = json.dumps(results, indent=2) dumped = json.dumps(results, indent=2)
print(dumped) print(dumped)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment