Unverified commit 9ad2fc3d authored by Hailey Schoelkopf, committed by GitHub

Account for padding in inplen calculation

parent d8bf52c6
@@ -289,6 +289,7 @@ class BaseLM(LM):
         ):
             inps = []
             cont_toks_list = []
+            inplens = []

             padding_length = None

@@ -336,19 +337,20 @@ class BaseLM(LM):
                 inps.append(inp.unsqueeze(0))  # [1, padding_length]
                 cont_toks_list.append(cont)
+                inplens.append(inplen)

-            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length
+            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
             multi_logits = F.log_softmax(
                 self._model_call(batched_inps), dim=-1
             ).cpu()  # [batch, padding_length, vocab]

-            for (cache_key, _, _), logits, inp, cont_toks in zip(
-                chunk, multi_logits, inps, cont_toks_list
+            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
+                chunk, multi_logits, inps, inplens, cont_toks_list
             ):

                 # Slice to original seq length
                 contlen = len(cont_toks)
-                inplen = logits.shape[0]
+                inplen = inplen + (logits.shape[0] - padding_length)
                 logits = logits[inplen - contlen : inplen].unsqueeze(
                     0
                 )  # [1, seq, vocab]
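The gist of the change: each sequence in the batch is padded up to padding_length before the model call, so the old inplen = logits.shape[0] measured the padded length rather than the true input length, and the continuation slice could land on padding tokens. The fix stores each example's real inplen before batching and reuses it when slicing. Below is a minimal sketch of why the new slice is correct; the toy values (padding_length, inplen, contlen, vocab) and variable names outside the diff are assumptions for illustration, not taken from the commit.

# Sketch only: illustrates the padding-aware slice from the diff on fake data.
import torch

padding_length = 8   # longest sequence in the batch (assumed toy value)
inplen = 5           # true, unpadded length of this example (assumed)
contlen = 2          # number of continuation tokens at the end of the input (assumed)
vocab = 10

# Fake per-example logits with shape [padding_length, vocab];
# rows at index >= inplen correspond to padding positions.
logits = torch.randn(padding_length, vocab)

# Old behaviour: treat the padded length as the sequence length,
# so the slice selects rows 6..7, which are padding.
old_inplen = logits.shape[0]                              # == 8
old_slice = logits[old_inplen - contlen : old_inplen]

# New behaviour: start from the stored unpadded length; the
# (logits.shape[0] - padding_length) term is zero when the model
# returns one logit row per padded input position.
new_inplen = inplen + (logits.shape[0] - padding_length)  # == 5
new_slice = logits[new_inplen - contlen : new_inplen]     # rows 3..4, the continuation

print(old_slice.shape, new_slice.shape)  # both [2, 10], but different rows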