chenpangpang / transformers · Commits

Unverified commit b4e559cf, authored Jan 28, 2021 by Sylvain Gugger, committed by GitHub on Jan 28, 2021.
Deprecate model_path in Trainer.train (#9854)
Parent: 2ee9f9b6

Showing 14 changed files with 96 additions and 78 deletions (+96, -78).
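In short: `Trainer.train(model_path=...)` becomes `Trainer.train(resume_from_checkpoint=...)`; the old keyword is still accepted for now but emits a `FutureWarning`. Below is a minimal, self-contained sketch of the migration. The `RegressionModel` and `RegressionDataset` stand-ins are hypothetical (modeled loosely on the regression toys in `tests/test_trainer.py`) so the snippet runs without downloading a pretrained model; it assumes `torch` and a transformers release containing this commit.

import numpy as np
import torch
from torch import nn
from transformers import Trainer, TrainingArguments


class RegressionModel(nn.Module):
    """Tiny y = a * x + b model so nothing has to be downloaded."""

    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.zeros(()))
        self.b = nn.Parameter(torch.zeros(()))

    def forward(self, input_x, labels=None):
        y = self.a * input_x + self.b
        if labels is None:
            return (y,)
        # Trainer treats the first element of the tuple as the loss.
        return (nn.functional.mse_loss(y, labels), y)


class RegressionDataset(torch.utils.data.Dataset):
    """64 random (x, 2x + 3) pairs."""

    def __init__(self, n=64):
        self.x = np.random.normal(size=(n,)).astype("float32")
        self.y = (2.0 * self.x + 3.0).astype("float32")

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return {"input_x": self.x[i], "labels": self.y[i]}


args = TrainingArguments(output_dir="out", num_train_epochs=2, save_steps=5)
trainer = Trainer(model=RegressionModel(), args=args, train_dataset=RegressionDataset())
trainer.train()  # first run; writes out/checkpoint-5, out/checkpoint-10, ...

# New spelling introduced by this commit:
trainer.train(resume_from_checkpoint="out/checkpoint-5")
# Deprecated spelling, still accepted but emits a FutureWarning:
# trainer.train(model_path="out/checkpoint-5")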
examples/language-modeling/run_clm.py  (+4, -4)
examples/language-modeling/run_mlm.py  (+4, -4)
examples/language-modeling/run_mlm_wwm.py  (+4, -4)
examples/language-modeling/run_plm.py  (+4, -4)
examples/multiple-choice/run_swag.py  (+4, -4)
examples/question-answering/run_qa.py  (+4, -4)
examples/question-answering/run_qa_beam_search.py  (+4, -4)
examples/seq2seq/run_seq2seq.py  (+4, -4)
examples/text-classification/run_glue.py  (+4, -4)
examples/token-classification/run_ner.py  (+4, -4)
src/transformers/integrations.py  (+6, -6)
src/transformers/trainer.py  (+38, -20)
templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py  (+7, -7)
tests/test_trainer.py  (+5, -5)
examples/language-modeling/run_clm.py

@@ -362,12 +362,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
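The same four-line substitution repeats in each example script below: pick the checkpoint to resume from (the auto-detected `last_checkpoint`, else a local `model_name_or_path` directory, else `None`) and hand it to `trainer.train(resume_from_checkpoint=...)`. As a standalone sketch of that selection logic (`pick_checkpoint` is a hypothetical helper name, not part of this diff):

import os
from typing import Optional


def pick_checkpoint(last_checkpoint: Optional[str], model_name_or_path: Optional[str]) -> Optional[str]:
    """Mirror of the checkpoint-selection logic shared by the example scripts."""
    if last_checkpoint is not None:
        return last_checkpoint  # checkpoint auto-detected in output_dir
    if model_name_or_path is not None and os.path.isdir(model_name_or_path):
        return model_name_or_path  # local model directory to resume from
    return None  # fresh training run


# train_result = trainer.train(resume_from_checkpoint=pick_checkpoint(last_checkpoint, model_name_or_path))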
examples/language-modeling/run_mlm.py

@@ -403,12 +403,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
examples/language-modeling/run_mlm_wwm.py

@@ -355,12 +355,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
examples/language-modeling/run_plm.py

@@ -384,12 +384,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
examples/multiple-choice/run_swag.py

@@ -342,12 +342,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
examples/question-answering/run_qa.py

@@ -463,12 +463,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
examples/question-answering/run_qa_beam_search.py

@@ -502,12 +502,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
examples/seq2seq/run_seq2seq.py

@@ -491,12 +491,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
examples/text-classification/run_glue.py

@@ -399,12 +399,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         metrics = train_result.metrics

         trainer.save_model()  # Saves the tokenizer too for easy upload
examples/token-classification/run_ner.py

@@ -380,12 +380,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
src/transformers/integrations.py

@@ -125,13 +125,13 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
     import optuna

     def _objective(trial, checkpoint_dir=None):
-        model_path = None
+        checkpoint = None
         if checkpoint_dir:
             for subdir in os.listdir(checkpoint_dir):
                 if subdir.startswith(PREFIX_CHECKPOINT_DIR):
-                    model_path = os.path.join(checkpoint_dir, subdir)
+                    checkpoint = os.path.join(checkpoint_dir, subdir)
         trainer.objective = None
-        trainer.train(model_path=model_path, trial=trial)
+        trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
         # If there hasn't been any evaluation during the training loop.
         if getattr(trainer, "objective", None) is None:
             metrics = trainer.evaluate()

@@ -150,13 +150,13 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
     import ray

     def _objective(trial, local_trainer, checkpoint_dir=None):
-        model_path = None
+        checkpoint = None
        if checkpoint_dir:
             for subdir in os.listdir(checkpoint_dir):
                 if subdir.startswith(PREFIX_CHECKPOINT_DIR):
-                    model_path = os.path.join(checkpoint_dir, subdir)
+                    checkpoint = os.path.join(checkpoint_dir, subdir)
         local_trainer.objective = None
-        local_trainer.train(model_path=model_path, trial=trial)
+        local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
         # If there hasn't been any evaluation during the training loop.
         if getattr(local_trainer, "objective", None) is None:
             metrics = local_trainer.evaluate()
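Both hyperparameter-search backends rely on the same idiom: when Optuna or Ray Tune hands the objective a `checkpoint_dir`, scan it for the subfolder that `Trainer` wrote (its name starts with `PREFIX_CHECKPOINT_DIR`, i.e. "checkpoint") and resume from there. A minimal sketch of just that discovery step, with `find_trainer_checkpoint` as a hypothetical name:

import os

PREFIX_CHECKPOINT_DIR = "checkpoint"  # value used by transformers.trainer_utils


def find_trainer_checkpoint(checkpoint_dir):
    """Return the Trainer checkpoint subfolder inside checkpoint_dir, if any."""
    if not checkpoint_dir:
        return None
    for subdir in os.listdir(checkpoint_dir):
        if subdir.startswith(PREFIX_CHECKPOINT_DIR):
            return os.path.join(checkpoint_dir, subdir)
    return None


# inside _objective:
# trainer.train(resume_from_checkpoint=find_trainer_checkpoint(checkpoint_dir), trial=trial)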
src/transformers/trainer.py

@@ -676,17 +676,33 @@ class Trainer:
         return model

-    def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
+    def train(
+        self,
+        resume_from_checkpoint: Optional[str] = None,
+        trial: Union["optuna.Trial", Dict[str, Any]] = None,
+        **kwargs,
+    ):
         """
         Main training entry point.

         Args:
-            model_path (:obj:`str`, `optional`):
-                Local path to the model if the model to train has been instantiated from a local path. If present,
-                training will resume from the optimizer/scheduler states loaded here.
+            resume_from_checkpoint (:obj:`str`, `optional`):
+                Local path to a saved checkpoint as saved by a previous instance of :class:`~transformers.Trainer`. If
+                present, training will resume from the model/optimizer/scheduler states loaded here.
             trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                 The trial run or the hyperparameter dictionary for hyperparameter search.
+            kwargs:
+                Additional keyword arguments used to hide deprecated arguments
         """
+        if "model_path" in kwargs:
+            resume_from_checkpoint = kwargs.pop("model_path")
+            warnings.warn(
+                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
+                "instead.",
+                FutureWarning,
+            )
+        if len(kwargs) > 0:
+            raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
         # This might change the seed so needs to run first.
         self._hp_search_setup(trial)

@@ -701,13 +717,13 @@
             self.optimizer, self.lr_scheduler = None, None

         # Load potential model checkpoint
-        if model_path is not None and os.path.isfile(os.path.join(model_path, WEIGHTS_NAME)):
-            logger.info(f"Loading model from {model_path}).")
+        if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
+            logger.info(f"Loading model from {resume_from_checkpoint}).")
             if isinstance(self.model, PreTrainedModel):
-                self.model = self.model.from_pretrained(model_path)
+                self.model = self.model.from_pretrained(resume_from_checkpoint)
                 model_reloaded = True
             else:
-                state_dict = torch.load(os.path.join(model_path, WEIGHTS_NAME))
+                state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME))
                 self.model.load_state_dict(state_dict)

         # If model was re-initialized, put it on the right device and update self.model_wrapped

@@ -757,7 +773,7 @@
         self.state.is_hyper_param_search = trial is not None

         # Check if saved optimizer or scheduler states exist
-        self._load_optimizer_and_scheduler(model_path)
+        self._load_optimizer_and_scheduler(resume_from_checkpoint)

         model = self.model_wrapped

@@ -827,8 +843,10 @@
         steps_trained_in_current_epoch = 0

         # Check if continuing training from a checkpoint
-        if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
-            self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
+        if resume_from_checkpoint is not None and os.path.isfile(
+            os.path.join(resume_from_checkpoint, "trainer_state.json")
+        ):
+            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, "trainer_state.json"))
             epochs_trained = self.state.global_step // num_update_steps_per_epoch
             if not self.args.ignore_data_skip:
                 steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)

@@ -1102,20 +1120,20 @@
         if self.is_world_process_zero():
             self._rotate_checkpoints(use_mtime=True)

-    def _load_optimizer_and_scheduler(self, model_path):
+    def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
-        if model_path is None:
+        if checkpoint is None:
             return

-        if os.path.isfile(os.path.join(model_path, "optimizer.pt")) and os.path.isfile(
-            os.path.join(model_path, "scheduler.pt")
+        if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
+            os.path.join(checkpoint, "scheduler.pt")
         ):
             # Load in optimizer and scheduler states
             if is_torch_tpu_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
-                optimizer_state = torch.load(os.path.join(model_path, "optimizer.pt"), map_location="cpu")
+                optimizer_state = torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location="cpu")
                 with warnings.catch_warnings(record=True) as caught_warnings:
-                    lr_scheduler_state = torch.load(os.path.join(model_path, "scheduler.pt"), map_location="cpu")
+                    lr_scheduler_state = torch.load(os.path.join(checkpoint, "scheduler.pt"), map_location="cpu")
                 reissue_pt_warnings(caught_warnings)

                 xm.send_cpu_data_to_device(optimizer_state, self.args.device)

@@ -1125,15 +1143,15 @@
                 self.lr_scheduler.load_state_dict(lr_scheduler_state)
             else:
                 self.optimizer.load_state_dict(
-                    torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
+                    torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=self.args.device)
                 )
                 with warnings.catch_warnings(record=True) as caught_warnings:
-                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
+                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
                 reissue_pt_warnings(caught_warnings)

         if self.deepspeed:
             # Not sure how to check if there is a saved deepspeed checkpoint, but since it just return None if it fails to find a deepspeed checkpoint this is sort of a check-n-load function
-            self.deepspeed.load_checkpoint(model_path, load_optimizer_states=True, load_lr_scheduler_states=True)
+            self.deepspeed.load_checkpoint(checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True)

     def hyperparameter_search(
         self,
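The heart of the change is the backward-compatibility shim at the top of `train()`: the renamed keyword is swallowed by `**kwargs`, remapped onto the new parameter with a `FutureWarning`, and any other unknown keyword still fails loudly, preserving normal signature behavior. The same pattern in isolation (a generic sketch, not the library's exact code):

import warnings


def train(resume_from_checkpoint=None, **kwargs):
    # Remap the deprecated keyword onto its replacement.
    if "model_path" in kwargs:
        resume_from_checkpoint = kwargs.pop("model_path")
        warnings.warn(
            "`model_path` is deprecated and will be removed in a future version. "
            "Use `resume_from_checkpoint` instead.",
            FutureWarning,
        )
    # Anything left over is a genuine mistake, so fail like a normal signature would.
    if kwargs:
        raise TypeError(f"train() got unexpected keyword arguments: {', '.join(kwargs)}.")
    return resume_from_checkpoint


assert train(model_path="out/checkpoint-5") == "out/checkpoint-5"  # warns once, then behaves like the new keyword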
templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py

@@ -341,20 +341,20 @@ def main():
     if training_args.do_train:
         {%- if cookiecutter.can_train_from_scratch == "False" %}
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
+            checkpoint = None
         {%- elif cookiecutter.can_train_from_scratch == "True" %}
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
+            checkpoint = None
         {% endif %}
-        train_result = trainer.train(model_path=model_path)
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
tests/test_trainer.py

@@ -581,7 +581,7 @@ class TrainerIntegrationTest(unittest.TestCase):
         # Reinitialize trainer
         trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)

-        trainer.train(model_path=checkpoint)
+        trainer.train(resume_from_checkpoint=checkpoint)
         (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
         state1 = dataclasses.asdict(trainer.state)
         self.assertEqual(a, a1)

@@ -594,7 +594,7 @@ class TrainerIntegrationTest(unittest.TestCase):
         # Reinitialize trainer and load model
         trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)

-        trainer.train(model_path=checkpoint)
+        trainer.train(resume_from_checkpoint=checkpoint)
         (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
         state1 = dataclasses.asdict(trainer.state)
         self.assertEqual(a, a1)

@@ -617,7 +617,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
         )

-        trainer.train(model_path=checkpoint)
+        trainer.train(resume_from_checkpoint=checkpoint)
         (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
         state1 = dataclasses.asdict(trainer.state)
         self.assertEqual(a, a1)

@@ -632,7 +632,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
         )

-        trainer.train(model_path=checkpoint)
+        trainer.train(resume_from_checkpoint=checkpoint)
         (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
         state1 = dataclasses.asdict(trainer.state)
         self.assertEqual(a, a1)

@@ -670,7 +670,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             learning_rate=0.1,
         )

-        trainer.train(model_path=checkpoint)
+        trainer.train(resume_from_checkpoint=checkpoint)
         (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
         state1 = dataclasses.asdict(trainer.state)
         self.assertEqual(a, a1)
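All five test updates exercise the same invariant: an uninterrupted run and a run resumed from an intermediate checkpoint must end with identical weights (the real tests also compare the trainer state). A condensed sketch of that check; it assumes the hypothetical `RegressionModel` and `RegressionDataset` classes from the first snippet above are in scope, and that exact float equality holds because both runs are fully seeded, as the library's own tests assume:

import torch
from transformers import Trainer, TrainingArguments

torch.manual_seed(0)
dataset = RegressionDataset()  # shared so both runs see identical data

# Uninterrupted run; writes full/checkpoint-5, full/checkpoint-10, ...
args = TrainingArguments(output_dir="full", num_train_epochs=2, save_steps=5, seed=42)
full = Trainer(model=RegressionModel(), args=args, train_dataset=dataset)
full.train()
a0, b0 = full.model.a.item(), full.model.b.item()

# Fresh trainer resumed from the step-5 checkpoint of the first run.
args = TrainingArguments(output_dir="resumed", num_train_epochs=2, save_steps=5, seed=42)
resumed = Trainer(model=RegressionModel(), args=args, train_dataset=dataset)
resumed.train(resume_from_checkpoint="full/checkpoint-5")
a1, b1 = resumed.model.a.item(), resumed.model.b.item()

assert (a0, b0) == (a1, b1), "resumed training should match the uninterrupted run"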