Unverified Commit 4d49ae1f authored by Casper's avatar Casper Committed by GitHub
Browse files

Fix position ids (#215)

parent 8dbf009b
......@@ -32,7 +32,7 @@ def prepare_input_ids(input_ids: torch.Tensor, last_forward_num_tokens: int):
num_new_tokens = num_input_tokens - last_forward_num_tokens
# after context is processed, slice to latest token
if num_new_tokens in [0,1]:
if num_new_tokens == 1:
input_ids = input_ids[:, -1:]
return input_ids, last_forward_num_tokens + num_new_tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment