Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ox696c
ktransformers
Commits
8704c091
Unverified
Commit
8704c091
authored
Feb 24, 2025
by
lazymio
Browse files
Allow temperature and top_p from requests
parent
4b5991e7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
12 deletions
+18
-12
ktransformers/server/api/openai/legacy/completions.py
ktransformers/server/api/openai/legacy/completions.py
+2
-2
ktransformers/server/backend/interfaces/ktransformers.py
ktransformers/server/backend/interfaces/ktransformers.py
+3
-3
ktransformers/server/backend/interfaces/transformers.py
ktransformers/server/backend/interfaces/transformers.py
+11
-7
ktransformers/server/schemas/legacy/completions.py
ktransformers/server/schemas/legacy/completions.py
+2
-0
No files found.
ktransformers/server/api/openai/legacy/completions.py
View file @
8704c091
...
@@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate):
...
@@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate):
if
create
.
stream
:
if
create
.
stream
:
async
def
inner
():
async
def
inner
():
async
for
token
in
interface
.
inference
(
create
.
prompt
,
id
):
async
for
token
in
interface
.
inference
(
create
.
prompt
,
id
,
create
.
temperature
,
create
.
top_p
):
d
=
{
'choices'
:[{
'delta'
:{
'content'
:
token
}}]}
d
=
{
'choices'
:[{
'delta'
:{
'content'
:
token
}}]}
yield
f
"data:
{
json
.
dumps
(
d
)
}
\n\n
"
yield
f
"data:
{
json
.
dumps
(
d
)
}
\n\n
"
d
=
{
'choices'
:[{
'delta'
:{
'content'
:
''
},
'finish_reason'
:
''
}]}
d
=
{
'choices'
:[{
'delta'
:{
'content'
:
''
},
'finish_reason'
:
''
}]}
...
@@ -28,6 +28,6 @@ async def create_completion(request:Request,create:CompletionCreate):
...
@@ -28,6 +28,6 @@ async def create_completion(request:Request,create:CompletionCreate):
return
stream_response
(
request
,
inner
())
return
stream_response
(
request
,
inner
())
else
:
else
:
comp
=
CompletionObject
(
id
=
id
,
object
=
'text_completion'
,
created
=
int
(
time
()))
comp
=
CompletionObject
(
id
=
id
,
object
=
'text_completion'
,
created
=
int
(
time
()))
async
for
token
in
interface
.
inference
(
create
.
prompt
,
id
):
async
for
token
in
interface
.
inference
(
create
.
prompt
,
id
,
create
.
temperature
,
create
.
top_p
):
comp
.
append_token
(
token
)
comp
.
append_token
(
token
)
return
comp
return
comp
ktransformers/server/backend/interfaces/ktransformers.py
View file @
8704c091
...
@@ -14,7 +14,7 @@ from ktransformers.models.custom_cache import StaticCache
...
@@ -14,7 +14,7 @@ from ktransformers.models.custom_cache import StaticCache
from
ktransformers.util.cuda_graph_runner
import
CUDAGraphRunner
from
ktransformers.util.cuda_graph_runner
import
CUDAGraphRunner
from
ktransformers.local_chat
import
custom_models
,
default_optimize_rules
from
ktransformers.local_chat
import
custom_models
,
default_optimize_rules
from
ktransformers.util.utils
import
get_device
from
ktransformers.util.utils
import
get_device
from
typing
import
Optional
warm_uped
=
False
warm_uped
=
False
...
@@ -207,7 +207,7 @@ class KTransformersInterface(TransformersInterface):
...
@@ -207,7 +207,7 @@ class KTransformersInterface(TransformersInterface):
device
=
self
.
device_map
.
get
(
"blk.0.self_attn"
,
{}).
get
(
"generate_device"
,
"cuda:0"
)
device
=
self
.
device_map
.
get
(
"blk.0.self_attn"
,
{}).
get
(
"generate_device"
,
"cuda:0"
)
return
torch
.
tensor
([
self
.
seq_length
-
1
],
device
=
device
)
return
torch
.
tensor
([
self
.
seq_length
-
1
],
device
=
device
)
async
def
inference
(
self
,
local_messages
,
thread_id
:
str
):
async
def
inference
(
self
,
local_messages
,
thread_id
:
str
,
temperature
:
Optional
[
float
],
top_p
:
Optional
[
float
]
):
async
with
self
.
_infer_lock
:
async
with
self
.
_infer_lock
:
async
for
v
in
super
().
inference
(
local_messages
,
thread_id
):
async
for
v
in
super
().
inference
(
local_messages
,
thread_id
,
temperature
,
top_p
):
yield
v
yield
v
ktransformers/server/backend/interfaces/transformers.py
View file @
8704c091
...
@@ -202,13 +202,17 @@ class TransformersInterface(BackendInterfaceBase):
...
@@ -202,13 +202,17 @@ class TransformersInterface(BackendInterfaceBase):
self
.
seq_length
+=
1
self
.
seq_length
+=
1
return
self
.
streamer
.
put
(
new_tokens
)
return
self
.
streamer
.
put
(
new_tokens
)
def
prepare_logits_wrapper
(
self
,
inputs
,
device
):
def
prepare_logits_wrapper
(
self
,
inputs
,
device
,
temperature
:
Optional
[
float
]
=
None
,
top_p
:
Optional
[
float
]
=
None
):
if
temperature
is
None
:
temperature
=
self
.
args
.
temperature
if
top_p
is
None
:
top_p
=
self
.
args
.
top_p
generation_config
,
model_kwargs
=
self
.
model
.
_prepare_generation_config
(
generation_config
,
model_kwargs
=
self
.
model
.
_prepare_generation_config
(
None
,
max_length
=
self
.
args
.
max_new_tokens
,
None
,
max_length
=
self
.
args
.
max_new_tokens
,
do_sample
=
True
,
do_sample
=
True
,
top_k
=
self
.
args
.
top_k
,
top_k
=
self
.
args
.
top_k
,
top_p
=
self
.
args
.
top_p
,
top_p
=
top_p
,
temperature
=
self
.
args
.
temperature
,
temperature
=
temperature
,
repetition_penalty
=
self
.
args
.
repetition_penalty
# change this to modify generate config
repetition_penalty
=
self
.
args
.
repetition_penalty
# change this to modify generate config
)
)
self
.
inputs
=
inputs
self
.
inputs
=
inputs
...
@@ -255,7 +259,7 @@ class TransformersInterface(BackendInterfaceBase):
...
@@ -255,7 +259,7 @@ class TransformersInterface(BackendInterfaceBase):
return
self
.
logits_to_token
(
logits
)
return
self
.
logits_to_token
(
logits
)
@
torch
.
no_grad
@
torch
.
no_grad
def
prefill
(
self
,
input_ids
:
torch
.
Tensor
,
is_new
:
bool
):
def
prefill
(
self
,
input_ids
:
torch
.
Tensor
,
is_new
:
bool
,
temperature
:
Optional
[
float
]
=
None
,
top_p
:
Optional
[
float
]
=
None
):
input_ids_length
=
input_ids
.
shape
[
-
1
]
input_ids_length
=
input_ids
.
shape
[
-
1
]
logger
.
debug
(
f
"input_ids:
{
input_ids
.
shape
}
"
)
logger
.
debug
(
f
"input_ids:
{
input_ids
.
shape
}
"
)
...
@@ -323,7 +327,7 @@ class TransformersInterface(BackendInterfaceBase):
...
@@ -323,7 +327,7 @@ class TransformersInterface(BackendInterfaceBase):
else
:
else
:
logits
=
self
.
model
(
inputs_embeds
=
inputs_embeds
,
return_dict
=
False
)[
0
]
logits
=
self
.
model
(
inputs_embeds
=
inputs_embeds
,
return_dict
=
False
)[
0
]
self
.
prepare_logits_wrapper
(
input_ids
,
device
)
self
.
prepare_logits_wrapper
(
input_ids
,
device
,
temperature
,
top_p
)
next_token
=
self
.
logits_to_token
(
logits
[
0
,
-
1
,
:])
next_token
=
self
.
logits_to_token
(
logits
[
0
,
-
1
,
:])
yield
self
.
append_new_tokens
(
next_token
)
yield
self
.
append_new_tokens
(
next_token
)
...
@@ -359,7 +363,7 @@ class TransformersInterface(BackendInterfaceBase):
...
@@ -359,7 +363,7 @@ class TransformersInterface(BackendInterfaceBase):
self
.
last_request_id
=
thread_id
self
.
last_request_id
=
thread_id
return
True
return
True
async
def
inference
(
self
,
local_messages
,
thread_id
:
str
):
async
def
inference
(
self
,
local_messages
,
thread_id
:
str
,
temperature
:
Optional
[
float
]
=
None
,
top_p
:
Optional
[
float
]
=
None
):
self
.
streamer
.
reset
()
self
.
streamer
.
reset
()
self
.
profiler
.
create_and_start_timer
(
"tokenize"
)
self
.
profiler
.
create_and_start_timer
(
"tokenize"
)
if
isinstance
(
local_messages
,
List
):
if
isinstance
(
local_messages
,
List
):
...
@@ -386,7 +390,7 @@ class TransformersInterface(BackendInterfaceBase):
...
@@ -386,7 +390,7 @@ class TransformersInterface(BackendInterfaceBase):
print
(
think
,
end
=
""
,
flush
=
True
)
print
(
think
,
end
=
""
,
flush
=
True
)
yield
think
yield
think
for
t
in
self
.
prefill
(
input_ids
,
self
.
check_is_new
(
thread_id
)):
for
t
in
self
.
prefill
(
input_ids
,
self
.
check_is_new
(
thread_id
)
,
temperature
,
top_p
):
# output think token after prefill done
# output think token after prefill done
if
t
is
not
None
:
if
t
is
not
None
:
print
(
t
,
end
=
""
,
flush
=
True
)
print
(
t
,
end
=
""
,
flush
=
True
)
...
...
ktransformers/server/schemas/legacy/completions.py
View file @
8704c091
...
@@ -9,6 +9,8 @@ class CompletionCreate(BaseModel):
...
@@ -9,6 +9,8 @@ class CompletionCreate(BaseModel):
model
:
str
model
:
str
prompt
:
str
|
List
[
str
]
prompt
:
str
|
List
[
str
]
stream
:
bool
=
False
stream
:
bool
=
False
temperature
:
Optional
[
float
]
top_p
:
Optional
[
float
]
def
get_tokenizer_messages
(
self
):
def
get_tokenizer_messages
(
self
):
if
isinstance
(
self
.
prompt
,
List
):
if
isinstance
(
self
.
prompt
,
List
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment