Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f717d47f
Unverified
Commit
f717d47f
authored
Jun 28, 2022
by
Yih-Dar
Committed by
GitHub
Jun 28, 2022
Browse files
Fix `test_number_of_steps_in_training_with_ipex` (#17889)
Co-authored-by:
ydshieh
<
ydshieh@users.noreply.github.com
>
parent
0b0dd977
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
2 deletions
+2
-2
tests/trainer/test_trainer.py
tests/trainer/test_trainer.py
+2
-2
No files found.
tests/trainer/test_trainer.py
View file @
f717d47f
...
@@ -649,14 +649,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
...
@@ -649,14 +649,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# Regular training has n_epochs * len(train_dl) steps
# Regular training has n_epochs * len(train_dl) steps
trainer
=
get_regression_trainer
(
learning_rate
=
0.1
,
use_ipex
=
True
,
bf16
=
mix_bf16
,
no_cuda
=
True
)
trainer
=
get_regression_trainer
(
learning_rate
=
0.1
,
use_ipex
=
True
,
bf16
=
mix_bf16
,
no_cuda
=
True
)
train_output
=
trainer
.
train
()
train_output
=
trainer
.
train
()
self
.
assertEqual
(
train_output
.
global_step
,
self
.
n_epochs
*
64
/
self
.
batch_size
)
self
.
assertEqual
(
train_output
.
global_step
,
self
.
n_epochs
*
64
/
trainer
.
args
.
train_
batch_size
)
# Check passing num_train_epochs works (and a float version too):
# Check passing num_train_epochs works (and a float version too):
trainer
=
get_regression_trainer
(
trainer
=
get_regression_trainer
(
learning_rate
=
0.1
,
num_train_epochs
=
1.5
,
use_ipex
=
True
,
bf16
=
mix_bf16
,
no_cuda
=
True
learning_rate
=
0.1
,
num_train_epochs
=
1.5
,
use_ipex
=
True
,
bf16
=
mix_bf16
,
no_cuda
=
True
)
)
train_output
=
trainer
.
train
()
train_output
=
trainer
.
train
()
self
.
assertEqual
(
train_output
.
global_step
,
int
(
1.5
*
64
/
self
.
batch_size
))
self
.
assertEqual
(
train_output
.
global_step
,
int
(
1.5
*
64
/
trainer
.
args
.
train_
batch_size
))
# If we pass a max_steps, num_train_epochs is ignored
# If we pass a max_steps, num_train_epochs is ignored
trainer
=
get_regression_trainer
(
trainer
=
get_regression_trainer
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment