OpenDAS / ktransformers · Commits · cc8d627e

Unverified commit cc8d627e, authored Feb 14, 2025 by Atream, committed by GitHub on Feb 14, 2025

Merge pull request #301 from kvcache-ai/fix-cuda-graph-bug

warm_up before capture

Parents: cadd5507, 1946493f
Showing 1 changed file with 12 additions and 6 deletions.

ktransformers/util/utils.py (+12 -6)
@@ -18,6 +18,8 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.util.textstream import TextStreamer
 
+warm_uped = False
+
 def set_module(model, submodule_key, module):
     tokens = submodule_key.split('.')
     sub_tokens = tokens[:-1]
@@ -99,6 +101,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     tokens = []
 
     def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
+        if cuda_graph_runner is None:
+            use_cuda_graph = False
         if use_cuda_graph:
             logits = cuda_graph_runner(cur_token, position_ids, cache_position)
         else:
@@ -182,14 +186,16 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         position_ids = cache_position.unsqueeze(0)
         seq_length += 1
 
-        if use_cuda_graph:
-            cuda_graph_runner = CUDAGraphRunner()
-            cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
-        else:
-            cuda_graph_runner = None
+        cuda_graph_runner = None
 
         start_time = time.time()
-        for _ in range(1, max_new_tokens):
+        for i in range(1, max_new_tokens):
+            global warm_uped
+            if use_cuda_graph and ((warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2)):
+                warm_uped = True
+                cuda_graph_runner = CUDAGraphRunner()
+                cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
             next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
             inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
             generated_ids[:, cache_position] = next_token.int()
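The change moves CUDA graph capture from before the decode loop to a point where at least one eager decode step has already run, so one-time CUDA work does not get recorded into the graph. Below is a minimal, self-contained sketch of that warm-up-before-capture pattern using PyTorch's stock torch.cuda.CUDAGraph API rather than this project's CUDAGraphRunner; the toy model and tensor names are placeholders, not code from this repository.

import torch

# Toy stand-in for the decoder; assumes a CUDA-capable device is available.
model = torch.nn.Linear(16, 16, device="cuda")
static_input = torch.randn(1, 16, device="cuda")

# Warm-up: run a few eager iterations on a side stream first, so one-time
# work (cuBLAS handle creation, allocator growth, autotuning) happens
# outside the captured graph.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for _ in range(3):
        static_output = model(static_input)
torch.cuda.current_stream().wait_stream(side_stream)

# Capture: record one forward pass into a CUDA graph over static buffers.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: copy fresh data into the static input, then replay the graph;
# static_output is overwritten in place.
static_input.copy_(torch.randn(1, 16, device="cuda"))
graph.replay()
print(static_output)

In the patched loop above, the same idea is expressed with the module-level warm_uped flag: capture happens on the first decode iteration only if a previous call already warmed the model up, otherwise on the second iteration, after one eager pass through decode_one_tokens. Until a runner has been captured, cuda_graph_runner is None and the new guard in decode_one_tokens falls back to the eager path.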