"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "7dfdc0a5563abe80a85f9f7fa0c3b2ef458e7783"
Commit 27c1b656 authored by Lysandre Debut

Fix error with global step in run_lm_finetuning.py

parent 24df44d9
@@ -264,15 +264,19 @@ def train(args, train_dataset, model, tokenizer):
     steps_trained_in_current_epoch = 0
     # Check if continuing training from a checkpoint
     if os.path.exists(args.model_name_or_path):
-        # set global_step to gobal_step of last saved checkpoint from model path
-        global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
-        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
-        steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
-
-        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
-        logger.info("  Continuing training from epoch %d", epochs_trained)
-        logger.info("  Continuing training from global step %d", global_step)
-        logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
+        try:
+            # set global_step to gobal_step of last saved checkpoint from model path
+            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
+            global_step = int(checkpoint_suffix)
+            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
+            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
+
+            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
+            logger.info("  Continuing training from epoch %d", epochs_trained)
+            logger.info("  Continuing training from global step %d", global_step)
+            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
+        except ValueError:
+            logger.info("  Starting fine-tuning.")

     tr_loss, logging_loss = 0.0, 0.0
...
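For context, the try/except matters because `args.model_name_or_path` only ends in a `checkpoint-<step>` directory when training resumes from a saved checkpoint; when it is a plain model identifier, the parsed suffix is not an integer and `int()` raises `ValueError`, which previously crashed `train()`. The sketch below illustrates the parsing behaviour the commit guards against; the helper name `parse_global_step` and the example paths are illustrative only and not part of the commit.

```python
# Minimal sketch (not from the commit) of the checkpoint-suffix parsing used in
# run_lm_finetuning.py's train(): only a "checkpoint-<step>" path yields a step.

def parse_global_step(model_name_or_path: str) -> int:
    """Return the global step encoded in a checkpoint path, or 0 if absent."""
    try:
        # Same parsing as the diff: take the text after the last "-" and before any "/".
        suffix = model_name_or_path.split("-")[-1].split("/")[0]
        return int(suffix)
    except ValueError:
        # Plain model name (e.g. "gpt2") or non-numeric suffix: start fine-tuning fresh.
        return 0


print(parse_global_step("output/checkpoint-500"))  # 500 -> resume from saved step
print(parse_global_step("gpt2"))                   # 0   -> start fine-tuning
```

With the fallback, passing a model identifier instead of a checkpoint directory simply starts fine-tuning from step 0 rather than raising an unhandled `ValueError`.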