Unverified Commit 04826b0f authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #77 from davidefiocco/patch-1

Correct assignement for logits in classifier example
parents 063be09b e60e8a60
...@@ -605,7 +605,8 @@ def main(): ...@@ -605,7 +605,8 @@ def main():
label_ids = label_ids.to(device) label_ids = label_ids.to(device)
with torch.no_grad(): with torch.no_grad():
tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids) tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
logits = model(input_ids, segment_ids, input_mask)
logits = logits.detach().cpu().numpy() logits = logits.detach().cpu().numpy()
label_ids = label_ids.to('cpu').numpy() label_ids = label_ids.to('cpu').numpy()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment