OpenDAS / ktransformers / Commits

Commit 8704c091 (unverified)
Authored Feb 24, 2025 by lazymio
Parent: 4b5991e7

Allow temperature and top_p from requests

Showing 4 changed files with 18 additions and 12 deletions:

    ktransformers/server/api/openai/legacy/completions.py      +2  -2
    ktransformers/server/backend/interfaces/ktransformers.py   +3  -3
    ktransformers/server/backend/interfaces/transformers.py    +11 -7
    ktransformers/server/schemas/legacy/completions.py         +2  -0
ktransformers/server/api/openai/legacy/completions.py

@@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate):
     if create.stream:
         async def inner():
-            async for token in interface.inference(create.prompt, id):
+            async for token in interface.inference(create.prompt, id, create.temperature, create.top_p):
                 d = {'choices':[{'delta':{'content': token}}]}
                 yield f"data: {json.dumps(d)}\n\n"
             d = {'choices':[{'delta':{'content': ''}, 'finish_reason': ''}]}
@@ -28,6 +28,6 @@ async def create_completion(request:Request,create:CompletionCreate):
         return stream_response(request, inner())
     else:
         comp = CompletionObject(id=id, object='text_completion', created=int(time()))
-        async for token in interface.inference(create.prompt, id):
+        async for token in interface.inference(create.prompt, id, create.temperature, create.top_p):
             comp.append_token(token)
         return comp
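With this change, the legacy OpenAI-style completions endpoint forwards the request's temperature and top_p to the backend instead of always using the server defaults. A minimal client sketch, assuming the server is reachable at http://localhost:10002 (the host, port, and model name here are illustrative, not taken from the commit):

import requests

# Illustrative endpoint; point this at wherever the ktransformers server runs.
URL = "http://localhost:10002/v1/completions"

payload = {
    "model": "demo-model",                 # hypothetical model name
    "prompt": "Explain KV caching in one sentence.",
    "stream": False,
    "temperature": 0.6,                    # now forwarded to interface.inference(...)
    "top_p": 0.95,                         # now forwarded to interface.inference(...)
}

resp = requests.post(URL, json=payload, timeout=600)
print(resp.json())

Omitting temperature or top_p from the JSON body leaves the server-side defaults in effect, as the fallback logic added to transformers.py below shows.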
ktransformers/server/backend/interfaces/ktransformers.py

@@ -14,7 +14,7 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
+from typing import Optional

 warm_uped = False
@@ -207,7 +207,7 @@ class KTransformersInterface(TransformersInterface):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)

-    async def inference(self, local_messages, thread_id: str):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float], top_p: Optional[float]):
         async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p):
                 yield v
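This override only wraps the base implementation in the interface's inference lock, so concurrent requests are serialized while the new sampling parameters pass straight through to the parent class. A self-contained sketch of that pattern, with hypothetical class names:

import asyncio
from typing import Optional

class BaseInterface:
    async def inference(self, local_messages, thread_id: str,
                        temperature: Optional[float] = None,
                        top_p: Optional[float] = None):
        # Stand-in for real token generation.
        for tok in ("hello", " ", "world"):
            await asyncio.sleep(0)
            yield tok

class LockedInterface(BaseInterface):
    def __init__(self):
        self._infer_lock = asyncio.Lock()

    async def inference(self, local_messages, thread_id: str,
                        temperature: Optional[float] = None,
                        top_p: Optional[float] = None):
        # Serialize generation; forward the sampling knobs unchanged.
        async with self._infer_lock:
            async for v in super().inference(local_messages, thread_id, temperature, top_p):
                yield v

async def main():
    iface = LockedInterface()
    async for tok in iface.inference("hi", "thread-1", temperature=0.7, top_p=0.9):
        print(tok, end="")

asyncio.run(main())

Note that the subclass in the diff declares temperature and top_p without defaults, while the base class below gives them = None; callers going through the KTransformers interface therefore pass both arguments explicitly, which the API layer above now does.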
ktransformers/server/backend/interfaces/transformers.py

@@ -202,13 +202,17 @@ class TransformersInterface(BackendInterfaceBase):
         self.seq_length += 1
         return self.streamer.put(new_tokens)

-    def prepare_logits_wrapper(self, inputs, device):
+    def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
+        if temperature is None:
+            temperature = self.args.temperature
+        if top_p is None:
+            top_p = self.args.top_p
         generation_config, model_kwargs = self.model._prepare_generation_config(
             None, max_length=self.args.max_new_tokens,
             do_sample=True,
             top_k=self.args.top_k,
-            top_p=self.args.top_p,
-            temperature=self.args.temperature,
+            top_p=top_p,
+            temperature=temperature,
             repetition_penalty=self.args.repetition_penalty # change this to modify generate config
         )
         self.inputs = inputs
@@ -255,7 +259,7 @@ class TransformersInterface(BackendInterfaceBase):
         return self.logits_to_token(logits)

     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
         logger.debug(f"input_ids: {input_ids.shape}")
@@ -323,7 +327,7 @@ class TransformersInterface(BackendInterfaceBase):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
-        self.prepare_logits_wrapper(input_ids, device)
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
@@ -359,7 +363,7 @@ class TransformersInterface(BackendInterfaceBase):
         self.last_request_id = thread_id
         return True

-    async def inference(self, local_messages, thread_id: str):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
         if isinstance(local_messages, List):
@@ -386,7 +390,7 @@ class TransformersInterface(BackendInterfaceBase):
             print(think, end="", flush=True)
             yield think
-        for t in self.prefill(input_ids, self.check_is_new(thread_id)):
+        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="", flush=True)
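The core of the change is the None-coalescing in prepare_logits_wrapper: a value supplied by the request wins, otherwise the server-wide default from self.args applies. Because the check is "is None" rather than truthiness, an explicit 0.0 from a client is honored instead of being silently replaced. A minimal, self-contained sketch of the same fallback logic (ServerArgs is a stand-in for the server's argument object, not the real class):

from dataclasses import dataclass
from typing import Optional

@dataclass
class ServerArgs:
    # Stand-in defaults; the real values come from the server configuration.
    temperature: float = 0.8
    top_p: float = 1.0

def resolve_sampling(args: ServerArgs,
                     temperature: Optional[float] = None,
                     top_p: Optional[float] = None) -> tuple[float, float]:
    # A request value of None means "not specified": fall back to the default.
    if temperature is None:
        temperature = args.temperature
    if top_p is None:
        top_p = args.top_p
    return temperature, top_p

assert resolve_sampling(ServerArgs()) == (0.8, 1.0)            # server defaults
assert resolve_sampling(ServerArgs(), 0.2, 0.9) == (0.2, 0.9)  # per-request override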
ktransformers/server/schemas/legacy/completions.py

@@ -9,6 +9,8 @@ class CompletionCreate(BaseModel):
     model: str
     prompt: str | List[str]
     stream: bool = False
+    temperature: Optional[float]
+    top_p: Optional[float]

     def get_tokenizer_messages(self):
         if isinstance(self.prompt, List):
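Finally, the schema grows two optional fields so clients can omit the parameters entirely. One caveat worth noting: under pydantic v1 a bare Optional[float] implicitly defaults to None, but under pydantic v2 such a field would be required, so a portable version of this model spells the default out. A minimal sketch that adds the explicit = None defaults (the diff itself omits them):

from typing import List, Optional
from pydantic import BaseModel

class CompletionCreate(BaseModel):
    model: str
    prompt: str | List[str]               # requires Python 3.10+ union syntax
    stream: bool = False
    temperature: Optional[float] = None   # explicit default; the diff writes it without "= None"
    top_p: Optional[float] = None

req = CompletionCreate(model="demo", prompt="Hello", temperature=0.3)
print(req.temperature, req.top_p)         # -> 0.3 None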