Commit 3cdb38a7 authored by Victor SANH, committed by Lysandre Debut

indents

parent ebd45980
@@ -123,8 +123,8 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
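As context for the hunk above: the script rebuilds the optimizer and learning-rate scheduler, reloads their saved state when resuming from a checkpoint, and then wraps the model and optimizer with NVIDIA apex when --fp16 is set. A minimal sketch of that setup, assuming AdamW with a linear warmup schedule and the optimizer.pt / scheduler.pt layout shown above (the helper name and argument fields are illustrative, not taken from this diff):

import os

import torch
from transformers import AdamW, get_linear_schedule_with_warmup


def setup_optimization(args, model, t_total):
    # Fresh optimizer and linear warmup/decay schedule for the student model.
    optimizer = AdamW(model.parameters(), lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Resume optimizer/scheduler state if a previous run saved them next to the model weights.
    optimizer_path = os.path.join(args.model_name_or_path, "optimizer.pt")
    scheduler_path = os.path.join(args.model_name_or_path, "scheduler.pt")
    if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
        optimizer.load_state_dict(torch.load(optimizer_path))
        scheduler.load_state_dict(torch.load(scheduler_path))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        # amp patches the model and optimizer in place for mixed-precision training.
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    return model, optimizer, scheduler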
@@ -157,7 +157,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
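The three counters initialized above are what let a resumed run fast-forward to where it stopped. A hedged sketch of the arithmetic these example scripts typically use, assuming the checkpoint directory name ends in the saved global step (e.g. checkpoint-500); the variable names mirror the hunk:

# Recover the optimizer step count from a folder name such as ".../checkpoint-500".
checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
global_step = int(checkpoint_suffix)

# One optimizer step is taken every `gradient_accumulation_steps` batches.
steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
epochs_trained = global_step // steps_per_epoch
steps_trained_in_current_epoch = global_step % steps_per_epoch

logger.info("  Continuing training from epoch %d (global step %d)", epochs_trained, global_step)
logger.info("  Will skip the first %d steps in the current epoch", steps_trained_in_current_epoch)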
@@ -178,10 +178,10 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    # Added here for reproducibility
    set_seed(args)

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
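set_seed(args) is re-invoked just before the epoch loop so that dataloader shuffling and dropout masks are reproducible, including after a resume. In these examples it is usually the small helper sketched below, assuming args carries a seed and the detected GPU count:

import random

import numpy as np
import torch


def set_seed(args):
    # Seed every RNG that influences training: Python, NumPy, and PyTorch (CPU and all GPUs).
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)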
@@ -207,7 +207,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
                inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if args.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
            outputs = model(**inputs)
            loss, start_logits_stu, end_logits_stu = outputs
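The student forward pass above returns its own SQuAD loss together with start/end logits. Since train() also receives a teacher, the usual next step in this kind of distillation script is to run the teacher without gradients and add a temperature-scaled KL term between the two sets of logits. The sketch below is an assumption-laden illustration of that step; the temperature and alpha_* hyperparameters are not shown in this hunk:

import torch
import torch.nn as nn
import torch.nn.functional as F


def add_distillation_loss(loss, start_logits_stu, end_logits_stu, teacher, inputs, args):
    if teacher is None:
        return loss

    # The teacher is frozen: forward pass only, no gradient tracking.
    with torch.no_grad():
        start_logits_tea, end_logits_tea = teacher(
            input_ids=inputs["input_ids"],
            token_type_ids=inputs.get("token_type_ids"),
            attention_mask=inputs["attention_mask"],
        )[:2]

    # KL divergence between temperature-softened student and teacher distributions.
    temperature = args.temperature
    loss_fct = nn.KLDivLoss(reduction="batchmean")
    loss_start = loss_fct(
        F.log_softmax(start_logits_stu / temperature, dim=-1),
        F.softmax(start_logits_tea / temperature, dim=-1),
    ) * (temperature ** 2)
    loss_end = loss_fct(
        F.log_softmax(end_logits_stu / temperature, dim=-1),
        F.softmax(end_logits_tea / temperature, dim=-1),
    ) * (temperature ** 2)
    loss_ce = (loss_start + loss_end) / 2.0

    # Blend the soft-target term with the student's own span-prediction loss.
    return args.alpha_ce * loss_ce + args.alpha_squad * loss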
@@ -261,7 +261,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
...@@ -281,7 +281,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None): ...@@ -281,7 +281,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
model.module if hasattr(model, "module") else model model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training ) # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir) model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, "training_args.bin")) torch.save(args, os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to %s", output_dir) logger.info("Saving model checkpoint to %s", output_dir)
@@ -325,7 +325,7 @@ def evaluate(args, model, tokenizer, prefix=""):
    logger.info("  Batch size = %d", args.eval_batch_size)

    all_results = []
    start_time = timeit.default_timer()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        model.eval()
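Each evaluation batch is run with the model in eval mode and autograd disabled, and the resulting start/end logits are collected for SQuAD post-processing. A minimal sketch of the body of that loop, assuming the same batch layout as in training; the distilbert guard around token_type_ids is an assumption, not shown in this hunk:

for batch in tqdm(eval_dataloader, desc="Evaluating"):
    model.eval()
    batch = tuple(t.to(args.device) for t in batch)

    with torch.no_grad():
        inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
        if args.model_type != "distilbert":
            inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
        # For QA heads the first two outputs are the start and end logits.
        start_logits, end_logits = model(**inputs)[:2]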
@@ -425,7 +425,7 @@ def evaluate(args, model, tokenizer, prefix=""):
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    if args.local_rank not in [-1, 0] and not evaluate:
        # Make sure only the first process in distributed training processes the dataset, and the others will use the cache
        torch.distributed.barrier()

    # Load data features from cache or dataset file
    input_file = args.predict_file if evaluate else args.train_file
@@ -468,7 +468,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
            max_query_length=args.max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
            threads=args.threads,
        )

    if args.local_rank in [-1, 0]:
@@ -476,7 +476,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
        torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        # Make sure only the first process in distributed training processes the dataset, and the others will use the cache
        torch.distributed.barrier()

    if output_examples:
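The two torch.distributed.barrier() calls in load_and_cache_examples form a small rendezvous: ranks other than 0 wait at the first barrier while rank 0 builds and caches the features, and rank 0 then releases them by reaching the second barrier, after which every process reads from the cache. A stripped-down sketch of the same first-process-caches pattern (the function and file names here are illustrative):

import os

import torch


def build_once_then_share(args, build_fn, cache_path):
    # Ranks other than 0 block here; rank 0 falls through and does the expensive work.
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    if args.local_rank in [-1, 0] and not os.path.exists(cache_path):
        torch.save(build_fn(), cache_path)  # expensive preprocessing, done exactly once

    # Rank 0 reaches the matching barrier only after the cache exists, releasing the waiters.
    if args.local_rank == 0:
        torch.distributed.barrier()

    return torch.load(cache_path)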
@@ -541,11 +541,11 @@ def main():
        help="The input data dir. Should contain the .json files for the task."
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--train_file",
        default=None,
        type=str,
        help="The input training file. If a data dir is specified, will look for the file there"
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
@@ -688,7 +688,7 @@ def main():
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")

    args = parser.parse_args()

    if (
@@ -743,7 +743,7 @@ def main():
    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    args.model_type = args.model_type.lower()
@@ -781,7 +781,7 @@ def main():
        teacher = None

    if args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    model.to(args.device)
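The teacher set up just above is optional: when no teacher checkpoint is supplied it stays None and training reduces to plain fine-tuning. A hedged sketch of how the teacher is typically instantiated in this script family, kept frozen and colocated with the student; the --teacher_type / --teacher_name_or_path flags are assumptions, not shown in this hunk:

from transformers import AutoConfig, AutoModelForQuestionAnswering

teacher = None
if getattr(args, "teacher_type", None) is not None:
    # Load the teacher once; it is only ever used for no-grad forward passes.
    teacher_config = AutoConfig.from_pretrained(args.teacher_name_or_path)
    teacher = AutoModelForQuestionAnswering.from_pretrained(args.teacher_name_or_path, config=teacher_config)
    teacher.to(args.device)
    teacher.eval()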
...