OpenDAS / Lmdeploy · Commit 289ffa3c

fix the offset during streaming chat (#142)

Unverified commit 289ffa3c, authored Jul 19, 2023 by lvhan028; committed by GitHub on Jul 19, 2023.

parent 79595cd1
Showing 3 changed files, with 22 additions and 25 deletions:

    lmdeploy/model.py                      +1  -1
    lmdeploy/serve/client.py               +1  -1
    lmdeploy/serve/turbomind/chatbot.py    +20 -23
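Annotator's note: before this change, stream_consumer measured everything relative to preseq_length (the history length) alone, so both the decoded text window and the per-chunk token count still contained the current round's prompt, and the call site compensated with `tokens - input_tokens` after the fact. The patch folds both lengths into a single offset inside the consumer. A minimal sketch of the arithmetic, with made-up lengths (only the offset formula mirrors the patch):

    # Made-up lengths; only offset = n_input_token + preseq_length is from the patch.
    preseq_length = 100    # tokens already in the session from earlier rounds
    n_input_token = 20     # tokens in this round's prompt
    sequence_length = 135  # total sequence length reported by the server

    offset = n_input_token + preseq_length  # the quantity this commit introduces

    # Old accounting: history subtracted, prompt still counted.
    print(sequence_length - preseq_length)  # 35 (includes the 20 prompt tokens)

    # New accounting: only freshly generated tokens remain.
    print(sequence_length - offset)         # 15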
lmdeploy/model.py

@@ -113,7 +113,7 @@ conversation""" # noqa: E501
     def get_prompt(self, prompt, sequence_start=True):
         if sequence_start:
-            return f'<bos>{self.system}\n' \
+            return f'<BOS>{self.system}\n' \
                    f'{self.user}:{prompt}{self.eoh}\n' \
                    f'{self.assistant}:'
         else:
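Annotator's note: for orientation, a toy stand-in for the template class touched above. The field values and the else branch are invented; only the f-string layout and the <BOS> sentinel mirror the hunk:

    # Toy template, not the real lmdeploy model class; field values are made up.
    class ToyTemplate:
        system = 'You are a helpful assistant.'
        user = 'User'
        eoh = '<eoh>'
        assistant = 'Assistant'

        def get_prompt(self, prompt, sequence_start=True):
            # First round: prepend <BOS> and the system preamble, as in the patch.
            if sequence_start:
                return f'<BOS>{self.system}\n' \
                       f'{self.user}:{prompt}{self.eoh}\n' \
                       f'{self.assistant}:'
            # Follow-up rounds (assumed shape, not shown in the hunk).
            return f'\n{self.user}:{prompt}{self.eoh}\n' \
                   f'{self.assistant}:'

    print(ToyTemplate().get_prompt('hello'))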
lmdeploy/serve/client.py

@@ -37,7 +37,7 @@ def main(tritonserver_addr: str, model_name: str, session_id: int = 1):
             chatbot.end(session_id)
         else:
             request_id = f'{session_id}-{nth_round}'
-            for status, res, tokens in chatbot.stream_infer(
+            for status, res, n_token in chatbot.stream_infer(
                     session_id,
                     prompt,
                     request_id=request_id,
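Annotator's note: after the rename, the third element of each yielded tuple is already the net generated-token count, so the client can report it as-is. A runnable sketch, with a stub generator standing in for chatbot.stream_infer (the yielded values are made up):

    # Stub for chatbot.stream_infer; yields (status, cumulative_text, n_token).
    def fake_stream_infer(session_id, prompt, request_id=''):
        for text, n_token in [('Hel', 1), ('Hello', 2), ('Hello!', 3)]:
            yield 'ok', text, n_token

    for status, res, n_token in fake_stream_infer(1, 'hi', request_id='1-0'):
        # n_token needs no further correction by the caller.
        print(f'[{status}] {res!r} ({n_token} tokens generated)')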
lmdeploy/serve/turbomind/chatbot.py

@@ -381,13 +381,16 @@ class Chatbot:
                              request_output_len, sequence_start,
                              sequence_end, preseq_length, cancel))
         producer.start()
         for state, res, tokens in self.stream_consumer(
                 self.postprocess, que,
-                session, preseq_length, cancel, logger,
+                session, input_tokens,
+                preseq_length, cancel, logger,
                 self.display, self.profile_generation, self.eos_id):
             if state.value < 0:
                 yield state, res, 0
             else:
-                yield state, res, tokens - input_tokens
+                yield state, res, tokens
         producer.join()
         self._session = que.get()
         curseq_length = self._session.sequence_length
@@ -477,8 +480,9 @@ class Chatbot:
         que.put(None)

     @staticmethod
-    def stream_consumer(postprocess, res_queue, session, preseq_length, cancel,
-                        logger, display, profile_generation, eos_id):
+    def stream_consumer(postprocess, res_queue, session, n_input_token,
+                        preseq_length, cancel, logger, display,
+                        profile_generation, eos_id):
         """Consume the response from the triton inference server.

         Args:

@@ -486,6 +490,7 @@ class Chatbot:
                 the generated tokens
             res_queue (multiprocessing.Queue): response queue
             session (Session): an instance of a session
+            n_input_token (int): token number of input prompt
             preseq_length (int): the history sequence length
             cancel (bool): indicator for cancelling the session
             logger (util.Logger):
@@ -496,12 +501,12 @@ class Chatbot:
         Yields:
             tuple: status, text, generated token number
         """
+        offset = n_input_token + preseq_length
         while True:
             result = res_queue.get()
             if result is None:
-                yield StatusCode.TRITON_STREAM_END, \
-                    session.response[len(session.prompt):], \
-                    session.sequence_length - preseq_length
+                yield (StatusCode.TRITON_STREAM_END, session.response,
+                       session.sequence_length - offset)
                 session.status = StatusCode.TRITON_STREAM_END
                 break
             if 'errcode' in result:
@@ -521,7 +526,7 @@ class Chatbot:
                 output_ids = result.as_numpy('output_ids')
                 session.sequence_length = sequence_length.squeeze()
-                sequence_length = sequence_length - preseq_length
+                sequence_length = sequence_length - offset
                 last_token_id = output_ids[-1][-1][session.sequence_length - 1]
                 if last_token_id == eos_id:
                     session.sequence_length = session.sequence_length - 1
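Annotator's note: this hunk and the next both move the subtraction point from preseq_length to offset. The effect is easiest to see on a toy output_ids tensor (assumed shapes and values, not real Triton output):

    import numpy as np

    preseq_length = 4  # history tokens (assumed)
    n_input_token = 3  # current prompt tokens (assumed)
    offset = n_input_token + preseq_length

    # batch 1, beam 1, 10 token ids: 4 history + 3 prompt + 3 generated
    output_ids = np.arange(10).reshape(1, 1, 10)
    sequence_length = np.array([[10]])

    print(output_ids[:, :, preseq_length:])         # old window: prompt + generation
    print(output_ids[:, :, offset:])                # new window: generation only
    print(int(sequence_length.squeeze()) - offset)  # 3 newly generated tokens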
@@ -536,22 +541,14 @@ class Chatbot:
                            'postprocessing is ignored during profiling '
                            'token generation', sequence_length.squeeze())
                     continue
-                output_str = postprocess(output_ids[:, :, preseq_length:],
+                output_str = postprocess(output_ids[:, :, offset:],
                                          sequence_length)
                 text = output_str[0].decode()
                 if display:
-                    if len(text) > len(session.prompt):
-                        if session.status == StatusCode.TRITON_SESSION_READY:
-                            new_text = text[len(session.prompt):]
-                            session.status = StatusCode.TRITON_STREAM_ING
-                        else:
-                            new_text = text[len(session.response):]
-                        print(new_text, end='', flush=True)
+                    new_text = text[len(session.response):]
+                    print(new_text, end='', flush=True)
                 session.response = text
-                if len(session.response) > len(session.prompt):
-                    yield (StatusCode.TRITON_STREAM_ING,
-                           session.response[len(session.prompt):],
-                           sequence_length.squeeze())
+                session.status = StatusCode.TRITON_STREAM_ING
+                yield (StatusCode.TRITON_STREAM_ING, session.response,
+                       sequence_length.squeeze())
             except Exception as e:
                 logger.error(f'catch exception: {e}')
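Annotator's note: with the prompt gone from the decoded text, the display path reduces to printing the delta between consecutive cumulative chunks. A self-contained sketch of that pattern (the chunk strings are invented):

    # Each decoded `text` is cumulative; printing text[len(response):] emits
    # only what is new since the previous chunk, as in the patch.
    response = ''
    for text in ['Sure', 'Sure, here', 'Sure, here you go.']:
        new_text = text[len(response):]
        print(new_text, end='', flush=True)
        response = text
    print()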