Commit ba4da393 authored by Casper Hansen's avatar Casper Hansen
Browse files

Fix Falcon inputs ids

parent f03f72de
......@@ -43,6 +43,11 @@ class FalconModel(nn.Module):
@torch.inference_mode()
def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
# NOTE: falcon input ids contain full context
# after context is processed, slice to latest token
if self.blocks[0].attn.start_pos != 0 and input_ids.shape[-1] != 1:
input_ids = input_ids[:, self.blocks[0].attn.start_pos:]
_bsz, seqlen = input_ids.shape
h = self.word_embeddings(input_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment