Commit 3701bd08 authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

add hacks for triviaqa

parent d15ee17a
......@@ -290,7 +290,7 @@ def evaluate(
for metric, value in metrics.items():
vals[(task_name, key, metric)].append(value)
if lm.world_size > 1:
if lm.world_size >= 1:
# if multigpu, then gather data across all ranks
vals_torch = collections.defaultdict(list)
for (task_name, key, metric), items in vals.items():
......@@ -298,26 +298,38 @@ def evaluate(
numitem = 0
if type(items[0]) == tuple:
numitem = len(items[0])
# distributed gather requires all ranks to have same dimensions
# so we pad out with float32 min value
pad_value = torch.finfo(torch.float32).min
metrics_tensor = torch.tensor(items, device=lm.device)
original_dtype = metrics_tensor.dtype # store original dtype
torch_device_tensor = lm.accelerator.pad_across_processes(
metrics_tensor.to(torch.float32), pad_index=pad_value
)
gathered_item = lm.accelerator.gather(torch_device_tensor)
if numitem > 0:
gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
# strings = items[1]
# if numitem = 2:
# for i, string in enumerate(numitem[1]):
# numitem[1][i] = torch.tensor(list(string.encode("ascii")))
# print(string, numitem[1][i])
print(items)
if isinstance(items[0], str):
items = torch.distributed.all_gather_object(items, items)
print(items)
else:
gathered_filtered = gathered_item[gathered_item != pad_value]
print(items)
continue
# items = items[0]
# distributed gather requires all ranks to have same dimensions
# so we pad out with float32 min value
pad_value = torch.finfo(torch.float32).min
metrics_tensor = torch.tensor(items, device=lm.device)
original_dtype = metrics_tensor.dtype # store original dtype
torch_device_tensor = lm.accelerator.pad_across_processes(
metrics_tensor.to(torch.float32), pad_index=pad_value
)
gathered_item = lm.accelerator.gather(torch_device_tensor)
gathered_item = (
gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
)
if numitem > 0:
gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
else:
gathered_filtered = gathered_item[gathered_item != pad_value]
gathered_item = (
gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
)
# reconvert if we were passed a tuple of values
if numitem > 0:
gathered_item = [tuple(g) for g in gathered_item]
......@@ -331,6 +343,7 @@ def evaluate(
### Aggregate results over all datapoints ###
# aggregate results ; run bootstrap CIs
for (task_name, key, metric), items in vals.items():
task = task_dict[task_name]
results[task_name][metric + "," + key] = task.aggregation()[metric](items)
......
......@@ -233,7 +233,7 @@ class HFLM(LM):
@property
def max_gen_toks(self):
    # Maximum number of tokens to generate per request.
    # NOTE(review): the second return below is unreachable as written —
    # `return 256` exits first. This is a diff artifact: the commit replaces
    # 256 with 16 (a hack for TriviaQA short-answer generation); in the
    # post-commit source only `return 16` remains. Confirm against the
    # applied file before relying on either value.
    return 256
    return 16
@property
def batch_size(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment