OpenDAS / Lmdeploy · Commits · 289ffa3c

Unverified commit 289ffa3c, authored Jul 19, 2023 by lvhan028, committed by GitHub on Jul 19, 2023

fix the offset during streaming chat (#142)
parent 79595cd1

3 changed files with 22 additions and 25 deletions (+22 -25):

  lmdeploy/model.py                      +1   -1
  lmdeploy/serve/client.py               +1   -1
  lmdeploy/serve/turbomind/chatbot.py    +20  -23
lmdeploy/model.py
@@ -113,7 +113,7 @@ conversation"""  # noqa: E501
     def get_prompt(self, prompt, sequence_start=True):
         if sequence_start:
-            return f'<bos>{self.system}\n' \
+            return f'<BOS>{self.system}\n' \
                    f'{self.user}:{prompt}{self.eoh}\n' \
                    f'{self.assistant}:'
         else:
...
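For reference, a minimal sketch of what the updated template renders for the first turn. The attribute values below are placeholders (assumptions), not the real strings defined in lmdeploy/model.py, and the else branch is elided here as it is in the diff:

class DemoModel:
    # Placeholder values (assumptions); the real system/user/assistant/eoh
    # strings are defined elsewhere in lmdeploy/model.py.
    system = '<system text>'
    user = '<|User|>'
    assistant = '<|Bot|>'
    eoh = '<eoh>'

    def get_prompt(self, prompt, sequence_start=True):
        if sequence_start:
            return f'<BOS>{self.system}\n' \
                   f'{self.user}:{prompt}{self.eoh}\n' \
                   f'{self.assistant}:'


print(DemoModel().get_prompt('hi'))
# <BOS><system text>
# <|User|>:hi<eoh>
# <|Bot|>: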
lmdeploy/serve/client.py
@@ -37,7 +37,7 @@ def main(tritonserver_addr: str, model_name: str, session_id: int = 1):
             chatbot.end(session_id)
         else:
             request_id = f'{session_id}-{nth_round}'
-            for status, res, tokens in chatbot.stream_infer(
+            for status, res, n_token in chatbot.stream_infer(
                     session_id,
                     prompt,
                     request_id=request_id,
...
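For context, a hedged sketch of how the changed loop is driven on the client side; everything beyond the names visible in the diff (Chatbot, stream_infer, session_id, prompt, request_id, n_token) is an assumption:

from lmdeploy.serve.turbomind.chatbot import Chatbot


def chat_once(chatbot: Chatbot, session_id: int, prompt: str,
              nth_round: int) -> int:
    """Run one streaming round and return the generated token count."""
    request_id = f'{session_id}-{nth_round}'
    n_token = 0
    for status, res, n_token in chatbot.stream_infer(
            session_id, prompt, request_id=request_id):
        # `res` is the text generated so far in this round; after this
        # commit `n_token` counts only newly generated tokens, because the
        # input-prompt offset is now subtracted inside stream_consumer.
        pass
    return n_token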
lmdeploy/serve/turbomind/chatbot.py
@@ -381,13 +381,16 @@ class Chatbot:
                                request_output_len, sequence_start,
                                sequence_end, preseq_length, cancel))
         producer.start()
         for state, res, tokens in self.stream_consumer(
-                self.postprocess, que, session, preseq_length, cancel, logger,
-                self.display, self.profile_generation, self.eos_id):
+                self.postprocess, que, session, input_tokens,
+                preseq_length, cancel, logger,
+                self.display, self.profile_generation, self.eos_id):
             if state.value < 0:
                 yield state, res, 0
             else:
-                yield state, res, tokens - input_tokens
+                yield state, res, tokens
         producer.join()
         self._session = que.get()
         curseq_length = self._session.sequence_length
@@ -477,8 +480,9 @@ class Chatbot:
         que.put(None)

     @staticmethod
-    def stream_consumer(postprocess, res_queue, session, preseq_length, cancel,
-                        logger, display, profile_generation, eos_id):
+    def stream_consumer(postprocess, res_queue, session, n_input_token,
+                        preseq_length, cancel, logger, display,
+                        profile_generation, eos_id):
         """Consume the response from the triton inference server.

         Args:
@@ -486,6 +490,7 @@ class Chatbot:
                 the generated tokens
             res_queue (multiprocessing.Queue): response queue
             session (Session): an instance of a session
+            n_input_token (int): token number of input prompt
             preseq_length (int): the history sequence length
             cancel (bool): indicator for cancelling the session
             logger (util.Logger):
@@ -496,12 +501,12 @@ class Chatbot:
         Yields:
             tuple: status, text, generated token number
         """
+        offset = n_input_token + preseq_length
         while True:
             result = res_queue.get()
             if result is None:
-                yield StatusCode.TRITON_STREAM_END, \
-                    session.response[len(session.prompt):], \
-                    session.sequence_length - preseq_length
+                yield (StatusCode.TRITON_STREAM_END, session.response,
+                       session.sequence_length - offset)
                 session.status = StatusCode.TRITON_STREAM_END
                 break
             if 'errcode' in result:
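To make the offset introduced above concrete, a small worked example with illustrative numbers only:

preseq_length = 100    # history sequence length from earlier rounds
n_input_token = 20     # token number of the current input prompt
sequence_length = 150  # total sequence length reported by the server

offset = n_input_token + preseq_length   # 120
generated = sequence_length - offset     # 30 tokens actually generated
# Before this commit the yielded count was
# sequence_length - preseq_length == 50, i.e. the 20 prompt tokens were
# wrongly reported as generated tokens; that is the offset bug fixed here.
print(offset, generated)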
@@ -521,7 +526,7 @@ class Chatbot:
...
@@ -521,7 +526,7 @@ class Chatbot:
output_ids
=
result
.
as_numpy
(
'output_ids'
)
output_ids
=
result
.
as_numpy
(
'output_ids'
)
session
.
sequence_length
=
sequence_length
.
squeeze
()
session
.
sequence_length
=
sequence_length
.
squeeze
()
sequence_length
=
sequence_length
-
preseq_length
sequence_length
=
sequence_length
-
offset
last_token_id
=
output_ids
[
-
1
][
-
1
][
session
.
sequence_length
-
1
]
last_token_id
=
output_ids
[
-
1
][
-
1
][
session
.
sequence_length
-
1
]
if
last_token_id
==
eos_id
:
if
last_token_id
==
eos_id
:
session
.
sequence_length
=
session
.
sequence_length
-
1
session
.
sequence_length
=
session
.
sequence_length
-
1
...
@@ -536,23 +541,15 @@ class Chatbot:
...
@@ -536,23 +541,15 @@ class Chatbot:
'postprocessing is ignored during profiling '
'postprocessing is ignored during profiling '
'token generation'
,
sequence_length
.
squeeze
())
'token generation'
,
sequence_length
.
squeeze
())
continue
continue
output_str
=
postprocess
(
output_ids
[:,
:,
preseq_length
:],
output_str
=
postprocess
(
output_ids
[:,
:,
offset
:],
sequence_length
)
sequence_length
)
text
=
output_str
[
0
].
decode
()
text
=
output_str
[
0
].
decode
()
if
display
:
if
display
:
if
len
(
text
)
>
len
(
session
.
prompt
):
new_text
=
text
[
len
(
session
.
response
):]
if
session
.
status
==
StatusCode
.
TRITON_SESSION_READY
:
print
(
new_text
,
end
=
''
,
flush
=
True
)
new_text
=
text
[
len
(
session
.
prompt
):]
session
.
status
=
StatusCode
.
TRITON_STREAM_ING
else
:
new_text
=
text
[
len
(
session
.
response
):]
print
(
new_text
,
end
=
''
,
flush
=
True
)
session
.
response
=
text
session
.
response
=
text
if
len
(
session
.
response
)
>
len
(
session
.
prompt
):
yield
(
StatusCode
.
TRITON_STREAM_ING
,
session
.
response
,
session
.
status
=
StatusCode
.
TRITON_STREAM_ING
sequence_length
.
squeeze
())
yield
(
StatusCode
.
TRITON_STREAM_ING
,
session
.
response
[
len
(
session
.
prompt
):],
sequence_length
.
squeeze
())
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
f
'catch exception:
{
e
}
'
)
logger
.
error
(
f
'catch exception:
{
e
}
'
)
...
...
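With the prompt already stripped by the offset, the display path reduces to plain incremental printing. A standalone sketch of that delta-printing idea (not the library code itself):

def show_delta(text, previous):
    """Print only the part of `text` that was not shown last time."""
    new_text = text[len(previous):]
    print(new_text, end='', flush=True)
    return text  # becomes `previous` (session.response) on the next call


previous = ''
for text in ['Hel', 'Hello', 'Hello, wor', 'Hello, world!']:
    previous = show_delta(text, previous)
print()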