Unverified Commit 8546dc55 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix Trainer tests in a multiGPU env (#7458)

parent d0fd7154
......@@ -109,12 +109,15 @@ if is_torch_available():
loss = torch.nn.functional.mse_loss(y, labels)
return (loss, y, y) if self.double_output else (loss, y)
def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, **kwargs):
def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, pretrained=True, **kwargs):
label_names = kwargs.get("label_names", None)
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
config = RegressionModelConfig(a=a, b=b, double_output=double_output)
model = RegressionPreTrainedModel(config)
if pretrained:
config = RegressionModelConfig(a=a, b=b, double_output=double_output)
model = RegressionPreTrainedModel(config)
else:
model = RegressionModel(a=a, b=b, double_output=double_output)
compute_metrics = kwargs.pop("compute_metrics", None)
data_collator = kwargs.pop("data_collator", None)
optimizers = kwargs.pop("optimizers", (None, None))
......@@ -178,6 +181,7 @@ class TrainerIntegrationTest(unittest.TestCase):
best_model = RegressionModel()
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
best_model.load_state_dict(state_dict)
best_model.to(trainer.args.device)
self.assertTrue(torch.allclose(best_model.a, trainer.model.a))
self.assertTrue(torch.allclose(best_model.b, trainer.model.b))
......@@ -360,8 +364,7 @@ class TrainerIntegrationTest(unittest.TestCase):
# With a regular model that is not a PreTrainedModel
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
trainer.model = RegressionModel()
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5, pretrained=False)
trainer.train()
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
......@@ -426,8 +429,8 @@ class TrainerIntegrationTest(unittest.TestCase):
eval_steps=5,
evaluation_strategy="steps",
load_best_model_at_end=True,
pretrained=False,
)
trainer.model = RegressionModel(a=1.5, b=2.5)
self.assertFalse(trainer.args.greater_is_better)
trainer.train()
self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment