ox696c / ktransformers / Commits / e5694f91

Unverified commit e5694f91, authored Mar 10, 2025 by Yuhao Tsui, committed by GitHub on Mar 10, 2025.

    Merge branch 'kvcache-ai:main' into main

Parents: d050d865, 09c043d8
Showing 17 changed files with 353 additions and 160 deletions (+353 −160):

    ktransformers/__init__.py                                  +1    −1
    ktransformers/ktransformers_ext/CMakeLists.txt             +1    −0
    ktransformers/operators/attention.py                       +1    −1
    ktransformers/operators/flashinfer_wrapper.py            +143   −54
    ktransformers/server/api/ollama/completions.py            +26   −16
    ktransformers/server/api/openai/endpoints/chat.py         +78   −11
    ktransformers/server/api/openai/legacy/completions.py     +14    −6
    ktransformers/server/backend/base.py                      +11    −6
    ktransformers/server/backend/interfaces/ktransformers.py  +10    −0
    ktransformers/server/backend/interfaces/transformers.py   +15    −8
    ktransformers/server/requirements.txt                      +1    −0
    ktransformers/server/schemas/endpoints/chat.py            +22   −42
    ktransformers/tests/AIME_2024/eval_api.py                  +9    −5
    ktransformers/tests/humaneval/eval_api.py                 +17    −7
    ktransformers/tests/humaneval/evaluation.py                +1    −1
    ktransformers/util/utils.py                                +1    −1
    third_party/llamafile/iqk_mul_mat.inc                      +2    −1
ktransformers/__init__.py

@@ -8,4 +8,4 @@ Version : 1.0.0
 LastEditors : chenxl
 LastEditTime : 2025-02-15 03:53:02
 '''
-__version__ = "0.2.3"
\ No newline at end of file
+__version__ = "0.2.3.post1"
ktransformers/ktransformers_ext/CMakeLists.txt

@@ -175,6 +175,7 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
         list(APPEND ARCH_FLAGS -mavx512bw)
         list(APPEND ARCH_FLAGS -mavx512dq)
         list(APPEND ARCH_FLAGS -mavx512vnni)
+        list(APPEND ARCH_FLAGS -mavx512vpopcntdq)
     endif()
     if (LLAMA_AVX512_BF16)
         list(APPEND ARCH_FLAGS -mavx512bf16)
ktransformers/operators/attention.py

@@ -25,7 +25,7 @@ from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
 import os
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 if flashinfer_enabled:
-    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, attention_ref
+    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
 logger = logging.getLogger("attention")
ktransformers/operators/flashinfer_wrapper.py

 '''
 Description : flashinfer MLA wrapper
 Author : Boxin Zhang
-Version : 0.2.2
+Version : 0.2.3
 '''
 import torch
 import os
 from ktransformers.operators.triton_attention import decode_attention_fwd_grouped

 flashinfer_enabled = False
@@ -17,7 +19,7 @@ except ImportError:

 import math

-def attention_ref(
+def attention_ref_torch(
     batch_size,
     q: torch.Tensor,
     k: torch.Tensor,
@@ -139,11 +141,6 @@ class MLAWrapper():
         )

     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
-        #print("run")
-        #print(self.wrapper._qo_indptr_buf)
-        #print(self.wrapper._kv_indptr_buf)
-        #print(self.wrapper._kv_indices_buf)
-        #print(self.wrapper._kv_len_arr_buf)
         return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)

 class MLAWrapperSingleton():
@@ -203,20 +200,58 @@ class MLAWrapperSingleton():
             wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
             wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf

 def checksame():
-    flashinfer_folder = "./flashinfer_output"
-    triton_folder = "./triton_output"
+    flashinfer_folder = "./kv_cache_flashinfer"
+    triton_folder = "./kv_cache_triton"

     max_layer_id = 1
     max_forward_id = 2

     for forward_id in range(0, 19):
         print("forward_id", forward_id)
         for layer_id in range(max_layer_id):
             print(layer_id)
             #file_name = f"layer_{layer_id}_forward_{forward_id}_attn_output.pt"
             #file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
             file_name = f"layer_{layer_id}.pt"
             flashinfer_path = os.path.join(flashinfer_folder, file_name)
             triton_path = os.path.join(triton_folder, file_name)
             if not os.path.exists(triton_path):
                 print(f"{file_name} not exist in {triton_folder}")
                 continue
             if not os.path.exists(flashinfer_path):
                 print(f"{file_name} not exist in {flashinfer_folder}")
                 continue
             flashinfer_tensor = torch.load(flashinfer_path)[1:2, :62] #
             triton_tensor = torch.load(triton_path)[1:2, :62] #.squeeze(1)#
             try:
                 torch.testing.assert_close(flashinfer_tensor, triton_tensor, rtol=1e-9, atol=1e-9)
             except AssertionError as e:
                 print(e)

 if __name__ == "__main__":
     torch.set_default_dtype(torch.bfloat16)
     #checksame()
     #exit(0)
     max_batch_size = 1
-    max_pages = 128
+    max_pages = 64
     page_size = 64
     num_heads = 128

     # warm-up
     kv_len = 4023
     q_len = 1
-    q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
-    q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
-    ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
-    k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
+    q_nope_buf = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
+    q_pe_buf = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
+    kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
+    ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)

     wrapper = MLAWrapperSingleton.get_instance(
@@ -241,18 +276,41 @@ if __name__ == "__main__":
         torch.bfloat16,
     )

-    attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+    attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
     print(attn_output.shape)

     graph = torch.cuda.CUDAGraph()
     with torch.cuda.graph(graph):
-        attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+        attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
+    # warm-up finished

     for forward_id in range(0, 1):
         print("forward_id", forward_id)
         for layer_id in range(1):
             print(layer_id)

+    flashinfer_folder = "./kv_cache_flashinfer"
+    forward_id = 17
+    layer_id = 0
+    file_name = f"layer_{layer_id}.pt"
+    kv_cache_path = os.path.join(flashinfer_folder, file_name)
+    flashinfer_folder = "./flashinfer_output"
+
+    q_len = 1
+    kv_len = 126
+    file_name = f"layer_{layer_id}_forward_{forward_id}_q_nope.pt"
+    q_nope = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len, 128, 512).to(device="cuda")
+    file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
+    q_pe = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len, 128, 64).to(device="cuda")
+    q = torch.cat([q_nope, q_pe], dim=-1)
+    kv_cache = torch.load(kv_cache_path).to(device="cuda")
+    pages, page_size, _, head_dim = kv_cache.shape
+    kv_cache = kv_cache.view(pages, page_size, head_dim)
+    ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)
+    kv_len = 6789
     kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
     qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
     wrapper.plan(
         qo_indptr,
         None,
         None,
         None,
         kv_len_arr,
@@ -265,27 +323,58 @@ if __name__ == "__main__":
         torch.bfloat16,
     )

     q_nope_buf.copy_(q_nope)
     q_pe_buf.copy_(q_pe)
     kv_buf[:pages].copy_(kv_cache)

     torch.cuda.synchronize()
     graph.replay()
     torch.cuda.synchronize()

     # ref_torch
     k = (
         torch.cat([ckv, k_pe], dim=-1)
         .view(-1, 1, 512 + 64)
         .repeat_interleave(num_heads, dim=1)
     )
     v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
     print(k[:kv_len].shape)
     print(v[:kv_len].shape)

-    attn_ref, lse_ref = attention_ref(
+    attn_ref, lse_ref = attention_ref_torch(
         max_batch_size,
-        torch.cat([q_nope, q_pe], dim=-1),
+        q,
         k[:kv_len],
         v[:kv_len],
-        True,
+        False,
         192 ** (-0.5)
     )
     print(attn_ref.shape)
     torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)

     # ref_triton
     attn_logits = torch.empty(
         (
             max_batch_size,
             num_heads,
             4, #num_kv_splits # follow vLLM, fix it TODO
             512 + 1,
         ),
         dtype=torch.float32,
         device="cuda"
     )
     triton_ref = torch.zeros_like(q_nope)
     page_table = torch.arange(max_pages, dtype=torch.int32, device="cuda")
     ckv_with_pe = torch.cat([ckv, k_pe], dim=-1).contiguous().view(pages, page_size, 1, 576)
     ckv = ckv.view(pages, page_size, 1, 512)
     decode_attention_fwd_grouped(q, ckv_with_pe, ckv, triton_ref, page_table, kv_len_arr, attn_logits,
                                  4, #num_kv_splits # follow vLLM, fix it TODO
                                  192 ** (-0.5), page_size)
     torch.testing.assert_close(attn_output, triton_ref, rtol=1e-3, atol=1e-3)

     #file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt"
     #ktrans_output = torch.load(file_name)
     #torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3)

     print("test past")
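The test above packs the paged MLA KV cache into a single 576-wide buffer per token and splits it into the 512-dim compressed-KV part ("ckv") and the 64-dim rotary part ("k_pe") with torch.split. A minimal sketch of that layout, assuming only torch (shapes mirror the test; the page count here is made up):

import torch

# Each page holds page_size tokens; each token stores 512 compressed-KV
# values plus 64 rotary-key values contiguously, i.e. 576 values per token.
max_pages, page_size = 4, 64
kv_buf = torch.randn((max_pages, page_size, 512 + 64), dtype=torch.bfloat16)

# torch.split over the last dim returns views, so ckv and k_pe alias kv_buf:
# a CUDA graph captured against them can be replayed after kv_buf is
# refilled in place, as the test does with kv_buf[:pages].copy_(kv_cache).
ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)
assert ckv.shape == (max_pages, page_size, 512)
assert k_pe.shape == (max_pages, page_size, 64)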
ktransformers/server/api/ollama/completions.py

@@ -13,6 +13,8 @@ from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import check_link_response
 from ktransformers.server.backend.base import BackendInterfaceBase
+from ktransformers.server.schemas.endpoints.chat import RawUsage

 router = APIRouter(prefix='/api')

 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion

@@ -61,7 +63,11 @@ async def generate(request: Request, input: OllamaGenerateCompletionRequest):
     if input.stream:
         async def inner():
-            async for token in interface.inference(input.prompt, id):
+            async for res in interface.inference(input.prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
                 d = OllamaGenerationStreamResponse(
                     model=config.model_name,
                     created_at=str(datetime.now()),

@@ -142,7 +148,11 @@ async def chat(request: Request, input: OllamaChatCompletionRequest):
         eval_count = 0  # number of generated tokens
         tokens = []

-        async for token in interface.inference(prompt, id):
+        async for res in interface.inference(prompt, id):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
             d = OllamaChatCompletionStreamResponse(
                 model=config.model_name,
                 created_at=str(datetime.now()),
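With this change interface.inference() no longer yields bare tokens: it yields (token, finish_reason) tuples and, as its final item, a RawUsage object, so every caller branches on the item type. A minimal sketch of a consumer of that protocol (the async generator below is a toy stand-in, not the real interface):

import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, Optional, Tuple, Union

@dataclass
class RawUsage:  # stand-in for ktransformers.server.schemas.endpoints.chat.RawUsage
    prefill_count: int
    decode_count: int

async def fake_inference() -> AsyncIterator[Union[Tuple[str, Optional[str]], RawUsage]]:
    yield ("Hello", None)           # tokens stream as (token, finish_reason)
    yield (" world", "stop")        # the last token carries the stop cause
    yield RawUsage(prefill_count=5, decode_count=2)  # usage arrives last

async def consume() -> None:
    async for res in fake_inference():
        if isinstance(res, RawUsage):
            print("\nprompt tokens:", res.prefill_count, "completion tokens:", res.decode_count)
        else:
            token, finish_reason = res
            print(token, end="")

asyncio.run(consume())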
ktransformers/server/api/openai/endpoints/chat.py

@@ -5,10 +5,16 @@ from fastapi import APIRouter
 from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import chat_stream_response
-from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate, ChatCompletionChunk, ChatCompletionObject, Usage
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
+from ktransformers.server.schemas.endpoints.chat import RawUsage
+from ktransformers.server.backend.base import BackendInterfaceBase
 from ktransformers.server.config.config import Config
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
+from openai.types.chat import ChatCompletion
+from openai.types.completion_usage import CompletionUsage

 router = APIRouter()

 @router.get('/models', tags=['openai'])

@@ -29,15 +35,76 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
         assert request.headers.get('Authorization', '').split()[-1] == Config().api_key

     if create.stream:
+        from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
+
         async def inner():
-            chunk = ChatCompletionChunk(id = id, object = 'chat.completion.chunk', created = int(time()))
-            async for token in interface.inference(input_message, id, create.temperature, create.top_p):
-                chunk.set_token(token)
-                yield chunk
+            chunk = ChatCompletionChunk(
+                id = id,
+                choices = [],
+                object = 'chat.completion.chunk',
+                created = int(time()),
+                model = Config().model_name,
+            )
+            async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+                if isinstance(res, RawUsage):
+                    # at the end of inference, interface.inference() will return the usage of inference
+                    raw_usage = res
+                    chunk.choices = []
+                    chunk.usage = CompletionUsage(
+                        prompt_tokens = raw_usage.prefill_count,
+                        completion_tokens = raw_usage.decode_count,
+                        total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                    )
+                    yield chunk
+                else:
+                    token, finish_reason = res
+                    choice = Choice(
+                        index = 0,
+                        delta = ChoiceDelta(content=token, role=None, tool_calls=None),
+                        finish_reason = finish_reason,
+                        logprobs = None,
+                    )
+                    chunk.choices = [choice]
+                    yield chunk
         return chat_stream_response(request, inner())
     else:
-        comp = ChatCompletionObject(id = id, object = 'chat.completion', created = int(time()))
-        comp.usage = Usage(completion_tokens = 1, prompt_tokens = 1, total_tokens = 2)
-        async for token in interface.inference(input_message, id, create.temperature, create.top_p):
-            comp.append_token(token)
-        return comp
+        from openai.types.chat.chat_completion import Choice
+        from openai.types.chat.chat_completion_message import ChatCompletionMessage
+
+        content = ""
+        finish_reason = None
+        async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+                usage = CompletionUsage(
+                    prompt_tokens = raw_usage.prefill_count,
+                    completion_tokens = raw_usage.decode_count,
+                    total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                )
+            else:
+                token, finish_reason = res
+                content = content + token
+                finish_reason = finish_reason
+
+        choice = Choice(
+            index = 0,
+            finish_reason = finish_reason,
+            message = ChatCompletionMessage(
+                content = content,
+                role = "assistant"
+            ))
+
+        chat_completion = ChatCompletion(
+            id = id,
+            choices = [choice],
+            created = int(time()),
+            model = Config().model_name,
+            object = 'chat.completion',
+            usage = usage
+        )
+
+        return chat_completion
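In the streaming branch above, ordinary chunks carry one delta choice each; after the RawUsage item arrives, the same chunk object is re-sent with an empty choices list and a usage block, matching the OpenAI streaming convention of a trailing usage-only chunk. A sketch of the resulting SSE payloads (ids, model name, and token counts below are illustrative only):

import json

delta_chunk = {
    "id": "chat-123",
    "object": "chat.completion.chunk",
    "created": 1741600000,
    "model": "DeepSeek-V3",
    "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": None}],
}
usage_chunk = {
    "id": "chat-123",
    "object": "chat.completion.chunk",
    "created": 1741600000,
    "model": "DeepSeek-V3",
    "choices": [],  # choices are emptied before the usage-only chunk is sent
    "usage": {"prompt_tokens": 128, "completion_tokens": 64, "total_tokens": 192},
}
for chunk in (delta_chunk, usage_chunk):
    print(f"data: {json.dumps(chunk)}\n")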
ktransformers/server/api/openai/legacy/completions.py

@@ -6,6 +6,7 @@ from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import stream_response
 from ktransformers.server.schemas.legacy.completions import CompletionCreate, CompletionObject
+from ktransformers.server.schemas.endpoints.chat import RawUsage

 router = APIRouter()

@@ -17,10 +18,13 @@ async def create_completion(request:Request,create:CompletionCreate):
     print(f'COMPLETION INPUT:----\n{create.prompt}\n----')

     if create.stream:
        async def inner():
-            async for token in interface.inference(create.prompt, id, create.temperature, create.top_p):
+            async for res in interface.inference(create.prompt, id, create.temperature, create.top_p):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
                 d = {'choices':[{'delta':{'content':token}}]}
                 yield f"data:{json.dumps(d)}\n\n"
            d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}

@@ -28,6 +32,10 @@ async def create_completion(request:Request,create:CompletionCreate):
         return stream_response(request, inner())
     else:
         comp = CompletionObject(id = id, object = 'text_completion', created = int(time()))
-        async for token in interface.inference(create.prompt, id, create.temperature, create.top_p):
+        async for res in interface.inference(create.prompt, id, create.temperature, create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
             comp.append_token(token)
         return comp
ktransformers/server/backend/base.py

@@ -15,6 +15,7 @@ from ktransformers.server.schemas.assistants.assistants import AssistantObject
 from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, Role
 from ktransformers.server.schemas.assistants.runs import RunObject
 from ktransformers.server.schemas.assistants.threads import ThreadObject
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.schemas.base import ObjectID, Order
 from ktransformers.server.utils.multi_timer import Profiler

@@ -142,7 +143,11 @@ class ThreadContext:
         yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
         yield self.run.stream_response_with_event(RunObject.Status.in_progress)

-        async for token in self.interface.inference(local_messages, self.thread.id):
+        async for res in self.interface.inference(local_messages, self.thread.id):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
             if self.run.status == RunObject.Status.cancelling:
                 logger.warn(f'Run {self.run.id} cancelling')
                 break
ktransformers/server/backend/interfaces/ktransformers.py

@@ -16,6 +16,7 @@ from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
 from typing import Optional
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
+from ktransformers.server.schemas.endpoints.chat import RawUsage

 warm_uped = False

@@ -231,3 +232,12 @@ class KTransformersInterface(TransformersInterface):
         async with self._infer_lock:
             async for v in super().inference(local_messages, thread_id, temperature, top_p):
                 yield v
+
+            # return this inference raw usage
+            yield RawUsage(
+                tokenize_time = self.profiler.get_timer_sec('tokenize'),
+                prefill_time = self.profiler.get_timer_sec('prefill'),
+                decode_time = self.profiler.get_timer_sec('decode'),
+                prefill_count = self.profiler.get_counter('prefill'),
+                decode_count = self.profiler.get_counter('decode'),
+            )
\ No newline at end of file
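Because the RawUsage yielded here carries both the profiler's wall-clock timers and its token counters, a client or test could derive prefill and decode throughput from a single record. A small sketch, assuming only the fields defined in the schema (the dataclass is a local stand-in for the pydantic model):

from dataclasses import dataclass

@dataclass
class RawUsage:  # same fields as ktransformers.server.schemas.endpoints.chat.RawUsage
    tokenize_time: float
    prefill_time: float
    decode_time: float
    prefill_count: int
    decode_count: int

def throughput(u: RawUsage) -> dict:
    # tokens per second for each phase; guard against zero timers
    return {
        "prefill_tok_per_s": u.prefill_count / u.prefill_time if u.prefill_time else 0.0,
        "decode_tok_per_s": u.decode_count / u.decode_time if u.decode_time else 0.0,
    }

print(throughput(RawUsage(0.01, 0.5, 2.0, 128, 64)))
# {'prefill_tok_per_s': 256.0, 'decode_tok_per_s': 32.0}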
ktransformers/server/backend/interfaces/transformers.py

@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
         logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
         if (self.max_new_tokens <= 0):
             logger.warning("max_new_tokens is less than 0")
-            yield self.streamer.end()
+            yield self.streamer.end(), "length"
             return
         logger.info(f"max_new_tokens: {self.max_new_tokens}")
         self.profiler.set_counter("decode", 0)

@@ -344,14 +344,21 @@
                 MLAWrapperSingleton.plan_all(None, None, None, self.active_cache_position.to(torch.int32) + 1,
                                              num_heads = self.model.config.num_attention_heads,
                                              head_dim_ckv = self.model.config.kv_lora_rank,
                                              head_dim_kpe = self.model.config.qk_rope_head_dim,
                                              page_size = self.cache.page_size,
-                                             sm_scale = (self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5),
-                                             q_data_type = torch.bfloat16, kv_data_type = torch.bfloat16)
+                                             sm_scale = self.model.model.layers[0].self_attn.softmax_scale,
+                                             q_data_type = torch.bfloat16, kv_data_type = torch.bfloat16)
             next_token = self.decode_one_tokens()
             self.profiler.inc("decode")
             if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
+                yield self.streamer.end(), None
+                yield "", "stop"
                 assert self.args.batch_size == 1
                 break
-            yield self.append_new_tokens(next_token)
-        yield self.streamer.end()
+            yield self.append_new_tokens(next_token), None
+        else:    # for's else, if output get max new tokens
+            yield self.streamer.end(), None
+            yield "", "length"

     def check_is_new(self, thread_id: str):
         if not self.use_static_cache:

@@ -391,20 +398,20 @@
         if Config().user_force_think:
             think = '<think>\n'
             print(think, end="", flush=True)
-            yield think
+            yield think, None

         for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="", flush=True)
-                yield t
+                yield t, None
         self.profiler.pause_timer("prefill")

         self.profiler.create_and_start_timer("decode")
-        for t in self.generate():
+        for t, finish_reason in self.generate():
             if t is not None:
                 print(t, end="", flush=True)
-                yield t
+                yield t, finish_reason
         print("")
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
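After this change every item yielded by inference() is a (token, finish_reason) pair: finish_reason stays None while tokens stream and becomes "stop" (EOS or <|im_end|>) or "length" (token budget exhausted) on the final, possibly empty, item. A minimal sketch of a caller that separates the text from the stop cause (the generator below is a toy stand-in for generate()):

from typing import Iterator, Optional, Tuple

def toy_generate() -> Iterator[Tuple[str, Optional[str]]]:
    yield "The", None
    yield " answer", None
    yield "", "stop"   # would be "length" if max_new_tokens had been reached

text, finish_reason = "", None
for token, reason in toy_generate():
    text += token
    if reason is not None:
        finish_reason = reason

print(repr(text), finish_reason)  # 'The answer' stop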
ktransformers/server/requirements.txt

@@ -5,6 +5,7 @@ langchain >= 0.2.0
 blessed >= 1.20.0
 accelerate >= 0.31.0
 sentencepiece >= 0.1.97
+openai
 setuptools
 build
 ninja
ktransformers/server/schemas/endpoints/chat.py

 from typing import List, Optional
+from typing_extensions import Literal
 from enum import Enum

 from pydantic import BaseModel
 from ktransformers.server.schemas.base import Object
+from openai.types.completion_usage import CompletionUsage
+from openai.types.chat.chat_completion_chunk import Choice

 class Role(Enum):
     system = 'system'
     user = 'user'

@@ -31,50 +36,25 @@ class ChatCompletionCreate(BaseModel):
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]

 class FinishReason(Enum):
     stop = 'stop'
     length = 'length'

-class Choice(BaseModel):
-    index: int
-    message: Message
-    logprobs: Optional[str] = None
-    finish_reason: FinishReason = None
-
-class DeltaChoice(BaseModel):
-    index: int
-    delta: Message
-    logprobs: Optional[str] = None
-    finish_reason: FinishReason = None
-
-class Usage(BaseModel):
-    completion_tokens: int
-    prompt_tokens: int
-    total_tokens: int
+class ChatCompletionChunk(BaseModel):
+    id: str
+    choices: List[Choice]
+    created: int
+    model: str
+    object: Literal["chat.completion.chunk"]
+    service_tier: Optional[Literal["scale", "default"]] = None
+    system_fingerprint: Optional[str] = None
+    usage: Optional[CompletionUsage] = None

 class ChatCompletionBase(Object):
     created: int
     model: str = 'not implmented'
     system_fingerprint: str = 'not implmented'
     usage: Optional[Usage] = None

 class ChatCompletionObject(ChatCompletionBase):
     choices: List[Choice] = []

     def append_token(self, token: str):
         if len(self.choices) == 0:
             self.choices.append(Choice(index=0, message=Message(content='', role=Role.assistant)))
         self.choices[0].message.content += token

-class ChatCompletionChunk(ChatCompletionBase):
-    choices: List[DeltaChoice] = []
-
-    def set_token(self, token: str):
-        self.choices = [DeltaChoice(index=0, delta=Message(content=token, role=Role.assistant))]
-
-    def to_stream_reply(self):
-        return f"data:{self.model_dump_json()}\n\n"
-
+class RawUsage(BaseModel):
+    tokenize_time: float
+    prefill_time: float
+    decode_time: float
+    prefill_count: int
+    decode_count: int
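RawUsage is the internal accounting record; the OpenAI-facing endpoints fold it into a CompletionUsage by treating prefill tokens as prompt tokens and decode tokens as completion tokens. A sketch of that mapping using local pydantic stand-ins (the real classes live in this module and in openai.types.completion_usage):

from pydantic import BaseModel

class RawUsage(BaseModel):
    tokenize_time: float
    prefill_time: float
    decode_time: float
    prefill_count: int
    decode_count: int

class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

def to_completion_usage(raw: RawUsage) -> CompletionUsage:
    # Same arithmetic the chat endpoints use when building the usage block.
    return CompletionUsage(
        prompt_tokens=raw.prefill_count,
        completion_tokens=raw.decode_count,
        total_tokens=raw.prefill_count + raw.decode_count,
    )

raw = RawUsage(tokenize_time=0.01, prefill_time=0.5, decode_time=2.0,
               prefill_count=128, decode_count=64)
print(to_completion_usage(raw))
# prompt_tokens=128 completion_tokens=64 total_tokens=192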
ktransformers/tests/AIME_2024/eval_api.py

@@ -78,13 +78,15 @@ def run_eval_api(
         format_tabs: bool = False,
         auth_token: str = None,
         problem_file: str = None,
-        append: bool = False
+        append: bool = False,
+        skip: int = 0
 ):
     data = load_data(problem_file)
     pbar = tqdm.tqdm(total=len(data) * 1)
+    pbar.update(skip)
     for i in range(len(data)):
+        i = i + skip
         data_item = data[i]
         question = data_item['Problem']
         # Start the timer for this evaluation

@@ -97,6 +99,7 @@ def run_eval_api(
         score = get_score(completion, answer)
         elapsed_time = time.time() - start_time
         result = {
+            "index": i,
            "question_id": data_item["ID"],
            "answer": answer,
            "prediction": completion,

@@ -114,9 +117,9 @@ def run_eval_api(
         pbar.update(1)

-def main(output_path, api_url, model_name, auth_token, format_tabs, problem_file, append):
+def main(output_path, api_url, model_name, auth_token, format_tabs, problem_file, append, skip):
     os.makedirs(os.path.dirname(output_path), exist_ok=True)
-    run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file, append)
+    run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file, append, skip)

 if __name__ == "__main__":

@@ -128,6 +131,7 @@ if __name__ == "__main__":
     parser.add_argument("--format_tabs", action="store_true", help="Format Tabs")
     parser.add_argument("--problem_file", type=str, default="Maxwell-Jia/AIME_2024", help="Evalset File")
     parser.add_argument("--no_append", action="store_false", help="Append to existing file")
+    parser.add_argument("--skip", type=int, default=0, help="Skip some tasks")
     args = parser.parse_args()
     # api_url = "https://api.siliconflow.cn/v1/chat/completions"
-    main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append)
\ No newline at end of file
+    main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append, args.skip)
\ No newline at end of file
ktransformers/tests/humaneval/eval_api.py

@@ -39,7 +39,8 @@ def run_eval_api(
         format_tabs: bool = False,
         auth_token: str = None,
         problem_file: str = None,
-        append: bool = False
+        append: bool = False,
+        skip: int = 0
 ):
     if (problem_file is None):
         problems = read_problems()

@@ -47,8 +48,14 @@ def run_eval_api(
         problems = read_problems(problem_file)
     samples = []
     pbar = tqdm.tqdm(total=len(problems) * 1)
+    pbar.update(skip)
     try:
         for task_id in problems:
+            # skip some tasks
+            if skip > 0:
+                skip -= 1
+                continue
             if format_tabs:
                 prompt = problems[task_id]["prompt"].replace("    ", "\t")
             else:

@@ -67,23 +74,26 @@ def run_eval_api(
         if not append:
             write_jsonl(out_path, samples, append=append)
     except Exception as e:
         if not append:
             write_jsonl(out_path, samples, append=append)
         print(f"Error: {e}")

-def main(output_path, api_url, model_name, auth_token, format_tabs, problem_file, append):
+def main(output_path, api_url, model_name, auth_token, format_tabs, problem_file, append, skip):
     os.makedirs(os.path.dirname(output_path), exist_ok=True)
-    run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file, append)
+    run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file, append, skip)

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="API Generate Tester")
-    parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
+    #parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
+    parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
     parser.add_argument("--model_name", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model Name")
-    parser.add_argument("--out_path", type=str, default="results/api/eval.jsonl", help="Output Path")
+    parser.add_argument("--out_path", type=str, default="results/api/eval_b.jsonl", help="Output Path")
     parser.add_argument("--auth_token", type=str, default=None, help="Auth Token")
     parser.add_argument("--format_tabs", action="store_true", help="Format Tabs")
     parser.add_argument("--problem_file", type=str, default=None, help="Evalset File")
     parser.add_argument("--no_append", action="store_false", help="Append to existing file")
+    parser.add_argument("--skip", type=int, default=0, help="Skip first n problems")
     args = parser.parse_args()
     # api_url = "https://api.siliconflow.cn/v1/chat/completions"
-    main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append)
\ No newline at end of file
+    main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append, args.skip)
\ No newline at end of file
ktransformers/tests/humaneval/evaluation.py

@@ -8,7 +8,7 @@ def filter_code(completion: str) -> str:
     completion = completion.split('if __name__ == "__main__":')[0]
     if "# Example usage" in completion:
         completion = completion.split("# Example usage")[0]
-    return completion.split("\n\n")[0]
+    return completion

 def fix_indents(text: str) -> str:
ktransformers/util/utils.py

@@ -239,7 +239,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud…
         if use_flashinfer_mla:
             MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1) + 1,
                                          num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
-                                         q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
+                                         model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
         global warm_uped
         if use_cuda_graph and ((warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2)):
             warm_uped = True
third_party/llamafile/iqk_mul_mat.inc

@@ -2388,7 +2388,8 @@ struct SimpleBits {
 struct EvenSignHelper {
-#ifdef HAVE_FANCY_SIMD
+#if defined HAVE_FANCY_SIMD
+    // #pragma message("Using AVX512VPOPCNTDQ in even sign helper")
     union sbits_t {
         __m128i vec;
         __mmask32 mask[4];