Unverified Commit f717d47f authored by Yih-Dar, committed by GitHub
Browse files

Fix `test_number_of_steps_in_training_with_ipex` (#17889)


Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 0b0dd977
...@@ -649,14 +649,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -649,14 +649,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# Regular training has n_epochs * len(train_dl) steps # Regular training has n_epochs * len(train_dl) steps
trainer = get_regression_trainer(learning_rate=0.1, use_ipex=True, bf16=mix_bf16, no_cuda=True) trainer = get_regression_trainer(learning_rate=0.1, use_ipex=True, bf16=mix_bf16, no_cuda=True)
train_output = trainer.train() train_output = trainer.train()
self.assertEqual(train_output.global_step, self.n_epochs * 64 / self.batch_size) self.assertEqual(train_output.global_step, self.n_epochs * 64 / trainer.args.train_batch_size)
# Check passing num_train_epochs works (and a float version too): # Check passing num_train_epochs works (and a float version too):
trainer = get_regression_trainer( trainer = get_regression_trainer(
learning_rate=0.1, num_train_epochs=1.5, use_ipex=True, bf16=mix_bf16, no_cuda=True learning_rate=0.1, num_train_epochs=1.5, use_ipex=True, bf16=mix_bf16, no_cuda=True
) )
train_output = trainer.train() train_output = trainer.train()
self.assertEqual(train_output.global_step, int(1.5 * 64 / self.batch_size)) self.assertEqual(train_output.global_step, int(1.5 * 64 / trainer.args.train_batch_size))
# If we pass a max_steps, num_train_epochs is ignored # If we pass a max_steps, num_train_epochs is ignored
trainer = get_regression_trainer( trainer = get_regression_trainer(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.