"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "9b2540b5a762436a0ac2b603f1fce93451535156"
Unverified Commit 4083a55a authored by Marcin Zabłocki, committed by GitHub

Flos fix (#7384)

parent ae3e84f3
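
Context for the change: when training runs under `torch.nn.DataParallel` or `torch.nn.parallel.DistributedDataParallel`, `self.model` is a wrapper object, and attributes such as `config` are only reachable through the wrapper's `.module`. The diff below therefore routes the `total_flos` read in `train()` and the FLOPs computation through a new `_actual_model` helper that unwraps the model first. A minimal sketch of the underlying gotcha (`TinyModel` is a hypothetical stand-in for a transformers model, not part of the commit):

```python
import torch

# Hypothetical stand-in for a transformers model: a module carrying a
# `config` object with a `total_flos` field, like PretrainedConfig does.
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)
        self.config = type("Config", (), {"total_flos": 123})()

model = TinyModel()
wrapped = torch.nn.DataParallel(model)  # constructing this works on CPU-only machines too

# nn.DataParallel does not proxy arbitrary attributes to the inner module,
# so `config` is invisible on the wrapper ...
assert getattr(wrapped, "config", None) is None
# ... but reachable through `.module`, which is what the fix relies on.
assert wrapped.module.config.total_flos == 123
```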
@@ -695,7 +695,7 @@ class Trainer:
             # set global_step to global_step of last saved checkpoint from model path
             try:
                 self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
-                self.total_flos = getattr(model.config, "total_flos", 0)
+                self.total_flos = getattr(self._actual_model(model).config, "total_flos", 0)
                 epochs_trained = self.global_step // num_update_steps_per_epoch
                 steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
@@ -1448,15 +1448,29 @@ class Trainer:
             :obj:`int`: The number of floating-point operations.
         """
-        if isinstance(self.model, torch.nn.DataParallel) or isinstance(
-            self.model, torch.nn.parallel.DistributedDataParallel
-        ):
-            model = self.model.module
-        else:
-            model = self.model
+        model = self._actual_model(self.model)
         if hasattr(model, "floating_point_ops"):
             return model.floating_point_ops(inputs)
         else:
             return 0
+
+    @staticmethod
+    def _actual_model(
+        model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]
+    ) -> torch.nn.modules.Module:
+        """
+        Args:
+            model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`):
+                Model object used during training
+
+        Returns:
+            :obj:`torch.nn.modules.Module`: unwrapped module
+        """
+        if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel):
+            model = model.module
+        return model
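
For reference, a minimal usage sketch of the new helper (standalone; assumes a transformers version that ships this commit):

```python
import torch
from transformers import Trainer

linear = torch.nn.Linear(4, 4)
wrapped = torch.nn.DataParallel(linear)

# _actual_model is a staticmethod, so no Trainer instance is needed.
assert Trainer._actual_model(wrapped) is linear  # unwraps the DataParallel wrapper
assert Trainer._actual_model(linear) is linear   # plain modules pass through unchanged
```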
@@ -336,3 +336,16 @@ class TrainerIntegrationTest(unittest.TestCase):
         trainer = get_regression_trainer(train_len=64, per_device_train_batch_size=16, gradient_accumulation_steps=5)
         train_output = trainer.train()
         self.assertEqual(train_output.global_step, int(self.n_epochs))
+
+    def test_flos_extraction(self):
+        trainer = get_regression_trainer(learning_rate=0.1)
+
+        def assert_flos_extraction(trainer, wrapped_model_to_check):
+            self.assertEqual(trainer.model, trainer._actual_model(wrapped_model_to_check))
+            self.assertGreaterEqual(getattr(trainer._actual_model(wrapped_model_to_check).config, "total_flos", 0), 0)
+
+        # with plain model
+        assert_flos_extraction(trainer, trainer.model)
+
+        # with enforced DataParallel
+        assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
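
The test covers the plain-model and DataParallel paths; the DistributedDataParallel branch is harder to exercise in a unit test because DDP requires an initialized process group. A hedged sketch of what such a check could look like (single-process gloo group on CPU; the address and port values are arbitrary assumptions, not from the commit):

```python
import os
import torch
import torch.distributed as dist

def check_ddp_unwrapping():
    # DistributedDataParallel needs an initialized process group even for a
    # single process; the gloo backend works on CPU-only machines.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")  # any free port
    dist.init_process_group("gloo", rank=0, world_size=1)
    try:
        model = torch.nn.Linear(2, 2)
        ddp_model = torch.nn.parallel.DistributedDataParallel(model)
        # `.module` is exactly what Trainer._actual_model returns for DDP.
        assert ddp_model.module is model
    finally:
        dist.destroy_process_group()

check_ddp_unwrapping()
```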