"docs/source/vscode:/vscode.git/clone" did not exist on "40ed717232bf87f42ce2d3c16a14f0015e9c5fa9"
Commit f24232cd authored by Lysandre Debut

Fix error with global step in run_squad.py

parent 1b59b57b
@@ -170,8 +170,10 @@ def train(args, train_dataset, model, tokenizer):
     steps_trained_in_current_epoch = 0
     # Check if continuing training from a checkpoint
     if os.path.exists(args.model_name_or_path):
-        # set global_step to global_step of last saved checkpoint from model path
-        global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
-        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
-        steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
+        try:
+            # set global_step to global_step of last saved checkpoint from model path
+            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
+            global_step = int(checkpoint_suffix)
+            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
+            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
@@ -179,6 +181,8 @@ def train(args, train_dataset, model, tokenizer):
-        logger.info("  Continuing training from epoch %d", epochs_trained)
-        logger.info("  Continuing training from global step %d", global_step)
-        logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
+            logger.info("  Continuing training from epoch %d", epochs_trained)
+            logger.info("  Continuing training from global step %d", global_step)
+            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
+        except ValueError:
+            logger.info("  Starting fine-tuning.")
 
     tr_loss, logging_loss = 0.0, 0.0
     model.zero_grad()
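Why the change: before this commit, the block above called int() directly on the suffix parsed from args.model_name_or_path, so passing a plain pretrained model identifier (e.g. bert-base-uncased, whose suffix "uncased" is not a number) crashed train() with a ValueError instead of starting a fresh fine-tuning run. The fix wraps the parse in try/except and falls back to a fresh run. Below is a minimal standalone sketch of that recovery logic; the resume_state helper and its steps_per_epoch parameter are hypothetical, introduced here only to show the behavior in isolation (in run_squad.py the per-epoch step count is len(train_dataloader) // args.gradient_accumulation_steps).

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def resume_state(model_name_or_path, steps_per_epoch):
    """Hypothetical helper mirroring the checkpoint-resume logic in train().

    Checkpoint directories are named like "checkpoint-1500", so the global
    step is the numeric suffix after the last "-" (the extra split("/")
    tolerates a trailing slash). A non-numeric suffix, e.g. a model id like
    "bert-base-uncased", means there is no checkpoint to resume from.
    """
    try:
        checkpoint_suffix = model_name_or_path.split("-")[-1].split("/")[0]
        global_step = int(checkpoint_suffix)
        epochs_trained = global_step // steps_per_epoch
        steps_trained_in_current_epoch = global_step % steps_per_epoch
        logger.info("  Continuing training from global step %d", global_step)
    except ValueError:
        # Suffix was not a step count: start fine-tuning from scratch.
        global_step, epochs_trained, steps_trained_in_current_epoch = 0, 0, 0
        logger.info("  Starting fine-tuning.")
    return global_step, epochs_trained, steps_trained_in_current_epoch

# resume_state("output/checkpoint-1500", steps_per_epoch=500) -> (1500, 3, 0)
# resume_state("bert-base-uncased", steps_per_epoch=500)      -> (0, 0, 0)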