"INSTALL/git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "9615e7eaa0cb6d02ac3d94f0bf7d06f3ffc277ed"
Commit 994d8660 authored by thomwolf's avatar thomwolf
Browse files

fixing PYTORCH_PRETRAINED_BERT_CACHE use in examples

parent 2dd8f524
......@@ -495,7 +495,7 @@ def main():
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
# Prepare model
cache_dir = args.cache_dir if args.cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank))
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank))
model = BertForSequenceClassification.from_pretrained(args.bert_model,
cache_dir=cache_dir,
num_labels = num_labels)
......
......@@ -894,7 +894,7 @@ def main():
# Prepare model
model = BertForQuestionAnswering.from_pretrained(args.bert_model,
cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)))
cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)))
if args.fp16:
model.half()
......
......@@ -367,7 +367,7 @@ def main():
# Prepare model
model = BertForMultipleChoice.from_pretrained(args.bert_model,
cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)),
cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)),
num_choices=4)
if args.fp16:
model.half()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment