Unverified Commit 4b460e72 authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

fix(server): fix flash batch filtering (#220)

parent 1ffea36e
...@@ -188,9 +188,10 @@ class FlashCausalLMBatch(Batch): ...@@ -188,9 +188,10 @@ class FlashCausalLMBatch(Batch):
position_ids.append(self.position_ids[idx]) position_ids.append(self.position_ids[idx])
cu_seqlens.append(cumulative_length + request_input_length) cu_seqlens.append(cumulative_length + request_input_length)
max_seqlen = max(max_seqlen, request_input_length) max_seqlen = max(max_seqlen, request_input_length)
# True index for past
past_key_values.append(self.past_key_values[2 * idx])
if not single_request: if not single_request:
# True index for past
past_key_values.append(self.past_key_values[2 * idx])
# Add one padding # Add one padding
past_key_values.append(self.past_pad) past_key_values.append(self.past_pad)
...@@ -209,7 +210,7 @@ class FlashCausalLMBatch(Batch): ...@@ -209,7 +210,7 @@ class FlashCausalLMBatch(Batch):
if single_request: if single_request:
# Preallocate tensor for bs = 1 case # Preallocate tensor for bs = 1 case
past_key_values = torch.nn.functional.pad( past_key_values = torch.nn.functional.pad(
self.past_key_values[0], past_key_values[0],
( (
0, 0,
0, 0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment