Commit 33599556 authored by thomwolf

updating run_classif

parent 29b7b30e
@@ -50,15 +50,6 @@ else:
 logger = logging.getLogger(__name__)
 
-
-def average_distributed_scalar(scalar, args):
-    """ Average a scalar over the nodes if we are in distributed training. We use this for distributed evaluation. """
-    if args.local_rank == -1:
-        return scalar
-    scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device) / torch.distributed.get_world_size()
-    torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
-    return scalar_t.item()
-
 def main():
     parser = argparse.ArgumentParser()
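For reference, the deleted helper pre-divides each scalar by the world size and then all-reduces with SUM, which yields the mean across workers. A minimal standalone sketch of that pattern, assuming a torch.distributed process group has already been initialized; the argument names are illustrative, not the script's args object:

    import torch
    import torch.distributed as dist

    def average_distributed_scalar(scalar, local_rank, device):
        # Single-process runs (local_rank == -1): nothing to reduce.
        if local_rank == -1:
            return scalar
        # Pre-divide by the world size so a SUM all_reduce yields the
        # mean of the per-worker scalars.
        scalar_t = torch.tensor(scalar, dtype=torch.float, device=device) / dist.get_world_size()
        dist.all_reduce(scalar_t, op=dist.ReduceOp.SUM)
        return scalar_t.item()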
@@ -368,7 +359,7 @@ def main():
         model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
         tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
     else:
-        model = BertForQuestionAnswering.from_pretrained(args.bert_model)
+        model = BertForSequenceClassification.from_pretrained(args.bert_model)
     model.to(device)
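The one-line fix above replaces the span-prediction head (BertForQuestionAnswering) with the single classification head that run_classifier actually needs. A minimal sketch of loading that head, assuming the pytorch_pretrained_bert API of this commit's era; the model name and label count below are placeholders, not values from the script:

    from pytorch_pretrained_bert import BertForSequenceClassification, BertTokenizer

    # Placeholders: "bert-base-uncased" stands in for args.bert_model,
    # num_labels=2 for a binary task such as SST-2.
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)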
@@ -453,10 +444,6 @@ def main():
             preds = np.squeeze(preds)
         result = compute_metrics(task_name, preds, out_label_ids)
-        if args.local_rank != -1:
-            # Average over distributed nodes if needed
-            result = {key: average_distributed_scalar(value, args) for key, value in result.items()}
-
         loss = tr_loss/global_step if args.do_train else None
         result['eval_loss'] = eval_loss
@@ -530,10 +517,6 @@ def main():
             preds = np.argmax(preds, axis=1)
         result = compute_metrics(task_name, preds, out_label_ids)
-        if args.local_rank != -1:
-            # Average over distributed nodes if needed
-            result = {key: average_distributed_scalar(value, args) for key, value in result.items()}
-
         loss = tr_loss/global_step if args.do_train else None
         result['eval_loss'] = eval_loss
...
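One detail worth noting across the last two hunks: predictions are post-processed with np.squeeze for the regression-style task (a single output column) but with np.argmax over class logits for classification. An illustrative sketch of that branch, using a hypothetical helper name of my own rather than anything from the script:

    import numpy as np

    def postprocess_preds(preds, is_regression):
        # Regression (e.g. STS-B): one value per example; drop the
        # trailing dimension.
        if is_regression:
            return np.squeeze(preds)
        # Classification: pick the highest-scoring class per example.
        return np.argmax(preds, axis=1)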