Commit 3701bd08 authored by haileyschoelkopf

add hacks for triviaqa

parent d15ee17a
@@ -290,7 +290,7 @@ def evaluate(
             for metric, value in metrics.items():
                 vals[(task_name, key, metric)].append(value)

-    if lm.world_size > 1:
+    if lm.world_size >= 1:
         # if multigpu, then gather data across all ranks
         vals_torch = collections.defaultdict(list)
         for (task_name, key, metric), items in vals.items():
@@ -298,7 +298,19 @@ def evaluate(
             numitem = 0
             if type(items[0]) == tuple:
                 numitem = len(items[0])
+                # strings = items[1]
+                # if numitem = 2:
+                #     for i, string in enumerate(numitem[1]):
+                #         numitem[1][i] = torch.tensor(list(string.encode("ascii")))
+                #         print(string, numitem[1][i])
+            print(items)
+            if isinstance(items[0], str):
+                items = torch.distributed.all_gather_object(items, items)
+                print(items)
+            else:
+                print(items)
+                continue
+            # items = items[0]
             # distributed gather requires all ranks to have same dimensions
             # so we pad out with float32 min value
             pad_value = torch.finfo(torch.float32).min
@@ -331,6 +343,7 @@ def evaluate(
     ### Aggregate results over all datapoints ###
     # aggregate results ; run bootstrap CIs
     for (task_name, key, metric), items in vals.items():
         task = task_dict[task_name]
         results[task_name][metric + "," + key] = task.aggregation()[metric](items)
...
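Note on the gather hack in the hunk above: the active branch calls `torch.distributed.all_gather_object(items, items)`, passing the local list as both the output buffer and the object to gather. The PyTorch API instead expects a pre-sized output list with one slot per rank, which it fills in place and returns `None`. For reference only, a minimal sketch of that usual pattern (not part of this commit; the helper name is made up):

```python
import torch.distributed as dist

def gather_string_items(local_items):
    """Hypothetical helper: gather lists of arbitrary Python objects
    (e.g. TriviaQA answer strings) from every rank onto every rank."""
    world_size = dist.get_world_size()
    # all_gather_object fills a pre-allocated list with one entry per rank
    gathered = [None] * world_size
    dist.all_gather_object(gathered, local_items)
    # flatten the per-rank lists back into a single list of items
    return [item for rank_items in gathered for item in rank_items]
```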
@@ -233,7 +233,7 @@ class HFLM(LM):
     @property
     def max_gen_toks(self):
-        return 256
+        return 16

     @property
     def batch_size(self):
...
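The second hunk drops `max_gen_toks` from 256 to 16, capping how many tokens the HF model generates per request. A rough, self-contained sketch of how such a cap is typically applied with `transformers` generation (assumed usage for illustration; the actual call site inside `HFLM` is not shown in this diff):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

MAX_GEN_TOKS = 16  # value after this commit, down from 256

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Q: Who wrote Hamlet?\nA:", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    max_new_tokens=MAX_GEN_TOKS,  # stop after at most 16 generated tokens
    do_sample=False,              # greedy decoding
)
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:]))
```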