"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e0c50b274ab396ca41beec9e0e025820524a1513"
Unverified Commit f717d47f authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `test_number_of_steps_in_training_with_ipex` (#17889)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 0b0dd977
......@@ -649,14 +649,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# Regular training has n_epochs * len(train_dl) steps
trainer = get_regression_trainer(learning_rate=0.1, use_ipex=True, bf16=mix_bf16, no_cuda=True)
train_output = trainer.train()
self.assertEqual(train_output.global_step, self.n_epochs * 64 / self.batch_size)
self.assertEqual(train_output.global_step, self.n_epochs * 64 / trainer.args.train_batch_size)
# Check passing num_train_epochs works (and a float version too):
trainer = get_regression_trainer(
learning_rate=0.1, num_train_epochs=1.5, use_ipex=True, bf16=mix_bf16, no_cuda=True
)
train_output = trainer.train()
self.assertEqual(train_output.global_step, int(1.5 * 64 / self.batch_size))
self.assertEqual(train_output.global_step, int(1.5 * 64 / trainer.args.train_batch_size))
# If we pass a max_steps, num_train_epochs is ignored
trainer = get_regression_trainer(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment