OpenDAS / ktransformers

Commit 7c94df4b, authored Oct 28, 2024 by liam

🚑️ Roll transformers.py back to the pre-bug version, and fix a typo in local_chat.py

Parent: dd1d8667
2 changed files, with 19 additions and 7 deletions:

ktransformers/local_chat.py (+1, -1)
ktransformers/server/backend/interfaces/transformers.py (+18, -6)
ktransformers/local_chat.py

```diff
@@ -91,7 +91,7 @@ def local_chat():
         generated = asyncio.run(async_inference(messages))
         his_content += [
             {"role": "user", "content": content},
-            {"role": "assitant", "content": generated},
+            {"role": "assistant", "content": generated},
         ]
```
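The fixed line matters because chat templates dispatch on the role string, so a misspelled role like "assitant" is silently carried along in the history and can break or degrade templating on the next turn. A minimal sketch of the corrected history bookkeeping, where `run_inference` is a hypothetical stand-in for the real async_inference call:

```python
import asyncio

async def run_inference(messages: list[dict]) -> str:
    # Placeholder: a real implementation would call the model here.
    return f"(reply to {messages[-1]['content']!r})"

his_content: list[dict] = []
for content in ["hello", "what can you do?"]:
    messages = his_content + [{"role": "user", "content": content}]
    generated = asyncio.run(run_inference(messages))
    his_content += [
        {"role": "user", "content": content},
        # "assitant" here would corrupt every later turn's prompt
        {"role": "assistant", "content": generated},
    ]

print(his_content)
```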
ktransformers/server/backend/interfaces/transformers.py

```diff
@@ -164,7 +164,6 @@ class TransformersInterface(BackendInterfaceBase):
             if m["role"] == "system":
                 logger.warning(f'change {m["role"]} to user')
                 m["role"] = "user"
         new_messages = [messages[0]]
         for m in messages[1:]:
             if m["role"] == "user" and new_messages[-1]["role"] == "user":
```
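The normalization pass above downgrades system messages to the user role and detects consecutive user turns, since many chat templates reject non-alternating roles or mishandle system prompts. A minimal sketch of that normalization, assuming the merge branch (truncated out of this hunk) joins contents with a newline:

```python
messages = [
    {"role": "system", "content": "Be concise."},
    {"role": "user", "content": "hi"},
    {"role": "user", "content": "and explain briefly"},
]

# Downgrade system messages to user, as some templates reject them.
for m in messages:
    if m["role"] == "system":
        m["role"] = "user"

# Merge adjacent user turns so roles strictly alternate.
new_messages = [messages[0]]
for m in messages[1:]:
    if m["role"] == "user" and new_messages[-1]["role"] == "user":
        new_messages[-1]["content"] += "\n" + m["content"]  # assumed merge rule
    else:
        new_messages.append(m)

print(new_messages)
```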
```diff
@@ -173,12 +172,25 @@ class TransformersInterface(BackendInterfaceBase):
             else:
                 new_messages.append(m)
         # if (self.last_request_id is not None) and self.last_request_id == thread_id:
         #     logger.debug(f"last message: {new_messages[-1]}")
         #     input_ids = self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt", add_generation_prompt=False).to(self.args.device)
         # else:
         #     input_ids = self.tokenizer.apply_chat_template(
         #         new_messages, return_tensors="pt", add_generation_prompt=True
         #     ).to(self.args.device)
-        input_ids = self.tokenizer.apply_chat_template(new_messages, return_tensors='pt', add_generation_prompt=True).to(self.args.device)
+        if (self.last_request_id is not None) and self.last_request_id == thread_id:
+            input_ids = self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt").to(self.args.device)
+        else:
+            input_ids = self.tokenizer.apply_chat_template(
+                new_messages, return_tensors="pt", add_generation_prompt=True
+            ).to(self.args.device)
+        x = self.generated_ids[:, :self.seq_length]
+        y = input_ids[:, :self.seq_length]
+        # We can only hope that the input_ids are the same
+        unequal_mask = torch.ne(x, y)
+        unequal_positions = torch.nonzero(unequal_mask)
+        num_unequal_elements = unequal_mask.sum().item()
+        logger.warning(f'num_unequal_elements: {num_unequal_elements}')
+        input_ids = input_ids[:, self.seq_length:]
         logger.debug(f"get input ids of shape {input_ids.shape}")
         return input_ids
```
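The comparison block restored here is a simple prefix-reuse check: when a request continues an existing thread, the freshly tokenized prompt should begin with exactly the ids the server already holds, so only the new suffix needs prefilling, and any mismatch is worth a warning. A standalone sketch of that technique (the tensor contents and `seq_length` cutoff are illustrative, not the interface's actual state):

```python
import torch

# Ids the server already holds for this thread (prompt + generated so far).
cached_ids = torch.tensor([[101, 42, 7, 99, 5]])
seq_length = cached_ids.shape[1]

# Re-tokenized prompt for the follow-up request: old prefix + new user turn.
new_input_ids = torch.tensor([[101, 42, 7, 99, 5, 61, 62]])

# Compare the overlapping prefix element-wise; ideally nothing differs.
x = cached_ids[:, :seq_length]
y = new_input_ids[:, :seq_length]
unequal_mask = torch.ne(x, y)
print("mismatched positions:", torch.nonzero(unequal_mask))
print("mismatch count:", unequal_mask.sum().item())

# If the prefix matches, only the new suffix needs to be prefilled.
suffix = new_input_ids[:, seq_length:]
print("suffix to prefill:", suffix)
```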