ox696c / ktransformers

Commit bf36547f (Unverified) · authored Feb 24, 2025 by lazymio
Parent: 8704c091

Also allow repetition_penalty
Showing 3 changed files with 11 additions and 8 deletions (+11 -8):

- ktransformers/server/api/openai/legacy/completions.py (+2 -2)
- ktransformers/server/backend/interfaces/transformers.py (+8 -6)
- ktransformers/server/schemas/legacy/completions.py (+1 -0)
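The commit threads a new optional `repetition_penalty` sampling parameter from the legacy OpenAI-style completions endpoint, through the backend interface, down to the Hugging Face generation config. A minimal client sketch for exercising the new field, assuming the server is reachable at an OpenAI-compatible route (host, port, and path are assumptions, not part of this commit):

```python
# Hypothetical client call; only the repetition_penalty field itself
# comes from this commit. Host, port, and route are assumed.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",  # assumed address
    json={
        "prompt": "Once upon a time",
        "stream": False,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.2,  # new field; values > 1 discourage repeats
    },
)
print(resp.json())
```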
ktransformers/server/api/openai/legacy/completions.py

```diff
@@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate):
     if create.stream:
         async def inner():
-            async for token in interface.inference(create.prompt, id, create.temperature, create.top_p):
+            async for token in interface.inference(create.prompt, id, create.temperature, create.top_p, create.repetition_penalty):
                 d = {'choices':[{'delta':{'content':token}}]}
                 yield f"data: {json.dumps(d)}\n\n"
             d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
@@ -28,6 +28,6 @@ async def create_completion(request:Request,create:CompletionCreate):
         return stream_response(request, inner())
     else:
         comp = CompletionObject(id=id, object='text_completion', created=int(time()))
-        async for token in interface.inference(create.prompt, id, create.temperature, create.top_p):
+        async for token in interface.inference(create.prompt, id, create.temperature, create.top_p, create.repetition_penalty):
             comp.append_token(token)
         return comp
```
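In the streaming branch, `inner()` frames each generated token as an SSE-style `data: <json>` chunk whose payload mirrors the OpenAI delta format, followed by a final empty-content chunk carrying `finish_reason`. A consumer sketch, assuming exactly the `data: ...\n\n` framing shown in the diff (address again assumed):

```python
# Sketch of consuming the stream produced by inner(); assumes the
# "data: <json>\n\n" framing visible in the diff above.
import json
import requests

with requests.post(
    "http://localhost:8000/v1/completions",  # assumed address
    json={"prompt": "Hello", "stream": True, "repetition_penalty": 1.1},
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data:"):
            chunk = json.loads(line[len("data:"):])
            delta = chunk["choices"][0].get("delta", {})
            print(delta.get("content", ""), end="", flush=True)
```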
ktransformers/server/backend/interfaces/transformers.py

```diff
@@ -202,18 +202,20 @@ class TransformersInterface(BackendInterfaceBase):
         self.seq_length += 1
         return self.streamer.put(new_tokens)

-    def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None, repetition_penalty: Optional[float] = None):
         if temperature is None:
             temperature = self.args.temperature
         if top_p is None:
             top_p = self.args.top_p
+        if repetition_penalty is None:
+            repetition_penalty = self.args.repetition_penalty
         generation_config, model_kwargs = self.model._prepare_generation_config(
             None, max_length=self.args.max_new_tokens,
             do_sample=True,
             top_k=self.args.top_k,
             top_p=top_p,
             temperature=temperature,
-            repetition_penalty=self.args.repetition_penalty # change this to modify generate config
+            repetition_penalty=repetition_penalty # change this to modify generate config
         )
         self.inputs = inputs
         self.generation_config = generation_config
```
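Once the value reaches `_prepare_generation_config`, Hugging Face applies it during sampling through its repetition-penalty logits processor: the logit of every previously generated token is divided by the penalty when positive and multiplied by it when negative, so any penalty above 1 makes repeats less likely. A self-contained sketch of that rule (an illustration, not ktransformers code):

```python
# Illustration of the repetition-penalty rule applied downstream by the
# Hugging Face logits processor; roughly equivalent, not the exact code.
import torch

def apply_repetition_penalty(logits: torch.Tensor,
                             generated_ids: torch.Tensor,
                             penalty: float) -> torch.Tensor:
    """Penalize tokens already present in generated_ids (penalty > 1)."""
    seen = logits.gather(-1, generated_ids)
    # Dividing positive logits and multiplying negative ones both push the
    # repeated token's probability down when penalty > 1.
    seen = torch.where(seen > 0, seen / penalty, seen * penalty)
    return logits.scatter(-1, generated_ids, seen)

logits = torch.tensor([2.0, -1.0, 0.5])
print(apply_repetition_penalty(logits, torch.tensor([0]), 1.25))
# token 0's logit drops from 2.0 to 1.6; the others are untouched
```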
```diff
@@ -259,7 +261,7 @@ class TransformersInterface(BackendInterfaceBase):
         return self.logits_to_token(logits)

     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, repetition_penalty: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
         logger.debug(f"input_ids: {input_ids.shape}")
@@ -327,7 +329,7 @@ class TransformersInterface(BackendInterfaceBase):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]

-        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p, repetition_penalty)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
@@ -363,7 +365,7 @@ class TransformersInterface(BackendInterfaceBase):
             self.last_request_id = thread_id
             return True

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, repetition_penalty: Optional[float] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
         if isinstance(local_messages, List):
@@ -390,7 +392,7 @@ class TransformersInterface(BackendInterfaceBase):
             print(think, end="", flush=True)
             yield think

-        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
+        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p, repetition_penalty):
             # output think token after prefill done
             if t is not None:
                 print(t, end="", flush=True)
```
ktransformers/server/schemas/legacy/completions.py

```diff
@@ -11,6 +11,7 @@ class CompletionCreate(BaseModel):
     stream: bool = False
     temperature: Optional[float]
     top_p: Optional[float]
+    repetition_penalty: Optional[float]

     def get_tokenizer_messages(self):
         if isinstance(self.prompt, List):
```
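For reference, the request model after this change looks roughly like the sketch below (only the fields visible in the hunk; `prompt`'s type is an assumption). How an omitted field behaves depends on the Pydantic major version: under v1, `Optional[float]` with no default is implicitly optional and defaults to None, while under v2 it is required but nullable. Either way, a None value flows down unchanged to `prepare_logits_wrapper`, which substitutes `self.args.repetition_penalty`:

```python
# Sketch of CompletionCreate after this commit; only fields visible in the
# hunk are shown, and prompt's type is an assumption.
from typing import List, Optional, Union
from pydantic import BaseModel

class CompletionCreate(BaseModel):
    prompt: Union[str, List[str]]  # assumed; not shown in the hunk
    stream: bool = False
    temperature: Optional[float]
    top_p: Optional[float]
    repetition_penalty: Optional[float]  # the field this commit adds
```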