Commit 8f859cd2 authored by Benjamin Fattori
Browse files

balance documents among hosts for wikitext

parent 3cfd23a1
......@@ -171,8 +171,8 @@ class HFLM(LM):
if self.world_size > 1:
cumulative_batches = 0 # balance token batches among iterators
# compute cumulative batches once -> could also just cache this and then use it later
for (string,) in tqdm([req.args for req in requests],disable=(self.rank != 0)):
# compute cumulative batches seen per host
for (string,) in tqdm([req.args for req in requests],disable=True):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
......@@ -193,7 +193,9 @@ class HFLM(LM):
# compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
numpad_batches = max(gathered_item) - gathered_item[self.rank]
extra_pad = [('pad',)] if numpad_batches > 0 else []
# pad iterators with a pseudodocument
extra_pad = [('pad',)] if max(gathered_item) - min(gathered_item) > 0 else []
loglikelihoods = []
for (string,) in tqdm(extra_pad + [req.args for req in requests],disable=(self.rank != 0)):
......@@ -231,7 +233,7 @@ class HFLM(LM):
rolling_token_windows, disable_tqdm=True
)
if numpad_batches > 0:
if (numpad_batches > 0) or (string == 'pad'):
numpad_batches = 0
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment