Unverified Commit 86bca365 authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

fix(server): fix flash causal (#218)

parent afc5b999
......@@ -453,7 +453,10 @@ class FlashCausalLM(Model):
)
# Set in batch in case it needs to be used later in concatenate()
batch.past_pad = self.past_pad
if len(batch) != 1:
if len(batch) == 1:
# present is already pre-padded
batch.past_key_values = present
else:
# Add padding after each sequence
# This will have the correct shape after the final past_key_values concatenation before the model
# forward
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment