Unverified Commit ab2cabb9 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Pass along seed to DistributedSampler (#11406)

* Pass along seed to DistributedSampler

* Add seed to DistributedLengthGroupedSampler
parent b24ead87
...@@ -547,6 +547,7 @@ class Trainer: ...@@ -547,6 +547,7 @@ class Trainer:
rank=self.args.process_index, rank=self.args.process_index,
lengths=lengths, lengths=lengths,
model_input_name=model_input_name, model_input_name=model_input_name,
seed=self.args.seed,
) )
else: else:
...@@ -562,10 +563,14 @@ class Trainer: ...@@ -562,10 +563,14 @@ class Trainer:
batch_size=self.args.per_device_train_batch_size, batch_size=self.args.per_device_train_batch_size,
num_replicas=self.args.world_size, num_replicas=self.args.world_size,
rank=self.args.process_index, rank=self.args.process_index,
seed=self.args.seed,
) )
else: else:
return DistributedSampler( return DistributedSampler(
self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index self.train_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
seed=self.args.seed,
) )
def get_train_dataloader(self) -> DataLoader: def get_train_dataloader(self) -> DataLoader:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment