Unverified Commit 53754d41 authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge pull request #679 from EleutherAI/fix-padding-ranks

[Refactor] Fix padding ranks
parents e5161a6d abccc756
...@@ -191,6 +191,8 @@ def evaluate( ...@@ -191,6 +191,8 @@ def evaluate(
samples = collections.defaultdict(list) samples = collections.defaultdict(list)
requests = collections.defaultdict(list) requests = collections.defaultdict(list)
padding_requests = collections.defaultdict(int)
# get lists of each type of request # get lists of each type of request
for task_name, task in task_dict.items(): for task_name, task in task_dict.items():
versions[task_name] = task.VERSION versions[task_name] = task.VERSION
...@@ -239,6 +241,7 @@ def evaluate( ...@@ -239,6 +241,7 @@ def evaluate(
# compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks) # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
numpad = max(gathered_item) - gathered_item[lm.rank] numpad = max(gathered_item) - gathered_item[lm.rank]
padding_requests[task.OUTPUT_TYPE] += numpad
### Run LM on inputs, get all outputs ### ### Run LM on inputs, get all outputs ###
# execute each type of request # execute each type of request
...@@ -249,8 +252,8 @@ def evaluate( ...@@ -249,8 +252,8 @@ def evaluate(
for req in reqs: for req in reqs:
cloned_reqs.extend([req] * req.repeats) cloned_reqs.extend([req] * req.repeats)
if (lm.world_size > 1) and (numpad > 0): if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
for _ in range(numpad): for _ in range(padding_requests[reqtype]):
cloned_reqs.extend([req] * req.repeats) cloned_reqs.extend([req] * req.repeats)
# run requests through model # run requests through model
...@@ -260,8 +263,8 @@ def evaluate( ...@@ -260,8 +263,8 @@ def evaluate(
for x, req in zip(resps, cloned_reqs): for x, req in zip(resps, cloned_reqs):
req.resps.append(x) req.resps.append(x)
if lm.world_size > 1: if lm.world_size > 1:
lm.accelerator.wait_for_everyone() lm.accelerator.wait_for_everyone()
### Postprocess outputs ### ### Postprocess outputs ###
# TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately) # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment