OpenDAS / ktransformers · Commits

Commit bb0ccc7b
Authored Feb 17, 2025 by ceerrep
Parent: c515cc49

    feat: add prefix cache for server

Showing 5 changed files with 132 additions and 55 deletions:

    ktransformers/models/custom_cache.py                        +13   -1
    ktransformers/server/args.py                                 +2   -1
    ktransformers/server/backend/interfaces/ktransformers.py   +42  -21
    ktransformers/server/backend/interfaces/transformers.py    +71  -32
    ktransformers/server/main.py                                 +4   -0
ktransformers/models/custom_cache.py  (view file @ bb0ccc7b)

@@ -174,7 +174,19 @@ class StaticCache(transformers.StaticCache):
                 self.key_cache[layer_idx].zero_()
             if self.value_cache[layer_idx] is not None:
                 self.value_cache[layer_idx].zero_()
+            self.past_tokens[layer_idx] = 0
+
+    def remove_suffix(self, start_pos):
+        for layer_idx in range(len(self.key_cache)):
+            # In-place ops prevent breaking the static address
+            if self.is_MLA:
+                k_cache = self.key_cache[layer_idx]
+                k_cache.view(-1, k_cache.shape[-1])[start_pos:].zero_()
+            else:
+                self.key_cache[layer_idx][..., start_pos:, :].zero_()
+                self.value_cache[layer_idx][..., start_pos:, :].zero_()
+            self.past_tokens[layer_idx] = start_pos
 
     def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
         """Returns the maximum shape of the cache."""
         return self.max_cache_len
\ No newline at end of file
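Note on remove_suffix: it truncates the KV cache in place to the first start_pos positions, so entries for a shared prompt prefix survive while the stale suffix is cleared without reallocating the static buffers. A minimal stand-alone sketch of the same in-place pattern, on invented tensor shapes (not the shapes ktransformers actually allocates):

    import torch

    start_pos = 3                              # cached positions to keep (the matched prefix)

    # Regular per-layer K/V cache, invented shape (batch=1, heads=2, seq=8, head_dim=4):
    key_cache = torch.randn(1, 2, 8, 4)
    value_cache = torch.randn(1, 2, 8, 4)
    key_cache[..., start_pos:, :].zero_()      # same pattern as the non-MLA branch
    value_cache[..., start_pos:, :].zero_()

    # MLA-style cache kept as one flat buffer per layer: view it as (position, hidden)
    # rows and zero everything past the prefix, as the is_MLA branch does.
    mla_cache = torch.randn(8, 16)             # invented (seq, latent_dim) shape
    mla_cache.view(-1, mla_cache.shape[-1])[start_pos:].zero_()

    assert key_cache[..., start_pos:, :].abs().sum() == 0
    assert mla_cache[start_pos:].abs().sum() == 0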
ktransformers/server/args.py  (view file @ bb0ccc7b)

@@ -90,7 +90,8 @@ class ArgumentParser:
         # user config
         parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
         parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
-        parser.add_argument("--force_think", type=bool, default=self.cfg.user_force_think)
+        parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.user_force_think)
+        parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.use_cuda_graph)
 
         # web config
         parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
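argparse.BooleanOptionalAction (Python 3.9+) turns each of these flags into a --flag / --no-flag pair, which is what later lets a user override the cfg.use_cuda_graph default that main.py now forces off. A self-contained toy parser showing the behavior (only the flag names come from the diff; defaults are illustrative, and the type=bool seen in the diff is omitted here since BooleanOptionalAction already yields booleans):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, default=False)

    print(parser.parse_args([]))                        # Namespace(force_think=False, use_cuda_graph=False)
    print(parser.parse_args(["--use_cuda_graph"]))      # Namespace(force_think=False, use_cuda_graph=True)
    print(parser.parse_args(["--no-use_cuda_graph"]))   # Namespace(force_think=False, use_cuda_graph=False)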
ktransformers/server/backend/interfaces/ktransformers.py  (view file @ bb0ccc7b)

@@ -121,33 +121,53 @@ class KTransformersInterface(TransformersInterface):
     @torch.no_grad
     def prefill(self, input_ids: torch.Tensor, is_new: bool):
         input_ids_length = input_ids.shape[-1]
-        self.profiler.set_counter("prefill", input_ids_length)
         logger.debug(f"input_ids: {input_ids.shape}")
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
 
         if is_new:
-            self.cache.reset()
             self.ever_generated_ids.clear()
-            former_seq_length = 0
-            self.seq_length = input_ids_length
-            self.generated_ids = torch.zeros(
-                self.args.batch_size,
-                self.seq_length + self.args.max_new_tokens + 1,
-                dtype=torch.int,
-                device=self.args.device,
-            )
-        else:
-            logger.debug(f"generate_ids: {self.generated_ids.shape}")
-            former_seq_length = self.seq_length
-            self.seq_length += input_ids_length
-            expected_length = self.seq_length + self.args.max_new_tokens + 1
-            delta_length = expected_length - self.generated_ids.shape[-1]
-            if delta_length > 0:
-                new_generate_ids = torch.zeros(
-                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
-                )
-                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+            same_prefix = 0
+            flat_input_ids = input_ids.flatten()
+
+            if getattr(self, 'generated_ids', None) is None:
+                self.generated_ids = torch.zeros(
+                    self.args.batch_size,
+                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                    dtype=torch.int,
+                    device=self.args.device,
+                )
+                self.seq_length = 1
+
+            flat_prev_ids = self.generated_ids.flatten()
+            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+                if flat_input_ids[i] == flat_prev_ids[i]:
+                    same_prefix += 1
+                else:
+                    break
+
+            logger.debug(f"same prefix len: {same_prefix}")
+            self.cache.remove_suffix(same_prefix)
+            self.seq_length = same_prefix
+            self.generated_ids = self.generated_ids[..., :same_prefix]
+            input_ids = input_ids[..., same_prefix:]
+            input_ids_length = input_ids.shape[-1]
+
+        self.ever_generated_ids.clear()
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+        logger.debug(f"generate_ids: {self.generated_ids.shape}")
+
+        former_seq_length = self.seq_length
+        self.seq_length += input_ids_length
+        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        delta_length = expected_length - self.generated_ids.shape[-1]
+        if delta_length > 0:
+            new_generate_ids = torch.zeros(
+                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            )
+            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+
         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
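The prefix cache itself is the block above: for each request it walks the new token ids and the previously recorded generated_ids in lockstep, stops at the first mismatch, keeps that many positions in the KV cache via cache.remove_suffix(same_prefix), and only prefills the remainder. A stand-alone sketch of the matching step (the helper name and tensors are illustrative, not part of the ktransformers API):

    import torch

    def shared_prefix_len(prev_ids: torch.Tensor, new_ids: torch.Tensor, prev_len: int) -> int:
        # Mirrors the loop in prefill: the "- 1" keeps at least the final position out of
        # the reused prefix, so the model always runs on at least one token and produces
        # fresh logits to sample from.
        same = 0
        for i in range(min(prev_len, new_ids.shape[0]) - 1):
            if new_ids[i] == prev_ids[i]:
                same += 1
            else:
                break
        return same

    prev = torch.tensor([1, 2, 3, 4, 5, 0, 0])   # ids already recorded in generated_ids
    new = torch.tensor([1, 2, 3, 9, 9, 9])       # new request sharing a 3-token prefix
    print(shared_prefix_len(prev, new, prev_len=5))   # -> 3
    # prefill would then call cache.remove_suffix(3) and only run the model on new[3:].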
@@ -168,6 +188,7 @@ class KTransformersInterface(TransformersInterface):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
 
+        self.prepare_logits_wrapper(input_ids, device)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)

@@ -179,4 +200,4 @@ class KTransformersInterface(TransformersInterface):
     async def inference(self, local_messages, thread_id: str):
         async with self._infer_lock:
             async for v in super().inference(local_messages, thread_id):
                 yield v
\ No newline at end of file
ktransformers/server/backend/interfaces/transformers.py  (view file @ bb0ccc7b)

@@ -198,14 +198,28 @@ class TransformersInterface(BackendInterfaceBase):
         self.seq_length += 1
         return self.streamer.put(new_tokens)
 
-    def logits_to_token(self, logits: torch.Tensor):
-        logits = logits / self.args.temperature if self.args.temperature != 0 else logits
-
-        for token_idx in self.ever_generated_ids:
-            if logits[token_idx] < 0:
-                logits[token_idx] *= self.args.repetition_penalty
-            else:
-                logits[token_idx] /= self.args.repetition_penalty
+    def prepare_logits_wrapper(self, inputs, device):
+        generation_config, model_kwargs = self.model._prepare_generation_config(
+            None,
+            max_length=self.args.max_new_tokens,
+            do_sample=True,
+            top_k=self.args.top_k,
+            top_p=self.args.top_p,
+            temperature=self.args.temperature,
+            repetition_penalty=self.args.repetition_penalty  # change this to modify generate config
+        )
+        self.inputs = inputs
+        self.generation_config = generation_config
+        try:  # transformers==4.43
+            self.logits_warper = (
+                self.model._get_logits_warper(generation_config, device=device)
+            )
+        except:
+            self.logits_warper = (
+                self.model._get_logits_warper(generation_config)
+            )
+
+    def logits_to_token(self, logits: torch.Tensor):
+        logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
 
         probs = torch.nn.functional.softmax(logits, dim=-1)
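prepare_logits_wrapper replaces the hand-rolled temperature scaling and repetition-penalty loop with the logits warpers that transformers builds from a GenerationConfig; the try/except is there because the private _get_logits_warper helper changed its signature around transformers 4.43, per the comment in the diff. A rough equivalent of what the resulting warper does at each step, using only public processor classes and illustrative sampling values:

    import torch
    from transformers import (LogitsProcessorList, TemperatureLogitsWarper,
                              TopKLogitsWarper, TopPLogitsWarper)

    warper = LogitsProcessorList([
        TemperatureLogitsWarper(0.6),
        TopKLogitsWarper(top_k=50),
        TopPLogitsWarper(top_p=0.95),
    ])

    input_ids = torch.tensor([[1, 2, 3]])   # tokens seen so far, shape (batch, seq)
    logits = torch.randn(1, 32000)          # raw next-token logits over a toy vocab

    warped = warper(input_ids, logits)      # processors applied in order
    probs = torch.nn.functional.softmax(warped, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)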
@@ -239,31 +253,51 @@ class TransformersInterface(BackendInterfaceBase):
     @torch.no_grad
     def prefill(self, input_ids: torch.Tensor, is_new: bool):
         input_ids_length = input_ids.shape[-1]
-        self.profiler.set_counter("prefill", input_ids_length)
         logger.debug(f"input_ids: {input_ids.shape}")
 
         if is_new:
-            self.cache.reset()
             self.ever_generated_ids.clear()
-            former_seq_length = 0
-            self.seq_length = input_ids_length
-            self.generated_ids = torch.zeros(
-                self.args.batch_size,
-                self.seq_length + self.args.max_new_tokens + 1,
-                dtype=torch.int,
-                device=self.args.device,
-            )
-        else:
-            logger.debug(f"generate_ids: {self.generated_ids.shape}")
-            former_seq_length = self.seq_length
-            self.seq_length += input_ids_length
-            expected_length = self.seq_length + self.args.max_new_tokens + 1
-            delta_length = expected_length - self.generated_ids.shape[-1]
-            if delta_length > 0:
-                new_generate_ids = torch.zeros(
-                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
-                )
-                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+            same_prefix = 0
+            flat_input_ids = input_ids.flatten()
+
+            if getattr(self, 'generated_ids', None) is None:
+                self.generated_ids = torch.zeros(
+                    self.args.batch_size,
+                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                    dtype=torch.int,
+                    device=self.args.device,
+                )
+                self.seq_length = 1
+
+            flat_prev_ids = self.generated_ids.flatten()
+            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+                if flat_input_ids[i] == flat_prev_ids[i]:
+                    same_prefix += 1
+                else:
+                    break
+
+            logger.debug(f"same prefix len: {same_prefix}")
+            self.cache.remove_suffix(same_prefix)
+            self.seq_length = same_prefix
+            self.generated_ids = self.generated_ids[..., :same_prefix]
+            input_ids = input_ids[..., same_prefix:]
+            input_ids_length = input_ids.shape[-1]
+
+        self.ever_generated_ids.clear()
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+        logger.debug(f"generate_ids: {self.generated_ids.shape}")
+
+        former_seq_length = self.seq_length
+        self.seq_length += input_ids_length
+        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        delta_length = expected_length - self.generated_ids.shape[-1]
+        if delta_length > 0:
+            new_generate_ids = torch.zeros(
+                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            )
+            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+
         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
@@ -285,6 +319,7 @@ class TransformersInterface(BackendInterfaceBase):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
 
+        self.prepare_logits_wrapper(input_ids, device)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)

@@ -315,6 +350,7 @@ class TransformersInterface(BackendInterfaceBase):
         return True
 
     async def inference(self, local_messages, thread_id: str):
+        self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)

@@ -325,7 +361,7 @@ class TransformersInterface(BackendInterfaceBase):
         else:
             raise ValueError("local_messages should be List or str")
 
         if Config().user_force_think:
-            token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n", add_special_tokens=False)], device=input_ids.device)
+            token_thinks = torch.tensor([self.tokenizer.encode("<think>\n", add_special_tokens=False)], device=input_ids.device)
             input_ids = torch.cat(
                 [input_ids, token_thinks], dim=1
             )

@@ -333,11 +369,14 @@ class TransformersInterface(BackendInterfaceBase):
         self.profiler.pause_timer("tokenize")
 
         self.profiler.create_and_start_timer("prefill")
+
+        if Config().user_force_think:
+            t = "<think>\n"
+            print(t, end="", flush=True)
+            yield t
+
         for t in self.prefill(input_ids, self.check_is_new(thread_id)):
-            # output think token after prefill done
-            if Config().user_force_think:
-                think = '<think>\n'
-                print(think, end="", flush=True)
-                yield think
             if t is not None:
                 print(t, end="", flush=True)
                 yield t
ktransformers/server/main.py  (view file @ bb0ccc7b)

@@ -105,6 +105,10 @@ def custom_openapi(app):
 def main():
     cfg = Config()
+
+    # Temporarily disable cuda graph by default because of a bug in the prefix cache.
+    cfg.use_cuda_graph = False
+
     arg_parser = ArgumentParser(cfg)
 
     # 初始化消息 (initialize messages)
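Because use_cuda_graph is forced to False here before the ArgumentParser is built, and --use_cuda_graph is registered as a BooleanOptionalAction in args.py, the command line should still be able to opt back in (assuming, as with the other flags, that the parsed value is written back into the config). Assuming the server is started through this main() entry point, with <other flags> standing in for whatever model and path options your deployment needs:

    python -m ktransformers.server.main <other flags> --use_cuda_graph      # re-enable despite the new default
    python -m ktransformers.server.main <other flags> --no-use_cuda_graph   # keep it off explicitly (the default)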