GitLab · OpenDAS / Lmdeploy · Commits

Unverified commit 0ed1e4d4
Authored Aug 07, 2023 by lvhan028; committed by GitHub Aug 07, 2023
Parent: 18c386d9

Improve postprocessing in TIS serving by applying Incremental de-tokenizing (#197)

* change to incremental decoding
* update
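Before this change, the TIS (Triton Inference Server) streaming path de-tokenized the entire generated sequence from the end of the prompt on every streaming step, so each step repeated all of the decoding work of the previous ones. The commit records a per-session sequence_offset and de-tokenizes only the tokens produced since the last step. Below is a minimal standalone sketch of the idea with a toy vocabulary; decode stands in for lmdeploy's postprocess model and is not its real API.

VOCAB = ['<eos>', 'Hello', ',', ' world', '!']


def decode(token_ids):
    """Toy de-tokenizer: cost grows with the number of ids decoded."""
    return ''.join(VOCAB[i] for i in token_ids)


def stream_full(steps):
    # Old behavior: every step re-decodes all tokens generated so far,
    # so total decoding work grows quadratically with response length.
    for tokens in steps:
        yield decode(tokens)


def stream_incremental(steps):
    # New behavior: decode only tokens[offset:], then advance the offset,
    # so each token is de-tokenized exactly once.
    offset, response = 0, ''
    for tokens in steps:
        response += decode(tokens[offset:])
        offset = len(tokens)
        yield response


steps = [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]]
assert list(stream_full(steps))[-1] == list(stream_incremental(steps))[-1]

In chatbot.py below, the same bookkeeping appears as session.sequence_offset, which advances to session.sequence_length after each decode, and sentinel, which stays fixed at the end of the prompt.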
Showing 3 changed files with 28 additions and 25 deletions (+28 -25):

  benchmark/profile_serving.py          +12  -11
  lmdeploy/serve/client.py               +0   -1
  lmdeploy/serve/turbomind/chatbot.py   +16  -13
benchmark/profile_serving.py
@@ -55,7 +55,7 @@ def infer(chatbot, session_id: int, req_que: mp.Queue, res_que: mp.Queue):
 def warmup(tritonserver_addr: str,
            concurrency: int,
            output_seqlen: int,
-           warmup_round: int = 4):
+           warmup_round: int = 1):
     print('start to warmup ...')

     def _infer(_chatbot, session_id):
@@ -87,7 +87,7 @@ def warmup(tritonserver_addr: str,
 def read_dataset(tokenizer_path: str, dataset_path: str, samples: int,
-                 session_len: int):
+                 test_round: int, session_len: int):
     start = time.perf_counter()
     with open(dataset_path) as f:
         dataset = json.load(f)
@@ -119,14 +119,12 @@ def read_dataset(tokenizer_path: str, dataset_path: str, samples: int,
     if samples > 0:
         filtered_dataset = random.sample(filtered_dataset, samples)
+    filtered_dataset *= test_round
+    random.shuffle(filtered_dataset)
     que = mp.Queue()
     for data in filtered_dataset:
         que.put(data)
     print(f'elapsed time for filtering: '
           f'{round(time.perf_counter() - start, 2)} s')
-    return que
+    return que, len(filtered_dataset)


 def main(tritonserver_addr: str,
@@ -134,11 +132,10 @@ def main(tritonserver_addr: str,
          dataset_path: str,
          concurrency: int = 1,
          session_len: int = 2048,
-         samples: int = 1000):
+         samples: int = 1000,
+         test_round: int = 1):
     warmup(tritonserver_addr, concurrency, session_len - 1)
-    req_que = read_dataset(tokenizer_path, dataset_path, samples,
-                           session_len)
+    req_que, n_req = read_dataset(tokenizer_path, dataset_path, samples,
+                                  test_round, session_len)
     res_que = mp.Queue()
     procs = []
     _start = time.perf_counter()
@@ -168,13 +165,17 @@ def main(tritonserver_addr: str,
     first_token_latency_min = np.min(stats[:, 0], axis=0)
     first_token_latency_max = np.max(stats[:, 0], axis=0)
     first_token_latency_ave = np.mean(stats[:, 0], axis=0)
-    throughput = np.sum(stats[:, 1], axis=0) / elapsed_time
+    token_throughput = np.sum(stats[:, 1], axis=0) / elapsed_time
+    req_throughput = n_req / elapsed_time
     print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n'
           f'elapsed_time: {elapsed_time:.2f}s\n'
           f'first_token latency(min, max, ave): '
           f'{first_token_latency_min:.2f}s, {first_token_latency_max:.2f}s, '
           f'{first_token_latency_ave:.2f}s\n'
-          f'throughput: {throughput:.2f} token/s\n{"-" * 50}')
+          f'token throughput: {token_throughput:.2f} token/s\n'
+          f'req throughput: {req_throughput} req/s\n'
+          f'{"-" * 50}\n')


 if __name__ == '__main__':
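The benchmark changes above are bookkeeping for the same feature: read_dataset now replicates the sampled prompts test_round times and returns the request count, and main reports request throughput alongside token throughput. A small standalone sketch of that accounting; the prompt strings and timings are made up for illustration, not taken from the dataset:

import random

samples, test_round = 4, 3
filtered_dataset = [f'prompt-{i}' for i in range(samples)]
filtered_dataset *= test_round         # 4 prompts x 3 rounds = 12 requests
random.shuffle(filtered_dataset)       # interleave the repeated rounds
n_req = len(filtered_dataset)          # returned alongside the queue

elapsed_time = 6.0                     # seconds, made-up figure
req_throughput = n_req / elapsed_time  # 2.0 req/s
print(f'req throughput: {req_throughput} req/s')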
lmdeploy/serve/client.py
@@ -20,7 +20,6 @@ def main(tritonserver_addr: str, session_id: int = 1):
     Args:
         tritonserver_addr (str): the address in format "ip:port" of
             triton inference server
-        model_name (str): the name of the deployed model
         session_id (int): the identical id of a session
     """
     log_level = os.environ.get('SERVICE_LOG_LEVEL', 'WARNING')
lmdeploy/serve/turbomind/chatbot.py
@@ -26,6 +26,7 @@ class Session:
     request_id: str = ''
     histories: str = ''  # history conversations of the session
     sequence_length: int = 0  # the total generated token number in the session
+    sequence_offset: int = 0  # the new generated token offset in the session
     prompt: str = ''
     response: str = ''
     status: int = None  # status of the session
@@ -539,14 +540,15 @@ class Chatbot:
         Yields:
             tuple: status, text, generated token number
         """
-        offset = n_input_token + preseq_length
+        session.sequence_offset = n_input_token + preseq_length
+        sentinel = n_input_token + preseq_length
         status, res, n_token = None, '', 0
         while True:
             result = res_queue.get()
             if result is None:
                 status = StatusCode.TRITON_STREAM_END
                 res = session.response
-                n_token = session.sequence_length - offset
+                n_token = session.sequence_length - sentinel
                 session.status = StatusCode.TRITON_STREAM_END
                 break
             if 'errcode' in result:
@@ -569,30 +571,31 @@ class Chatbot:
                 output_ids = result.as_numpy('output_ids')

                 session.sequence_length = sequence_length.squeeze()
-                sequence_length = sequence_length - offset
+                new_token_length = sequence_length - session.sequence_offset
                 last_token_id = output_ids[-1][-1][session.sequence_length - 1]
                 if last_token_id == eos_id:
                     session.sequence_length = session.sequence_length - 1
-                    sequence_length = sequence_length - 1
+                    new_token_length = new_token_length - 1

                 output_ids = output_ids.reshape((1, 1, output_ids.shape[-1]))
-                sequence_length = sequence_length.reshape(
-                    (1, sequence_length.shape[-1]))
+                new_token_length = new_token_length.reshape(
+                    (1, new_token_length.shape[-1]))

                 if profile_generation:
                     yield (StatusCode.TRITON_STREAM_ING,
                            'postprocessing is ignored during profiling '
-                           'token generation', sequence_length.squeeze())
+                           'token generation', new_token_length.squeeze())
                     continue
-                output_str = postprocess(output_ids[:, :, offset:],
-                                         sequence_length)
+                output_str = postprocess(
+                    output_ids[:, :, session.sequence_offset:],
+                    new_token_length)
+                session.sequence_offset = session.sequence_length
                 text = output_str[0].decode()
                 if display:
-                    print(text, end='', flush=True)
-                session.response += text
+                    new_text = text[len(session.response):]
+                    print(new_text, end='', flush=True)
+                session.response = text
                 yield (StatusCode.TRITON_STREAM_ING, session.response,
-                       sequence_length.squeeze())
+                       session.sequence_offset - sentinel)
             except Exception as e:
                 logger.error(f'catch exception: {e}')
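Putting the chatbot.py pieces together: sentinel stays fixed at the end of the prompt (input tokens plus any pre-existing sequence), session.sequence_offset advances past whatever has already been de-tokenized, and their difference is the cumulative number of generated tokens reported to the caller. A standalone walkthrough of the counters with illustrative lengths:

n_input_token, preseq_length = 5, 0
sentinel = n_input_token + preseq_length  # 5, fixed for the whole request
sequence_offset = sentinel                # 5, advances after every decode

for sequence_length in (7, 9, 12):        # total tokens after each step
    new_token_length = sequence_length - sequence_offset
    # postprocess would decode only tokens[sequence_offset:sequence_length]
    sequence_offset = sequence_length
    print(f'decoded {new_token_length} new tokens, '
          f'{sequence_offset - sentinel} generated in total')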