Commit c4447dab authored by Benjamin Fattori

clean up cumul batches for rolling + remove unused print

parent e6492441
@@ -159,9 +159,9 @@ class HFLM(LM):
         extra_pad = []
         numpad_batches = 0
-        # balance token batches among iterators
-        if self.world_size > 1:
-            cumulative_batches = 0
+        cumulative_batches = 0 # balance token batches among iterators
         # compute cumulative batches once -> could also just cache this and then use it later
         for (string,) in tqdm([req.args for req in requests],disable=(self.rank != 0)):
             rolling_token_windows = list(
@@ -179,15 +179,13 @@ class HFLM(LM):
             rolling_token_windows = [(None,) + x for x in rolling_token_windows]
             cumulative_batches += len(rolling_token_windows)
-        cum_batches_ranks = torch.tensor(cumulative_batches, device = self.device)
-        gathered_item = self.accelerator.gather(cum_batches_ranks).cpu().detach().numpy().tolist()
+        cumul_batches_ranks = torch.tensor(cumulative_batches, device = self.device)
+        gathered_item = self.accelerator.gather(cumul_batches_ranks).cpu().detach().numpy().tolist()
         # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
         numpad_batches = max(gathered_item) - gathered_item[self.rank]
         extra_pad = [('pad',)] if numpad_batches > 0 else []
-        print(self.rank, numpad_batches)
         loglikelihoods = []
         for (string,) in tqdm(extra_pad + [req.args for req in requests],disable=(self.rank != 0)):
             if numpad_batches > 0:
......
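For context, this change implements a common rank-balancing pattern for distributed evaluation: each rank counts how many batches its rolling token windows will produce, the per-rank counts are gathered, and ranks with fewer batches pad their request list so every rank runs the same number of forward passes (FSDP/DDP collectives hang if one rank stops iterating early). Below is a minimal, hypothetical sketch of that pattern, assuming an initialized `accelerate.Accelerator`; `pad_requests_to_max_rank_count` is an illustrative helper, not part of the harness, and it prepends one dummy request per missing batch rather than replaying a single `('pad',)` entry the way the diff above does:

```python
import torch
from accelerate import Accelerator

def pad_requests_to_max_rank_count(accelerator: Accelerator, requests: list):
    """Pad this rank's request list so its length matches the busiest rank.

    Returns (padded_requests, num_pad); the caller should discard results
    produced by the leading num_pad dummy entries.
    """
    # count local work and gather one count per rank
    local_count = torch.tensor(len(requests), device=accelerator.device)
    all_counts = accelerator.gather(local_count).cpu().tolist()
    # how far this rank falls short of the busiest rank
    num_pad = max(all_counts) - all_counts[accelerator.process_index]
    # dummy requests keep every rank's loop (and its collectives) in lockstep
    return [("pad",)] * num_pad + requests, num_pad
```

The key design point is that padding happens once, before the main loop, so the per-rank iterators stay aligned for the entire evaluation; the extra forward passes are wasted compute, but they are bounded by the batch-count imbalance between ranks.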