"configs/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "023aef48786e65d81d03e6a5fb8c730a4aec8373"
Unverified commit 5b7dcc73, authored by David Hall, committed by GitHub

Seed _get_train_sampler's generator with arg seed to improve reproducibility (#15961)

* Seed get_train_sampler's generator with arg seed to improve reproducibility,
  and make the world_size<=1 code path more similar to the others
* move test file into trainer test explicitly
* fix dumb typo
* make style lint happy
* per discussion, switch to data_seed
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 70203b59
@@ -591,7 +591,16 @@ class Trainer:
         generator = None
         if self.args.world_size <= 1 and _is_torch_generator_available:
             generator = torch.Generator()
-            generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
+            # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
+            # `args.seed`) if data_seed isn't provided.
+            # Further on in this method, we default to `args.seed` instead.
+            if self.args.data_seed is None:
+                seed = int(torch.empty((), dtype=torch.int64).random_().item())
+            else:
+                seed = self.args.data_seed
+            generator.manual_seed(seed)
+
+        seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed

         # Build the sampler.
         if self.args.group_by_length:
@@ -620,7 +629,7 @@ class Trainer:
                     rank=self.args.process_index,
                     lengths=lengths,
                     model_input_name=model_input_name,
-                    seed=self.args.seed,
+                    seed=seed,
                 )

         else:
@@ -638,14 +647,14 @@ class Trainer:
                     batch_size=self.args.per_device_train_batch_size,
                     num_replicas=self.args.world_size,
                     rank=self.args.process_index,
-                    seed=self.args.seed,
+                    seed=seed,
                 )
             else:
                 return DistributedSampler(
                     self.train_dataset,
                     num_replicas=self.args.world_size,
                     rank=self.args.process_index,
-                    seed=self.args.seed,
+                    seed=seed,
                 )

     def get_train_dataloader(self) -> DataLoader:
...
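To make the single-process behaviour above concrete, here is a minimal, self-contained sketch (not part of the diff; `make_train_sampler` is a hypothetical helper): when `data_seed` is given, the sampler's generator is seeded from it and the shuffle order depends only on that value; otherwise a seed is drawn from torch's global RNG, which `Trainer` has already seeded with `args.seed`, preserving the old behaviour.

```python
import torch
from torch.utils.data import RandomSampler, TensorDataset

def make_train_sampler(dataset, data_seed=None):
    # Give the sampler its own generator, mirroring the world_size <= 1 path.
    generator = torch.Generator()
    if data_seed is None:
        # Old behaviour: draw a seed from torch's global RNG.
        generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
    else:
        # New behaviour: the data order is pinned by data_seed alone.
        generator.manual_seed(data_seed)
    return RandomSampler(dataset, generator=generator)

dataset = TensorDataset(torch.arange(10))
# Same data_seed -> same shuffle order, independent of any other randomness.
assert list(make_train_sampler(dataset, data_seed=7)) == list(make_train_sampler(dataset, data_seed=7))
```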
@@ -220,6 +220,10 @@ class TrainingArguments:
         seed (`int`, *optional*, defaults to 42):
             Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the
             [`~Trainer.model_init`] function to instantiate the model if it has some randomly initialized parameters.
+        data_seed (`int`, *optional*):
+            Random seed to be used with data samplers. If not set, random generators for data sampling will use the
+            same seed as `seed`. This can be used to ensure reproducibility of data sampling, independent of the model
+            seed.
         bf16 (`bool`, *optional*, defaults to `False`):
             Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher
             NVIDIA architecture. This is an experimental API and it may change.
@@ -539,6 +543,7 @@ class TrainingArguments:
     )
     no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
     seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
+    data_seed: int = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
     bf16: bool = field(
         default=False,
         metadata={
...
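As a usage sketch (illustrative only; the output directory is arbitrary and this requires a transformers version that includes this change): with the new field, two runs that share `data_seed` see training examples in the same order even when `seed` differs.

```python
from transformers import TrainingArguments

# `seed` still governs model init, dropout, etc.; `data_seed` governs only the
# order in which training data is sampled (it falls back to `seed` when unset).
args = TrainingArguments(
    output_dir="out",
    seed=11,       # vary this to change model randomness
    data_seed=42,  # keep this fixed to keep the data order fixed
)
```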
@@ -647,6 +647,67 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         new_eval_dataset = RegressionDataset(length=128)
         self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu))

+    def test_sampler_seed(self):
+        # nb: we don't want to inherit from IterableDataset to hit the right code path
+        class DummyDataset(torch.utils.data.Dataset):
+            def __init__(self, length: int = 101):
+                self.length = length
+
+            def __len__(self):
+                return self.length
+
+            def __getitem__(self, i):
+                if (i < 0) or (i >= self.length):
+                    raise IndexError
+                return {"input_ids": [i]}
+
+        class DummyModel(PreTrainedModel):
+            def __init__(self, num_params: int):
+                super().__init__(PretrainedConfig())
+                # Add some (unused) params. the point here is that randomness in model_init shouldn't influence
+                # data loader order.
+                self.params = nn.Parameter(torch.randn(num_params))
+
+            def forward(self, input_ids, labels=None):
+                if labels is not None:
+                    return torch.tensor(0.0, device=input_ids.device), input_ids
+                else:
+                    return input_ids
+
+        def _get_first_data_sample(num_params, seed, data_seed, **kwargs):
+            with tempfile.TemporaryDirectory() as tmpdir:
+                trainer = Trainer(
+                    model_init=lambda: DummyModel(num_params),
+                    args=TrainingArguments(
+                        output_dir=tmpdir,
+                        **kwargs,
+                        seed=seed,
+                        data_seed=data_seed,
+                        local_rank=-1,
+                    ),
+                    train_dataset=DummyDataset(),
+                )
+
+                return next(iter(trainer.get_train_dataloader()))
+
+        # test that the seed is passed to the sampler
+        # the codepath we want to hit is world_size <= 1, and both group_by_length
+        for group_by_length in [True, False]:
+            sample42_1 = _get_first_data_sample(num_params=10, seed=42, data_seed=42, group_by_length=group_by_length)
+            sample42_2 = _get_first_data_sample(num_params=11, seed=42, data_seed=42, group_by_length=group_by_length)
+            self.assertTrue(torch.equal(sample42_1["input_ids"], sample42_2["input_ids"]))
+
+            # should get same samples with different seed, so long as data_seed is the same
+            sample42_3 = _get_first_data_sample(num_params=11, seed=11, data_seed=42, group_by_length=group_by_length)
+            self.assertTrue(torch.equal(sample42_1["input_ids"], sample42_3["input_ids"]))
+
+            # make sure we have some randomness in the samples if data_seed is different
+            others = [
+                _get_first_data_sample(num_params=i, seed=42, data_seed=i, group_by_length=group_by_length)
+                for i in range(10)
+            ]
+            self.assertTrue(any(not torch.equal(sample42_1["input_ids"], sample["input_ids"]) for sample in others))
+
     @require_torch_multi_gpu
     def test_data_is_not_parallelized_when_model_is_parallel(self):
         model = RegressionModel()
...