Unverified Commit 53754d41 authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge pull request #679 from EleutherAI/fix-padding-ranks

[Refactor] Fix padding ranks
parents e5161a6d abccc756
...@@ -191,6 +191,8 @@ def evaluate( ...@@ -191,6 +191,8 @@ def evaluate(
samples = collections.defaultdict(list) samples = collections.defaultdict(list)
requests = collections.defaultdict(list) requests = collections.defaultdict(list)
padding_requests = collections.defaultdict(int)
# get lists of each type of request # get lists of each type of request
for task_name, task in task_dict.items(): for task_name, task in task_dict.items():
versions[task_name] = task.VERSION versions[task_name] = task.VERSION
...@@ -239,6 +241,7 @@ def evaluate( ...@@ -239,6 +241,7 @@ def evaluate(
# compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks) # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
numpad = max(gathered_item) - gathered_item[lm.rank] numpad = max(gathered_item) - gathered_item[lm.rank]
padding_requests[task.OUTPUT_TYPE] += numpad
### Run LM on inputs, get all outputs ### ### Run LM on inputs, get all outputs ###
# execute each type of request # execute each type of request
...@@ -249,8 +252,8 @@ def evaluate( ...@@ -249,8 +252,8 @@ def evaluate(
for req in reqs: for req in reqs:
cloned_reqs.extend([req] * req.repeats) cloned_reqs.extend([req] * req.repeats)
if (lm.world_size > 1) and (numpad > 0): if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
for _ in range(numpad): for _ in range(padding_requests[reqtype]):
cloned_reqs.extend([req] * req.repeats) cloned_reqs.extend([req] * req.repeats)
# run requests through model # run requests through model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment