Commit a994bf40 authored by Grégory Châtel's avatar Grégory Châtel
Browse files

Fixing related to issue #83.

parent c6d9d539
...@@ -423,6 +423,12 @@ def main(): ...@@ -423,6 +423,12 @@ def main():
"mrpc": MrpcProcessor, "mrpc": MrpcProcessor,
} }
num_labels_task = {
"cola": 2,
"mnli": 3,
"mrpc": 2,
}
if args.local_rank == -1 or args.no_cuda: if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count() n_gpu = torch.cuda.device_count()
...@@ -461,6 +467,7 @@ def main(): ...@@ -461,6 +467,7 @@ def main():
raise ValueError("Task not found: %s" % (task_name)) raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]() processor = processors[task_name]()
num_labels = num_labels_task[task_name]
label_list = processor.get_labels() label_list = processor.get_labels()
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
...@@ -474,7 +481,8 @@ def main(): ...@@ -474,7 +481,8 @@ def main():
# Prepare model # Prepare model
model = BertForSequenceClassification.from_pretrained(args.bert_model, model = BertForSequenceClassification.from_pretrained(args.bert_model,
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank)) cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
num_labels = num_labels)
if args.fp16: if args.fp16:
model.half() model.half()
model.to(device) model.to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment