Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
aadc9cb4
Unverified
Commit
aadc9cb4
authored
Nov 04, 2024
by
Travis Addair
Committed by
GitHub
Nov 04, 2024
Browse files
Fix prefix caching + speculative decoding (#2711)
parent
a5593ba8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
6 deletions
+13
-6
server/text_generation_server/models/flash_causal_lm.py
server/text_generation_server/models/flash_causal_lm.py
+13
-6
No files found.
server/text_generation_server/models/flash_causal_lm.py
View file @
aadc9cb4
...
@@ -887,11 +887,12 @@ class FlashCausalLMBatch(Batch):
...
@@ -887,11 +887,12 @@ class FlashCausalLMBatch(Batch):
fsm_grammar_states
=
fsm_grammar_states
,
fsm_grammar_states
=
fsm_grammar_states
,
)
)
speculative_ids
=
(
# We skip computing the speculative_ids when the batch size is too large, so
torch
.
cat
([
b
.
speculative_ids
for
b
in
batches
],
dim
=
0
)
# we must check that all batches have them, otherwise they must be discarded
if
batches
[
0
].
speculative_ids
is
not
None
if
get_speculate
()
>
0
and
all
(
b
.
speculative_ids
is
not
None
for
b
in
batches
):
else
None
speculative_ids
=
torch
.
cat
([
b
.
speculative_ids
for
b
in
batches
],
dim
=
0
)
)
else
:
speculative_ids
=
None
if
adapter_segment_builder
is
not
None
:
if
adapter_segment_builder
is
not
None
:
adapter_segments
,
adapter_segment_indices
=
adapter_segment_builder
.
build
()
adapter_segments
,
adapter_segment_indices
=
adapter_segment_builder
.
build
()
...
@@ -1724,7 +1725,13 @@ class FlashCausalLM(Model):
...
@@ -1724,7 +1725,13 @@ class FlashCausalLM(Model):
new_position_ids
=
(
new_position_ids
=
(
position_ids
.
unsqueeze
(
-
1
).
expand
(
B
,
new_length
)
+
arange
position_ids
.
unsqueeze
(
-
1
).
expand
(
B
,
new_length
)
+
arange
).
view
(
-
1
)
).
view
(
-
1
)
slots
=
(
slots
.
unsqueeze
(
-
1
).
expand
(
B
,
new_length
)
+
arange_int
).
view
(
-
1
)
# Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
# then update the slots with the additional indices to ensure we're grabbing the ones that have been
# allocated
slot_indices
=
(
batch
.
slot_indices
.
unsqueeze
(
-
1
).
expand
(
B
,
new_length
)
+
arange_int
).
view
(
-
1
)
slots
=
batch
.
slots
[
slot_indices
]
input_lengths
=
(
input_lengths
=
(
input_lengths
.
unsqueeze
(
-
1
).
expand
(
B
,
new_length
)
+
arange_int
input_lengths
.
unsqueeze
(
-
1
).
expand
(
B
,
new_length
)
+
arange_int
).
view
(
-
1
)
).
view
(
-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment