"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "204d2513105b299cdc0e5f66d5778d2f6f871424"
Unverified Commit 1c9fcd0e authored by Sylvain Gugger, committed by GitHub

Fix RNG reload in resume training from epoch checkpoint (#17055)

* Fix RNG reload in resume training from epoch checkpoint

* Fix test
parent 6e17ba6a
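
For context, this commit makes the `Trainer` restore the saved random-number-generator states when training resumes from a checkpoint written at an epoch boundary, not only from one written mid-epoch. Below is a minimal sketch of the save/restore pattern involved; the helper names and the `rng_state.pth` filename are illustrative assumptions, not the exact `Trainer` internals.

```python
import os
import random

import numpy as np
import torch


def save_rng_state(checkpoint_dir: str) -> None:
    """Capture the state of every RNG the training loop touches (illustrative sketch)."""
    rng_states = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "cpu": torch.random.get_rng_state(),
    }
    if torch.cuda.is_available():
        rng_states["cuda"] = torch.cuda.get_rng_state_all()
    torch.save(rng_states, os.path.join(checkpoint_dir, "rng_state.pth"))


def load_rng_state(checkpoint_dir: str) -> None:
    """Restore the RNG states saved alongside a checkpoint, if any were saved."""
    path = os.path.join(checkpoint_dir, "rng_state.pth")
    if not os.path.isfile(path):
        return
    rng_states = torch.load(path)
    random.setstate(rng_states["python"])
    np.random.set_state(rng_states["numpy"])
    torch.random.set_rng_state(rng_states["cpu"])
    if torch.cuda.is_available() and "cuda" in rng_states:
        torch.cuda.set_rng_state_all(rng_states["cuda"])
```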
@@ -789,13 +789,16 @@ class ModuleUtilsMixin:
         Returns:
             `int`: The total number of tokens.
         """
+        if not hasattr(self, "warnings_issued"):
+            self.warnings_issued = {}
         if self.main_input_name in input_dict:
             return input_dict[self.main_input_name].numel()
-        else:
+        elif "estimate_tokens" not in self.warnings_issued:
             logger.warning(
                 "Could not estimate the number of tokens of the input, floating-point operations will not be computed"
             )
-            return 0
+            self.warnings_issued["estimate_tokens"] = True
+        return 0
 
     def floating_point_ops(
         self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
@@ -895,6 +898,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         # Save config and origin of the pretrained weights if given in model
         self.config = config
         self.name_or_path = config.name_or_path
+        self.warnings_issued = {}
 
     def post_init(self):
         """
......
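
The `warnings_issued` dictionary added above turns the token-estimation warning into a warn-once message per model instance. A generic sketch of that deduplication pattern, using a hypothetical `warn_once` helper, might look like this:

```python
import logging

logger = logging.getLogger(__name__)


class WarnOnceMixin:
    """Illustrative warn-once pattern, analogous to the warnings_issued dict used above."""

    def warn_once(self, key: str, message: str) -> None:
        # Create the registry lazily so instances built before the attribute existed still work,
        # mirroring the hasattr() guard in ModuleUtilsMixin.estimate_tokens.
        if not hasattr(self, "warnings_issued"):
            self.warnings_issued = {}
        if key not in self.warnings_issued:
            logger.warning(message)
            self.warnings_issued[key] = True
```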
@@ -1151,7 +1151,8 @@ class Trainer:
             kwargs:
                 Additional keyword arguments used to hide deprecated arguments
         """
-        resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint
+        if resume_from_checkpoint is False:
+            resume_from_checkpoint = None
 
         # memory metrics - must set up as early as possible
         self._memory_tracker.start()
@@ -1395,6 +1396,9 @@ class Trainer:
             )
             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
 
+            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
+                self._load_rng_state(resume_from_checkpoint)
+
             step = -1
             for step, inputs in enumerate(epoch_iterator):
......
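
The epoch-start branch matters because a checkpoint saved with `save_strategy="epoch"` leaves `steps_trained_in_current_epoch` at 0, so the existing mid-epoch path that skips already-seen batches (and reloads the RNG afterwards) never runs; without the new check, the resumed run draws different random numbers than the original one. A hedged usage sketch of resuming from the latest checkpoint follows; `get_last_checkpoint` is the helper in `transformers.trainer_utils`, while the `output_dir` value and the pre-built `trainer` are assumptions.

```python
from transformers.trainer_utils import get_last_checkpoint

# `trainer` is assumed to be an already-constructed transformers.Trainer whose
# TrainingArguments use save_strategy="epoch" and output_dir="outputs".
last_checkpoint = get_last_checkpoint("outputs")
if last_checkpoint is not None:
    # With this fix, the torch / numpy / python RNG states stored with the epoch
    # checkpoint are restored when the first resumed epoch begins.
    trainer.train(resume_from_checkpoint=last_checkpoint)
else:
    trainer.train()
```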
@@ -58,7 +58,6 @@ from transformers.testing_utils import (
     require_torch_bf16,
     require_torch_gpu,
     require_torch_multi_gpu,
-    require_torch_non_multi_gpu,
     require_torch_tf32,
     require_torch_up_to_2_gpus,
     require_wandb,
@@ -162,11 +161,12 @@ class AlmostAccuracy:
 class RegressionModelConfig(PretrainedConfig):
-    def __init__(self, a=0, b=0, double_output=False, **kwargs):
+    def __init__(self, a=0, b=0, double_output=False, random_torch=True, **kwargs):
         super().__init__(**kwargs)
         self.a = a
         self.b = b
         self.double_output = double_output
+        self.random_torch = random_torch
         self.hidden_size = 1
@@ -264,14 +264,18 @@ if is_torch_available():
             super().__init__(config)
             self.a = nn.Parameter(torch.tensor(config.a).float())
             self.b = nn.Parameter(torch.tensor(config.b).float())
+            self.random_torch = config.random_torch
 
         def forward(self, input_x, labels=None, **kwargs):
             y = input_x * self.a + self.b
-            torch_rand = torch.randn(1).squeeze()
+            if self.random_torch:
+                torch_rand = torch.randn(1).squeeze()
             np_rand = np.random.rand()
             rand_rand = random.random()
 
-            y += 0.05 * torch_rand + 0.05 * torch.tensor(np_rand + rand_rand)
+            if self.random_torch:
+                y += 0.05 * torch_rand
+            y += 0.05 * torch.tensor(np_rand + rand_rand)
 
             if labels is None:
                 return (y,)
@@ -1016,33 +1020,60 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
                 trainer.train(resume_from_checkpoint=True)
             self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
 
-    @require_torch_non_multi_gpu
     def test_resume_training_with_randomness(self):
-        # This test will fail flakily for more than 1 GPUs since the result will be slightly more different
-        # TODO: investigate why it fails for 2 GPUs?
+        # For more than 1 GPUs, since the randomness is introduced in the model and with DataParallel (which is used
+        # in this test for more than 2 GPUs), the calls to the torch RNG will happen in a random order (sometimes
+        # GPU 0 will call first and sometimes GPU 1).
+        random_torch = not torch.cuda.is_available() or torch.cuda.device_count() <= 1
+
         if torch.cuda.is_available():
             torch.backends.cudnn.deterministic = True
         train_dataset = RegressionDataset(length=128)
         eval_dataset = RegressionDataset()
 
-        config = RegressionModelConfig(a=0, b=2)
-        model = RegressionRandomPreTrainedModel(config)
-
-        tmp_dir = self.get_auto_remove_tmp_dir()
-        args = RegressionTrainingArguments(tmp_dir, save_steps=5, learning_rate=0.1)
-        trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
-
-        trainer.train()
-        (a, b) = trainer.model.a.item(), trainer.model.b.item()
-
-        model = RegressionRandomPreTrainedModel(config)
-        trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
-        trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, "checkpoint-15"))
-        (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
-
-        self.assertAlmostEqual(a, a1, delta=1e-8)
-        self.assertAlmostEqual(b, b1, delta=1e-8)
+        with self.subTest("Test every step"):
+            config = RegressionModelConfig(a=0, b=2, random_torch=random_torch)
+            model = RegressionRandomPreTrainedModel(config)
+
+            tmp_dir = self.get_auto_remove_tmp_dir()
+            args = RegressionTrainingArguments(tmp_dir, save_steps=5, learning_rate=0.1)
+            trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+
+            trainer.train()
+            (a, b) = trainer.model.a.item(), trainer.model.b.item()
+
+            model = RegressionRandomPreTrainedModel(config)
+            trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+            trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, "checkpoint-15"))
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            self.assertAlmostEqual(a, a1, delta=1e-8)
+            self.assertAlmostEqual(b, b1, delta=1e-8)
+
+        with self.subTest("Test every epoch"):
+            config = RegressionModelConfig(a=0, b=2, random_torch=random_torch)
+            model = RegressionRandomPreTrainedModel(config)
+
+            tmp_dir = self.get_auto_remove_tmp_dir()
+            args = RegressionTrainingArguments(tmp_dir, save_strategy="epoch", learning_rate=0.1)
+            trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+
+            trainer.train()
+            (a, b) = trainer.model.a.item(), trainer.model.b.item()
+
+            model = RegressionRandomPreTrainedModel(config)
+            trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+
+            checkpoints = [d for d in os.listdir(tmp_dir) if d.startswith("checkpoint-")]
+            # There should be one checkpoint per epoch.
+            self.assertEqual(len(checkpoints), 3)
+            checkpoint_dir = sorted(checkpoints, key=lambda x: int(x.replace("checkpoint-", "")))[0]
+
+            trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, checkpoint_dir))
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            self.assertAlmostEqual(a, a1, delta=1e-8)
+            self.assertAlmostEqual(b, b1, delta=1e-8)
 
     # regression for this issue: https://github.com/huggingface/transformers/issues/12970
     def test_training_with_resume_from_checkpoint_false(self):
......
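
For reference, checkpoint directories are named `checkpoint-<global_step>`, so the numeric sort in the new subTest picks the checkpoint written at the end of the first epoch. A small sketch of that selection, factored into a standalone helper purely for illustration (the function name is hypothetical):

```python
import os


def first_epoch_checkpoint(output_dir: str) -> str:
    """Return the path of the earliest checkpoint in output_dir (illustrative helper)."""
    checkpoints = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")]
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoint-* directories found in {output_dir}")
    # Directory names encode the global step at save time, so a numeric sort orders them
    # chronologically; index 0 is the end of the first epoch when save_strategy="epoch" is used.
    earliest = sorted(checkpoints, key=lambda name: int(name.replace("checkpoint-", "")))[0]
    return os.path.join(output_dir, earliest)
```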