Unverified Commit 491ec989 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Merge pull request #533 from sywangyi/fix_ptun

fix p-tuning inaccuracy, because output logits contain virtual tokens …
parents 25dfd3f6 318bd988
......@@ -364,7 +364,7 @@ class BaseLM(LM):
cont_toks_list.append(cont)
inplens.append(inplen)
batched_inps = torch.cat(inps, dim=0) # [batch, padding_length
batched_inps = torch.cat(inps, dim=0) # [batch, padding_length]
multi_logits = F.log_softmax(
self._model_call(batched_inps), dim=-1
).cpu() # [batch, padding_length, vocab]
......@@ -375,6 +375,7 @@ class BaseLM(LM):
# Slice to original seq length
contlen = len(cont_toks)
inplen = inplen + (logits.shape[0] - padding_length) # if "virtual tokens" (from prompt tuning) are added, inplen is larger
logits = logits[inplen - contlen : inplen].unsqueeze(
0
) # [1, seq, vocab]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment