Unverified Commit ab2cabb9 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Pass along seed to DistributedSampler (#11406)

* Pass along seed to DistributedSampler

* Add seed to DistributedLengthGroupedSampler
parent b24ead87
......@@ -547,6 +547,7 @@ class Trainer:
rank=self.args.process_index,
lengths=lengths,
model_input_name=model_input_name,
seed=self.args.seed,
)
else:
......@@ -562,10 +563,14 @@ class Trainer:
batch_size=self.args.per_device_train_batch_size,
num_replicas=self.args.world_size,
rank=self.args.process_index,
seed=self.args.seed,
)
else:
return DistributedSampler(
self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index
self.train_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
seed=self.args.seed,
)
def get_train_dataloader(self) -> DataLoader:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment