Unverified Commit e60e8a60 authored by Davide Fiocco's avatar Davide Fiocco Committed by GitHub
Browse files

Correct assignement for logits in classifier example

I tried to address https://github.com/huggingface/pytorch-pretrained-BERT/issues/76
should be correct, but there's likely a more efficient way.
parent 063be09b
......@@ -605,7 +605,8 @@ def main():
label_ids = label_ids.to(device)
with torch.no_grad():
tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
logits = model(input_ids, segment_ids, input_mask)
logits = logits.detach().cpu().numpy()
label_ids = label_ids.to('cpu').numpy()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment