Unverified commit 5ebd8989 authored by elk-cloner, committed by GitHub

Fix dataset shuffling for distributed training (huggingface#3721) (#3766)

parent 7972a401
@@ -317,8 +317,12 @@ def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedToke
         epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
     )
     set_seed(args)  # Added here for reproducibility
-    for _ in train_iterator:
+    for epoch in train_iterator:
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
+        if args.local_rank != -1:
+            train_sampler.set_epoch(epoch)
         for step, batch in enumerate(epoch_iterator):
             # Skip past any already trained steps if resuming training
...
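For context, here is a minimal standalone sketch (not part of this commit) of why the added `set_epoch` call matters: PyTorch's `DistributedSampler` derives its shuffle seed from an internal epoch counter, so unless `set_epoch` is called at the start of each epoch, every rank replays the same data order in every epoch. The dataset, sizes, and rank used below are illustrative only.

```python
# Minimal sketch showing DistributedSampler.set_epoch behaviour.
# Passing num_replicas/rank explicitly avoids needing an initialized process group.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))                        # toy dataset with 8 samples
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=2)

for epoch in range(3):
    sampler.set_epoch(epoch)                                    # reseeds the shuffle for this epoch
    order = [int(x) for batch in loader for x in batch[0]]
    print(f"epoch {epoch}: rank-0 sees {order}")                # order now differs between epochs
```

In the patched training loop the call is guarded by `args.local_rank != -1` because `train_sampler` is only a `DistributedSampler` when running distributed; in single-process runs the regular `RandomSampler` has no `set_epoch` method.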