Commit 29b7b30e authored by thomwolf

updating evaluation on a single gpu

parent 7d2001aa
@@ -306,10 +306,10 @@ def main():
         logger.info("  Num steps = %d", num_train_optimization_steps)
         model.train()
-        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
+        for _ in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]):
             tr_loss = 0
             nb_tr_examples, nb_tr_steps = 0, 0
-            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
+            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
                 batch = tuple(t.to(device) for t in batch)
                 input_ids, input_mask, segment_ids, label_ids = batch
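
In the new version both progress bars pass `disable=args.local_rank not in [-1, 0]`, so tqdm output is drawn only when there is no distributed training (`local_rank == -1`) or on the main process (`local_rank == 0`); the other ranks train silently instead of printing interleaved duplicate bars. A minimal standalone sketch of the same pattern (the `train_loop` function and its arguments are illustrative, not part of this commit):

from tqdm import tqdm, trange

def train_loop(train_dataloader, num_epochs, local_rank=-1):
    # Progress bars are drawn only when not distributed (local_rank == -1)
    # or on the main process (rank 0); all other ranks disable them.
    show_progress = local_rank in [-1, 0]
    for _ in trange(num_epochs, desc="Epoch", disable=not show_progress):
        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=not show_progress)):
            pass  # forward/backward pass would go here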
@@ -367,21 +367,13 @@ def main():
         # Load a trained model and vocabulary that you have fine-tuned
         model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
         tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
     else:
         model = BertForQuestionAnswering.from_pretrained(args.bert_model)
-    # Distributed/fp16/parallel settings (optional)
     model.to(device)
-    if args.fp16:
-        model.half()
-    if args.local_rank != -1:
-        model = torch.nn.parallel.DistributedDataParallel(model,
-                                                          device_ids=[args.local_rank],
-                                                          output_device=args.local_rank,
-                                                          find_unused_parameters=True)
-    elif n_gpu > 1:
-        model = torch.nn.DataParallel(model)
 
     ### Evaluation
-    if args.do_eval:
+    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         eval_examples = processor.get_dev_examples(args.data_dir)
         cached_eval_features_file = os.path.join(args.data_dir, 'dev_{0}_{1}_{2}'.format(
             list(filter(None, args.bert_model.split('/'))).pop(),
...
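
The evaluation-side change is the one the commit message refers to: evaluation now runs on a single process, either because there is no distributed setup (`args.local_rank == -1`) or because the process is global rank 0, and the `fp16`/`DistributedDataParallel`/`DataParallel` wrapping before evaluation is dropped so the plain model is evaluated on one GPU. A sketch of the same single-process gating, assuming an already-initialized process group whenever `local_rank != -1` (the `run_eval` callable is a placeholder, not from the commit):

import torch

def evaluate_on_main_process(model, run_eval, local_rank=-1):
    # Run evaluation exactly once: either training is not distributed,
    # or this process is global rank 0 of the process group.
    if local_rank == -1 or torch.distributed.get_rank() == 0:
        return run_eval(model)
    # All other ranks skip evaluation and return nothing.
    return None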