Commit f7e2ac01 authored by thomwolf's avatar thomwolf
Browse files

update barrier

parent 4d8c4337
......@@ -50,12 +50,6 @@ else:
logger = logging.getLogger(__name__)
def barrier():
    """Block until every rank in the distributed group reaches this point.

    Hand-rolled barrier: runs a throwaway all_reduce on a scalar CUDA
    tensor, then synchronizes CUDA so the collective has actually
    completed before returning. Assumes torch.distributed is initialized
    and a CUDA device is available — TODO confirm at call sites.
    NOTE(review): torch.distributed.barrier() is the standard equivalent
    (this commit replaces calls to this helper with it).
    """
    # Scalar GPU tensor; the all_reduce on it forces all ranks to rendezvous.
    t = torch.randn((), device='cuda')
    torch.distributed.all_reduce(t)
    # CUDA work is queued asynchronously from the host; wait for completion.
    torch.cuda.synchronize()
def main():
parser = argparse.ArgumentParser()
......@@ -208,11 +202,11 @@ def main():
num_labels = len(label_list)
if args.local_rank not in [-1, 0]:
barrier() # Make sure only the first process in distributed training will download model & vocab
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
if args.local_rank == 0:
barrier()
torch.distributed.barrier()
if args.fp16:
model.half()
......
......@@ -183,10 +183,12 @@ def main():
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
# Prepare model
model = BertForQuestionAnswering.from_pretrained(args.bert_model)
if args.local_rank == 0:
torch.distributed.barrier()
if args.fp16:
model.half()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment