OpenDAS / text-generation-inference · Commit 022f6515
Unverified commit 022f6515, authored Jul 02, 2024 by Nicolas Patry; committed by GitHub, Jul 02, 2024.
Fixing graph capture for flash decoding. (#2163)
Parent: 4327210e

Showing 1 changed file with 3 additions and 2 deletions (+3 −2):
server/text_generation_server/models/flash_causal_lm.py
server/text_generation_server/models/flash_causal_lm.py (view file @ 022f6515)

@@ -926,7 +926,7 @@ class FlashCausalLM(Model):
             "slots": slots,
             "input_lengths": input_lengths,
         }
-        input_lengths = Seqlen(input_lengths=input_lengths)
+        input_lengths_ = Seqlen(input_lengths=input_lengths)
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph
@@ -939,7 +939,7 @@ class FlashCausalLM(Model):
             kv_cache=self.kv_cache,
             block_tables=block_tables,
             slots=slots,
-            input_lengths=input_lengths,
+            input_lengths=input_lengths_,
             max_s=max_s,
             prefill_cache_indices=None,
             lm_head_indices=None,
@@ -947,6 +947,7 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()
         with torch.cuda.graph(graph, pool=MEM_POOL):
+            input_lengths = Seqlen(input_lengths=input_lengths)
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
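The diff addresses a common CUDA graph capture pitfall: tensors read inside the captured region must keep referring to the static buffers that are updated in place before each replay. The warmup path therefore wraps the raw tensor under a fresh name (`input_lengths_`) instead of rebinding `input_lengths`, and the `Seqlen` view used during capture is built inside the `torch.cuda.graph` region. Below is a minimal CPU-only sketch of the same static-buffer discipline; the `Seqlen` stand-in, `capture` helper, and toy model here are illustrative assumptions, not the real TGI implementation.

```python
import torch

class Seqlen:
    """Illustrative stand-in for TGI's Seqlen wrapper: just holds the tensor."""
    def __init__(self, input_lengths):
        self.input_lengths = input_lengths

def capture(static_input, model):
    # "Capture": bind to the static buffer exactly once. Replay reuses that
    # binding, mirroring how a captured CUDA graph re-reads static tensors.
    seqlen = Seqlen(input_lengths=static_input)
    def replay():
        return model(seqlen.input_lengths)
    return replay

static_input = torch.tensor([1.0, 2.0])
replay = capture(static_input, lambda t: t * 2)

# Feed new data by copying into the static buffer in place. Rebinding the
# name (static_input = ...) would break the link to the captured storage,
# which is the bug class this commit avoids.
static_input.copy_(torch.tensor([5.0, 7.0]))
out = replay()
assert torch.equal(out, torch.tensor([10.0, 14.0]))
```

The key invariant is that `seqlen.input_lengths` and `static_input` share storage, so in-place updates are visible on replay without re-capturing.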