Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ktransformers
Commits
299c4dca
Commit
299c4dca
authored
Mar 07, 2025
by
BITcyman
Browse files
[update] support openai chat completion api
parent
63b1c852
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
166 additions
and
83 deletions
+166
-83
ktransformers/server/api/ollama/completions.py
ktransformers/server/api/ollama/completions.py
+26
-16
ktransformers/server/api/openai/endpoints/chat.py
ktransformers/server/api/openai/endpoints/chat.py
+78
-11
ktransformers/server/api/openai/legacy/completions.py
ktransformers/server/api/openai/legacy/completions.py
+14
-6
ktransformers/server/backend/base.py
ktransformers/server/backend/base.py
+1
-1
ktransformers/server/backend/interfaces/ktransformers.py
ktransformers/server/backend/interfaces/ktransformers.py
+10
-0
ktransformers/server/backend/interfaces/transformers.py
ktransformers/server/backend/interfaces/transformers.py
+14
-7
ktransformers/server/requirements.txt
ktransformers/server/requirements.txt
+1
-0
ktransformers/server/schemas/endpoints/chat.py
ktransformers/server/schemas/endpoints/chat.py
+22
-42
No files found.
ktransformers/server/api/ollama/completions.py
View file @
299c4dca
...
...
@@ -13,6 +13,8 @@ from ktransformers.server.utils.create_interface import get_interface
from
ktransformers.server.schemas.assistants.streaming
import
check_link_response
from
ktransformers.server.backend.base
import
BackendInterfaceBase
from
ktransformers.server.schemas.endpoints.chat
import
RawUsage
router
=
APIRouter
(
prefix
=
'/api'
)
# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
...
...
@@ -58,14 +60,18 @@ async def generate(request: Request, input: OllamaGenerateCompletionRequest):
if
input
.
stream
:
async
def
inner
():
async
for
token
in
interface
.
inference
(
input
.
prompt
,
id
):
d
=
OllamaGenerationStreamResponse
(
model
=
config
.
model_name
,
created_at
=
str
(
datetime
.
now
()),
response
=
token
,
done
=
False
)
yield
d
.
model_dump_json
()
+
'
\n
'
async
for
res
in
interface
.
inference
(
input
.
prompt
,
id
):
if
isinstance
(
res
,
RawUsage
):
raw_usage
=
res
else
:
token
,
finish_reason
=
res
d
=
OllamaGenerationStreamResponse
(
model
=
config
.
model_name
,
created_at
=
str
(
datetime
.
now
()),
response
=
token
,
done
=
False
)
yield
d
.
model_dump_json
()
+
'
\n
'
d
=
OllamaGenerationStreamResponse
(
model
=
config
.
model_name
,
created_at
=
str
(
datetime
.
now
()),
...
...
@@ -123,14 +129,18 @@ async def chat(request: Request, input: OllamaChatCompletionRequest):
eval_count
=
0
# 统计生成的 token 数量
tokens
=
[]
async
for
token
in
interface
.
inference
(
prompt
,
id
):
d
=
OllamaChatCompletionStreamResponse
(
model
=
config
.
model_name
,
created_at
=
str
(
datetime
.
now
()),
message
=
{
"role"
:
"assistant"
,
"content"
:
token
},
done
=
False
)
yield
d
.
model_dump_json
()
+
'
\n
'
async
for
res
in
interface
.
inference
(
prompt
,
id
):
if
isinstance
(
res
,
RawUsage
):
raw_usage
=
res
else
:
token
,
finish_reason
=
res
d
=
OllamaChatCompletionStreamResponse
(
model
=
config
.
model_name
,
created_at
=
str
(
datetime
.
now
()),
message
=
{
"role"
:
"assistant"
,
"content"
:
token
},
done
=
False
)
yield
d
.
model_dump_json
()
+
'
\n
'
# 计算性能数据
end_time
=
time
()
total_duration
=
int
((
end_time
-
start_time
)
*
1_000_000_000
)
# 转换为纳秒
...
...
ktransformers/server/api/openai/endpoints/chat.py
View file @
299c4dca
...
...
@@ -5,10 +5,16 @@ from fastapi import APIRouter
from
fastapi.requests
import
Request
from
ktransformers.server.utils.create_interface
import
get_interface
from
ktransformers.server.schemas.assistants.streaming
import
chat_stream_response
from
ktransformers.server.schemas.endpoints.chat
import
ChatCompletionCreate
,
ChatCompletionChunk
,
ChatCompletionObject
,
Usage
from
ktransformers.server.schemas.endpoints.chat
import
ChatCompletionCreate
from
ktransformers.server.schemas.endpoints.chat
import
RawUsage
from
ktransformers.server.backend.base
import
BackendInterfaceBase
from
ktransformers.server.config.config
import
Config
from
ktransformers.server.schemas.endpoints.chat
import
ChatCompletionChunk
from
openai.types.chat
import
ChatCompletion
from
openai.types.completion_usage
import
CompletionUsage
router
=
APIRouter
()
@
router
.
get
(
'/models'
,
tags
=
[
'openai'
])
...
...
@@ -29,15 +35,76 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
assert
request
.
headers
.
get
(
'Authorization'
,
''
).
split
()[
-
1
]
==
Config
().
api_key
if
create
.
stream
:
from
openai.types.chat.chat_completion_chunk
import
Choice
,
ChoiceDelta
async
def
inner
():
chunk
=
ChatCompletionChunk
(
id
=
id
,
object
=
'chat.completion.chunk'
,
created
=
int
(
time
()))
async
for
token
in
interface
.
inference
(
input_message
,
id
,
create
.
temperature
,
create
.
top_p
):
chunk
.
set_token
(
token
)
yield
chunk
return
chat_stream_response
(
request
,
inner
())
chunk
=
ChatCompletionChunk
(
id
=
id
,
choices
=
[],
object
=
'chat.completion.chunk'
,
created
=
int
(
time
()),
model
=
Config
().
model_name
,
)
async
for
res
in
interface
.
inference
(
input_message
,
id
,
create
.
temperature
,
create
.
top_p
):
if
isinstance
(
res
,
RawUsage
):
# at the end of inference, interface.inference() will return the usage of inference
raw_usage
=
res
chunk
.
choices
=
[]
chunk
.
usage
=
CompletionUsage
(
prompt_tokens
=
raw_usage
.
prefill_count
,
completion_tokens
=
raw_usage
.
decode_count
,
total_tokens
=
raw_usage
.
prefill_count
+
raw_usage
.
decode_count
)
yield
chunk
else
:
token
,
finish_reason
=
res
choice
=
Choice
(
index
=
0
,
delta
=
ChoiceDelta
(
content
=
token
,
role
=
None
,
tool_calls
=
None
),
finish_reason
=
finish_reason
,
logprobs
=
None
,
)
chunk
.
choices
=
[
choice
]
yield
chunk
return
chat_stream_response
(
request
,
inner
())
else
:
comp
=
ChatCompletionObject
(
id
=
id
,
object
=
'chat.completion'
,
created
=
int
(
time
()))
comp
.
usage
=
Usage
(
completion_tokens
=
1
,
prompt_tokens
=
1
,
total_tokens
=
2
)
async
for
token
in
interface
.
inference
(
input_message
,
id
,
create
.
temperature
,
create
.
top_p
):
comp
.
append_token
(
token
)
return
comp
from
openai.types.chat.chat_completion
import
Choice
from
openai.types.chat.chat_completion_message
import
ChatCompletionMessage
content
=
""
finish_reason
=
None
async
for
res
in
interface
.
inference
(
input_message
,
id
,
create
.
temperature
,
create
.
top_p
):
if
isinstance
(
res
,
RawUsage
):
raw_usage
=
res
usage
=
CompletionUsage
(
prompt_tokens
=
raw_usage
.
prefill_count
,
completion_tokens
=
raw_usage
.
decode_count
,
total_tokens
=
raw_usage
.
prefill_count
+
raw_usage
.
decode_count
)
else
:
token
,
finish_reason
=
res
content
=
content
+
token
finish_reason
=
finish_reason
choice
=
Choice
(
index
=
0
,
finish_reason
=
finish_reason
,
message
=
ChatCompletionMessage
(
content
=
content
,
role
=
"assistant"
))
chat_completion
=
ChatCompletion
(
id
=
id
,
choices
=
[
choice
],
created
=
int
(
time
()),
model
=
Config
().
model_name
,
object
=
'chat.completion'
,
usage
=
usage
)
return
chat_completion
ktransformers/server/api/openai/legacy/completions.py
View file @
299c4dca
...
...
@@ -6,6 +6,7 @@ from fastapi.requests import Request
from
ktransformers.server.utils.create_interface
import
get_interface
from
ktransformers.server.schemas.assistants.streaming
import
stream_response
from
ktransformers.server.schemas.legacy.completions
import
CompletionCreate
,
CompletionObject
from
ktransformers.server.schemas.endpoints.chat
import
RawUsage
router
=
APIRouter
()
...
...
@@ -17,17 +18,24 @@ async def create_completion(request:Request,create:CompletionCreate):
print
(
f
'COMPLETION INPUT:----
\n
{
create
.
prompt
}
\n
----'
)
if
create
.
stream
:
async
def
inner
():
async
for
token
in
interface
.
inference
(
create
.
prompt
,
id
,
create
.
temperature
,
create
.
top_p
):
d
=
{
'choices'
:[{
'delta'
:{
'content'
:
token
}}]}
yield
f
"data:
{
json
.
dumps
(
d
)
}
\n\n
"
async
for
res
in
interface
.
inference
(
create
.
prompt
,
id
,
create
.
temperature
,
create
.
top_p
):
if
isinstance
(
res
,
RawUsage
):
raw_usage
=
res
else
:
token
,
finish_reason
=
res
d
=
{
'choices'
:[{
'delta'
:{
'content'
:
token
}}]}
yield
f
"data:
{
json
.
dumps
(
d
)
}
\n\n
"
d
=
{
'choices'
:[{
'delta'
:{
'content'
:
''
},
'finish_reason'
:
''
}]}
yield
f
"data:
{
json
.
dumps
(
d
)
}
\n\n
"
return
stream_response
(
request
,
inner
())
else
:
comp
=
CompletionObject
(
id
=
id
,
object
=
'text_completion'
,
created
=
int
(
time
()))
async
for
token
in
interface
.
inference
(
create
.
prompt
,
id
,
create
.
temperature
,
create
.
top_p
):
comp
.
append_token
(
token
)
async
for
res
in
interface
.
inference
(
create
.
prompt
,
id
,
create
.
temperature
,
create
.
top_p
):
if
isinstance
(
res
,
RawUsage
):
raw_usage
=
res
else
:
token
,
finish_reason
=
res
comp
.
append_token
(
token
)
return
comp
ktransformers/server/backend/base.py
View file @
299c4dca
...
...
@@ -142,7 +142,7 @@ class ThreadContext:
yield
reply_message
.
stream_response_with_event
(
MessageObject
.
Status
.
in_progress
)
yield
self
.
run
.
stream_response_with_event
(
RunObject
.
Status
.
in_progress
)
async
for
token
in
self
.
interface
.
inference
(
local_messages
,
self
.
thread
.
id
):
async
for
token
,
finish_reason
in
self
.
interface
.
inference
(
local_messages
,
self
.
thread
.
id
):
if
self
.
run
.
status
==
RunObject
.
Status
.
cancelling
:
logger
.
warn
(
f
'Run
{
self
.
run
.
id
}
cancelling'
)
break
...
...
ktransformers/server/backend/interfaces/ktransformers.py
View file @
299c4dca
...
...
@@ -16,6 +16,7 @@ from ktransformers.local_chat import custom_models, default_optimize_rules
from
ktransformers.util.utils
import
get_device
from
typing
import
Optional
from
ktransformers.operators.flashinfer_wrapper
import
flashinfer_enabled
,
MLAWrapperSingleton
from
ktransformers.server.schemas.endpoints.chat
import
RawUsage
warm_uped
=
False
...
...
@@ -231,3 +232,12 @@ class KTransformersInterface(TransformersInterface):
async
with
self
.
_infer_lock
:
async
for
v
in
super
().
inference
(
local_messages
,
thread_id
,
temperature
,
top_p
):
yield
v
# return this inference raw usage
yield
RawUsage
(
tokenize_time
=
self
.
profiler
.
get_timer_sec
(
'tokenize'
),
prefill_time
=
self
.
profiler
.
get_timer_sec
(
'prefill'
),
decode_time
=
self
.
profiler
.
get_timer_sec
(
'decode'
),
prefill_count
=
self
.
profiler
.
get_counter
(
'prefill'
),
decode_count
=
self
.
profiler
.
get_counter
(
'decode'
),
)
\ No newline at end of file
ktransformers/server/backend/interfaces/transformers.py
View file @
299c4dca
...
...
@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
logger
.
info
(
f
"args.max_new_tokens:
{
self
.
args
.
max_new_tokens
}
, cache_lens:
{
self
.
args
.
cache_lens
}
, seq_length:
{
self
.
seq_length
}
"
)
if
(
self
.
max_new_tokens
<=
0
):
logger
.
warning
(
"max_new_tokens is less than 0"
)
yield
self
.
streamer
.
end
()
yield
self
.
streamer
.
end
()
,
"length"
return
logger
.
info
(
f
"max_new_tokens:
{
self
.
max_new_tokens
}
"
)
self
.
profiler
.
set_counter
(
"decode"
,
0
)
...
...
@@ -348,10 +348,17 @@ class TransformersInterface(BackendInterfaceBase):
next_token
=
self
.
decode_one_tokens
()
self
.
profiler
.
inc
(
"decode"
)
if
next_token
==
self
.
tokenizer
.
eos_token_id
or
"<|im_end|>"
==
self
.
tokenizer
.
decode
(
next_token
):
yield
self
.
streamer
.
end
(),
None
yield
""
,
"stop"
assert
self
.
args
.
batch_size
==
1
break
yield
self
.
append_new_tokens
(
next_token
)
yield
self
.
streamer
.
end
()
yield
self
.
append_new_tokens
(
next_token
),
None
else
:
# for's else, if output get max new tokens
yield
self
.
streamer
.
end
(),
None
yield
""
,
"length"
def
check_is_new
(
self
,
thread_id
:
str
):
if
not
self
.
use_static_cache
:
...
...
@@ -391,20 +398,20 @@ class TransformersInterface(BackendInterfaceBase):
if
Config
().
user_force_think
:
think
=
'<think>
\n
'
print
(
think
,
end
=
""
,
flush
=
True
)
yield
think
yield
think
,
None
for
t
in
self
.
prefill
(
input_ids
,
self
.
check_is_new
(
thread_id
),
temperature
,
top_p
):
# output think token after prefill done
if
t
is
not
None
:
print
(
t
,
end
=
""
,
flush
=
True
)
yield
t
yield
t
,
None
self
.
profiler
.
pause_timer
(
"prefill"
)
self
.
profiler
.
create_and_start_timer
(
"decode"
)
for
t
in
self
.
generate
():
for
t
,
finish_reason
in
self
.
generate
():
if
t
is
not
None
:
print
(
t
,
end
=
""
,
flush
=
True
)
yield
t
yield
t
,
finish_reason
print
(
""
)
self
.
profiler
.
pause_timer
(
"decode"
)
self
.
report_last_time_performance
()
ktransformers/server/requirements.txt
View file @
299c4dca
...
...
@@ -5,6 +5,7 @@ langchain >= 0.2.0
blessed >= 1.20.0
accelerate >= 0.31.0
sentencepiece >= 0.1.97
openai
setuptools
build
ninja
...
...
ktransformers/server/schemas/endpoints/chat.py
View file @
299c4dca
from
typing
import
List
,
Optional
from
typing_extensions
import
Literal
from
enum
import
Enum
from
pydantic
import
BaseModel
from
ktransformers.server.schemas.base
import
Object
from
openai.types.completion_usage
import
CompletionUsage
from
openai.types.chat.chat_completion_chunk
import
Choice
class
Role
(
Enum
):
system
=
'system'
user
=
'user'
...
...
@@ -31,50 +36,25 @@ class ChatCompletionCreate(BaseModel):
def
get_tokenizer_messages
(
self
):
return
[
m
.
to_tokenizer_message
()
for
m
in
self
.
messages
]
class
FinishReason
(
Enum
):
stop
=
'stop'
length
=
'length'
class
Choice
(
BaseModel
):
index
:
int
message
:
Message
logprobs
:
Optional
[
str
]
=
None
finish_reason
:
FinishReason
=
None
class
DeltaChoice
(
BaseModel
):
index
:
int
delta
:
Message
logprobs
:
Optional
[
str
]
=
None
finish_reason
:
FinishReason
=
None
class
Usage
(
BaseModel
):
completion_tokens
:
int
prompt_tokens
:
int
total_tokens
:
int
class
ChatCompletionChunk
(
BaseModel
):
id
:
str
choices
:
List
[
Choice
]
created
:
int
model
:
str
object
:
Literal
[
"chat.completion.chunk"
]
service_tier
:
Optional
[
Literal
[
"scale"
,
"default"
]]
=
None
system_fingerprint
:
Optional
[
str
]
=
None
usage
:
Optional
[
CompletionUsage
]
=
None
class
ChatCompletionBase
(
Object
):
created
:
int
model
:
str
=
'not implmented'
system_fingerprint
:
str
=
'not implmented'
usage
:
Optional
[
Usage
]
=
None
class
ChatCompletionObject
(
ChatCompletionBase
):
choices
:
List
[
Choice
]
=
[]
def
append_token
(
self
,
token
:
str
):
if
len
(
self
.
choices
)
==
0
:
self
.
choices
.
append
(
Choice
(
index
=
0
,
message
=
Message
(
content
=
''
,
role
=
Role
.
assistant
)))
self
.
choices
[
0
].
message
.
content
+=
token
class
ChatCompletionChunk
(
ChatCompletionBase
):
choices
:
List
[
DeltaChoice
]
=
[]
def
set_token
(
self
,
token
:
str
):
self
.
choices
=
[
DeltaChoice
(
index
=
0
,
delta
=
Message
(
content
=
token
,
role
=
Role
.
assistant
))
]
def
to_stream_reply
(
self
):
return
f
"data:
{
self
.
model_dump_json
()
}
\n\n
"
class
RawUsage
(
BaseModel
):
tokenize_time
:
float
prefill_time
:
float
decode_time
:
float
prefill_count
:
int
decode_count
:
int
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment