"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "333536f696264082813b404656eb3c2f0aca1c20"
Commit f7e2ac01 authored by thomwolf's avatar thomwolf
Browse files

update barrier

parent 4d8c4337
...@@ -50,12 +50,6 @@ else: ...@@ -50,12 +50,6 @@ else:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def barrier():
t = torch.randn((), device='cuda')
torch.distributed.all_reduce(t)
torch.cuda.synchronize()
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -208,11 +202,11 @@ def main(): ...@@ -208,11 +202,11 @@ def main():
num_labels = len(label_list) num_labels = len(label_list)
if args.local_rank not in [-1, 0]: if args.local_rank not in [-1, 0]:
barrier() # Make sure only the first process in distributed training will download model & vocab torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels) model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
if args.local_rank == 0: if args.local_rank == 0:
barrier() torch.distributed.barrier()
if args.fp16: if args.fp16:
model.half() model.half()
......
...@@ -183,10 +183,12 @@ def main(): ...@@ -183,10 +183,12 @@ def main():
if not os.path.exists(args.output_dir): if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir) os.makedirs(args.output_dir)
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
# Prepare model
model = BertForQuestionAnswering.from_pretrained(args.bert_model) model = BertForQuestionAnswering.from_pretrained(args.bert_model)
if args.local_rank == 0:
torch.distributed.barrier()
if args.fp16: if args.fp16:
model.half() model.half()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment