OpenDAS / ktransformers · Commits · 96d75d53

Unverified commit 96d75d53, authored Mar 07, 2025 by wang jiahao, committed by GitHub on Mar 07, 2025.

Merge pull request #835 from BITcyman/fix-openai_chat_completion

[fix] support openai chat completion api

Parents: 63b1c852, 299c4dca
Changes: 8 files, 166 additions and 83 deletions (+166 -83)

  ktransformers/server/api/ollama/completions.py             +26  -16
  ktransformers/server/api/openai/endpoints/chat.py          +78  -11
  ktransformers/server/api/openai/legacy/completions.py      +14   -6
  ktransformers/server/backend/base.py                         +1   -1
  ktransformers/server/backend/interfaces/ktransformers.py   +10   -0
  ktransformers/server/backend/interfaces/transformers.py    +14   -7
  ktransformers/server/requirements.txt                        +1   -0
  ktransformers/server/schemas/endpoints/chat.py              +22  -42
ktransformers/server/api/ollama/completions.py

@@ -13,6 +13,8 @@ from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import check_link_response
 from ktransformers.server.backend.base import BackendInterfaceBase
+from ktransformers.server.schemas.endpoints.chat import RawUsage
+
 router = APIRouter(prefix='/api')

 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion

@@ -58,7 +60,11 @@ async def generate(request: Request, input: OllamaGenerateCompletionRequest):
     if input.stream:
         async def inner():
-            async for token in interface.inference(input.prompt, id):
+            async for res in interface.inference(input.prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
                     d = OllamaGenerationStreamResponse(
                         model=config.model_name,
                         created_at=str(datetime.now()),

@@ -123,7 +129,11 @@ async def chat(request: Request, input: OllamaChatCompletionRequest):
            eval_count = 0  # count of generated tokens
            tokens = []

-           async for token in interface.inference(prompt, id):
+           async for res in interface.inference(prompt, id):
+               if isinstance(res, RawUsage):
+                   raw_usage = res
+               else:
+                   token, finish_reason = res
                    d = OllamaChatCompletionStreamResponse(
                        model=config.model_name,
                        created_at=str(datetime.now()),
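Note: with this change, interface.inference() no longer yields bare tokens; it yields (token, finish_reason) tuples and, at the very end, a single RawUsage object carrying the profiler counters. Every consumer therefore has to branch on the item type, as the Ollama handlers above now do. A minimal consumer sketch (assuming an already-constructed interface from get_interface(); not part of this commit):

    from ktransformers.server.schemas.endpoints.chat import RawUsage

    async def collect(interface, prompt: str, thread_id: str):
        # Accumulate streamed tokens and keep the trailing usage report.
        tokens, raw_usage = [], None
        async for res in interface.inference(prompt, thread_id):
            if isinstance(res, RawUsage):
                raw_usage = res            # emitted once, after generation finishes
            else:
                token, finish_reason = res
                tokens.append(token)
        return "".join(tokens), raw_usage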
ktransformers/server/api/openai/endpoints/chat.py

@@ -5,10 +5,16 @@ from fastapi import APIRouter
 from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import chat_stream_response
-from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate, ChatCompletionChunk, ChatCompletionObject, Usage
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.backend.base import BackendInterfaceBase
 from ktransformers.server.config.config import Config
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
+
+from openai.types.chat import ChatCompletion
+from openai.types.completion_usage import CompletionUsage

 router = APIRouter()

 @router.get('/models', tags=['openai'])

@@ -29,15 +35,76 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
     assert request.headers.get('Authorization', '').split()[-1] == Config().api_key

     if create.stream:
+        from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
+
         async def inner():
-            chunk = ChatCompletionChunk(id=id, object='chat.completion.chunk', created=int(time()))
-            async for token in interface.inference(input_message, id, create.temperature, create.top_p):
-                chunk.set_token(token)
+            chunk = ChatCompletionChunk(
+                id = id,
+                choices = [],
+                object = 'chat.completion.chunk',
+                created = int(time()),
+                model = Config().model_name,
+            )
+
+            async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+                if isinstance(res, RawUsage):
+                    # at the end of inference, interface.inference() will return the usage of inference
+                    raw_usage = res
+                    chunk.choices = []
+                    chunk.usage = CompletionUsage(
+                        prompt_tokens = raw_usage.prefill_count,
+                        completion_tokens = raw_usage.decode_count,
+                        total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                    )
+                    yield chunk
+                else:
+                    token, finish_reason = res
+                    choice = Choice(
+                        index = 0,
+                        delta = ChoiceDelta(content=token, role=None, tool_calls=None),
+                        finish_reason = finish_reason,
+                        logprobs = None,
+                    )
+                    chunk.choices = [choice]
+                    yield chunk
+
         return chat_stream_response(request, inner())
     else:
-        comp = ChatCompletionObject(id=id, object='chat.completion', created=int(time()))
-        comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
-        async for token in interface.inference(input_message, id, create.temperature, create.top_p):
-            comp.append_token(token)
-        return comp
+        from openai.types.chat.chat_completion import Choice
+        from openai.types.chat.chat_completion_message import ChatCompletionMessage
+
+        content = ""
+        finish_reason = None
+        async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+                usage = CompletionUsage(
+                    prompt_tokens = raw_usage.prefill_count,
+                    completion_tokens = raw_usage.decode_count,
+                    total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                )
+            else:
+                token, finish_reason = res
+                content = content + token
+                finish_reason = finish_reason
+
+        choice = Choice(
+            index = 0,
+            finish_reason = finish_reason,
+            message = ChatCompletionMessage(
+                content = content,
+                role = "assistant"
+            ))
+
+        chat_completion = ChatCompletion(
+            id = id,
+            choices = [choice],
+            created = int(time()),
+            model = Config().model_name,
+            object = 'chat.completion',
+            usage = usage
+        )
+
+        return chat_completion
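With the endpoint now emitting OpenAI-shaped chunks (delta choices while streaming, plus a final usage-only chunk built from RawUsage), the stock openai Python client should be able to consume it directly. A hedged client-side sketch; the base_url, port, api_key, and model name are placeholders, not values taken from this commit:

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:10002/v1", api_key="your-api-key")

    stream = client.chat.completions.create(
        model="ktransformers-model",   # placeholder model name
        messages=[{"role": "user", "content": "Hello!"}],
        stream=True,
    )
    for chunk in stream:
        if chunk.choices:              # ordinary delta chunks
            print(chunk.choices[0].delta.content or "", end="", flush=True)
        elif chunk.usage:              # final usage-only chunk introduced by this PR
            print(f"\nprompt={chunk.usage.prompt_tokens}, completion={chunk.usage.completion_tokens}")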
ktransformers/server/api/openai/legacy/completions.py

@@ -6,6 +6,7 @@ from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import stream_response
 from ktransformers.server.schemas.legacy.completions import CompletionCreate, CompletionObject
+from ktransformers.server.schemas.endpoints.chat import RawUsage

 router = APIRouter()

@@ -17,10 +18,13 @@ async def create_completion(request:Request,create:CompletionCreate):
     print(f'COMPLETION INPUT:----\n{create.prompt}\n----')

     if create.stream:
         async def inner():
-            async for token in interface.inference(create.prompt, id, create.temperature, create.top_p):
+            async for res in interface.inference(create.prompt, id, create.temperature, create.top_p):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
                     d = {'choices':[{'delta':{'content':token}}]}
                     yield f"data:{json.dumps(d)}\n\n"
             d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}

@@ -28,6 +32,10 @@ async def create_completion(request:Request,create:CompletionCreate):
         return stream_response(request, inner())
     else:
         comp = CompletionObject(id=id, object='text_completion', created=int(time()))
-        async for token in interface.inference(create.prompt, id, create.temperature, create.top_p):
+        async for res in interface.inference(create.prompt, id, create.temperature, create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
                 comp.append_token(token)
         return comp
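The legacy endpoint keeps its hand-rolled SSE framing (data:{...} lines), now skipping the trailing RawUsage item instead of emitting it. A rough client-side sketch for reading that stream with httpx; the URL path and payload fields are assumptions modelled on the OpenAI legacy completions API, not taken from this commit:

    import json
    import httpx

    payload = {"model": "ktransformers-model", "prompt": "Hello", "stream": True}
    with httpx.stream("POST", "http://localhost:10002/v1/completions",
                      json=payload, timeout=None) as r:
        for line in r.iter_lines():
            if not line.startswith("data:"):
                continue
            body = line[len("data:"):].strip()
            if not body or body == "[DONE]":
                continue
            delta = json.loads(body)["choices"][0]["delta"].get("content", "")
            print(delta, end="", flush=True)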
ktransformers/server/backend/base.py

@@ -142,7 +142,7 @@ class ThreadContext:
         yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
         yield self.run.stream_response_with_event(RunObject.Status.in_progress)

-        async for token in self.interface.inference(local_messages, self.thread.id):
+        async for token, finish_reason in self.interface.inference(local_messages, self.thread.id):
             if self.run.status == RunObject.Status.cancelling:
                 logger.warn(f'Run {self.run.id} cancelling')
                 break
ktransformers/server/backend/interfaces/ktransformers.py

@@ -16,6 +16,7 @@ from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
 from typing import Optional
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
+from ktransformers.server.schemas.endpoints.chat import RawUsage

 warm_uped = False

@@ -231,3 +232,12 @@ class KTransformersInterface(TransformersInterface):
         async with self._infer_lock:
             async for v in super().inference(local_messages, thread_id, temperature, top_p):
                 yield v
+
+            # return this inference raw usage
+            yield RawUsage(
+                tokenize_time = self.profiler.get_timer_sec('tokenize'),
+                prefill_time = self.profiler.get_timer_sec('prefill'),
+                decode_time = self.profiler.get_timer_sec('decode'),
+                prefill_count = self.profiler.get_counter('prefill'),
+                decode_count = self.profiler.get_counter('decode'),
+            )
\ No newline at end of file
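The RawUsage object yielded here packages the profiler's timers and counters for the whole request. An illustrative helper (not part of the commit) that turns it into throughput figures; the zero-division guard is an added assumption:

    from ktransformers.server.schemas.endpoints.chat import RawUsage

    def summarize(usage: RawUsage) -> str:
        # tokens per second for the prefill and decode phases
        prefill_tps = usage.prefill_count / usage.prefill_time if usage.prefill_time else 0.0
        decode_tps = usage.decode_count / usage.decode_time if usage.decode_time else 0.0
        return (f"prefill: {usage.prefill_count} tok in {usage.prefill_time:.2f}s ({prefill_tps:.1f} tok/s); "
                f"decode: {usage.decode_count} tok in {usage.decode_time:.2f}s ({decode_tps:.1f} tok/s)")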
ktransformers/server/backend/interfaces/transformers.py

@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
         logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
         if(self.max_new_tokens <= 0):
             logger.warning("max_new_tokens is less than 0")
-            yield self.streamer.end()
+            yield self.streamer.end(), "length"
             return
         logger.info(f"max_new_tokens: {self.max_new_tokens}")
         self.profiler.set_counter("decode", 0)

@@ -348,10 +348,17 @@ class TransformersInterface(BackendInterfaceBase):
             next_token = self.decode_one_tokens()
             self.profiler.inc("decode")
             if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
+                yield self.streamer.end(), None
+                yield "", "stop"
                 assert self.args.batch_size == 1
                 break
-            yield self.append_new_tokens(next_token)
-        yield self.streamer.end()
+            yield self.append_new_tokens(next_token), None
+        else:  # for's else, if output get max new tokens
+            yield self.streamer.end(), None
+            yield "", "length"

     def check_is_new(self, thread_id: str):
         if not self.use_static_cache:

@@ -391,20 +398,20 @@ class TransformersInterface(BackendInterfaceBase):
         if Config().user_force_think:
             think = '<think>\n'
             print(think, end="", flush=True)
-            yield think
+            yield think, None

         for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="", flush=True)
-                yield t
+                yield t, None
         self.profiler.pause_timer("prefill")

         self.profiler.create_and_start_timer("decode")
-        for t in self.generate():
+        for t, finish_reason in self.generate():
             if t is not None:
                 print(t, end="", flush=True)
-                yield t
+                yield t, finish_reason
         print("")
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
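After this change, generate() and inference() yield (token, finish_reason) pairs: finish_reason is None while tokens are still flowing, "stop" when EOS or <|im_end|> is hit, and "length" when max_new_tokens is exhausted. A minimal sketch of relying on that contract against the base TransformersInterface (the KTransformersInterface subclass additionally yields a RawUsage object, see above); the interface construction is omitted and the thread id is a placeholder:

    async def last_finish_reason(interface, messages, thread_id="demo-thread"):
        finish_reason = None
        async for token, reason in interface.inference(messages, thread_id):
            if reason is not None:
                finish_reason = reason    # "stop" or "length"
        return finish_reason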
ktransformers/server/requirements.txt

@@ -5,6 +5,7 @@ langchain >= 0.2.0
 blessed >= 1.20.0
 accelerate >= 0.31.0
 sentencepiece >= 0.1.97
+openai
 setuptools
 build
 ninja
ktransformers/server/schemas/endpoints/chat.py

 from typing import List, Optional
+from typing_extensions import Literal
 from enum import Enum

 from pydantic import BaseModel

 from ktransformers.server.schemas.base import Object

+from openai.types.completion_usage import CompletionUsage
+from openai.types.chat.chat_completion_chunk import Choice

 class Role(Enum):
     system = 'system'
     user = 'user'

@@ -31,50 +36,25 @@ class ChatCompletionCreate(BaseModel):
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]

 class FinishReason(Enum):
     stop = 'stop'
     length = 'length'

-class Choice(BaseModel):
-    index: int
-    message: Message
-    logprobs: Optional[str] = None
-    finish_reason: FinishReason = None
-
-class DeltaChoice(BaseModel):
-    index: int
-    delta: Message
-    logprobs: Optional[str] = None
-    finish_reason: FinishReason = None
-
-class Usage(BaseModel):
-    completion_tokens: int
-    prompt_tokens: int
-    total_tokens: int
+class ChatCompletionChunk(BaseModel):
+    id: str
+    choices: List[Choice]
+    created: int
+    model: str
+    object: Literal["chat.completion.chunk"]
+    service_tier: Optional[Literal["scale", "default"]] = None
+    system_fingerprint: Optional[str] = None
+    usage: Optional[CompletionUsage] = None

-class ChatCompletionBase(Object):
-    created: int
-    model: str = 'not implmented'
-    system_fingerprint: str = 'not implmented'
-    usage: Optional[Usage] = None
-
-class ChatCompletionObject(ChatCompletionBase):
-    choices: List[Choice] = []
-
-    def append_token(self, token: str):
-        if len(self.choices) == 0:
-            self.choices.append(Choice(index=0, message=Message(content='', role=Role.assistant)))
-        self.choices[0].message.content += token
-
-class ChatCompletionChunk(ChatCompletionBase):
-    choices: List[DeltaChoice] = []
-
-    def set_token(self, token: str):
-        self.choices = [DeltaChoice(index=0, delta=Message(content=token, role=Role.assistant))]
-
-    def to_stream_reply(self):
-        return f"data:{self.model_dump_json()}\n\n"
+class RawUsage(BaseModel):
+    tokenize_time: float
+    prefill_time: float
+    decode_time: float
+    prefill_count: int
+    decode_count: int
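The reworked ChatCompletionChunk mirrors the OpenAI wire format and reuses Choice and CompletionUsage from the openai package, so a usage-only final chunk can be built straight from a RawUsage. An illustrative sketch (not part of the commit); the id and model strings are placeholders:

    from time import time
    from openai.types.completion_usage import CompletionUsage
    from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, RawUsage

    def usage_chunk(raw: RawUsage, request_id: str = "chatcmpl-demo") -> str:
        chunk = ChatCompletionChunk(
            id=request_id,
            choices=[],                   # the usage-only chunk carries no delta
            created=int(time()),
            model="ktransformers-model",  # placeholder
            object="chat.completion.chunk",
            usage=CompletionUsage(
                prompt_tokens=raw.prefill_count,
                completion_tokens=raw.decode_count,
                total_tokens=raw.prefill_count + raw.decode_count,
            ),
        )
        return f"data: {chunk.model_dump_json()}\n\n"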