"tests/vscode:/vscode.git/clone" did not exist on "2d3eac307e7de25429d673cbd6627e44c5931087"
Commit 3701bd08 authored by haileyschoelkopf

add hacks for triviaqa

parent d15ee17a
@@ -290,7 +290,7 @@ def evaluate(
         for metric, value in metrics.items():
             vals[(task_name, key, metric)].append(value)
 
-    if lm.world_size > 1:
+    if lm.world_size >= 1:
         # if multigpu, then gather data across all ranks
         vals_torch = collections.defaultdict(list)
         for (task_name, key, metric), items in vals.items():
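
Changing `>` to `>=` makes the gather branch run even in a single-process run, presumably so the new string-gathering hack below is always exercised. For context, a minimal sketch of the pad-then-gather pattern this branch relies on, using Hugging Face Accelerate; the `Accelerator()` setup here is illustrative, not code from this commit:

# Minimal sketch of the pad-then-gather pattern used in the hunk below.
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Each rank may hold a different number of metric values.
local_vals = torch.tensor([0.5, 1.0], device=accelerator.device)

# gather() needs identical shapes on every rank, so right-pad each
# rank's tensor to the longest length with a sentinel value first.
pad_value = torch.finfo(torch.float32).min
padded = accelerator.pad_across_processes(local_vals, pad_index=pad_value)

# Concatenate across ranks, then strip the sentinel padding back out.
gathered = accelerator.gather(padded)
gathered = gathered[gathered != pad_value]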
@@ -298,26 +298,38 @@ def evaluate(
             numitem = 0
             if type(items[0]) == tuple:
                 numitem = len(items[0])
-
+                # strings = items[1]
+                # if numitem = 2:
+                #     for i, string in enumerate(numitem[1]):
+                #         numitem[1][i] = torch.tensor(list(string.encode("ascii")))
+                #         print(string, numitem[1][i])
+            print(items)
+            if isinstance(items[0], str):
+                items = torch.distributed.all_gather_object(items, items)
+                print(items)
+            else:
+                print(items)
+                continue
+            # items = items[0]
             # distributed gather requires all ranks to have same dimensions
             # so we pad out with float32 min value
             pad_value = torch.finfo(torch.float32).min
             metrics_tensor = torch.tensor(items, device=lm.device)
 
             original_dtype = metrics_tensor.dtype  # store original dtype
             torch_device_tensor = lm.accelerator.pad_across_processes(
                 metrics_tensor.to(torch.float32), pad_index=pad_value
             )
             gathered_item = lm.accelerator.gather(torch_device_tensor)
 
             if numitem > 0:
                 gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
             else:
                 gathered_filtered = gathered_item[gathered_item != pad_value]
 
             gathered_item = (
                 gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
             )
             # reconvert if we were passed a tuple of values
             if numitem > 0:
                 gathered_item = [tuple(g) for g in gathered_item]
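
One caveat with the string branch above: `torch.distributed.all_gather_object` fills a pre-sized output list in place and returns None, so assigning its return value to `items` discards the gathered data. A sketch of the conventional usage, assuming an initialized process group (e.g. via `torch.distributed.init_process_group`); the helper name is hypothetical:

# Sketch of conventional all_gather_object usage; not code from this commit.
import torch.distributed as dist

def gather_strings(local_items):
    # all_gather_object fills a pre-sized list in place and returns None,
    # so the result must be read from `gathered`, not from the call itself.
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_items)
    # Flatten the per-rank lists into one list of strings.
    return [s for rank_items in gathered for s in rank_items]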
@@ -331,6 +343,7 @@ def evaluate(
 
     ### Aggregate results over all datapoints ###
     # aggregate results ; run bootstrap CIs
     for (task_name, key, metric), items in vals.items():
         task = task_dict[task_name]
         results[task_name][metric + "," + key] = task.aggregation()[metric](items)
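
For reference, `task.aggregation()` in lm-evaluation-harness maps each metric name to a reduction over the per-example values. A minimal illustration of that contract; the `acc`/mean pairing is an example, not taken from this commit:

# Illustrative only: the aggregation() contract the loop above relies on.
def aggregation():
    # metric name -> function reducing a list of per-example values
    return {"acc": lambda items: sum(items) / len(items)}

items = [1, 0, 1, 1]
print(aggregation()["acc"](items))  # 0.75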
...
@@ -233,7 +233,7 @@ class HFLM(LM):
 
     @property
     def max_gen_toks(self):
-        return 256
+        return 16
 
     @property
     def batch_size(self):
...
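
Dropping `max_gen_toks` from 256 to 16 caps how many tokens are generated per request; TriviaQA answers are short, so this mostly speeds up evaluation. A rough sketch of how such a cap typically feeds generation, using standard `transformers` calls rather than code from this commit:

# Sketch: bounding generation length with a max_gen_toks-style cap.
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

max_gen_toks = 16  # the new cap; was 256
inputs = tok("Q: Who wrote Hamlet?\nA:", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=max_gen_toks, do_sample=False)
print(tok.decode(out[0][inputs["input_ids"].shape[1]:]))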