Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b4e559cf
Unverified
Commit
b4e559cf
authored
Jan 28, 2021
by
Sylvain Gugger
Committed by
GitHub
Jan 28, 2021
Browse files
Deprecate model_path in Trainer.train (#9854)
parent
2ee9f9b6
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
96 additions
and
78 deletions
+96
-78
examples/language-modeling/run_clm.py
examples/language-modeling/run_clm.py
+4
-4
examples/language-modeling/run_mlm.py
examples/language-modeling/run_mlm.py
+4
-4
examples/language-modeling/run_mlm_wwm.py
examples/language-modeling/run_mlm_wwm.py
+4
-4
examples/language-modeling/run_plm.py
examples/language-modeling/run_plm.py
+4
-4
examples/multiple-choice/run_swag.py
examples/multiple-choice/run_swag.py
+4
-4
examples/question-answering/run_qa.py
examples/question-answering/run_qa.py
+4
-4
examples/question-answering/run_qa_beam_search.py
examples/question-answering/run_qa_beam_search.py
+4
-4
examples/seq2seq/run_seq2seq.py
examples/seq2seq/run_seq2seq.py
+4
-4
examples/text-classification/run_glue.py
examples/text-classification/run_glue.py
+4
-4
examples/token-classification/run_ner.py
examples/token-classification/run_ner.py
+4
-4
src/transformers/integrations.py
src/transformers/integrations.py
+6
-6
src/transformers/trainer.py
src/transformers/trainer.py
+38
-20
templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py
...directory_name}}/run_{{cookiecutter.example_shortcut}}.py
+7
-7
tests/test_trainer.py
tests/test_trainer.py
+5
-5
No files found.
examples/language-modeling/run_clm.py
View file @
b4e559cf
...
...
@@ -362,12 +362,12 @@ def main():
# Training
if
training_args
.
do_train
:
if
last_checkpoint
is
not
None
:
model_path
=
last_checkpoint
checkpoint
=
last_checkpoint
elif
model_args
.
model_name_or_path
is
not
None
and
os
.
path
.
isdir
(
model_args
.
model_name_or_path
):
model_path
=
model_args
.
model_name_or_path
checkpoint
=
model_args
.
model_name_or_path
else
:
model_path
=
None
train_result
=
trainer
.
train
(
model_path
=
model_path
)
checkpoint
=
None
train_result
=
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
trainer
.
save_model
()
# Saves the tokenizer too for easy upload
output_train_file
=
os
.
path
.
join
(
training_args
.
output_dir
,
"train_results.txt"
)
...
...
examples/language-modeling/run_mlm.py
View file @
b4e559cf
...
...
@@ -403,12 +403,12 @@ def main():
# Training
if
training_args
.
do_train
:
if
last_checkpoint
is
not
None
:
model_path
=
last_checkpoint
checkpoint
=
last_checkpoint
elif
model_args
.
model_name_or_path
is
not
None
and
os
.
path
.
isdir
(
model_args
.
model_name_or_path
):
model_path
=
model_args
.
model_name_or_path
checkpoint
=
model_args
.
model_name_or_path
else
:
model_path
=
None
train_result
=
trainer
.
train
(
model_path
=
model_path
)
checkpoint
=
None
train_result
=
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
trainer
.
save_model
()
# Saves the tokenizer too for easy upload
output_train_file
=
os
.
path
.
join
(
training_args
.
output_dir
,
"train_results.txt"
)
...
...
examples/language-modeling/run_mlm_wwm.py
View file @
b4e559cf
...
...
@@ -355,12 +355,12 @@ def main():
# Training
if
training_args
.
do_train
:
if
last_checkpoint
is
not
None
:
model_path
=
last_checkpoint
checkpoint
=
last_checkpoint
elif
model_args
.
model_name_or_path
is
not
None
and
os
.
path
.
isdir
(
model_args
.
model_name_or_path
):
model_path
=
model_args
.
model_name_or_path
checkpoint
=
model_args
.
model_name_or_path
else
:
model_path
=
None
train_result
=
trainer
.
train
(
model_path
=
model_path
)
checkpoint
=
None
train_result
=
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
trainer
.
save_model
()
# Saves the tokenizer too for easy upload
output_train_file
=
os
.
path
.
join
(
training_args
.
output_dir
,
"train_results.txt"
)
...
...
examples/language-modeling/run_plm.py
View file @
b4e559cf
...
...
@@ -384,12 +384,12 @@ def main():
# Training
if
training_args
.
do_train
:
if
last_checkpoint
is
not
None
:
model_path
=
last_checkpoint
checkpoint
=
last_checkpoint
elif
model_args
.
model_name_or_path
is
not
None
and
os
.
path
.
isdir
(
model_args
.
model_name_or_path
):
model_path
=
model_args
.
model_name_or_path
checkpoint
=
model_args
.
model_name_or_path
else
:
model_path
=
None
train_result
=
trainer
.
train
(
model_path
=
model_path
)
checkpoint
=
None
train_result
=
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
trainer
.
save_model
()
# Saves the tokenizer too for easy upload
output_train_file
=
os
.
path
.
join
(
training_args
.
output_dir
,
"train_results.txt"
)
...
...
examples/multiple-choice/run_swag.py
View file @
b4e559cf
...
...
@@ -342,12 +342,12 @@ def main():
# Training
if
training_args
.
do_train
:
if
last_checkpoint
is
not
None
:
model_path
=
last_checkpoint
checkpoint
=
last_checkpoint
elif
os
.
path
.
isdir
(
model_args
.
model_name_or_path
):
model_path
=
model_args
.
model_name_or_path
checkpoint
=
model_args
.
model_name_or_path
else
:
model_path
=
None
train_result
=
trainer
.
train
(
model_path
=
model_path
)
checkpoint
=
None
train_result
=
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
trainer
.
save_model
()
# Saves the tokenizer too for easy upload
output_train_file
=
os
.
path
.
join
(
training_args
.
output_dir
,
"train_results.txt"
)
...
...
examples/question-answering/run_qa.py
View file @
b4e559cf
...
...
@@ -463,12 +463,12 @@ def main():
# Training
if
training_args
.
do_train
:
if
last_checkpoint
is
not
None
:
model_path
=
last_checkpoint
checkpoint
=
last_checkpoint
elif
os
.
path
.
isdir
(
model_args
.
model_name_or_path
):
model_path
=
model_args
.
model_name_or_path
checkpoint
=
model_args
.
model_name_or_path
else
:
model_path
=
None
train_result
=
trainer
.
train
(
model_path
=
model_path
)
checkpoint
=
None
train_result
=
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
trainer
.
save_model
()
# Saves the tokenizer too for easy upload
output_train_file
=
os
.
path
.
join
(
training_args
.
output_dir
,
"train_results.txt"
)
...
...
examples/question-answering/run_qa_beam_search.py
View file @
b4e559cf
...
...
@@ -502,12 +502,12 @@ def main():
# Training
if
training_args
.
do_train
:
if
last_checkpoint
is
not
None
:
model_path
=
last_checkpoint
checkpoint
=
last_checkpoint
elif
os
.
path
.
isdir
(
model_args
.
model_name_or_path
):
model_path
=
model_args
.
model_name_or_path
checkpoint
=
model_args
.
model_name_or_path
else
:
model_path
=
None
train_result
=
trainer
.
train
(
model_path
=
model_path
)
checkpoint
=
None
train_result
=
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
trainer
.
save_model
()
# Saves the tokenizer too for easy upload
output_train_file
=
os
.
path
.
join
(
training_args
.
output_dir
,
"train_results.txt"
)
...
...
examples/seq2seq/run_seq2seq.py
View file @
b4e559cf
...
...
@@ -491,12 +491,12 @@ def main():
# Training
if
training_args
.
do_train
:
if
last_checkpoint
is
not
None
:
model_path
=
last_checkpoint
checkpoint
=
last_checkpoint
elif
os
.
path
.
isdir
(
model_args
.
model_name_or_path
):
model_path
=
model_args
.
model_name_or_path
checkpoint
=
model_args
.
model_name_or_path
else
:
model_path
=
None
train_result
=
trainer
.
train
(
model_path
=
model_path
)
checkpoint
=
None
train_result
=
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
trainer
.
save_model
()
# Saves the tokenizer too for easy upload
output_train_file
=
os
.
path
.
join
(
training_args
.
output_dir
,
"train_results.txt"
)
...
...
examples/text-classification/run_glue.py
View file @
b4e559cf
...
...
@@ -399,12 +399,12 @@ def main():
# Training
if
training_args
.
do_train
:
if
last_checkpoint
is
not
None
:
model_path
=
last_checkpoint
checkpoint
=
last_checkpoint
elif
os
.
path
.
isdir
(
model_args
.
model_name_or_path
):
model_path
=
model_args
.
model_name_or_path
checkpoint
=
model_args
.
model_name_or_path
else
:
model_path
=
None
train_result
=
trainer
.
train
(
model_path
=
model_path
)
checkpoint
=
None
train_result
=
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
metrics
=
train_result
.
metrics
trainer
.
save_model
()
# Saves the tokenizer too for easy upload
...
...
examples/token-classification/run_ner.py
View file @
b4e559cf
...
...
@@ -380,12 +380,12 @@ def main():
# Training
if
training_args
.
do_train
:
if
last_checkpoint
is
not
None
:
model_path
=
last_checkpoint
checkpoint
=
last_checkpoint
elif
os
.
path
.
isdir
(
model_args
.
model_name_or_path
):
model_path
=
model_args
.
model_name_or_path
checkpoint
=
model_args
.
model_name_or_path
else
:
model_path
=
None
train_result
=
trainer
.
train
(
model_path
=
model_path
)
checkpoint
=
None
train_result
=
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
trainer
.
save_model
()
# Saves the tokenizer too for easy upload
output_train_file
=
os
.
path
.
join
(
training_args
.
output_dir
,
"train_results.txt"
)
...
...
src/transformers/integrations.py
View file @
b4e559cf
...
...
@@ -125,13 +125,13 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
import
optuna
def
_objective
(
trial
,
checkpoint_dir
=
None
):
model_path
=
None
checkpoint
=
None
if
checkpoint_dir
:
for
subdir
in
os
.
listdir
(
checkpoint_dir
):
if
subdir
.
startswith
(
PREFIX_CHECKPOINT_DIR
):
model_path
=
os
.
path
.
join
(
checkpoint_dir
,
subdir
)
checkpoint
=
os
.
path
.
join
(
checkpoint_dir
,
subdir
)
trainer
.
objective
=
None
trainer
.
train
(
model_path
=
model_path
,
trial
=
trial
)
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
,
trial
=
trial
)
# If there hasn't been any evaluation during the training loop.
if
getattr
(
trainer
,
"objective"
,
None
)
is
None
:
metrics
=
trainer
.
evaluate
()
...
...
@@ -150,13 +150,13 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
import
ray
def
_objective
(
trial
,
local_trainer
,
checkpoint_dir
=
None
):
model_path
=
None
checkpoint
=
None
if
checkpoint_dir
:
for
subdir
in
os
.
listdir
(
checkpoint_dir
):
if
subdir
.
startswith
(
PREFIX_CHECKPOINT_DIR
):
model_path
=
os
.
path
.
join
(
checkpoint_dir
,
subdir
)
checkpoint
=
os
.
path
.
join
(
checkpoint_dir
,
subdir
)
local_trainer
.
objective
=
None
local_trainer
.
train
(
model_path
=
model_path
,
trial
=
trial
)
local_trainer
.
train
(
resume_from_checkpoint
=
checkpoint
,
trial
=
trial
)
# If there hasn't been any evaluation during the training loop.
if
getattr
(
local_trainer
,
"objective"
,
None
)
is
None
:
metrics
=
local_trainer
.
evaluate
()
...
...
src/transformers/trainer.py
View file @
b4e559cf
...
...
@@ -676,17 +676,33 @@ class Trainer:
return
model
def
train
(
self
,
model_path
:
Optional
[
str
]
=
None
,
trial
:
Union
[
"optuna.Trial"
,
Dict
[
str
,
Any
]]
=
None
):
def
train
(
self
,
resume_from_checkpoint
:
Optional
[
str
]
=
None
,
trial
:
Union
[
"optuna.Trial"
,
Dict
[
str
,
Any
]]
=
None
,
**
kwargs
):
"""
Main training entry point.
Args:
model_path
(:obj:`str`, `optional`):
Local path to
the model if the model to train has been instantiated from a local path. If present,
training will resume from the optimizer/scheduler states loaded here.
resume_from_checkpoint
(:obj:`str`, `optional`):
Local path to
a saved checkpoint as saved by a previous instance of :class:`~transformers.Trainer`. If
present,
training will resume from the
model/
optimizer/scheduler states loaded here.
trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
The trial run or the hyperparameter dictionary for hyperparameter search.
kwargs:
Additional keyword arguments used to hide deprecated arguments
"""
if
"model_path"
in
kwargs
:
resume_from_checkpoint
=
kwargs
.
pop
(
"model_path"
)
warnings
.
warn
(
"`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
"instead."
,
FutureWarning
,
)
if
len
(
kwargs
)
>
0
:
raise
TypeError
(
f
"train() received got unexpected keyword arguments:
{
', '
.
join
(
list
(
kwargs
.
keys
()))
}
."
)
# This might change the seed so needs to run first.
self
.
_hp_search_setup
(
trial
)
...
...
@@ -701,13 +717,13 @@ class Trainer:
self
.
optimizer
,
self
.
lr_scheduler
=
None
,
None
# Load potential model checkpoint
if
model_path
is
not
None
and
os
.
path
.
isfile
(
os
.
path
.
join
(
model_path
,
WEIGHTS_NAME
)):
logger
.
info
(
f
"Loading model from
{
model_path
}
)."
)
if
resume_from_checkpoint
is
not
None
and
os
.
path
.
isfile
(
os
.
path
.
join
(
resume_from_checkpoint
,
WEIGHTS_NAME
)):
logger
.
info
(
f
"Loading model from
{
resume_from_checkpoint
}
)."
)
if
isinstance
(
self
.
model
,
PreTrainedModel
):
self
.
model
=
self
.
model
.
from_pretrained
(
model_path
)
self
.
model
=
self
.
model
.
from_pretrained
(
resume_from_checkpoint
)
model_reloaded
=
True
else
:
state_dict
=
torch
.
load
(
os
.
path
.
join
(
model_path
,
WEIGHTS_NAME
))
state_dict
=
torch
.
load
(
os
.
path
.
join
(
resume_from_checkpoint
,
WEIGHTS_NAME
))
self
.
model
.
load_state_dict
(
state_dict
)
# If model was re-initialized, put it on the right device and update self.model_wrapped
...
...
@@ -757,7 +773,7 @@ class Trainer:
self
.
state
.
is_hyper_param_search
=
trial
is
not
None
# Check if saved optimizer or scheduler states exist
self
.
_load_optimizer_and_scheduler
(
model_path
)
self
.
_load_optimizer_and_scheduler
(
resume_from_checkpoint
)
model
=
self
.
model_wrapped
...
...
@@ -827,8 +843,10 @@ class Trainer:
steps_trained_in_current_epoch
=
0
# Check if continuing training from a checkpoint
if
model_path
and
os
.
path
.
isfile
(
os
.
path
.
join
(
model_path
,
"trainer_state.json"
)):
self
.
state
=
TrainerState
.
load_from_json
(
os
.
path
.
join
(
model_path
,
"trainer_state.json"
))
if
resume_from_checkpoint
is
not
None
and
os
.
path
.
isfile
(
os
.
path
.
join
(
resume_from_checkpoint
,
"trainer_state.json"
)
):
self
.
state
=
TrainerState
.
load_from_json
(
os
.
path
.
join
(
resume_from_checkpoint
,
"trainer_state.json"
))
epochs_trained
=
self
.
state
.
global_step
//
num_update_steps_per_epoch
if
not
self
.
args
.
ignore_data_skip
:
steps_trained_in_current_epoch
=
self
.
state
.
global_step
%
(
num_update_steps_per_epoch
)
...
...
@@ -1102,20 +1120,20 @@ class Trainer:
if
self
.
is_world_process_zero
():
self
.
_rotate_checkpoints
(
use_mtime
=
True
)
def
_load_optimizer_and_scheduler
(
self
,
model_path
):
def
_load_optimizer_and_scheduler
(
self
,
checkpoint
):
"""If optimizer and scheduler states exist, load them."""
if
model_path
is
None
:
if
checkpoint
is
None
:
return
if
os
.
path
.
isfile
(
os
.
path
.
join
(
model_path
,
"optimizer.pt"
))
and
os
.
path
.
isfile
(
os
.
path
.
join
(
model_path
,
"scheduler.pt"
)
if
os
.
path
.
isfile
(
os
.
path
.
join
(
checkpoint
,
"optimizer.pt"
))
and
os
.
path
.
isfile
(
os
.
path
.
join
(
checkpoint
,
"scheduler.pt"
)
):
# Load in optimizer and scheduler states
if
is_torch_tpu_available
():
# On TPU we have to take some extra precautions to properly load the states on the right device.
optimizer_state
=
torch
.
load
(
os
.
path
.
join
(
model_path
,
"optimizer.pt"
),
map_location
=
"cpu"
)
optimizer_state
=
torch
.
load
(
os
.
path
.
join
(
checkpoint
,
"optimizer.pt"
),
map_location
=
"cpu"
)
with
warnings
.
catch_warnings
(
record
=
True
)
as
caught_warnings
:
lr_scheduler_state
=
torch
.
load
(
os
.
path
.
join
(
model_path
,
"scheduler.pt"
),
map_location
=
"cpu"
)
lr_scheduler_state
=
torch
.
load
(
os
.
path
.
join
(
checkpoint
,
"scheduler.pt"
),
map_location
=
"cpu"
)
reissue_pt_warnings
(
caught_warnings
)
xm
.
send_cpu_data_to_device
(
optimizer_state
,
self
.
args
.
device
)
...
...
@@ -1125,15 +1143,15 @@ class Trainer:
self
.
lr_scheduler
.
load_state_dict
(
lr_scheduler_state
)
else
:
self
.
optimizer
.
load_state_dict
(
torch
.
load
(
os
.
path
.
join
(
model_path
,
"optimizer.pt"
),
map_location
=
self
.
args
.
device
)
torch
.
load
(
os
.
path
.
join
(
checkpoint
,
"optimizer.pt"
),
map_location
=
self
.
args
.
device
)
)
with
warnings
.
catch_warnings
(
record
=
True
)
as
caught_warnings
:
self
.
lr_scheduler
.
load_state_dict
(
torch
.
load
(
os
.
path
.
join
(
model_path
,
"scheduler.pt"
)))
self
.
lr_scheduler
.
load_state_dict
(
torch
.
load
(
os
.
path
.
join
(
checkpoint
,
"scheduler.pt"
)))
reissue_pt_warnings
(
caught_warnings
)
if
self
.
deepspeed
:
# Not sure how to check if there is a saved deepspeed checkpoint, but since it just return None if it fails to find a deepspeed checkpoint this is sort of a check-n-load function
self
.
deepspeed
.
load_checkpoint
(
model_path
,
load_optimizer_states
=
True
,
load_lr_scheduler_states
=
True
)
self
.
deepspeed
.
load_checkpoint
(
checkpoint
,
load_optimizer_states
=
True
,
load_lr_scheduler_states
=
True
)
def
hyperparameter_search
(
self
,
...
...
templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py
View file @
b4e559cf
...
...
@@ -341,20 +341,20 @@ def main():
if
training_args
.
do_train
:
{
%-
if
cookiecutter
.
can_train_from_scratch
==
"False"
%
}
if
last_checkpoint
is
not
None
:
model_path
=
last_checkpoint
checkpoint
=
last_checkpoint
elif
os
.
path
.
isdir
(
model_args
.
model_name_or_path
):
model_path
=
model_args
.
model_name_or_path
checkpoint
=
model_args
.
model_name_or_path
else
:
model_path
=
None
checkpoint
=
None
{
%-
elif
cookiecutter
.
can_train_from_scratch
==
"True"
%
}
if
last_checkpoint
is
not
None
:
model_path
=
last_checkpoint
checkpoint
=
last_checkpoint
elif
model_args
.
model_name_or_path
is
not
None
and
os
.
path
.
isdir
(
model_args
.
model_name_or_path
):
model_path
=
model_args
.
model_name_or_path
checkpoint
=
model_args
.
model_name_or_path
else
:
model_path
=
None
checkpoint
=
None
{
%
endif
%
}
train_result
=
trainer
.
train
(
model_path
=
model_path
)
train_result
=
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
trainer
.
save_model
()
# Saves the tokenizer too for easy upload
output_train_file
=
os
.
path
.
join
(
training_args
.
output_dir
,
"train_results.txt"
)
...
...
tests/test_trainer.py
View file @
b4e559cf
...
...
@@ -581,7 +581,7 @@ class TrainerIntegrationTest(unittest.TestCase):
# Reinitialize trainer
trainer
=
get_regression_trainer
(
output_dir
=
tmpdir
,
train_len
=
128
,
save_steps
=
5
,
learning_rate
=
0.1
)
trainer
.
train
(
model_path
=
checkpoint
)
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
(
a1
,
b1
)
=
trainer
.
model
.
a
.
item
(),
trainer
.
model
.
b
.
item
()
state1
=
dataclasses
.
asdict
(
trainer
.
state
)
self
.
assertEqual
(
a
,
a1
)
...
...
@@ -594,7 +594,7 @@ class TrainerIntegrationTest(unittest.TestCase):
# Reinitialize trainer and load model
trainer
=
get_regression_trainer
(
output_dir
=
tmpdir
,
train_len
=
128
,
save_steps
=
5
,
learning_rate
=
0.1
)
trainer
.
train
(
model_path
=
checkpoint
)
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
(
a1
,
b1
)
=
trainer
.
model
.
a
.
item
(),
trainer
.
model
.
b
.
item
()
state1
=
dataclasses
.
asdict
(
trainer
.
state
)
self
.
assertEqual
(
a
,
a1
)
...
...
@@ -617,7 +617,7 @@ class TrainerIntegrationTest(unittest.TestCase):
output_dir
=
tmpdir
,
train_len
=
128
,
save_steps
=
5
,
learning_rate
=
0.1
,
pretrained
=
False
)
trainer
.
train
(
model_path
=
checkpoint
)
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
(
a1
,
b1
)
=
trainer
.
model
.
a
.
item
(),
trainer
.
model
.
b
.
item
()
state1
=
dataclasses
.
asdict
(
trainer
.
state
)
self
.
assertEqual
(
a
,
a1
)
...
...
@@ -632,7 +632,7 @@ class TrainerIntegrationTest(unittest.TestCase):
output_dir
=
tmpdir
,
train_len
=
128
,
save_steps
=
5
,
learning_rate
=
0.1
,
pretrained
=
False
)
trainer
.
train
(
model_path
=
checkpoint
)
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
(
a1
,
b1
)
=
trainer
.
model
.
a
.
item
(),
trainer
.
model
.
b
.
item
()
state1
=
dataclasses
.
asdict
(
trainer
.
state
)
self
.
assertEqual
(
a
,
a1
)
...
...
@@ -670,7 +670,7 @@ class TrainerIntegrationTest(unittest.TestCase):
learning_rate
=
0.1
,
)
trainer
.
train
(
model_path
=
checkpoint
)
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
(
a1
,
b1
)
=
trainer
.
model
.
a
.
item
(),
trainer
.
model
.
b
.
item
()
state1
=
dataclasses
.
asdict
(
trainer
.
state
)
self
.
assertEqual
(
a
,
a1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment