Commit d69b0b0e authored by thomwolf's avatar thomwolf
Browse files

fixes + clean up + mask is long

parent 3ddff783
...@@ -24,8 +24,8 @@ import logging ...@@ -24,8 +24,8 @@ import logging
import json import json
import math import math
import os import os
import six
import random import random
import six
from tqdm import tqdm, trange from tqdm import tqdm, trange
import numpy as np import numpy as np
...@@ -750,7 +750,7 @@ def main(): ...@@ -750,7 +750,7 @@ def main():
n_gpu = 1 n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl') torch.distributed.init_process_group(backend='nccl')
print("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1)) logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
if args.accumulate_gradients < 1: if args.accumulate_gradients < 1:
raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format( raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
...@@ -855,7 +855,7 @@ def main(): ...@@ -855,7 +855,7 @@ def main():
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
input_ids, input_mask, segment_ids, start_positions, end_positions = batch input_ids, input_mask, segment_ids, start_positions, end_positions = batch
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device) segment_ids = segment_ids.to(device)
start_positions = start_positions.to(device) start_positions = start_positions.to(device)
end_positions = start_positions.to(device) end_positions = start_positions.to(device)
...@@ -904,12 +904,12 @@ def main(): ...@@ -904,12 +904,12 @@ def main():
model.eval() model.eval()
all_results = [] all_results = []
logger.info("Start evaluating") logger.info("Start evaluating")
for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, descr="Evaluating"): for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, desc="Evaluating"):
if len(all_results) % 1000 == 0: if len(all_results) % 1000 == 0:
logger.info("Processing example: %d" % (len(all_results))) logger.info("Processing example: %d" % (len(all_results)))
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device) segment_ids = segment_ids.to(device)
start_logits, end_logits = model(input_ids, segment_ids, input_mask) start_logits, end_logits = model(input_ids, segment_ids, input_mask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment