Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4083a55a
Unverified
Commit
4083a55a
authored
Sep 28, 2020
by
Marcin Zabłocki
Committed by
GitHub
Sep 28, 2020
Browse files
Flos fix (#7384)
parent
ae3e84f3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
7 deletions
+34
-7
src/transformers/trainer.py
src/transformers/trainer.py
+21
-7
tests/test_trainer.py
tests/test_trainer.py
+13
-0
No files found.
src/transformers/trainer.py
View file @
4083a55a
...
@@ -695,7 +695,7 @@ class Trainer:
...
@@ -695,7 +695,7 @@ class Trainer:
# set global_step to global_step of last saved checkpoint from model path
# set global_step to global_step of last saved checkpoint from model path
try
:
try
:
self
.
global_step
=
int
(
model_path
.
split
(
"-"
)[
-
1
].
split
(
os
.
path
.
sep
)[
0
])
self
.
global_step
=
int
(
model_path
.
split
(
"-"
)[
-
1
].
split
(
os
.
path
.
sep
)[
0
])
self
.
total_flos
=
getattr
(
model
.
config
,
"total_flos"
,
0
)
self
.
total_flos
=
getattr
(
self
.
_actual_model
(
model
)
.
config
,
"total_flos"
,
0
)
epochs_trained
=
self
.
global_step
//
num_update_steps_per_epoch
epochs_trained
=
self
.
global_step
//
num_update_steps_per_epoch
steps_trained_in_current_epoch
=
self
.
global_step
%
(
num_update_steps_per_epoch
)
steps_trained_in_current_epoch
=
self
.
global_step
%
(
num_update_steps_per_epoch
)
...
@@ -1448,15 +1448,29 @@ class Trainer:
...
@@ -1448,15 +1448,29 @@ class Trainer:
:obj:`int`: The number of floating-point operations.
:obj:`int`: The number of floating-point operations.
"""
"""
if
isinstance
(
self
.
model
,
torch
.
nn
.
DataParallel
)
or
isinstance
(
model
=
self
.
_actual_model
(
self
.
model
)
self
.
model
,
torch
.
nn
.
parallel
.
DistributedDataParallel
):
model
=
self
.
model
.
module
else
:
model
=
self
.
model
if
hasattr
(
model
,
"floating_point_ops"
):
if
hasattr
(
model
,
"floating_point_ops"
):
return
model
.
floating_point_ops
(
inputs
)
return
model
.
floating_point_ops
(
inputs
)
else
:
else
:
return
0
return
0
@staticmethod
def _actual_model(
    model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]
) -> torch.nn.modules.Module:
    """
    Return the module wrapped by a parallel container, or the model itself when it is not wrapped.

    Args:
        model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`):
            Model object used during training

    Returns:
        :obj:`torch.nn.modules.Module`: unwrapped module
    """
    parallel_wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    if isinstance(model, parallel_wrappers):
        # Both wrapper types expose the wrapped network through ``.module``.
        return model.module
    # Anything else is handed back untouched.
    return model
tests/test_trainer.py
View file @
4083a55a
...
@@ -336,3 +336,16 @@ class TrainerIntegrationTest(unittest.TestCase):
...
@@ -336,3 +336,16 @@ class TrainerIntegrationTest(unittest.TestCase):
trainer
=
get_regression_trainer
(
train_len
=
64
,
per_device_train_batch_size
=
16
,
gradient_accumulation_steps
=
5
)
trainer
=
get_regression_trainer
(
train_len
=
64
,
per_device_train_batch_size
=
16
,
gradient_accumulation_steps
=
5
)
train_output
=
trainer
.
train
()
train_output
=
trainer
.
train
()
self
.
assertEqual
(
train_output
.
global_step
,
int
(
self
.
n_epochs
))
self
.
assertEqual
(
train_output
.
global_step
,
int
(
self
.
n_epochs
))
def test_flos_extraction(self):
    """
    Check ``Trainer._actual_model`` on a plain model and on a ``DataParallel``-wrapped one.

    For each input, the unwrapped result must equal ``trainer.model`` and its config's ``total_flos``
    (treated as 0 when the attribute is absent) must be non-negative.
    """
    trainer = get_regression_trainer(learning_rate=0.1)

    # Each factory is applied only when its turn comes, so the DataParallel wrapper is built
    # after the plain-model checks have run.
    wrapper_factories = (
        lambda plain: plain,  # with plain model
        torch.nn.DataParallel,  # with enforced DataParallel
    )
    for make_candidate in wrapper_factories:
        candidate = make_candidate(trainer.model)
        self.assertEqual(trainer.model, trainer._actual_model(candidate))
        reported_flos = getattr(trainer._actual_model(candidate).config, "total_flos", 0)
        self.assertGreaterEqual(reported_flos, 0)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment