"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a315988baeddd64da3eb4f030ca804ef92a73d1f"
Unverified commit b4e559cf, authored by Sylvain Gugger, committed by GitHub

Deprecate model_path in Trainer.train (#9854)

parent 2ee9f9b6
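
This commit renames the `model_path` argument of `Trainer.train` to `resume_from_checkpoint` across the Trainer itself, the hyperparameter-search integrations, the example scripts, the cookiecutter template, and the tests, keeping the old name working behind a `FutureWarning`. A minimal migration sketch, assuming an already-built `Trainer` instance named `trainer` and a hypothetical checkpoint directory:

    # Deprecated spelling: still works after this commit, but emits a FutureWarning.
    trainer.train(model_path="output/checkpoint-500")

    # New spelling:
    trainer.train(resume_from_checkpoint="output/checkpoint-500")
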
@@ -362,12 +362,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")

@@ -403,12 +403,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")

@@ -355,12 +355,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")

@@ -384,12 +384,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")

@@ -342,12 +342,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")

@@ -463,12 +463,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")

@@ -502,12 +502,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")

@@ -491,12 +491,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")

@@ -399,12 +399,12 @@ def main():
     # Training
    if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         metrics = train_result.metrics

         trainer.save_model()  # Saves the tokenizer too for easy upload

@@ -380,12 +380,12 @@ def main():
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
-        train_result = trainer.train(model_path=model_path)
+            checkpoint = None
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")

@@ -125,13 +125,13 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
     import optuna

     def _objective(trial, checkpoint_dir=None):
-        model_path = None
+        checkpoint = None
         if checkpoint_dir:
             for subdir in os.listdir(checkpoint_dir):
                 if subdir.startswith(PREFIX_CHECKPOINT_DIR):
-                    model_path = os.path.join(checkpoint_dir, subdir)
+                    checkpoint = os.path.join(checkpoint_dir, subdir)
         trainer.objective = None
-        trainer.train(model_path=model_path, trial=trial)
+        trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
         # If there hasn't been any evaluation during the training loop.
         if getattr(trainer, "objective", None) is None:
             metrics = trainer.evaluate()
@@ -150,13 +150,13 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
     import ray

     def _objective(trial, local_trainer, checkpoint_dir=None):
-        model_path = None
+        checkpoint = None
         if checkpoint_dir:
             for subdir in os.listdir(checkpoint_dir):
                 if subdir.startswith(PREFIX_CHECKPOINT_DIR):
-                    model_path = os.path.join(checkpoint_dir, subdir)
+                    checkpoint = os.path.join(checkpoint_dir, subdir)
         local_trainer.objective = None
-        local_trainer.train(model_path=model_path, trial=trial)
+        local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
         # If there hasn't been any evaluation during the training loop.
         if getattr(local_trainer, "objective", None) is None:
             metrics = local_trainer.evaluate()

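Both hyperparameter-search objectives above locate the checkpoint to resume from with the same directory scan. Pulled out as a standalone helper, the logic looks like the sketch below (`find_checkpoint` is a hypothetical name for illustration; in transformers, `PREFIX_CHECKPOINT_DIR` is the string `"checkpoint"`, so this picks up subdirectories such as `checkpoint-500`):

    import os

    def find_checkpoint(checkpoint_dir, prefix="checkpoint"):
        # Return the first subdirectory that looks like a saved checkpoint,
        # or None so that train() starts from scratch.
        for subdir in os.listdir(checkpoint_dir):
            if subdir.startswith(prefix):
                return os.path.join(checkpoint_dir, subdir)
        return None
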
@@ -676,17 +676,33 @@ class Trainer:
         return model

-    def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
+    def train(
+        self,
+        resume_from_checkpoint: Optional[str] = None,
+        trial: Union["optuna.Trial", Dict[str, Any]] = None,
+        **kwargs
+    ):
         """
         Main training entry point.

         Args:
-            model_path (:obj:`str`, `optional`):
-                Local path to the model if the model to train has been instantiated from a local path. If present,
-                training will resume from the optimizer/scheduler states loaded here.
+            resume_from_checkpoint (:obj:`str`, `optional`):
+                Local path to a saved checkpoint as saved by a previous instance of :class:`~transformers.Trainer`. If
+                present, training will resume from the model/optimizer/scheduler states loaded here.
             trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                 The trial run or the hyperparameter dictionary for hyperparameter search.
+            kwargs:
+                Additional keyword arguments used to hide deprecated arguments.
         """
+        if "model_path" in kwargs:
+            resume_from_checkpoint = kwargs.pop("model_path")
+            warnings.warn(
+                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
+                "instead.",
+                FutureWarning,
+            )
+        if len(kwargs) > 0:
+            raise TypeError(f"train() received unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
+
         # This might change the seed so needs to run first.
         self._hp_search_setup(trial)
@@ -701,13 +717,13 @@ class Trainer:
         self.optimizer, self.lr_scheduler = None, None

         # Load potential model checkpoint
-        if model_path is not None and os.path.isfile(os.path.join(model_path, WEIGHTS_NAME)):
-            logger.info(f"Loading model from {model_path}).")
+        if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
+            logger.info(f"Loading model from {resume_from_checkpoint}.")
             if isinstance(self.model, PreTrainedModel):
-                self.model = self.model.from_pretrained(model_path)
+                self.model = self.model.from_pretrained(resume_from_checkpoint)
                 model_reloaded = True
             else:
-                state_dict = torch.load(os.path.join(model_path, WEIGHTS_NAME))
+                state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME))
                 self.model.load_state_dict(state_dict)

         # If model was re-initialized, put it on the right device and update self.model_wrapped
@@ -757,7 +773,7 @@ class Trainer:
         self.state.is_hyper_param_search = trial is not None

         # Check if saved optimizer or scheduler states exist
-        self._load_optimizer_and_scheduler(model_path)
+        self._load_optimizer_and_scheduler(resume_from_checkpoint)

         model = self.model_wrapped
@@ -827,8 +843,10 @@ class Trainer:
         steps_trained_in_current_epoch = 0

         # Check if continuing training from a checkpoint
-        if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
-            self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
+        if resume_from_checkpoint is not None and os.path.isfile(
+            os.path.join(resume_from_checkpoint, "trainer_state.json")
+        ):
+            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, "trainer_state.json"))
             epochs_trained = self.state.global_step // num_update_steps_per_epoch
             if not self.args.ignore_data_skip:
                 steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
@@ -1102,20 +1120,20 @@ class Trainer:
         if self.is_world_process_zero():
             self._rotate_checkpoints(use_mtime=True)

-    def _load_optimizer_and_scheduler(self, model_path):
+    def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
-        if model_path is None:
+        if checkpoint is None:
             return

-        if os.path.isfile(os.path.join(model_path, "optimizer.pt")) and os.path.isfile(
-            os.path.join(model_path, "scheduler.pt")
+        if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
+            os.path.join(checkpoint, "scheduler.pt")
         ):
             # Load in optimizer and scheduler states
             if is_torch_tpu_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
-                optimizer_state = torch.load(os.path.join(model_path, "optimizer.pt"), map_location="cpu")
+                optimizer_state = torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location="cpu")
                 with warnings.catch_warnings(record=True) as caught_warnings:
-                    lr_scheduler_state = torch.load(os.path.join(model_path, "scheduler.pt"), map_location="cpu")
+                    lr_scheduler_state = torch.load(os.path.join(checkpoint, "scheduler.pt"), map_location="cpu")
                 reissue_pt_warnings(caught_warnings)

                 xm.send_cpu_data_to_device(optimizer_state, self.args.device)
@@ -1125,15 +1143,15 @@ class Trainer:
             self.lr_scheduler.load_state_dict(lr_scheduler_state)
         else:
             self.optimizer.load_state_dict(
-                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
+                torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=self.args.device)
             )
             with warnings.catch_warnings(record=True) as caught_warnings:
-                self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
+                self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
             reissue_pt_warnings(caught_warnings)

         if self.deepspeed:
             # Not sure how to check if there is a saved deepspeed checkpoint, but since it just returns None if it fails to find a deepspeed checkpoint, this is sort of a check-and-load function.
-            self.deepspeed.load_checkpoint(model_path, load_optimizer_states=True, load_lr_scheduler_states=True)
+            self.deepspeed.load_checkpoint(checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True)

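Two things are worth noting about the deprecation shim in `train()` above. Popping `model_path` from `**kwargs` keeps the old call spelling working while routing everything through `resume_from_checkpoint`, and the explicit `len(kwargs) > 0` check matters because adding `**kwargs` would otherwise silently swallow misspelled keywords that previously raised a `TypeError`. A sketch of how one might verify the warning, assuming a configured `trainer` and a valid `checkpoint` path (both hypothetical here):

    import warnings

    # The deprecated keyword should still resume training, but surface a FutureWarning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        trainer.train(model_path=checkpoint)
    assert any(issubclass(w.category, FutureWarning) for w in caught)

    # A genuinely unknown keyword still fails fast:
    # trainer.train(model_pth=checkpoint)  # raises TypeError
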
@@ -341,20 +341,20 @@ def main():
     if training_args.do_train:
         {%- if cookiecutter.can_train_from_scratch == "False" %}
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
+            checkpoint = None
         {%- elif cookiecutter.can_train_from_scratch == "True" %}
         if last_checkpoint is not None:
-            model_path = last_checkpoint
+            checkpoint = last_checkpoint
         elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
-            model_path = model_args.model_name_or_path
+            checkpoint = model_args.model_name_or_path
         else:
-            model_path = None
+            checkpoint = None
         {% endif %}
-        train_result = trainer.train(model_path=model_path)
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload

         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")

@@ -581,7 +581,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             # Reinitialize trainer
             trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)

-            trainer.train(model_path=checkpoint)
+            trainer.train(resume_from_checkpoint=checkpoint)
             (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
             state1 = dataclasses.asdict(trainer.state)
             self.assertEqual(a, a1)

@@ -594,7 +594,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             # Reinitialize trainer and load model
             trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)

-            trainer.train(model_path=checkpoint)
+            trainer.train(resume_from_checkpoint=checkpoint)
             (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
             state1 = dataclasses.asdict(trainer.state)
             self.assertEqual(a, a1)

@@ -617,7 +617,7 @@ class TrainerIntegrationTest(unittest.TestCase):
                 output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
             )

-            trainer.train(model_path=checkpoint)
+            trainer.train(resume_from_checkpoint=checkpoint)
             (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
             state1 = dataclasses.asdict(trainer.state)
             self.assertEqual(a, a1)

@@ -632,7 +632,7 @@ class TrainerIntegrationTest(unittest.TestCase):
                 output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
             )

-            trainer.train(model_path=checkpoint)
+            trainer.train(resume_from_checkpoint=checkpoint)
             (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
             state1 = dataclasses.asdict(trainer.state)
             self.assertEqual(a, a1)

@@ -670,7 +670,7 @@ class TrainerIntegrationTest(unittest.TestCase):
                 learning_rate=0.1,
             )

-            trainer.train(model_path=checkpoint)
+            trainer.train(resume_from_checkpoint=checkpoint)
             (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
             state1 = dataclasses.asdict(trainer.state)
             self.assertEqual(a, a1)