Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3318c246
Unverified
Commit
3318c246
authored
Mar 17, 2021
by
Stas Bekman
Committed by
GitHub
Mar 17, 2021
Browse files
make failure to find a resume checkpoint fatal + tests (#10777)
parent
cd8c93f7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
13 deletions
+28
-13
src/transformers/trainer.py
src/transformers/trainer.py
+4
-1
tests/test_trainer.py
tests/test_trainer.py
+24
-12
No files found.
src/transformers/trainer.py
View file @
3318c246
...
@@ -876,7 +876,10 @@ class Trainer:
...
@@ -876,7 +876,10 @@ class Trainer:
if
resume_from_checkpoint
is
None
:
if
resume_from_checkpoint
is
None
:
raise
ValueError
(
f
"No valid checkpoint found in output directory (
{
self
.
args
.
output_dir
}
)"
)
raise
ValueError
(
f
"No valid checkpoint found in output directory (
{
self
.
args
.
output_dir
}
)"
)
if
resume_from_checkpoint
is
not
None
and
os
.
path
.
isfile
(
os
.
path
.
join
(
resume_from_checkpoint
,
WEIGHTS_NAME
)):
if
resume_from_checkpoint
is
not
None
:
if
not
os
.
path
.
isfile
(
os
.
path
.
join
(
resume_from_checkpoint
,
WEIGHTS_NAME
)):
raise
ValueError
(
f
"Can't find a valid checkpoint at
{
resume_from_checkpoint
}
"
)
logger
.
info
(
f
"Loading model from
{
resume_from_checkpoint
}
)."
)
logger
.
info
(
f
"Loading model from
{
resume_from_checkpoint
}
)."
)
if
self
.
deepspeed
:
if
self
.
deepspeed
:
...
...
tests/test_trainer.py
View file @
3318c246
...
@@ -613,7 +613,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
...
@@ -613,7 +613,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
return
return
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
trainer
=
get_regression_trainer
(
output_dir
=
tmpdir
,
train_len
=
128
,
save_steps
=
5
,
learning_rate
=
0.1
)
kwargs
=
dict
(
output_dir
=
tmpdir
,
train_len
=
128
,
save_steps
=
5
,
learning_rate
=
0.1
)
trainer
=
get_regression_trainer
(
**
kwargs
)
trainer
.
train
()
trainer
.
train
()
(
a
,
b
)
=
trainer
.
model
.
a
.
item
(),
trainer
.
model
.
b
.
item
()
(
a
,
b
)
=
trainer
.
model
.
a
.
item
(),
trainer
.
model
.
b
.
item
()
state
=
dataclasses
.
asdict
(
trainer
.
state
)
state
=
dataclasses
.
asdict
(
trainer
.
state
)
...
@@ -621,7 +622,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
...
@@ -621,7 +622,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
checkpoint
=
os
.
path
.
join
(
tmpdir
,
"checkpoint-5"
)
checkpoint
=
os
.
path
.
join
(
tmpdir
,
"checkpoint-5"
)
# Reinitialize trainer
# Reinitialize trainer
trainer
=
get_regression_trainer
(
output_dir
=
tmpdir
,
train_len
=
128
,
save_steps
=
5
,
learning_rate
=
0.1
)
trainer
=
get_regression_trainer
(
**
kwargs
)
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
(
a1
,
b1
)
=
trainer
.
model
.
a
.
item
(),
trainer
.
model
.
b
.
item
()
(
a1
,
b1
)
=
trainer
.
model
.
a
.
item
(),
trainer
.
model
.
b
.
item
()
...
@@ -634,7 +635,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
...
@@ -634,7 +635,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
checkpoint
=
os
.
path
.
join
(
tmpdir
,
"checkpoint-15"
)
checkpoint
=
os
.
path
.
join
(
tmpdir
,
"checkpoint-15"
)
# Reinitialize trainer and load model
# Reinitialize trainer and load model
trainer
=
get_regression_trainer
(
output_dir
=
tmpdir
,
train_len
=
128
,
save_steps
=
5
,
learning_rate
=
0.1
)
trainer
=
get_regression_trainer
(
**
kwargs
)
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
(
a1
,
b1
)
=
trainer
.
model
.
a
.
item
(),
trainer
.
model
.
b
.
item
()
(
a1
,
b1
)
=
trainer
.
model
.
a
.
item
(),
trainer
.
model
.
b
.
item
()
...
@@ -645,9 +646,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
...
@@ -645,9 +646,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# With a regular model that is not a PreTrainedModel
# With a regular model that is not a PreTrainedModel
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
trainer
=
get_regression_trainer
(
kwargs
=
dict
(
output_dir
=
tmpdir
,
train_len
=
128
,
save_steps
=
5
,
learning_rate
=
0.1
,
pretrained
=
False
)
output_dir
=
tmpdir
,
train_len
=
128
,
save_steps
=
5
,
learning_rate
=
0.1
,
pretrained
=
False
)
trainer
=
get_regression_trainer
(
**
kwargs
)
trainer
.
train
()
trainer
.
train
()
(
a
,
b
)
=
trainer
.
model
.
a
.
item
(),
trainer
.
model
.
b
.
item
()
(
a
,
b
)
=
trainer
.
model
.
a
.
item
(),
trainer
.
model
.
b
.
item
()
state
=
dataclasses
.
asdict
(
trainer
.
state
)
state
=
dataclasses
.
asdict
(
trainer
.
state
)
...
@@ -655,9 +656,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
...
@@ -655,9 +656,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
checkpoint
=
os
.
path
.
join
(
tmpdir
,
"checkpoint-5"
)
checkpoint
=
os
.
path
.
join
(
tmpdir
,
"checkpoint-5"
)
# Reinitialize trainer and load model
# Reinitialize trainer and load model
trainer
=
get_regression_trainer
(
trainer
=
get_regression_trainer
(
**
kwargs
)
output_dir
=
tmpdir
,
train_len
=
128
,
save_steps
=
5
,
learning_rate
=
0.1
,
pretrained
=
False
)
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
(
a1
,
b1
)
=
trainer
.
model
.
a
.
item
(),
trainer
.
model
.
b
.
item
()
(
a1
,
b1
)
=
trainer
.
model
.
a
.
item
(),
trainer
.
model
.
b
.
item
()
...
@@ -670,9 +669,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
...
@@ -670,9 +669,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
checkpoint
=
os
.
path
.
join
(
tmpdir
,
"checkpoint-15"
)
checkpoint
=
os
.
path
.
join
(
tmpdir
,
"checkpoint-15"
)
# Reinitialize trainer and load model
# Reinitialize trainer and load model
trainer
=
get_regression_trainer
(
trainer
=
get_regression_trainer
(
**
kwargs
)
output_dir
=
tmpdir
,
train_len
=
128
,
save_steps
=
5
,
learning_rate
=
0.1
,
pretrained
=
False
)
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
(
a1
,
b1
)
=
trainer
.
model
.
a
.
item
(),
trainer
.
model
.
b
.
item
()
(
a1
,
b1
)
=
trainer
.
model
.
a
.
item
(),
trainer
.
model
.
b
.
item
()
...
@@ -681,6 +678,21 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
...
@@ -681,6 +678,21 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self
.
assertEqual
(
b
,
b1
)
self
.
assertEqual
(
b
,
b1
)
self
.
check_trainer_state_are_the_same
(
state
,
state1
)
self
.
check_trainer_state_are_the_same
(
state
,
state1
)
# Now check failures
# 1. fail to find a bogus checkpoint
trainer
=
get_regression_trainer
()
with
self
.
assertRaises
(
Exception
)
as
context
:
trainer
.
train
(
resume_from_checkpoint
=
f
"
{
checkpoint
}
-bogus"
)
self
.
assertTrue
(
"Can't find a valid checkpoint at"
in
str
(
context
.
exception
))
# 2. fail to find any checkpoint - due a fresh output_dir
output_dir2
=
self
.
get_auto_remove_tmp_dir
()
trainer
=
get_regression_trainer
(
output_dir
=
output_dir2
)
with
self
.
assertRaises
(
Exception
)
as
context
:
trainer
.
train
(
resume_from_checkpoint
=
True
)
self
.
assertTrue
(
"No valid checkpoint found in output directory"
in
str
(
context
.
exception
))
def
test_resume_training_with_gradient_accumulation
(
self
):
def
test_resume_training_with_gradient_accumulation
(
self
):
if
torch
.
cuda
.
device_count
()
>
2
:
if
torch
.
cuda
.
device_count
()
>
2
:
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment