Unverified Commit 99709ee6 authored by Jasdeep Singh's avatar Jasdeep Singh Committed by GitHub
Browse files

loading saved model when n_classes != 2

Required to for: Assertion `t >= 0 && t < n_classes` failed,  if your default number of classes is not 2.
parent 8da280eb
...@@ -558,7 +558,7 @@ def main(): ...@@ -558,7 +558,7 @@ def main():
# Load a trained model that you have fine-tuned # Load a trained model that you have fine-tuned
model_state_dict = torch.load(output_model_file) model_state_dict = torch.load(output_model_file)
model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict) model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict, num_labels=num_labels)
model.to(device) model.to(device)
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment