Unverified Commit 9ad2fc3d authored by Hailey Schoelkopf, committed by GitHub

Account for padding in inplen calculation

parent d8bf52c6
@@ -289,6 +289,7 @@ class BaseLM(LM):
        ):
            inps = []
            cont_toks_list = []
+           inplens = []

            padding_length = None
@@ -336,19 +337,20 @@ class BaseLM(LM):
                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
+               inplens.append(inplen)

-           batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length
+           batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]

-           for (cache_key, _, _), logits, inp, cont_toks in zip(
-               chunk, multi_logits, inps, cont_toks_list
+           for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
+               chunk, multi_logits, inps, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
-               inplen = logits.shape[0]
+               inplen = inplen + (logits.shape[0] - padding_length)
                logits = logits[inplen - contlen : inplen].unsqueeze(
                    0
                )  # [1, seq, vocab]
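
Context for the fix: in this version of the code, each `inp` is padded out to a shared `padding_length` (padding appended after the real tokens) before the batch is stacked and passed to `_model_call`, so `logits.shape[0]` reflects the padded length rather than the original sequence length. The old `inplen = logits.shape[0]` therefore sliced the last `contlen` rows out of the padding region instead of the rows that actually score the continuation. A minimal standalone sketch of the before/after slicing, using toy numbers (this is not harness code; the vocab size is arbitrary):

import torch

# Names mirror the variables in the diff above.
padding_length = 10   # shared padded length for the batch
true_inplen = 6       # original (unpadded) length of context + continuation
contlen = 2           # number of continuation tokens
vocab = 50

logits = torch.randn(padding_length, vocab)  # one sequence: [padding_length, vocab]

# Old behaviour: slices positions 8:10, which fall in the padding
# region, not the positions that score the continuation tokens.
old_slice = logits[logits.shape[0] - contlen : logits.shape[0]]

# New behaviour: when the model returns one row per input position,
# logits.shape[0] == padding_length and the correction term is zero,
# so inplen falls back to the true unpadded length.
inplen = true_inplen + (logits.shape[0] - padding_length)
new_slice = logits[inplen - contlen : inplen]  # positions 4:6 -> [contlen, vocab]

The `(logits.shape[0] - padding_length)` term keeps the slice aligned even if `_model_call` returns a different number of positions than it was fed; in the common case it vanishes.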