Commit 3cdb38a7 authored by Victor SANH, committed by Lysandre Debut

indents

parent ebd45980
@@ -123,8 +123,8 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
# Load in optimizer and scheduler states
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
if args.fp16:
try:
from apex import amp
except ImportError:
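For context, a minimal sketch of the pattern this hunk adjusts: restoring optimizer/scheduler state from a checkpoint directory and, when --fp16 is set, wrapping model and optimizer with apex amp. The helper name and the guard on file existence are assumptions; only the apex calls mirror the hunk.

    import os
    import torch

    def resume_and_maybe_wrap_fp16(args, model, optimizer, scheduler):
        # Hypothetical helper: restore states saved next to the model weights.
        opt_path = os.path.join(args.model_name_or_path, "optimizer.pt")
        sch_path = os.path.join(args.model_name_or_path, "scheduler.pt")
        if os.path.isfile(opt_path) and os.path.isfile(sch_path):
            optimizer.load_state_dict(torch.load(opt_path))
            scheduler.load_state_dict(torch.load(sch_path))
        if args.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            # The script exposes the opt level as --fp16_opt_level ("O1" = mixed precision).
            model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
        return model, optimizer, scheduler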
@@ -157,7 +157,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
logger.info(" Total optimization steps = %d", t_total)
global_step = 1
epochs_trained = 0
steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint
if os.path.exists(args.model_name_or_path):
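The checkpoint-continuation branch that follows typically derives the completed epochs and steps from the checkpoint name. A hedged sketch; the `checkpoint-<global_step>` naming convention is an assumption:

    def steps_already_done(model_name_or_path, dataloader_len, grad_accum_steps):
        # Checkpoint dirs are assumed to be named like ".../checkpoint-1500",
        # where the trailing integer is the global optimization step.
        global_step = int(model_name_or_path.split("-")[-1].split("/")[0])
        steps_per_epoch = dataloader_len // grad_accum_steps
        epochs_trained = global_step // steps_per_epoch
        steps_trained_in_current_epoch = global_step % steps_per_epoch
        return global_step, epochs_trained, steps_trained_in_current_epoch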
@@ -178,10 +178,10 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
tr_loss, logging_loss = 0.0, 0.0
model.zero_grad()
train_iterator = trange(
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
)
# Added here for reproducibility
set_seed(args)
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
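`set_seed` is the usual examples helper; a sketch of what it is assumed to do:

    import random
    import numpy as np
    import torch

    def set_seed(args):
        # Seed Python, NumPy, and PyTorch (all visible GPUs) so reruns match.
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if args.n_gpu > 0:
            torch.cuda.manual_seed_all(args.seed)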
@@ -207,7 +207,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
if args.model_type in ["xlnet", "xlm"]:
inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
if args.version_2_with_negative:
inputs.update({"is_impossible": batch[7]})
outputs = model(**inputs)
loss, start_logits_stu, end_logits_stu = outputs
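The student's start/end logits feed a distillation loss alongside the supervised QA loss. A minimal sketch of the temperature-scaled KL blend such a script typically uses; the names `temperature` and `alpha_ce` and the 50/50 start/end weighting are assumptions:

    import torch.nn as nn
    import torch.nn.functional as F

    def distillation_loss(loss_qa, start_logits_stu, end_logits_stu,
                          start_logits_tea, end_logits_tea,
                          temperature=2.0, alpha_ce=0.5):
        # KL divergence between temperature-softened student and teacher
        # distributions over answer-span positions; the T**2 factor keeps
        # gradient magnitudes comparable across temperatures.
        kl = nn.KLDivLoss(reduction="batchmean")
        loss_start = kl(F.log_softmax(start_logits_stu / temperature, dim=-1),
                        F.softmax(start_logits_tea / temperature, dim=-1)) * temperature ** 2
        loss_end = kl(F.log_softmax(end_logits_stu / temperature, dim=-1),
                      F.softmax(end_logits_tea / temperature, dim=-1)) * temperature ** 2
        loss_ce = (loss_start + loss_end) / 2.0
        return alpha_ce * loss_ce + (1.0 - alpha_ce) * loss_qa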
@@ -261,7 +261,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
model.zero_grad()
global_step += 1
# Log metrics
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
# Only evaluate when single GPU otherwise metrics may not average well
if args.local_rank == -1 and args.evaluate_during_training:
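A sketch of what the logging branch usually reports: the learning rate plus the loss averaged over the last `logging_steps` updates. The TensorBoard writer and helper shape are assumptions:

    from torch.utils.tensorboard import SummaryWriter

    tb_writer = SummaryWriter()  # created once, typically only on rank 0

    def log_metrics(scheduler, tr_loss, logging_loss, global_step, logging_steps):
        # tr_loss is cumulative, so the delta over the window is the recent average.
        tb_writer.add_scalar("lr", scheduler.get_last_lr()[0], global_step)
        tb_writer.add_scalar("loss", (tr_loss - logging_loss) / logging_steps, global_step)
        return tr_loss  # caller stores this as the new logging_loss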
@@ -281,7 +281,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to %s", output_dir)
@@ -325,7 +325,7 @@ def evaluate(args, model, tokenizer, prefix=""):
logger.info(" Batch size = %d", args.eval_batch_size)
all_results = []
start_time = timeit.default_timer()
for batch in tqdm(eval_dataloader, desc="Evaluating"):
model.eval()
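Note that `model.eval()` inside the loop is a no-op after the first iteration; the essential pieces of the evaluation loop are disabling gradients and moving each batch to the device. A sketch, with the batch layout assumed from the training loop above:

    import torch
    from tqdm import tqdm

    def run_eval(model, eval_dataloader, device):
        model.eval()  # calling it once before the loop is sufficient
        all_outputs = []
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(device) for t in batch)
            with torch.no_grad():  # no autograd bookkeeping during inference
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "token_type_ids": batch[2]}
                all_outputs.append(model(**inputs))
        return all_outputs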
@@ -425,7 +425,7 @@ def evaluate(args, model, tokenizer, prefix=""):
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
if args.local_rank not in [-1, 0] and not evaluate:
# Make sure only the first process in distributed training processes the dataset, and the others will use the cache
torch.distributed.barrier()
# Load data features from cache or dataset file
input_file = args.predict_file if evaluate else args.train_file
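The paired barrier calls in `load_and_cache_examples` implement a build-once pattern: rank 0 preprocesses the dataset while every other rank blocks, then the roles flip so the waiters read the finished cache. Abstracted into a sketch:

    import torch.distributed as dist

    def build_once_then_share(local_rank, build_fn):
        # Non-first ranks wait here until rank 0 has written the cache.
        if local_rank not in [-1, 0]:
            dist.barrier()
        data = build_fn()  # rank 0 builds; the others load the cached result
        # Rank 0 releases everyone once the cache exists.
        if local_rank == 0:
            dist.barrier()
        return data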
@@ -468,7 +468,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
max_query_length=args.max_query_length,
is_training=not evaluate,
return_dataset="pt",
threads=args.threads,
)
if args.local_rank in [-1, 0]:
@@ -476,7 +476,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)
if args.local_rank == 0 and not evaluate:
# Make sure only the first process in distributed training processes the dataset, and the others will use the cache
torch.distributed.barrier()
if output_examples:
@@ -541,11 +541,11 @@ def main():
help="The input data dir. Should contain the .json files for the task."
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
)
parser.add_argument(
"--train_file",
default=None,
type=str,
help="The input training file. If a data dir is specified, will look for the file there. "
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
)
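The tensorflow_datasets fallback mentioned in these help strings works roughly as below, a sketch using the transformers SQuAD processor; the exact wiring inside this script may differ:

    # When neither --data_dir nor --train_file/--predict_file is given,
    # SQuAD is pulled from tensorflow_datasets instead of local JSON files.
    import tensorflow_datasets as tfds
    from transformers import SquadV1Processor

    tfds_examples = tfds.load("squad")
    examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=False)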
parser.add_argument(
@@ -688,7 +688,7 @@ def main():
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
args = parser.parse_args()
if (
@@ -743,7 +743,7 @@ def main():
# Load pretrained model and tokenizer
if args.local_rank not in [-1, 0]:
# Make sure only the first process in distributed training will download model & vocab
torch.distributed.barrier()
args.model_type = args.model_type.lower()
@@ -781,7 +781,7 @@ def main():
teacher = None
if args.local_rank == 0:
# Make sure only the first process in distributed training will download model & vocab
torch.distributed.barrier()
model.to(args.device)
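This last hunk sits in the teacher setup: a fine-tuned teacher is loaded for distillation, put in eval mode, and moved to the device alongside the student. A sketch, where the flag name and the model class are assumptions:

    from transformers import BertForQuestionAnswering  # class choice is illustrative

    def load_teacher(args):
        # The teacher only supplies soft targets; it is never updated.
        if getattr(args, "teacher_name_or_path", None) is None:
            return None
        teacher = BertForQuestionAnswering.from_pretrained(args.teacher_name_or_path)
        teacher.to(args.device)
        teacher.eval()
        return teacher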