Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ox696c
ktransformers
Commits
cf4da5fd
Unverified
Commit
cf4da5fd
authored
Feb 19, 2025
by
Atream
Committed by
GitHub
Feb 19, 2025
Browse files
Merge pull request #382 from ceerRep/server-prefix-cache
fix server and add prefix cache for server
parents
ea75849d
584c7d56
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
192 additions
and
89 deletions
+192
-89
ktransformers/models/custom_cache.py
ktransformers/models/custom_cache.py
+13
-1
ktransformers/operators/attention.py
ktransformers/operators/attention.py
+46
-22
ktransformers/server/api/openai/endpoints/chat.py
ktransformers/server/api/openai/endpoints/chat.py
+5
-7
ktransformers/server/args.py
ktransformers/server/args.py
+2
-1
ktransformers/server/backend/interfaces/ktransformers.py
ktransformers/server/backend/interfaces/ktransformers.py
+47
-24
ktransformers/server/backend/interfaces/transformers.py
ktransformers/server/backend/interfaces/transformers.py
+78
-34
ktransformers/server/main.py
ktransformers/server/main.py
+1
-0
No files found.
ktransformers/models/custom_cache.py
View file @
cf4da5fd
...
...
@@ -172,7 +172,19 @@ class StaticCache(transformers.StaticCache):
self
.
key_cache
[
layer_idx
].
zero_
()
if
self
.
value_cache
[
layer_idx
]
is
not
None
:
self
.
value_cache
[
layer_idx
].
zero_
()
self
.
past_tokens
[
layer_idx
]
=
0
def
remove_suffix
(
self
,
start_pos
):
for
layer_idx
in
range
(
len
(
self
.
key_cache
)):
# In-place ops prevent breaking the static address
if
self
.
is_MLA
:
k_cache
=
self
.
key_cache
[
layer_idx
]
k_cache
.
view
(
-
1
,
k_cache
.
shape
[
-
1
])[
start_pos
:].
zero_
()
else
:
self
.
key_cache
[
layer_idx
][...,
start_pos
:,
:].
zero_
()
self
.
value_cache
[
layer_idx
][...,
start_pos
:,
:].
zero_
()
self
.
past_tokens
[
layer_idx
]
=
start_pos
def
get_max_cache_shape
(
self
)
->
Tuple
[
int
,
int
,
int
,
int
]:
"""Returns the maximum shape of the cache."""
return
self
.
max_cache_len
\ No newline at end of file
return
self
.
max_cache_len
ktransformers/operators/attention.py
View file @
cf4da5fd
...
...
@@ -129,8 +129,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
# compressed_kv [pages, page_size, 1, self.kv_lora_rank]
q_absorb
,
out_absorb
=
self
.
get_absorbed
()
if
hasattr
(
self
.
orig_module
,
'kv_b_proj'
):
del
self
.
orig_module
.
kv_b_proj
#
if hasattr(self.orig_module, 'kv_b_proj'):
#
del self.orig_module.kv_b_proj
# q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
# q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
...
...
@@ -222,6 +222,16 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
compressed_kv
=
self
.
kv_a_layernorm
(
compressed_kv
)
k_pe
=
k_pe
.
view
(
bsz
,
q_len
,
1
,
self
.
qk_rope_head_dim
)
compressed_kv
=
compressed_kv
.
view
(
bsz
,
q_len
,
1
,
self
.
kv_lora_rank
)
kv_seq_len
=
q_len
if
past_key_value
is
not
None
:
if
self
.
layer_idx
is
None
:
raise
ValueError
(
f
"The cache structure has changed since version v4.36. If you are using
{
self
.
__class__
.
__name__
}
"
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
cos
,
sin
=
self
.
rotary_emb
(
q_pe
,
position_ids
)
q_pe
,
k_pe
=
apply_rotary_pos_emb
(
q_pe
,
k_pe
,
cos
,
sin
,
unsqueeze_dim
=
2
)
...
...
@@ -293,26 +303,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
# Specific to RoPE models
k_pe
.
squeeze
(
0
)
compressed_kv
.
squeeze
(
0
)
past_key_value
.
update
(
compressed_kv
,
k_pe
,
self
.
layer_idx
,
cache_kwargs
)
k_pe
.
unsqueeze
(
0
)
compressed_kv
.
unsqueeze
(
0
)
k_pe
=
k_pe
[:,
:
q_len
]
compressed_kv
=
compressed_kv
[:,
:
q_len
]
compressed_kv_with_k_pe
,
_
=
past_key_value
.
update
(
compressed_kv
,
k_pe
,
self
.
layer_idx
,
cache_kwargs
)
compressed_kv
,
k_pe
=
torch
.
split
(
compressed_kv_with_k_pe
,
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
k_pe
=
k_pe
.
view
(
bsz
,
-
1
,
self
.
qk_rope_head_dim
)
k_pe
=
k_pe
[:,
:
kv_seq_len
]
compressed_kv
=
compressed_kv
.
view
(
bsz
,
-
1
,
self
.
kv_lora_rank
)
compressed_kv
=
compressed_kv
[:,
:
kv_seq_len
]
kv
=
(
self
.
kv_b_proj
(
compressed_kv
)
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
.
view
(
bsz
,
kv_se
q_len
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
)
k_nope
,
value_states
=
torch
.
split
(
kv
,
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
query_states
=
k_pe
.
new_empty
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
q_head_dim
)
query_states
[:,
:,
:,
:
self
.
qk_nope_head_dim
]
=
q_nope
query_states
[:,
:,
:,
self
.
qk_nope_head_dim
:]
=
q_pe
key_states
=
k_pe
.
new_empty
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
q_head_dim
)
key_states
=
k_pe
.
new_empty
(
bsz
,
kv_se
q_len
,
self
.
num_heads
,
self
.
q_head_dim
)
key_states
[:,
:,
:,
:
self
.
qk_nope_head_dim
]
=
k_nope
key_states
[:,
:,
:,
self
.
qk_nope_head_dim
:]
=
k_pe
key_states
[:,
:,
:,
self
.
qk_nope_head_dim
:]
=
k_pe
.
view
(
bsz
,
kv_seq_len
,
1
,
-
1
)
value_states
=
value_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
v_head_dim
)
value_states
=
value_states
.
view
(
bsz
,
kv_se
q_len
,
self
.
num_heads
,
self
.
v_head_dim
)
value_states_padded
=
torch
.
nn
.
functional
.
pad
(
value_states
,
[
0
,
query_states
.
shape
[
-
1
]
-
value_states
.
shape
[
-
1
]],
value
=
0
)
attn_output
=
flash_attn_func
(
...
...
@@ -362,6 +374,16 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
compressed_kv
=
self
.
kv_a_layernorm
(
compressed_kv
)
k_pe
=
k_pe
.
view
(
bsz
,
q_len
,
1
,
self
.
qk_rope_head_dim
)
compressed_kv
=
compressed_kv
.
view
(
bsz
,
q_len
,
1
,
self
.
kv_lora_rank
)
kv_seq_len
=
q_len
if
past_key_value
is
not
None
:
if
self
.
layer_idx
is
None
:
raise
ValueError
(
f
"The cache structure has changed since version v4.36. If you are using
{
self
.
__class__
.
__name__
}
"
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
cos
,
sin
=
self
.
rotary_emb
(
q_pe
,
position_ids
)
q_pe
,
k_pe
=
apply_rotary_pos_emb
(
q_pe
,
k_pe
,
cos
,
sin
,
unsqueeze_dim
=
2
)
...
...
@@ -441,26 +463,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
# Specific to RoPE models
k_pe
.
squeeze
(
0
)
compressed_kv
.
squeeze
(
0
)
past_key_value
.
update
(
compressed_kv
,
k_pe
,
self
.
layer_idx
,
cache_kwargs
)
k_pe
.
unsqueeze
(
0
)
compressed_kv
.
unsqueeze
(
0
)
k_pe
=
k_pe
[:,
:
q_len
]
compressed_kv
=
compressed_kv
[:,
:
q_len
]
compressed_kv_with_k_pe
,
_
=
past_key_value
.
update
(
compressed_kv
,
k_pe
,
self
.
layer_idx
,
cache_kwargs
)
compressed_kv
,
k_pe
=
torch
.
split
(
compressed_kv_with_k_pe
,
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
k_pe
=
k_pe
.
view
(
bsz
,
-
1
,
self
.
qk_rope_head_dim
)
k_pe
=
k_pe
[:,
:
kv_seq_len
]
compressed_kv
=
compressed_kv
.
view
(
bsz
,
-
1
,
self
.
kv_lora_rank
)
compressed_kv
=
compressed_kv
[:,
:
kv_seq_len
]
kv
=
(
self
.
kv_b_proj
(
compressed_kv
)
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
.
view
(
bsz
,
kv_se
q_len
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
)
k_nope
,
value_states
=
torch
.
split
(
kv
,
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
query_states
=
k_pe
.
new_empty
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
q_head_dim
)
query_states
[:,
:,
:,
:
self
.
qk_nope_head_dim
]
=
q_nope
query_states
[:,
:,
:,
self
.
qk_nope_head_dim
:]
=
q_pe
key_states
=
k_pe
.
new_empty
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
q_head_dim
)
key_states
=
k_pe
.
new_empty
(
bsz
,
kv_se
q_len
,
self
.
num_heads
,
self
.
q_head_dim
)
key_states
[:,
:,
:,
:
self
.
qk_nope_head_dim
]
=
k_nope
key_states
[:,
:,
:,
self
.
qk_nope_head_dim
:]
=
k_pe
key_states
[:,
:,
:,
self
.
qk_nope_head_dim
:]
=
k_pe
.
view
(
bsz
,
kv_seq_len
,
1
,
-
1
)
value_states
=
value_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
v_head_dim
)
value_states
=
value_states
.
view
(
bsz
,
kv_se
q_len
,
self
.
num_heads
,
self
.
v_head_dim
)
value_states_padded
=
torch
.
nn
.
functional
.
pad
(
value_states
,
[
0
,
query_states
.
shape
[
-
1
]
-
value_states
.
shape
[
-
1
]],
value
=
0
)
attn_output
=
flash_attn_func
(
...
...
ktransformers/server/api/openai/endpoints/chat.py
View file @
cf4da5fd
...
...
@@ -5,18 +5,15 @@ from fastapi import APIRouter
from
fastapi.requests
import
Request
from
ktransformers.server.utils.create_interface
import
get_interface
from
ktransformers.server.schemas.assistants.streaming
import
chat_stream_response
from
ktransformers.server.schemas.endpoints.chat
import
ChatCompletionCreate
,
ChatCompletionChunk
,
ChatCompletionObject
from
ktransformers.server.schemas.endpoints.chat
import
ChatCompletionCreate
,
ChatCompletionChunk
,
ChatCompletionObject
,
Usage
from
ktransformers.server.backend.base
import
BackendInterfaceBase
from
ktransformers.server.config.config
import
Config
router
=
APIRouter
()
models
=
[
{
"id"
:
"0"
,
"name"
:
"ktranformers-model"
},
]
@
router
.
get
(
'/models'
,
tags
=
[
'openai'
])
async
def
list_models
():
return
models
return
[{
"id"
:
Config
().
model_name
,
"name"
:
Config
().
model_name
}]
@
router
.
post
(
'/chat/completions'
,
tags
=
[
'openai'
])
...
...
@@ -36,7 +33,8 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
yield
chunk
return
chat_stream_response
(
request
,
inner
())
else
:
comp
=
ChatCompletionObject
(
id
=
id
,
object
=
'chat.completion.chunk'
,
created
=
int
(
time
()))
comp
=
ChatCompletionObject
(
id
=
id
,
object
=
'chat.completion'
,
created
=
int
(
time
()))
comp
.
usage
=
Usage
(
completion_tokens
=
1
,
prompt_tokens
=
1
,
total_tokens
=
2
)
async
for
token
in
interface
.
inference
(
input_message
,
id
):
comp
.
append_token
(
token
)
return
comp
ktransformers/server/args.py
View file @
cf4da5fd
...
...
@@ -90,7 +90,8 @@ class ArgumentParser:
# user config
parser
.
add_argument
(
"--user_secret_key"
,
type
=
str
,
default
=
self
.
cfg
.
user_secret_key
)
parser
.
add_argument
(
"--user_algorithm"
,
type
=
str
,
default
=
self
.
cfg
.
user_algorithm
)
parser
.
add_argument
(
"--force_think"
,
type
=
bool
,
default
=
self
.
cfg
.
user_force_think
)
parser
.
add_argument
(
"--force_think"
,
action
=
argparse
.
BooleanOptionalAction
,
type
=
bool
,
default
=
self
.
cfg
.
user_force_think
)
parser
.
add_argument
(
"--use_cuda_graph"
,
action
=
argparse
.
BooleanOptionalAction
,
type
=
bool
,
default
=
self
.
cfg
.
use_cuda_graph
)
# web config
parser
.
add_argument
(
"--web_cross_domain"
,
type
=
bool
,
default
=
self
.
cfg
.
web_cross_domain
)
...
...
ktransformers/server/backend/interfaces/ktransformers.py
View file @
cf4da5fd
...
...
@@ -15,7 +15,9 @@ from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from
ktransformers.local_chat
import
custom_models
,
default_optimize_rules
from
ktransformers.util.utils
import
get_device
warm_uped
=
False
class
KTransformersThreadContext
(
TransformersThreadContext
):
pass
...
...
@@ -74,13 +76,13 @@ class KTransformersInterface(TransformersInterface):
self
.
_infer_lock
=
asyncio
.
Lock
()
def
decode_one_tokens
(
self
):
global
warm_uped
device_map
=
self
.
model
.
gguf_loader
.
tensor_device_map
torch_device
=
get_device
(
"blk.0.self_attn"
,
device_map
)
torch_device
=
"cuda:0"
if
torch_device
==
"cuda"
else
torch_device
global
warm_uped
torch
.
cuda
.
set_device
(
torch_device
)
if
self
.
args
.
use_cuda_graph
and
warm_uped
==
True
:
if
warm_uped
and
self
.
args
.
use_cuda_graph
:
if
not
hasattr
(
self
,
"cuda_graph_runner"
):
self
.
cuda_graph_runner
=
CUDAGraphRunner
()
self
.
cuda_graph_runner
.
capture
(
...
...
@@ -127,34 +129,54 @@ class KTransformersInterface(TransformersInterface):
@
torch
.
no_grad
def
prefill
(
self
,
input_ids
:
torch
.
Tensor
,
is_new
:
bool
):
input_ids_length
=
input_ids
.
shape
[
-
1
]
self
.
profiler
.
set_counter
(
"prefill"
,
input_ids_length
)
logger
.
debug
(
f
"input_ids:
{
input_ids
.
shape
}
"
)
device
=
self
.
device_map
.
get
(
"blk.0.self_attn"
,
{}).
get
(
"generate_device"
,
"cuda:0"
)
device
=
"cuda:0"
if
device
==
"cuda"
else
device
if
is_new
:
self
.
cache
.
reset
()
self
.
ever_generated_ids
.
clear
()
former_seq_length
=
0
self
.
seq_length
=
input_ids_length
self
.
generated_ids
=
torch
.
zeros
(
self
.
args
.
batch_size
,
self
.
seq_length
+
self
.
args
.
max_new_tokens
+
1
,
dtype
=
torch
.
int
,
device
=
self
.
args
.
device
,
)
else
:
logger
.
debug
(
f
"generate_ids:
{
self
.
generated_ids
.
shape
}
"
)
former_seq_length
=
self
.
seq_length
self
.
seq_length
+=
input_ids_length
expected_length
=
self
.
seq_length
+
self
.
args
.
max_new_tokens
+
1
delta_length
=
expected_length
-
self
.
generated_ids
.
shape
[
-
1
]
if
delta_length
>
0
:
new_generate_ids
=
torch
.
zeros
(
self
.
args
.
batch_size
,
delta_length
,
dtype
=
torch
.
int
,
device
=
self
.
args
.
device
same_prefix
=
0
flat_input_ids
=
input_ids
.
flatten
()
if
getattr
(
self
,
'generated_ids'
,
None
)
is
None
:
self
.
generated_ids
=
torch
.
zeros
(
self
.
args
.
batch_size
,
input_ids
.
shape
[
-
1
]
+
self
.
args
.
max_new_tokens
+
1
,
dtype
=
torch
.
int
,
device
=
self
.
args
.
device
,
)
self
.
generated_ids
=
torch
.
cat
([
self
.
generated_ids
,
new_generate_ids
],
dim
=-
1
)
self
.
seq_length
=
1
flat_prev_ids
=
self
.
generated_ids
.
flatten
()
for
i
in
range
(
min
(
self
.
seq_length
,
flat_input_ids
.
shape
[
0
])
-
1
):
if
flat_input_ids
[
i
]
==
flat_prev_ids
[
i
]:
same_prefix
+=
1
else
:
break
logger
.
debug
(
f
"same prefix len:
{
same_prefix
}
"
)
self
.
cache
.
remove_suffix
(
same_prefix
)
self
.
seq_length
=
same_prefix
self
.
generated_ids
=
self
.
generated_ids
[...,
:
same_prefix
]
input_ids
=
input_ids
[...,
same_prefix
:]
input_ids_length
=
input_ids
.
shape
[
-
1
]
self
.
ever_generated_ids
.
clear
()
self
.
profiler
.
set_counter
(
"prefill"
,
input_ids_length
)
logger
.
debug
(
f
"input_ids:
{
input_ids
.
shape
}
"
)
logger
.
debug
(
f
"generate_ids:
{
self
.
generated_ids
.
shape
}
"
)
former_seq_length
=
self
.
seq_length
self
.
seq_length
+=
input_ids_length
expected_length
=
self
.
seq_length
+
self
.
args
.
max_new_tokens
+
1
delta_length
=
expected_length
-
self
.
generated_ids
.
shape
[
-
1
]
if
delta_length
>
0
:
new_generate_ids
=
torch
.
zeros
(
self
.
args
.
batch_size
,
delta_length
,
dtype
=
torch
.
int
,
device
=
self
.
args
.
device
)
self
.
generated_ids
=
torch
.
cat
([
self
.
generated_ids
,
new_generate_ids
],
dim
=-
1
)
logger
.
debug
(
f
"cache position:
{
former_seq_length
}
to
{
self
.
seq_length
}
"
)
cache_position
=
torch
.
arange
(
former_seq_length
,
self
.
seq_length
,
device
=
device
)
self
.
generated_ids
[:,
cache_position
]
=
input_ids
.
to
(
self
.
args
.
device
).
to
(
torch
.
int
)
...
...
@@ -176,6 +198,7 @@ class KTransformersInterface(TransformersInterface):
else
:
logits
=
self
.
model
(
inputs_embeds
=
inputs_embeds
,
return_dict
=
False
)[
0
]
self
.
prepare_logits_wrapper
(
input_ids
,
device
)
next_token
=
self
.
logits_to_token
(
logits
[
0
,
-
1
,
:])
yield
self
.
append_new_tokens
(
next_token
)
...
...
@@ -187,4 +210,4 @@ class KTransformersInterface(TransformersInterface):
async
def
inference
(
self
,
local_messages
,
thread_id
:
str
):
async
with
self
.
_infer_lock
:
async
for
v
in
super
().
inference
(
local_messages
,
thread_id
):
yield
v
\ No newline at end of file
yield
v
ktransformers/server/backend/interfaces/transformers.py
View file @
cf4da5fd
...
...
@@ -170,7 +170,7 @@ class TransformersInterface(BackendInterfaceBase):
for
m
in
messages
[
1
:]:
if
m
[
"role"
]
==
"user"
and
new_messages
[
-
1
][
"role"
]
==
"user"
:
logger
.
warning
(
"merge two adjacent user messages"
)
new_messages
[
-
1
][
"content"
]
+=
m
[
"content"
]
new_messages
[
-
1
][
"content"
]
+=
'
\n
'
+
m
[
"content"
]
else
:
new_messages
.
append
(
m
)
# if (self.last_request_id is not None) and self.last_request_id == thread_id:
...
...
@@ -179,7 +179,11 @@ class TransformersInterface(BackendInterfaceBase):
# input_ids = self.tokenizer.apply_chat_template(
# new_messages, return_tensors="pt", add_generation_prompt=True
# ).to(self.args.device)
input_ids
=
self
.
tokenizer
.
apply_chat_template
(
new_messages
,
return_tensors
=
'pt'
,
add_generation_prompt
=
True
).
to
(
self
.
args
.
device
)
input_str
:
str
=
self
.
tokenizer
.
apply_chat_template
(
new_messages
,
tokenize
=
False
,
add_generation_prompt
=
True
)
# drop <think> token in chat template
if
input_str
.
endswith
(
'<think>
\n
'
):
input_str
=
input_str
[:
-
len
(
'<think>
\n
'
)]
input_ids
=
self
.
tokenizer
.
encode
(
input_str
,
return_tensors
=
"pt"
).
to
(
self
.
args
.
device
)
if
(
self
.
last_request_id
is
not
None
)
and
self
.
last_request_id
==
thread_id
:
x
=
self
.
generated_ids
[:,:
self
.
seq_length
]
y
=
input_ids
[:,:
self
.
seq_length
]
...
...
@@ -198,14 +202,28 @@ class TransformersInterface(BackendInterfaceBase):
self
.
seq_length
+=
1
return
self
.
streamer
.
put
(
new_tokens
)
def
logits_to_token
(
self
,
logits
:
torch
.
Tensor
):
logits
=
logits
/
self
.
args
.
temperature
if
self
.
args
.
temperature
!=
0
else
logits
def
prepare_logits_wrapper
(
self
,
inputs
,
device
):
generation_config
,
model_kwargs
=
self
.
model
.
_prepare_generation_config
(
None
,
max_length
=
self
.
args
.
max_new_tokens
,
do_sample
=
True
,
top_k
=
self
.
args
.
top_k
,
top_p
=
self
.
args
.
top_p
,
temperature
=
self
.
args
.
temperature
,
repetition_penalty
=
self
.
args
.
repetition_penalty
# change this to modify generate config
)
self
.
inputs
=
inputs
self
.
generation_config
=
generation_config
try
:
# transformers==4.43
self
.
logits_warper
=
(
self
.
model
.
_get_logits_warper
(
generation_config
,
device
=
device
)
)
except
:
self
.
logits_warper
=
(
self
.
model
.
_get_logits_warper
(
generation_config
)
)
for
token_idx
in
self
.
ever_generated_ids
:
if
logits
[
token_idx
]
<
0
:
logits
[
token_idx
]
*=
self
.
args
.
repetition_penalty
else
:
logits
[
token_idx
]
/=
self
.
args
.
repetition_penalty
def
logits_to_token
(
self
,
logits
:
torch
.
Tensor
):
logits
=
self
.
logits_warper
(
self
.
inputs
.
view
(
1
,
-
1
),
logits
.
view
(
1
,
-
1
))
probs
=
torch
.
nn
.
functional
.
softmax
(
logits
,
dim
=-
1
)
...
...
@@ -239,31 +257,51 @@ class TransformersInterface(BackendInterfaceBase):
@
torch
.
no_grad
def
prefill
(
self
,
input_ids
:
torch
.
Tensor
,
is_new
:
bool
):
input_ids_length
=
input_ids
.
shape
[
-
1
]
self
.
profiler
.
set_counter
(
"prefill"
,
input_ids_length
)
logger
.
debug
(
f
"input_ids:
{
input_ids
.
shape
}
"
)
if
is_new
:
self
.
cache
.
reset
()
self
.
ever_generated_ids
.
clear
()
former_seq_length
=
0
self
.
seq_length
=
input_ids_length
self
.
generated_ids
=
torch
.
zeros
(
self
.
args
.
batch_size
,
self
.
seq_length
+
self
.
args
.
max_new_tokens
+
1
,
dtype
=
torch
.
int
,
device
=
self
.
args
.
device
,
)
else
:
logger
.
debug
(
f
"generate_ids:
{
self
.
generated_ids
.
shape
}
"
)
former_seq_length
=
self
.
seq_length
self
.
seq_length
+=
input_ids_length
expected_length
=
self
.
seq_length
+
self
.
args
.
max_new_tokens
+
1
delta_length
=
expected_length
-
self
.
generated_ids
.
shape
[
-
1
]
if
delta_length
>
0
:
new_generate_ids
=
torch
.
zeros
(
self
.
args
.
batch_size
,
delta_length
,
dtype
=
torch
.
int
,
device
=
self
.
args
.
device
same_prefix
=
0
flat_input_ids
=
input_ids
.
flatten
()
if
getattr
(
self
,
'generated_ids'
,
None
)
is
None
:
self
.
generated_ids
=
torch
.
zeros
(
self
.
args
.
batch_size
,
input_ids
.
shape
[
-
1
]
+
self
.
args
.
max_new_tokens
+
1
,
dtype
=
torch
.
int
,
device
=
self
.
args
.
device
,
)
self
.
generated_ids
=
torch
.
cat
([
self
.
generated_ids
,
new_generate_ids
],
dim
=-
1
)
self
.
seq_length
=
1
flat_prev_ids
=
self
.
generated_ids
.
flatten
()
for
i
in
range
(
min
(
self
.
seq_length
,
flat_input_ids
.
shape
[
0
])
-
1
):
if
flat_input_ids
[
i
]
==
flat_prev_ids
[
i
]:
same_prefix
+=
1
else
:
break
logger
.
debug
(
f
"same prefix len:
{
same_prefix
}
"
)
self
.
cache
.
remove_suffix
(
same_prefix
)
self
.
seq_length
=
same_prefix
self
.
generated_ids
=
self
.
generated_ids
[...,
:
same_prefix
]
input_ids
=
input_ids
[...,
same_prefix
:]
input_ids_length
=
input_ids
.
shape
[
-
1
]
self
.
ever_generated_ids
.
clear
()
self
.
profiler
.
set_counter
(
"prefill"
,
input_ids_length
)
logger
.
debug
(
f
"input_ids:
{
input_ids
.
shape
}
"
)
logger
.
debug
(
f
"generate_ids:
{
self
.
generated_ids
.
shape
}
"
)
former_seq_length
=
self
.
seq_length
self
.
seq_length
+=
input_ids_length
expected_length
=
self
.
seq_length
+
self
.
args
.
max_new_tokens
+
1
delta_length
=
expected_length
-
self
.
generated_ids
.
shape
[
-
1
]
if
delta_length
>
0
:
new_generate_ids
=
torch
.
zeros
(
self
.
args
.
batch_size
,
delta_length
,
dtype
=
torch
.
int
,
device
=
self
.
args
.
device
)
self
.
generated_ids
=
torch
.
cat
([
self
.
generated_ids
,
new_generate_ids
],
dim
=-
1
)
logger
.
debug
(
f
"cache position:
{
former_seq_length
}
to
{
self
.
seq_length
}
"
)
cache_position
=
torch
.
arange
(
former_seq_length
,
self
.
seq_length
,
device
=
self
.
args
.
device
)
self
.
generated_ids
[:,
cache_position
]
=
input_ids
.
to
(
self
.
args
.
device
).
to
(
torch
.
int
)
...
...
@@ -285,6 +323,7 @@ class TransformersInterface(BackendInterfaceBase):
else
:
logits
=
self
.
model
(
inputs_embeds
=
inputs_embeds
,
return_dict
=
False
)[
0
]
self
.
prepare_logits_wrapper
(
input_ids
,
device
)
next_token
=
self
.
logits_to_token
(
logits
[
0
,
-
1
,
:])
yield
self
.
append_new_tokens
(
next_token
)
...
...
@@ -321,6 +360,7 @@ class TransformersInterface(BackendInterfaceBase):
return
True
async
def
inference
(
self
,
local_messages
,
thread_id
:
str
):
self
.
streamer
.
reset
()
self
.
profiler
.
create_and_start_timer
(
"tokenize"
)
if
isinstance
(
local_messages
,
List
):
input_ids
=
self
.
format_and_tokenize_input_ids
(
thread_id
,
local_messages
)
...
...
@@ -330,8 +370,9 @@ class TransformersInterface(BackendInterfaceBase):
#input_ids = torch.tensor([[6366]], device=input_ids.device)
else
:
raise
ValueError
(
"local_messages should be List or str"
)
if
Config
().
user_force_think
:
token_thinks
=
torch
.
tensor
([
self
.
tokenizer
.
encode
(
"<think>
\
\
n"
,
add_special_tokens
=
False
)],
device
=
input_ids
.
device
)
token_thinks
=
torch
.
tensor
([
self
.
tokenizer
.
encode
(
"<think>
\n
"
,
add_special_tokens
=
False
)],
device
=
input_ids
.
device
)
input_ids
=
torch
.
cat
(
[
input_ids
,
token_thinks
],
dim
=
1
)
...
...
@@ -339,11 +380,14 @@ class TransformersInterface(BackendInterfaceBase):
self
.
profiler
.
pause_timer
(
"tokenize"
)
self
.
profiler
.
create_and_start_timer
(
"prefill"
)
if
Config
().
user_force_think
:
t
=
"<think>
\n
"
print
(
t
,
end
=
""
,
flush
=
True
)
yield
t
for
t
in
self
.
prefill
(
input_ids
,
self
.
check_is_new
(
thread_id
)):
# output think token after prefill done
if
Config
().
user_force_think
:
think
=
'<think>
\n
'
print
(
think
,
end
=
""
,
flush
=
True
)
yield
think
if
t
is
not
None
:
print
(
t
,
end
=
""
,
flush
=
True
)
yield
t
...
...
ktransformers/server/main.py
View file @
cf4da5fd
...
...
@@ -105,6 +105,7 @@ def custom_openapi(app):
def
main
():
cfg
=
Config
()
arg_parser
=
ArgumentParser
(
cfg
)
# 初始化消息
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment