"vscode:/vscode.git/clone" did not exist on "7448aac8b63f769e2381f4fb478fa8a08c8775d4"
Commit 074099f9 authored by Benjamin Fattori's avatar Benjamin Fattori
Browse files

clarify comments around padding

parent 32d7593d
...@@ -156,8 +156,7 @@ def evaluate( ...@@ -156,8 +156,7 @@ def evaluate(
gathered_item = lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist() gathered_item = lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
# compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks) # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
# we assume rank 0 always has largest iterator numpad = max(gathered_item) - gathered_item[lm.rank]
numpad = gathered_item[0] - gathered_item[lm.rank]
### Run LM on inputs, get all outputs ### ### Run LM on inputs, get all outputs ###
# execute each type of request # execute each type of request
...@@ -168,7 +167,7 @@ def evaluate( ...@@ -168,7 +167,7 @@ def evaluate(
for req in reqs: for req in reqs:
cloned_reqs.extend([req] * req.repeats) cloned_reqs.extend([req] * req.repeats)
if (lm.rank > 0) and (numpad > 0): if (numpad > 0):
for _ in range(numpad): for _ in range(numpad):
cloned_reqs.extend([req] * req.repeats) cloned_reqs.extend([req] * req.repeats)
...@@ -215,7 +214,8 @@ def evaluate( ...@@ -215,7 +214,8 @@ def evaluate(
if type(items[0]) == tuple: if type(items[0]) == tuple:
numitem = len(items[0]) numitem = len(items[0])
# distributed gather requires all ranks to have same dimensionality -> pad out with float32 min value # distributed gather requires all ranks to have same dimensions
# so we pad out with float32 min value
pad_value = torch.finfo(torch.float32).min pad_value = torch.finfo(torch.float32).min
metrics_tensor = torch.tensor(items, device = lm.device) metrics_tensor = torch.tensor(items, device = lm.device)
...@@ -223,7 +223,6 @@ def evaluate( ...@@ -223,7 +223,6 @@ def evaluate(
torch_device_tensor = lm.accelerator.pad_across_processes(metrics_tensor.to(torch.float32), pad_index = pad_value) torch_device_tensor = lm.accelerator.pad_across_processes(metrics_tensor.to(torch.float32), pad_index = pad_value)
gathered_item = lm.accelerator.gather(torch_device_tensor) gathered_item = lm.accelerator.gather(torch_device_tensor)
#TODO: This is required when we get a tensor with a tuple of info like (ppl, _bytes) from wikitext
if numitem > 0: if numitem > 0:
gathered_filtered = gathered_item[gathered_item[:,0] != pad_value] gathered_filtered = gathered_item[gathered_item[:,0] != pad_value]
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment