Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ktransformers
Commits
cd9f7f8f
Commit
cd9f7f8f
authored
Feb 17, 2025
by
ceerrep
Browse files
fix: server: drop <think> tag in chat template
parent
ca2090d8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
2 deletions
+7
-2
ktransformers/server/backend/interfaces/transformers.py
ktransformers/server/backend/interfaces/transformers.py
+7
-2
No files found.
ktransformers/server/backend/interfaces/transformers.py
View file @
cd9f7f8f
...
...
@@ -170,7 +170,7 @@ class TransformersInterface(BackendInterfaceBase):
for
m
in
messages
[
1
:]:
if
m
[
"role"
]
==
"user"
and
new_messages
[
-
1
][
"role"
]
==
"user"
:
logger
.
warning
(
"merge two adjacent user messages"
)
new_messages
[
-
1
][
"content"
]
+=
m
[
"content"
]
new_messages
[
-
1
][
"content"
]
+=
'
\n
'
+
m
[
"content"
]
else
:
new_messages
.
append
(
m
)
# if (self.last_request_id is not None) and self.last_request_id == thread_id:
...
...
@@ -179,7 +179,11 @@ class TransformersInterface(BackendInterfaceBase):
# input_ids = self.tokenizer.apply_chat_template(
# new_messages, return_tensors="pt", add_generation_prompt=True
# ).to(self.args.device)
input_ids
=
self
.
tokenizer
.
apply_chat_template
(
new_messages
,
return_tensors
=
'pt'
,
add_generation_prompt
=
True
).
to
(
self
.
args
.
device
)
input_str
:
str
=
self
.
tokenizer
.
apply_chat_template
(
new_messages
,
tokenize
=
False
,
add_generation_prompt
=
True
)
# drop <think> token in chat template
if
input_str
.
endswith
(
'<think>
\n
'
):
input_str
=
input_str
[:
-
len
(
'<think>
\n
'
)]
input_ids
=
self
.
tokenizer
.
encode
(
input_str
,
return_tensors
=
"pt"
).
to
(
self
.
args
.
device
)
if
(
self
.
last_request_id
is
not
None
)
and
self
.
last_request_id
==
thread_id
:
x
=
self
.
generated_ids
[:,:
self
.
seq_length
]
y
=
input_ids
[:,:
self
.
seq_length
]
...
...
@@ -360,6 +364,7 @@ class TransformersInterface(BackendInterfaceBase):
#input_ids = torch.tensor([[6366]], device=input_ids.device)
else
:
raise
ValueError
(
"local_messages should be List or str"
)
if
Config
().
user_force_think
:
token_thinks
=
torch
.
tensor
([
self
.
tokenizer
.
encode
(
"<think>
\n
"
,
add_special_tokens
=
False
)],
device
=
input_ids
.
device
)
input_ids
=
torch
.
cat
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment