Commit d8bf52c6 authored by Wang, Yi

fix p-tuning inaccuracy, because the output logits contain the virtual token length


Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
parent 441e6ac1
@@ -289,7 +289,6 @@ class BaseLM(LM):
         ):
             inps = []
             cont_toks_list = []
-            inplens = []

             padding_length = None
@@ -337,19 +336,19 @@ class BaseLM(LM):
                 inps.append(inp.unsqueeze(0))  # [1, padding_length]
                 cont_toks_list.append(cont)
-                inplens.append(inplen)

             batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
             multi_logits = F.log_softmax(
                 self._model_call(batched_inps), dim=-1
             ).cpu()  # [batch, padding_length, vocab]

-            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
-                chunk, multi_logits, inps, inplens, cont_toks_list
+            for (cache_key, _, _), logits, inp, cont_toks in zip(
+                chunk, multi_logits, inps, cont_toks_list
             ):
                 # Slice to original seq length
                 contlen = len(cont_toks)
+                inplen = logits.shape[0]
                 logits = logits[inplen - contlen : inplen].unsqueeze(
                     0
                 )  # [1, seq, vocab]
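For context, a minimal, self-contained sketch (not the harness code itself) of the off-by-virtual-tokens problem the commit fixes. With p-tuning, the model prepends learned soft-prompt embeddings, so it emits one logit row per virtual token in addition to the real input positions; a slice bound computed from the input length therefore stops short of the continuation. The toy shapes and the name num_virtual_tokens below are illustrative assumptions, and the sketch assumes the input fills the padded length so the continuation ends at the final logit row.

import torch
import torch.nn.functional as F

# Toy shapes; illustrative assumptions, not values from the commit.
inplen = 8                # unpadded input length tracked by the old code
num_virtual_tokens = 4    # soft-prompt positions prepended by p-tuning
vocab = 100

# Under p-tuning the model emits logits for the virtual tokens too,
# so the output has inplen + num_virtual_tokens rows, not inplen.
logits = F.log_softmax(torch.randn(inplen + num_virtual_tokens, vocab), dim=-1)

cont_toks = torch.tensor([5, 17, 3])  # continuation token ids
contlen = len(cont_toks)

# Old slice: bounded by the *input* length, so it lands
# num_virtual_tokens rows too early and scores the wrong positions.
wrong = logits[inplen - contlen : inplen]

# Fixed slice: take the length from the logits themselves, as the
# commit does with inplen = logits.shape[0].
inplen = logits.shape[0]
right = logits[inplen - contlen : inplen].unsqueeze(0)  # [1, seq, vocab]

# Per-token log-likelihood of the continuation under the fixed slice.
scores = torch.gather(right, 2, cont_toks.view(1, -1, 1)).squeeze(-1)
print(right.shape, scores.shape)  # torch.Size([1, 3, 100]) torch.Size([1, 3])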