Unverified Commit ca76618d authored by Sylvain Gugger, committed by GitHub
Browse files

Take gradient accumulation into account when defining samplers (#15095)

* Take gradient accumulation into account when defining samplers

* style
parent 9dc8fb2f
......@@ -581,7 +581,7 @@ class Trainer:
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
if self.args.world_size <= 1:
return LengthGroupedSampler(
self.args.train_batch_size,
self.args.train_batch_size * self.args.gradient_accumulation_steps,
dataset=self.train_dataset,
lengths=lengths,
model_input_name=model_input_name,
......@@ -589,7 +589,7 @@ class Trainer:
)
else:
return DistributedLengthGroupedSampler(
self.args.train_batch_size,
self.args.train_batch_size * self.args.gradient_accumulation_steps,
dataset=self.train_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment