OpenDAS / ktransformers · Commits

Unverified commit cf4da5fd, authored Feb 19, 2025 by Atream, committed by GitHub on Feb 19, 2025

Merge pull request #382 from ceerRep/server-prefix-cache

fix server and add prefix cache for server

Parents: ea75849d, 584c7d56

Showing 7 changed files with 192 additions and 89 deletions:
- ktransformers/models/custom_cache.py (+13, -1)
- ktransformers/operators/attention.py (+46, -22)
- ktransformers/server/api/openai/endpoints/chat.py (+5, -7)
- ktransformers/server/args.py (+2, -1)
- ktransformers/server/backend/interfaces/ktransformers.py (+47, -24)
- ktransformers/server/backend/interfaces/transformers.py (+78, -34)
- ktransformers/server/main.py (+1, -0)
ktransformers/models/custom_cache.py

```diff
@@ -172,6 +172,18 @@ class StaticCache(transformers.StaticCache):
                 self.key_cache[layer_idx].zero_()
             if self.value_cache[layer_idx] is not None:
                 self.value_cache[layer_idx].zero_()
+            self.past_tokens[layer_idx] = 0
+
+    def remove_suffix(self, start_pos):
+        for layer_idx in range(len(self.key_cache)):
+            # In-place ops prevent breaking the static address
+            if self.is_MLA:
+                k_cache = self.key_cache[layer_idx]
+                k_cache.view(-1, k_cache.shape[-1])[start_pos:].zero_()
+            else:
+                self.key_cache[layer_idx][..., start_pos:, :].zero_()
+                self.value_cache[layer_idx][..., start_pos:, :].zero_()
+            self.past_tokens[layer_idx] = start_pos

     def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
         """Returns the maximum shape of the cache."""
```
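The new `remove_suffix` rolls a preallocated (static) cache back to a given position without reallocating anything, so any kernels or CUDA graphs captured against that storage keep pointing at the same addresses. A minimal sketch of the idea on a plain tensor (the shape and the page-free layout here are illustrative assumptions, not the project's actual cache layout):

```python
import torch

# Illustrative static K cache: [batch, max_seq_len, head_dim], preallocated once.
k_cache = torch.zeros(1, 8, 4)
k_cache[:, :6] = 1.0           # pretend 6 positions are filled from a previous request

def remove_suffix(cache: torch.Tensor, start_pos: int) -> None:
    # Zero everything from start_pos onward *in place*; the tensor's storage
    # (and thus any captured addresses) is left untouched.
    cache[:, start_pos:].zero_()

remove_suffix(k_cache, 4)      # keep the 4-token shared prefix, drop the rest
print(k_cache[0, :, 0])        # tensor([1., 1., 1., 1., 0., 0., 0., 0.])
```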
ktransformers/operators/attention.py

```diff
@@ -129,8 +129,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # compressed_kv [pages, page_size, 1, self.kv_lora_rank]

         q_absorb, out_absorb = self.get_absorbed()
-        if hasattr(self.orig_module, 'kv_b_proj'):
-            del self.orig_module.kv_b_proj
+        # if hasattr(self.orig_module, 'kv_b_proj'):
+        #     del self.orig_module.kv_b_proj

         # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
         # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
@@ -223,6 +223,16 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
         compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)

+        kv_seq_len = q_len
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
         cos, sin = self.rotary_emb(q_pe, position_ids)
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
@@ -293,26 +303,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
             k_pe.squeeze(0)
             compressed_kv.squeeze(0)
-            past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
-            k_pe.unsqueeze(0)
-            compressed_kv.unsqueeze(0)
-            k_pe = k_pe[:, :q_len]
-            compressed_kv = compressed_kv[:, :q_len]
+            compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
+            compressed_kv, k_pe = torch.split(
+                compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+            k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)
+            k_pe = k_pe[:, :kv_seq_len]
+            compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)
+            compressed_kv = compressed_kv[:, :kv_seq_len]
             kv = (
                 self.kv_b_proj(compressed_kv)
-                .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+                .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
             )
             k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
             query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
             query_states[:, :, :, self.qk_nope_head_dim:] = q_pe
-            key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
+            key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
             key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
-            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
+            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
-            value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
+            value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
             value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

             attn_output = flash_attn_func(
```

The remaining hunks of this file (@@ -363,6 +375,16 @@ and @@ -441,26 +463,28 @@) make the same two changes, the `kv_seq_len` bookkeeping and the `kv_seq_len`-sized key/value reconstruction from the cache, in the class's second forward implementation.
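With the prefix cache in place, `past_key_value.update` hands back the full cached tensor (reused prefix plus the newly appended chunk) rather than only the fresh slice, so the attention code now splits the concatenated latent/RoPE parts back apart and sizes keys and values by `kv_seq_len` instead of `q_len`. A toy illustration of that split (the dimensions are made up for the example; `kv_lora_rank=512` and `qk_rope_head_dim=64` merely mirror DeepSeek-V2-like defaults):

```python
import torch

bsz, kv_seq_len = 1, 10
kv_lora_rank, qk_rope_head_dim = 512, 64   # assumed sizes, for illustration only

# What the cache hands back: compressed latent KV and rotary key halves, concatenated.
compressed_kv_with_k_pe = torch.randn(bsz, kv_seq_len, 1, kv_lora_rank + qk_rope_head_dim)

# Split along the last dim to recover the two pieces.
compressed_kv, k_pe = torch.split(
    compressed_kv_with_k_pe, [kv_lora_rank, qk_rope_head_dim], dim=-1)

print(compressed_kv.shape)  # torch.Size([1, 10, 1, 512])
print(k_pe.shape)           # torch.Size([1, 10, 1, 64])
```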
ktransformers/server/api/openai/endpoints/chat.py

```diff
@@ -5,18 +5,15 @@ from fastapi import APIRouter
 from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import chat_stream_response
-from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate, ChatCompletionChunk, ChatCompletionObject
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate, ChatCompletionChunk, ChatCompletionObject, Usage
 from ktransformers.server.backend.base import BackendInterfaceBase
+from ktransformers.server.config.config import Config

 router = APIRouter()

-models = [
-    {"id": "0", "name": "ktranformers-model"},
-]
-
 @router.get('/models', tags=['openai'])
 async def list_models():
-    return models
+    return [{"id": Config().model_name, "name": Config().model_name}]

 @router.post('/chat/completions', tags=['openai'])
@@ -36,7 +33,8 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
             yield chunk
         return chat_stream_response(request, inner())
     else:
-        comp = ChatCompletionObject(id=id, object='chat.completion.chunk', created=int(time()))
+        comp = ChatCompletionObject(id=id, object='chat.completion', created=int(time()))
+        comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
         async for token in interface.inference(input_message, id):
             comp.append_token(token)
         return comp
```
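For non-streaming requests the endpoint now reports `object='chat.completion'` (the `.chunk` object type is reserved for streamed deltas) and attaches a `usage` block, and `GET /models` reflects the configured model name instead of a hard-coded placeholder. A hedged sketch of the response shape an OpenAI-style client would see (the id and content values are illustrative, and the token counts are the placeholder 1/1/2 values set in the diff, not real measurements):

```python
# Illustrative non-streaming response body; not produced by running the server here.
response = {
    "id": "chatcmpl-123",                 # assumed id format
    "object": "chat.completion",          # was 'chat.completion.chunk' before this fix
    "created": 1739952000,
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "..."},
            "finish_reason": "stop",
        }
    ],
    "usage": {                            # placeholder counts set by the server
        "prompt_tokens": 1,
        "completion_tokens": 1,
        "total_tokens": 2,
    },
}
```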
ktransformers/server/args.py

```diff
@@ -90,7 +90,8 @@ class ArgumentParser:
         # user config
         parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
         parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
-        parser.add_argument("--force_think", type=bool, default=self.cfg.user_force_think)
+        parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.user_force_think)
+        parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.use_cuda_graph)

         # web config
         parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
```
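`type=bool` on its own is a common argparse pitfall: any non-empty string value, including "False", parses as True. `argparse.BooleanOptionalAction` (Python 3.9+) turns the flag into a proper on/off switch with an automatic `--no-...` negation. A minimal standalone illustration (the flag names mirror the server's, but this parser is just for demonstration):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, default=True)

print(parser.parse_args([]))                       # Namespace(force_think=False, use_cuda_graph=True)
print(parser.parse_args(["--force_think"]))        # Namespace(force_think=True, use_cuda_graph=True)
print(parser.parse_args(["--no-use_cuda_graph"]))  # Namespace(force_think=False, use_cuda_graph=False)
```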
ktransformers/server/backend/interfaces/ktransformers.py

```diff
@@ -15,7 +15,9 @@ from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device

 warm_uped = False

 class KTransformersThreadContext(TransformersThreadContext):
     pass
@@ -74,13 +76,13 @@ class KTransformersInterface(TransformersInterface):
         self._infer_lock = asyncio.Lock()

     def decode_one_tokens(self):
+        global warm_uped
         device_map = self.model.gguf_loader.tensor_device_map
         torch_device = get_device("blk.0.self_attn", device_map)
         torch_device = "cuda:0" if torch_device == "cuda" else torch_device
-        global warm_uped
         torch.cuda.set_device(torch_device)
-        if self.args.use_cuda_graph and warm_uped == True:
+        if warm_uped and self.args.use_cuda_graph:
             if not hasattr(self, "cuda_graph_runner"):
                 self.cuda_graph_runner = CUDAGraphRunner()
                 self.cuda_graph_runner.capture(
@@ -127,24 +129,43 @@ class KTransformersInterface(TransformersInterface):
     @torch.no_grad
     def prefill(self, input_ids: torch.Tensor, is_new: bool):
         input_ids_length = input_ids.shape[-1]
-        self.profiler.set_counter("prefill", input_ids_length)
         logger.debug(f"input_ids: {input_ids.shape}")
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         device = "cuda:0" if device == "cuda" else device
         if is_new:
-            self.cache.reset()
             self.ever_generated_ids.clear()
-            former_seq_length = 0
-            self.seq_length = input_ids_length
-            self.generated_ids = torch.zeros(
-                self.args.batch_size,
-                self.seq_length + self.args.max_new_tokens + 1,
-                dtype=torch.int,
-                device=self.args.device,
-            )
+            same_prefix = 0
+            flat_input_ids = input_ids.flatten()
+
+            if getattr(self, 'generated_ids', None) is None:
+                self.generated_ids = torch.zeros(
+                    self.args.batch_size,
+                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                    dtype=torch.int,
+                    device=self.args.device,
+                )
+                self.seq_length = 1
+
+            flat_prev_ids = self.generated_ids.flatten()
+            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+                if flat_input_ids[i] == flat_prev_ids[i]:
+                    same_prefix += 1
+                else:
+                    break
+
+            logger.debug(f"same prefix len: {same_prefix}")
+            self.cache.remove_suffix(same_prefix)
+            self.seq_length = same_prefix
+            self.generated_ids = self.generated_ids[..., :same_prefix]
+            input_ids = input_ids[..., same_prefix:]
+            input_ids_length = input_ids.shape[-1]
+
+        self.ever_generated_ids.clear()
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
         logger.debug(f"generate_ids: {self.generated_ids.shape}")
         former_seq_length = self.seq_length
         self.seq_length += input_ids_length
@@ -155,6 +176,7 @@ class KTransformersInterface(TransformersInterface):
             self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
         )
         self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
@@ -176,6 +198,7 @@ class KTransformersInterface(TransformersInterface):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+        self.prepare_logits_wrapper(input_ids, device)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
```
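The reworked `prefill` is where the prefix cache pays off: instead of resetting the KV cache on every new request, it compares the incoming prompt against the token ids kept from the previous request, keeps the longest common prefix, rolls the cache back to that point with `remove_suffix`, and only prefills the remaining suffix. A small self-contained sketch of the matching step (plain Python/torch, names chosen for the example):

```python
import torch

def longest_common_prefix(prev_ids: torch.Tensor, new_ids: torch.Tensor) -> int:
    """Count how many leading tokens of the new prompt match the cached ones."""
    same_prefix = 0
    limit = min(prev_ids.shape[0], new_ids.shape[0]) - 1  # leave at least one token to prefill
    for i in range(limit):
        if new_ids[i] == prev_ids[i]:
            same_prefix += 1
        else:
            break
    return same_prefix

prev = torch.tensor([1, 5, 9, 4, 7, 2])     # tokens kept from the last request
new = torch.tensor([1, 5, 9, 4, 8, 3, 6])   # new prompt sharing a 4-token prefix
keep = longest_common_prefix(prev, new)
print(keep)  # 4 -> cache.remove_suffix(4); only new[4:] needs prefilling
```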
ktransformers/server/backend/interfaces/transformers.py

```diff
@@ -170,7 +170,7 @@ class TransformersInterface(BackendInterfaceBase):
         for m in messages[1:]:
             if m["role"] == "user" and new_messages[-1]["role"] == "user":
                 logger.warning("merge two adjacent user messages")
-                new_messages[-1]["content"] += m["content"]
+                new_messages[-1]["content"] += '\n' + m["content"]
             else:
                 new_messages.append(m)
         # if (self.last_request_id is not None) and self.last_request_id == thread_id:
@@ -179,7 +179,11 @@ class TransformersInterface(BackendInterfaceBase):
         # input_ids = self.tokenizer.apply_chat_template(
         #     new_messages, return_tensors="pt", add_generation_prompt=True
         # ).to(self.args.device)
-        input_ids = self.tokenizer.apply_chat_template(new_messages, return_tensors='pt', add_generation_prompt=True).to(self.args.device)
+        input_str: str = self.tokenizer.apply_chat_template(new_messages, tokenize=False, add_generation_prompt=True)
+        # drop <think> token in chat template
+        if input_str.endswith('<think>\n'):
+            input_str = input_str[:-len('<think>\n')]
+        input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
         if (self.last_request_id is not None) and self.last_request_id == thread_id:
             x = self.generated_ids[:,:self.seq_length]
             y = input_ids[:,:self.seq_length]
@@ -198,14 +202,28 @@ class TransformersInterface(BackendInterfaceBase):
         self.seq_length += 1
         return self.streamer.put(new_tokens)

-    def logits_to_token(self, logits: torch.Tensor):
-        logits = logits / self.args.temperature if self.args.temperature != 0 else logits
-
-        for token_idx in self.ever_generated_ids:
-            if logits[token_idx] < 0:
-                logits[token_idx] *= self.args.repetition_penalty
-            else:
-                logits[token_idx] /= self.args.repetition_penalty
+    def prepare_logits_wrapper(self, inputs, device):
+        generation_config, model_kwargs = self.model._prepare_generation_config(
+            None, max_length=self.args.max_new_tokens,
+            do_sample=True,
+            top_k=self.args.top_k,
+            top_p=self.args.top_p,
+            temperature=self.args.temperature,
+            repetition_penalty=self.args.repetition_penalty # change this to modify generate config
+        )
+        self.inputs = inputs
+        self.generation_config = generation_config
+        try: # transformers==4.43
+            self.logits_warper = (
+                self.model._get_logits_warper(generation_config, device=device)
+            )
+        except:
+            self.logits_warper = (
+                self.model._get_logits_warper(generation_config)
+            )
+
+    def logits_to_token(self, logits: torch.Tensor):
+        logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))

         probs = torch.nn.functional.softmax(logits, dim=-1)
@@ -239,21 +257,40 @@ class TransformersInterface(BackendInterfaceBase):
     @torch.no_grad
     def prefill(self, input_ids: torch.Tensor, is_new: bool):
         input_ids_length = input_ids.shape[-1]
-        self.profiler.set_counter("prefill", input_ids_length)
         logger.debug(f"input_ids: {input_ids.shape}")
         if is_new:
-            self.cache.reset()
             self.ever_generated_ids.clear()
-            former_seq_length = 0
-            self.seq_length = input_ids_length
-            self.generated_ids = torch.zeros(
-                self.args.batch_size,
-                self.seq_length + self.args.max_new_tokens + 1,
-                dtype=torch.int,
-                device=self.args.device,
-            )
+            same_prefix = 0
+            flat_input_ids = input_ids.flatten()
+
+            if getattr(self, 'generated_ids', None) is None:
+                self.generated_ids = torch.zeros(
+                    self.args.batch_size,
+                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                    dtype=torch.int,
+                    device=self.args.device,
+                )
+                self.seq_length = 1
+
+            flat_prev_ids = self.generated_ids.flatten()
+            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+                if flat_input_ids[i] == flat_prev_ids[i]:
+                    same_prefix += 1
+                else:
+                    break
+
+            logger.debug(f"same prefix len: {same_prefix}")
+            self.cache.remove_suffix(same_prefix)
+            self.seq_length = same_prefix
+            self.generated_ids = self.generated_ids[..., :same_prefix]
+            input_ids = input_ids[..., same_prefix:]
+            input_ids_length = input_ids.shape[-1]
+
+        self.ever_generated_ids.clear()
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
         logger.debug(f"generate_ids: {self.generated_ids.shape}")
         former_seq_length = self.seq_length
         self.seq_length += input_ids_length
@@ -264,6 +301,7 @@ class TransformersInterface(BackendInterfaceBase):
             self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
         )
         self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
@@ -285,6 +323,7 @@ class TransformersInterface(BackendInterfaceBase):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+        self.prepare_logits_wrapper(input_ids, device)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
@@ -321,6 +360,7 @@ class TransformersInterface(BackendInterfaceBase):
         return True

     async def inference(self, local_messages, thread_id: str):
+        self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
@@ -330,8 +370,9 @@ class TransformersInterface(BackendInterfaceBase):
             #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")
         if Config().user_force_think:
-            token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n", add_special_tokens=False)], device=input_ids.device)
+            token_thinks = torch.tensor([self.tokenizer.encode("<think>\n", add_special_tokens=False)], device=input_ids.device)
             input_ids = torch.cat(
                 [input_ids, token_thinks], dim=1
             )
@@ -339,11 +380,14 @@ class TransformersInterface(BackendInterfaceBase):
         self.profiler.pause_timer("tokenize")
         self.profiler.create_and_start_timer("prefill")
-        if Config().user_force_think:
-            t = "<think>\n"
-            print(t, end="", flush=True)
-            yield t
         for t in self.prefill(input_ids, self.check_is_new(thread_id)):
+            # output think token after prefill done
+            if Config().user_force_think:
+                think = '<think>\n'
+                print(think, end="", flush=True)
+                yield think
             if t is not None:
                 print(t, end="", flush=True)
                 yield t
```
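Sampling is also reworked: instead of hand-rolled temperature scaling and a sign-sensitive repetition penalty, the interface now builds a generation config once per request and lets Transformers' own logits warpers (temperature, top-k, top-p, repetition penalty) reshape the scores; `_get_logits_warper` is a private Transformers helper whose signature changed around 4.43, hence the try/except. A rough equivalent using only the public processor classes (parameter values are arbitrary for the demo, and the pipeline is an approximation, not the server's exact configuration):

```python
import torch
from transformers import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

# Compose the same kind of pipeline the server obtains from _get_logits_warper.
warper = LogitsProcessorList([
    RepetitionPenaltyLogitsProcessor(penalty=1.1),
    TemperatureLogitsWarper(temperature=0.7),
    TopKLogitsWarper(top_k=50),
    TopPLogitsWarper(top_p=0.95),
])

vocab_size = 32
input_ids = torch.tensor([[3, 7, 7, 12]])   # (batch=1, seq_len) tokens generated so far
logits = torch.randn(1, vocab_size)         # raw scores for the next token

warped = warper(input_ids, logits)          # penalize repeats, rescale, mask the tail
probs = torch.softmax(warped, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(next_token.shape)                     # torch.Size([1, 1])
```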
ktransformers/server/main.py

```diff
@@ -105,6 +105,7 @@ def custom_openapi(app):
 def main():
     cfg = Config()
     arg_parser = ArgumentParser(cfg)

     # 初始化消息
```