OpenDAS / ktransformers / Commits / 5e3c6b4f

Commit 5e3c6b4f (unverified), authored Feb 27, 2025 by wang jiahao, committed via GitHub on Feb 27, 2025.

Merge pull request #644 from wtdcode/temperature_top_p_from_request

Allow temperature and top_p from /v1/chat/completions

Parents: 1f28f75f, b121ca4d

Showing 6 changed files with 25 additions and 17 deletions:

    ktransformers/server/api/openai/endpoints/chat.py         +2  -2
    ktransformers/server/api/openai/legacy/completions.py     +2  -2
    ktransformers/server/backend/interfaces/ktransformers.py  +5  -5
    ktransformers/server/backend/interfaces/transformers.py   +11 -7
    ktransformers/server/schemas/endpoints/chat.py             +3  -1
    ktransformers/server/schemas/legacy/completions.py         +2  -0
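In practice, this change lets a client set sampling parameters per request instead of relying on the server-wide defaults. A minimal client sketch, assuming a local ktransformers server; the host, port, and model name below are placeholders, not values from this commit:

    # Per-request sampling parameters on the OpenAI-compatible endpoint.
    # URL and model name are illustrative assumptions.
    import requests

    resp = requests.post(
        "http://localhost:10002/v1/chat/completions",
        json={
            "model": "placeholder-model",
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": False,
            "temperature": 0.6,   # now taken from the request
            "top_p": 0.95,        # now taken from the request
        },
    )
    print(resp.json())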
ktransformers/server/api/openai/endpoints/chat.py

@@ -31,13 +31,13 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
     if create.stream:
         async def inner():
             chunk = ChatCompletionChunk(id=id, object='chat.completion.chunk', created=int(time()))
-            async for token in interface.inference(input_message, id):
+            async for token in interface.inference(input_message, id, create.temperature, create.top_p):
                 chunk.set_token(token)
                 yield chunk
         return chat_stream_response(request, inner())
     else:
         comp = ChatCompletionObject(id=id, object='chat.completion', created=int(time()))
         comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
-        async for token in interface.inference(input_message, id):
+        async for token in interface.inference(input_message, id, create.temperature, create.top_p):
             comp.append_token(token)
         return comp
ktransformers/server/api/openai/legacy/completions.py

@@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate):
     if create.stream:
         async def inner():
-            async for token in interface.inference(create.prompt, id):
+            async for token in interface.inference(create.prompt, id, create.temperature, create.top_p):
                 d = {'choices':[{'delta':{'content': token}}]}
                 yield f"data:{json.dumps(d)}\n\n"
             d = {'choices':[{'delta':{'content': ''}, 'finish_reason': ''}]}
@@ -28,6 +28,6 @@ async def create_completion(request:Request,create:CompletionCreate):
         return stream_response(request, inner())
     else:
         comp = CompletionObject(id=id, object='text_completion', created=int(time()))
-        async for token in interface.inference(create.prompt, id):
+        async for token in interface.inference(create.prompt, id, create.temperature, create.top_p):
             comp.append_token(token)
         return comp
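The streaming branch frames each token as a server-sent-event line ("data:{...}\n\n", as yielded above). A minimal consumer sketch, assuming the legacy route is mounted at /v1/completions; URL, model name, and prompt are illustrative:

    # Reads the SSE stream emitted by the legacy completions endpoint.
    # URL, model name, and prompt are assumptions, not from this commit.
    import json
    import requests

    with requests.post(
        "http://localhost:10002/v1/completions",
        json={"model": "placeholder-model", "prompt": "Once upon a time",
              "stream": True, "temperature": 0.8, "top_p": 0.9},
        stream=True,
    ) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data:"):
                continue  # skip blank separators between events
            payload = json.loads(line[len("data:"):])
            delta = payload["choices"][0]["delta"].get("content", "")
            print(delta, end="", flush=True)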
ktransformers/server/backend/interfaces/ktransformers.py

@@ -14,9 +14,9 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
-
+from typing import Optional
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton

 warm_uped = False

 class KTransformersThreadContext(TransformersThreadContext):
@@ -128,7 +128,7 @@ class KTransformersInterface(TransformersInterface):
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
         input_ids_length = input_ids.shape[-1]
         logger.debug(f"input_ids: {input_ids.shape}")
@@ -203,7 +203,7 @@ class KTransformersInterface(TransformersInterface):
         if flashinfer_enabled:
             MLAWrapperSingleton.reset_buffer()
-        self.prepare_logits_wrapper(input_ids, device)
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
@@ -212,7 +212,7 @@ class KTransformersInterface(TransformersInterface):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)

-    async def inference(self, local_messages, thread_id: str):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float], top_p: Optional[float]):
         async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p):
                 yield v
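KTransformersInterface.inference only forwards the two new arguments to the parent class while holding its inference lock. The underlying pattern — an async generator that serializes access and delegates to super() — looks roughly like this self-contained sketch (class and method bodies here are illustrative stand-ins, not the repo's code):

    # Minimal sketch of the lock-then-delegate async-generator pattern.
    import asyncio

    class Base:
        async def inference(self, messages, thread_id, temperature=None, top_p=None):
            for tok in ("hello", " ", "world"):   # stand-in for real decoding
                await asyncio.sleep(0)
                yield tok

    class Locked(Base):
        def __init__(self):
            self._infer_lock = asyncio.Lock()

        async def inference(self, messages, thread_id, temperature=None, top_p=None):
            # Only one request decodes at a time; sampling knobs pass through.
            async with self._infer_lock:
                async for v in super().inference(messages, thread_id, temperature, top_p):
                    yield v

    async def main():
        out = [t async for t in Locked().inference([], "t0", temperature=0.7, top_p=0.9)]
        print("".join(out))

    asyncio.run(main())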
ktransformers/server/backend/interfaces/transformers.py

@@ -202,13 +202,17 @@ class TransformersInterface(BackendInterfaceBase):
         self.seq_length += 1
         return self.streamer.put(new_tokens)

-    def prepare_logits_wrapper(self, inputs, device):
+    def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
+        if temperature is None:
+            temperature = self.args.temperature
+        if top_p is None:
+            top_p = self.args.top_p
         generation_config, model_kwargs = self.model._prepare_generation_config(
             None, max_length=self.args.max_new_tokens,
             do_sample=True,
             top_k=self.args.top_k,
-            top_p=self.args.top_p,
-            temperature=self.args.temperature,
+            top_p=top_p,
+            temperature=temperature,
             repetition_penalty=self.args.repetition_penalty # change this to modify generate config
         )
         self.inputs = inputs
@@ -255,7 +259,7 @@ class TransformersInterface(BackendInterfaceBase):
         return self.logits_to_token(logits)

     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
         logger.debug(f"input_ids: {input_ids.shape}")
@@ -323,7 +327,7 @@ class TransformersInterface(BackendInterfaceBase):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
-        self.prepare_logits_wrapper(input_ids, device)
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
@@ -359,7 +363,7 @@ class TransformersInterface(BackendInterfaceBase):
             self.last_request_id = thread_id
             return True

-    async def inference(self, local_messages, thread_id: str):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
         if isinstance(local_messages, List):
@@ -386,7 +390,7 @@ class TransformersInterface(BackendInterfaceBase):
             print(think, end="", flush=True)
             yield think
-        for t in self.prefill(input_ids, self.check_is_new(thread_id)):
+        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="", flush=True)
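prepare_logits_wrapper now prefers the per-request values and falls back to the server-wide defaults (self.args.temperature / self.args.top_p) only when a field is omitted; the resolved values feed the Hugging Face generation config. As a reminder of what the two knobs control, here is an illustrative temperature-plus-nucleus sampling step in PyTorch — a sketch of the semantics, not code from this commit:

    # Illustrative temperature + top-p (nucleus) sampling over a logits vector.
    import torch

    def sample(logits: torch.Tensor, temperature: float = 1.0, top_p: float = 1.0) -> int:
        # Temperature scaling: lower values sharpen the distribution.
        probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        # Keep the smallest prefix whose cumulative mass reaches top_p.
        cutoff = int(torch.searchsorted(cumulative, torch.tensor(top_p)).item()) + 1
        kept = sorted_probs[:cutoff] / sorted_probs[:cutoff].sum()
        choice = torch.multinomial(kept, num_samples=1)
        return int(sorted_idx[choice].item())

    logits = torch.randn(32000)            # fake vocabulary logits
    print(sample(logits, temperature=0.6, top_p=0.95))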
ktransformers/server/schemas/endpoints/chat.py

@@ -25,7 +25,9 @@ class ChatCompletionCreate(BaseModel):
     messages: List[Message]
     model: str
     stream: bool = False
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None

     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
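Because both fields are declared Optional[float] = None, a request that omits them still validates, and the None signals "use the server default" downstream. A quick standalone sketch of that behavior (the Message stub is illustrative; the real model also defines the fields shown in the diff above):

    # Omitted sampling fields parse to None under Pydantic.
    from typing import List, Optional
    from pydantic import BaseModel

    class Message(BaseModel):          # illustrative stub
        role: str
        content: str

    class ChatCompletionCreate(BaseModel):
        messages: List[Message]
        model: str
        stream: bool = False
        temperature: Optional[float] = None
        top_p: Optional[float] = None

    req = ChatCompletionCreate(
        model="m",
        messages=[{"role": "user", "content": "hi"}],
    )
    print(req.temperature, req.top_p)  # None None -> backend falls back to defaults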
ktransformers/server/schemas/legacy/completions.py

@@ -9,6 +9,8 @@ class CompletionCreate(BaseModel):
     model: str
     prompt: str | List[str]
     stream: bool = False
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None

     def get_tokenizer_messages(self):
         if isinstance(self.prompt, List):