vllm · commit 311490a7 (unverified, parent da5ddcd5)
Add script for benchmarking serving throughput (#145)
Authored Jun 14, 2023 by Woosuk Kwon; committed via GitHub on Jun 14, 2023.

Showing 10 changed files with 421 additions and 415 deletions (+421, -415).
Changed files:

  benchmarks/benchmark_async_llm_server.py   +5    -3
  benchmarks/benchmark_latency.py            +1    -0
  benchmarks/benchmark_serving.py            +237  -0    (new file)
  benchmarks/benchmark_text_completion.py    +0    -255  (deleted)
  benchmarks/benchmark_throughput.py         +138  -28
  benchmarks/launch_tgi_server.sh            +16   -0    (new file)
  benchmarks/trace.py                        +0    -116  (deleted)
  cacheflow/server/arg_utils.py              +3    -0
  cacheflow/server/async_llm_server.py       +13   -7
  examples/simple_fastapi_client.py          +8    -6
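The headline addition is benchmarks/benchmark_serving.py (shown in full below), which records a (prompt_len, output_len, latency) tuple per completed request and reports request throughput plus average per-token latencies. As a hedged, standalone sketch of just those statistics, with illustrative data and variable names rather than the script's exact code:

import numpy as np

# Illustrative per-request records: (prompt_len, output_len, end-to-end latency in seconds),
# mirroring the REQUEST_LATENCY list kept by the new benchmark_serving.py.
request_latency = [(32, 128, 2.4), (512, 64, 3.1), (100, 200, 4.0)]
benchmark_time = 10.0  # wall-clock duration of the run (illustrative)

throughput = len(request_latency) / benchmark_time  # requests/s
avg_latency = np.mean([latency for _, _, latency in request_latency])
avg_per_token_latency = np.mean([
    latency / (prompt_len + output_len)
    for prompt_len, output_len, latency in request_latency
])
avg_per_output_token_latency = np.mean([
    latency / output_len for _, output_len, latency in request_latency
])

print(f"Throughput: {throughput:.2f} requests/s")
print(f"Average latency: {avg_latency:.2f} s")
print(f"Average latency per token: {avg_per_token_latency:.2f} s")
print(f"Average latency per output token: {avg_per_output_token_latency:.2f} s")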
benchmarks/benchmark_async_llm_server.py (+5, -3)

@@ -10,6 +10,7 @@ def main(args: argparse.Namespace):
     prompts = [f"Tell me a story with more than {''.join([str(i + 1)] * 5)} words"
                for i in range(args.n_threads)]
+    api_url = f"http://{args.host}:{args.port}/generate"
     headers = {"User-Agent": "CacheFlow Benchmark Client"}
     ploads = [{
         "prompt": p,
@@ -19,8 +20,8 @@ def main(args: argparse.Namespace):
     } for p in prompts]

     def send_request(results, i):
-        response = requests.post(args.api_url, headers=headers,
+        response = requests.post(api_url, headers=headers,
                                  json=ploads[i],
                                  stream=True)
         results[i] = response

     # use args.n_threads to prompt the backend
@@ -50,7 +51,8 @@ def main(args: argparse.Namespace):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--api-url", type=str, default="http://localhost:8001/generate")
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=8001)
     parser.add_argument("--max-tokens", type=int, default=128)
     parser.add_argument("--n-threads", type=int, default=128)
     args = parser.parse_args()
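The behavioral change here is that the target URL is now assembled from --host and --port instead of being passed via --api-url. For context, a minimal sketch of the thread-per-prompt fan-out this script performs against the /generate endpoint; the function and payload here are illustrative, not the script's exact code:

import threading
import requests

def fan_out(host: str, port: int, prompts, max_tokens: int = 128):
    api_url = f"http://{host}:{port}/generate"  # now built from --host/--port
    headers = {"User-Agent": "CacheFlow Benchmark Client"}
    results = [None] * len(prompts)

    def send_request(i: int) -> None:
        # Each thread posts one payload and stores the (streaming) response.
        results[i] = requests.post(
            api_url, headers=headers,
            json={"prompt": prompts[i], "max_tokens": max_tokens},
            stream=True)

    threads = [threading.Thread(target=send_request, args=(i,))
               for i in range(len(prompts))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results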
benchmarks/benchmark_latency.py (+1, -0)

+"""Benchmark the latency of processing a single batch of requests."""
 import argparse
 import time
benchmarks/benchmark_serving.py (new file, +237, -0)

"""Benchmark online serving throughput.

On the server side, run one of the following commands:
    (CacheFlow backend)
    python -m cacheflow.entrypoints.simple_fastapi_frontend \
        --disable-log-requests --model <your_model>

    (TGI backend)
    ./launch_hf_server.sh <your_model>

On the client side, run:
    python benchmarks/benchmark_serving.py \
        --backend <backend> \
        --tokenizer <your_model> --dataset <target_dataset> \
        --request-rate <request_rate>
"""
import argparse
import asyncio
import json
import random
import time
from typing import AsyncGenerator, List, Tuple

import aiohttp
import numpy as np
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase

# (prompt len, output len, latency)
REQUEST_LATENCY: List[Tuple[int, int, float]] = []


def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
    config = AutoConfig.from_pretrained(model_name)
    if config.model_type == "llama":
        # A workaround for potential protobuf errors.
        model_name = "hf-internal-testing/llama-tokenizer"
    return AutoTokenizer.from_pretrained(model_name)


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [
        data for data in dataset
        if len(data["conversations"]) >= 2
    ]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))

    # Filter out too long sequences.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_token_ids, output_len in tokenized_dataset:
        prompt_len = len(prompt_token_ids)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            # This is because TGI causes errors when the input or output length
            # is too short.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    # Sample the requests.
    sampled_requests = random.sample(filtered_dataset, num_requests)
    return sampled_requests


async def get_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
    input_requests = iter(input_requests)
    for request in input_requests:
        yield request

        if request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue
        # Sample the request interval from the exponential distribution.
        interval = np.random.exponential(1.0 / request_rate)
        # The next request will be sent after the interval.
        await asyncio.sleep(interval)


async def send_request(
    backend: str,
    api_url: str,
    prompt: str,
    prompt_len: int,
    output_len: int,
    best_of: int,
    use_beam_search: bool,
) -> None:
    request_start_time = time.time()

    headers = {"User-Agent": "Benchmark Client"}
    if backend == "cacheflow":
        pload = {
            "prompt": prompt,
            "n": 1,
            "best_of": best_of,
            "use_beam_search": use_beam_search,
            "temperature": 0.0 if use_beam_search else 1.0,
            "top_p": 1.0,
            "max_tokens": output_len,
            "ignore_eos": True,
            "stream": False,
        }
    elif backend == "tgi":
        assert not use_beam_search
        params = {
            "best_of": best_of,
            "max_new_tokens": output_len,
            "do_sample": True,
        }
        pload = {
            "inputs": prompt,
            "parameters": params,
        }
    else:
        raise ValueError(f"Unknown backend: {backend}")

    timeout = aiohttp.ClientTimeout(total=3 * 3600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        while True:
            async with session.post(api_url, headers=headers, json=pload) as response:
                chunks = []
                async for chunk, _ in response.content.iter_chunks():
                    chunks.append(chunk)
            output = b"".join(chunks).decode("utf-8")
            output = json.loads(output)

            # Re-send the request if it failed.
            if "error" not in output:
                break

    request_end_time = time.time()
    request_latency = request_end_time - request_start_time
    REQUEST_LATENCY.append((prompt_len, output_len, request_latency))


async def benchmark(
    backend: str,
    api_url: str,
    input_requests: List[Tuple[str, int, int]],
    best_of: int,
    use_beam_search: bool,
    request_rate: float,
) -> None:
    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
        task = asyncio.create_task(send_request(backend, api_url, prompt,
                                                prompt_len, output_len,
                                                best_of, use_beam_search))
        tasks.append(task)
    await asyncio.gather(*tasks)


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    api_url = f"http://{args.host}:{args.port}/generate"
    tokenizer = get_tokenizer(args.tokenizer)
    input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

    benchmark_start_time = time.time()
    asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
                          args.use_beam_search, args.request_rate))
    benchmark_end_time = time.time()
    benchmark_time = benchmark_end_time - benchmark_start_time
    print(f"Total time: {benchmark_time:.2f} s")
    print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")

    # Compute the latency statistics.
    avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
    print(f"Average latency: {avg_latency:.2f} s")
    avg_per_token_latency = np.mean([
        latency / (prompt_len + output_len)
        for prompt_len, output_len, latency in REQUEST_LATENCY
    ])
    print(f"Average latency per token: {avg_per_token_latency:.2f} s")
    avg_per_output_token_latency = np.mean([
        latency / output_len
        for _, output_len, latency in REQUEST_LATENCY
    ])
    print("Average latency per output token: "
          f"{avg_per_output_token_latency:.2f} s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the online serving throughput.")
    parser.add_argument("--backend", type=str, default="cacheflow",
                        choices=["cacheflow", "tgi"])
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8001)
    parser.add_argument("--dataset", type=str, required=True,
                        help="Path to the dataset.")
    parser.add_argument("--tokenizer", type=str, required=True,
                        help="Name or path of the tokenizer.")
    parser.add_argument("--best-of", type=int, default=1,
                        help="Generates `best_of` sequences per prompt and "
                             "returns the best one.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts", type=int, default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--request-rate", type=float, default=float("inf"),
                        help="Number of requests per second. If this is inf, "
                             "then all the requests are sent at time 0. "
                             "Otherwise, we use Poisson process to synthesize "
                             "the request arrival times.")
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    main(args)
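The --request-rate option drives the pacing in get_request: inter-arrival gaps are drawn from an exponential distribution, which yields a Poisson arrival process, and an infinite rate dispatches every request at time zero. A stripped-down sketch of that pacing loop, using the same numpy/asyncio primitives as the script (names here are illustrative):

import asyncio
import numpy as np

async def paced(requests, request_rate: float):
    # Yield requests with exponentially distributed gaps, i.e. Poisson arrivals.
    for request in requests:
        yield request
        if request_rate == float("inf"):
            continue  # no pacing: fire everything immediately
        await asyncio.sleep(np.random.exponential(1.0 / request_rate))

async def demo():
    async for r in paced(["a", "b", "c"], request_rate=2.0):
        print("dispatch", r)

asyncio.run(demo())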
benchmarks/benchmark_text_completion.py (deleted, -255). The removed file contained:

import argparse
import logging
import os
import pickle
import time
from typing import List

from tqdm import tqdm
from transformers import AutoConfig

from benchmark.trace import generate_text_completion_requests
from cacheflow.master.server import (
    add_server_arguments, process_server_arguments,
    init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams

logger = logging.getLogger(__name__)


def main(args: argparse.Namespace):
    server, frontend = init_local_server_and_frontend_with_arguments(args)

    # Generate requests.
    requests = generate_text_completion_requests(
        args.dataset,
        args.request_rate,
        args.duration,
        args.seed,
        args.n1,
        args.n2,
        args.n3,
        args.n4,
        args.n6,
        args.n2_beam,
        args.n4_beam,
        args.n6_beam,
        args.n8_beam,
    )

    # Warm up.
    logger.info('Warming up.')
    num_warmup_requests = 8
    warmup_input_len = 8
    warmup_output_len = 32
    warmup_sampling_params = SamplingParams(
        n=1,
        temperature=1.0,
        top_p=0.99,
        max_num_steps=warmup_output_len,
        use_beam_search=False,
        stop_token_ids=set(),
        num_logprobs=0,
        context_window_size=None,
    )
    for _ in range(num_warmup_requests):
        frontend._add_query([0] * warmup_input_len, warmup_sampling_params)
    server.add_sequence_groups(frontend.get_inputs())
    while True:
        server.step()
        if not server.has_unfinished_requests():
            break

    # Start benchmarking.
    logger.info('Start benchmarking.')
    # Initialize tqdm.
    pbar = tqdm(total=len(requests), desc='Finished requests')

    finished = []
    server.scheduler.reset_stats()
    start_time = time.time()
    while True:
        now = time.time()
        if args.timeout is not None and now - start_time > args.timeout:
            logger.info('Timeout. Stop benchmarking.')
            break

        while requests:
            if requests[0][0] <= now - start_time:
                request_time, input_tokens, sampling_params = requests.pop(0)
                frontend._add_query(
                    input_tokens, sampling_params,
                    arrival_time=start_time + request_time)
            else:
                break
        server.add_sequence_groups(frontend.get_inputs())
        updated_seq_groups = server.step()

        now = time.time()
        for seq_group in updated_seq_groups:
            if not seq_group.is_finished():
                continue
            arrival_time = seq_group.arrival_time
            finish_time = now
            for seq in seq_group.get_seqs():
                seq_len = seq.get_len()
                output_len = seq_len - seq.prompt_len
                finished.append({
                    'group_id': seq_group.group_id,
                    'seq_id': seq.seq_id,
                    'arrival_time': arrival_time,
                    'finish_time': finish_time,
                    'prompt_len': seq.prompt_len,
                    'output_len': output_len,
                })
            pbar.update(1)

        if not (requests or server.has_unfinished_requests()):
            break
    pbar.close()

    logger.info('Finish benchmarking. Saving stats.')
    server.scheduler.save_stats(args.output_dir)
    with open(os.path.join(args.output_dir, 'sequences.pkl'), 'wb') as f:
        pickle.dump(finished, f)
    logger.info('Done.')


def get_model_name(model: str) -> str:
    OPT_MODELS = [
        'opt-125m',
        'opt-350m',
        'opt-1.3b',
        'opt-2.7b',
        'opt-6.7b',
        'opt-13b',
        'opt-30b',
        'opt-66b',
        'opt-175b',
    ]
    for opt_model in OPT_MODELS:
        if opt_model in model:
            return opt_model

    config = AutoConfig.from_pretrained(model)
    assert config.model_type == 'llama'
    hidden_size = config.hidden_size
    if hidden_size == 4096:
        return 'llama-7b'
    elif hidden_size == 5120:
        return 'llama-13b'
    elif hidden_size == 6656:
        return 'llama-30b'
    elif hidden_size == 8192:
        return 'llama-65b'
    else:
        raise ValueError(f'Unknown model: {model}')


def get_dataset_name(dataset: str) -> str:
    if 'sharegpt' in dataset.lower():
        return 'sharegpt'
    elif 'alpaca' in dataset.lower():
        return 'alpaca'
    else:
        raise ValueError(f'Unknown dataset: {dataset}')


def get_sampling_dir_name(
    n1: float,
    n2: float,
    n3: float,
    n4: float,
    n6: float,
    n2_beam: float,
    n4_beam: float,
    n6_beam: float,
    n8_beam: float,
) -> str:
    method = ''
    if n1 > 0.0:
        method = 'n1' if n1 == 1.0 else method + f'n1-{n1}-'
    if n2 > 0.0:
        method = 'n2' if n2 == 1.0 else method + f'n2-{n2}-'
    if n3 > 0.0:
        method = 'n3' if n3 == 1.0 else method + f'n3-{n3}-'
    if n4 > 0.0:
        method = 'n4' if n4 == 1.0 else method + f'n4-{n4}-'
    if n6 > 0.0:
        method = 'n6' if n6 == 1.0 else method + f'n6-{n6}-'
    if n2_beam > 0.0:
        method = 'n2-beam' if n2_beam == 1.0 else method + f'n2-beam-{n2_beam}-'
    if n4_beam > 0.0:
        method = 'n4-beam' if n4_beam == 1.0 else method + f'n4-beam-{n4_beam}-'
    if n6_beam > 0.0:
        method = 'n6-beam' if n6_beam == 1.0 else method + f'n6-beam-{n6_beam}-'
    if n8_beam > 0.0:
        method = 'n8-beam' if n8_beam == 1.0 else method + f'n8-beam-{n8_beam}-'
    return method[:-1] if method.endswith('-') else method


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Benchmark the performance on a series of requests.')
    parser = add_server_arguments(parser)
    parser.add_argument('--output-dir', type=str,
                        help='path to output directory', default=None)
    parser.add_argument('--dataset', type=str,
                        help='path to dataset', required=True)
    parser.add_argument('--request-rate', type=float,
                        help='reqs/sec', required=True)
    parser.add_argument('--duration', type=int,
                        help='duration in seconds', required=True)
    parser.add_argument('--do-memory-analysis', action='store_true',
                        help='do memory analysis (This will lower the '
                             'throughput. Use this only for analysis.)')
    parser.add_argument('--timeout', type=int,
                        help='time out in seconds', default=None)
    parser.add_argument('--n1', type=float,
                        help='ratio of requests with n=1', default=0.0)
    parser.add_argument('--n2', type=float,
                        help='ratio of requests with n=2', default=0.0)
    parser.add_argument('--n3', type=float,
                        help='ratio of requests with n=3', default=0.0)
    parser.add_argument('--n4', type=float,
                        help='ratio of requests with n=4', default=0.0)
    parser.add_argument('--n6', type=float,
                        help='ratio of requests with n=6', default=0.0)
    parser.add_argument('--n2-beam', type=float,
                        help='ratio of requests with n=2 & beam search', default=0.0)
    parser.add_argument('--n4-beam', type=float,
                        help='ratio of requests with n=4 & beam search', default=0.0)
    parser.add_argument('--n6-beam', type=float,
                        help='ratio of requests with n=6 & beam search', default=0.0)
    parser.add_argument('--n8-beam', type=float,
                        help='ratio of requests with n=8 & beam search', default=0.0)
    args = parser.parse_args()
    args = process_server_arguments(args)
    if (args.n1 + args.n2 + args.n3 + args.n4 + args.n6 +
            args.n2_beam + args.n4_beam + args.n6_beam + args.n8_beam) != 1.0:
        raise ValueError('The ratios of requests must sum to 1.')

    model_name = get_model_name(args.model)
    dataset_name = get_dataset_name(args.dataset)
    if 'opt' in model_name:
        if 'opt' not in args.dataset.lower():
            raise ValueError(f'OPT models can only be used with OPT datasets.')
    elif 'llama' in model_name:
        if 'llama' not in args.dataset.lower():
            raise ValueError(f'Llama models can only be used with Llama datasets.')
    dataset_name = 'sharegpt' if 'sharegpt' in args.dataset else 'alpaca'
    sample_dir = get_sampling_dir_name(
        args.n1, args.n2, args.n3, args.n4, args.n6,
        args.n2_beam, args.n4_beam, args.n6_beam, args.n8_beam)
    if args.output_dir is None:
        args.output_dir = os.path.join(
            '../exp',
            dataset_name,
            f'{model_name}-tp{args.tensor_parallel_size}',
            sample_dir,
            'cacheflow',
            f'block{args.block_size}',
            f'req-rate-{args.request_rate}',
            f'seed{args.seed}',
            f'duration-{args.duration}',
        )
    os.makedirs(args.output_dir, exist_ok=True)

    # Set up logging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(os.path.join(args.output_dir, 'log.txt')),
        ],
    )
    logger.info(args)
    main(args)
benchmarks/benchmark_throughput.py (+138, -28)

+"""Benchmark offline inference throughput."""
 import argparse
 import json
 import random
@@ -5,14 +6,29 @@ import time
 from typing import List, Tuple

 from cacheflow import LLM, SamplingParams
-from transformers import PreTrainedTokenizerBase
+import torch
+from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM,
+                          PreTrainedTokenizerBase)
+from tqdm import tqdm
+
+
+def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
+    config = AutoConfig.from_pretrained(model_name)
+    if config.model_type == "llama":
+        # A workaround for potential protobuf errors.
+        model_name = "hf-internal-testing/llama-tokenizer"
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        # To enable padding in the HF backend.
+        tokenizer.pad_token = tokenizer.eos_token
+        return tokenizer
+    return AutoTokenizer.from_pretrained(model_name)


 def sample_requests(
     dataset_path: str,
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
-) -> List[Tuple[List[int], int]]:
+) -> List[Tuple[str, int, int]]:
     # Load the dataset.
     with open(dataset_path) as f:
         dataset = json.load(f)
@@ -35,45 +51,52 @@ def sample_requests(
     tokenized_dataset = []
     for i in range(len(dataset)):
         output_len = len(completion_token_ids[i])
-        tokenized_dataset.append((prompt_token_ids[i], output_len))
+        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))

-    # Filter out if the prompt length + output length is greater than 2048.
-    tokenized_dataset = [
-        (prompt_token_ids, output_len)
-        for prompt_token_ids, output_len in tokenized_dataset
-        if len(prompt_token_ids) + output_len <= 2048
-    ]
+    # Filter out too long sequences.
+    filtered_dataset: List[Tuple[str, int, int]] = []
+    for prompt, prompt_token_ids, output_len in tokenized_dataset:
+        prompt_len = len(prompt_token_ids)
+        if prompt_len < 4 or output_len < 4:
+            # Prune too short sequences.
+            continue
+        if prompt_len > 1024 or prompt_len + output_len > 2048:
+            # Prune too long sequences.
+            continue
+        filtered_dataset.append((prompt, prompt_len, output_len))

     # Sample the requests.
-    sampled_requests = random.sample(tokenized_dataset, num_requests)
+    sampled_requests = random.sample(filtered_dataset, num_requests)
     return sampled_requests


-def main(args: argparse.Namespace):
-    print(args)
-    random.seed(args.seed)
+def run_cacheflow(
+    requests: List[Tuple[str, int, int]],
+    model: str,
+    tensor_parallel_size: int,
+    seed: int,
+    n: int,
+    use_beam_search: bool,
+) -> float:
     llm = LLM(
-        model=args.model,
-        tensor_parallel_size=args.tensor_parallel_size,
-        seed=args.seed,
+        model=model,
+        tensor_parallel_size=tensor_parallel_size,
+        seed=seed,
     )
-    tokenizer = llm.get_tokenizer()
-    requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

     # Add the requests to the server.
-    for prompt_token_ids, output_len in requests:
+    for prompt, _, output_len in requests:
         sampling_params = SamplingParams(
-            n=args.n,
-            temperature=0.0 if args.use_beam_search else 1.0,
+            n=n,
+            temperature=0.0 if use_beam_search else 1.0,
             top_p=1.0,
-            use_beam_search=args.use_beam_search,
+            use_beam_search=use_beam_search,
             ignore_eos=True,
             max_tokens=output_len,
         )
         # FIXME(woosuk): Do not use internal method.
         llm._add_request(
-            prompt=None,
-            prompt_token_ids=prompt_token_ids,
+            prompt=prompt,
             sampling_params=sampling_params,
         )
@@ -81,17 +104,95 @@ def main(args: argparse.Namespace):
     # FIXME(woosuk): Do use internal method.
     llm._run_server(use_tqdm=True)
     end = time.time()
+    return end - start
+
+
+def run_hf(
+    requests: List[Tuple[str, int, int]],
+    model: str,
+    tokenizer: PreTrainedTokenizerBase,
+    n: int,
+    use_beam_search: bool,
+    max_batch_size: int,
+) -> float:
+    assert not use_beam_search
+    tokenizer = get_tokenizer(model)
+    llm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
+    llm = llm.cuda()
+
+    pbar = tqdm(total=len(requests))
+    start = time.time()
+    batch: List[str] = []
+    max_prompt_len = 0
+    max_output_len = 0
+    for i in range(len(requests)):
+        prompt, prompt_len, output_len = requests[i]
+        # Add the prompt to the batch.
+        batch.append(prompt)
+        max_prompt_len = max(max_prompt_len, prompt_len)
+        max_output_len = max(max_output_len, output_len)
+        if len(batch) < max_batch_size and i != len(requests) - 1:
+            # Check if we can add more requests to the batch.
+            _, next_prompt_len, next_output_len = requests[i + 1]
+            if (max(max_prompt_len, next_prompt_len) +
+                    max(max_output_len, next_output_len)) <= 2048:
+                # We can add more requests to the batch.
+                continue
+
+        # Generate the sequences.
+        input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
+        llm_outputs = llm.generate(
+            input_ids=input_ids.cuda(),
+            do_sample=not use_beam_search,
+            num_return_sequences=n,
+            temperature=1.0,
+            top_p=1.0,
+            use_cache=True,
+            max_new_tokens=max_output_len,
+        )
+        # Include the decoding time.
+        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
+        pbar.update(len(batch))
+
+        # Clear the batch.
+        batch = []
+        max_prompt_len = 0
+        max_output_len = 0
+    end = time.time()
+    return end - start
+
+
+def main(args: argparse.Namespace):
+    print(args)
+    random.seed(args.seed)
+
+    # Sample the requests.
+    tokenizer = get_tokenizer(args.model)
+    requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
+
+    if args.backend == "cacheflow":
+        elapsed_time = run_cacheflow(
+            requests, args.model, args.tensor_parallel_size, args.seed, args.n,
+            args.use_beam_search)
+    elif args.backend == "hf":
+        assert args.tensor_parallel_size == 1
+        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
+                              args.use_beam_search, args.hf_max_batch_size)
+    else:
+        raise ValueError(f"Unknown backend: {args.backend}")

     total_num_tokens = sum(
-        len(prompt_token_ids) + output_len
-        for prompt_token_ids, output_len in requests
+        prompt_len + output_len
+        for _, prompt_len, output_len in requests
     )
-    elapsed_time = end - start
     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
           f"{total_num_tokens / elapsed_time:.2f} tokens/s")


 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Benchmark the throughput.")
+    parser.add_argument("--backend", type=str, choices=["cacheflow", "hf"],
+                        default="cacheflow")
     parser.add_argument("--dataset", type=str, required=True,
                         help="Path to the dataset.")
     parser.add_argument("--model", type=str, default="facebook/opt-125m")
@@ -102,5 +203,14 @@ if __name__ == "__main__":
     parser.add_argument("--num-prompts", type=int, default=1000,
                         help="Number of prompts to process.")
     parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--hf-max-batch-size", type=int, default=None,
+                        help="Maximum batch size for HF backend.")
     args = parser.parse_args()
+    if args.backend == "cacheflow":
+        if args.hf_max_batch_size is not None:
+            raise ValueError("HF max batch size is only for HF backend.")
+    elif args.backend == "hf":
+        if args.hf_max_batch_size is None:
+            raise ValueError("HF max batch size is required for HF backend.")
     main(args)
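The new HF path batches prompts greedily: it keeps adding requests until the padded batch (max prompt length plus max new tokens) would exceed 2048 tokens or --hf-max-batch-size is reached, then generates and decodes the whole batch at once. The flush rule in isolation, as a hedged sketch with illustrative names rather than the script's exact code:

from typing import List, Tuple

def pack_batches(requests: List[Tuple[str, int, int]],
                 max_batch_size: int,
                 max_total_len: int = 2048) -> List[List[str]]:
    batches, batch = [], []
    max_prompt_len = max_output_len = 0
    for i, (prompt, prompt_len, output_len) in enumerate(requests):
        batch.append(prompt)
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_output_len = max(max_output_len, output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            _, next_prompt_len, next_output_len = requests[i + 1]
            if (max(max_prompt_len, next_prompt_len) +
                    max(max_output_len, next_output_len)) <= max_total_len:
                continue  # the next request still fits: keep growing the batch
        batches.append(batch)  # flush the current batch
        batch, max_prompt_len, max_output_len = [], 0, 0
    return batches

# Example: (prompt, prompt_len, output_len) triples as produced by sample_requests.
print(pack_batches([("p1", 10, 20), ("p2", 1000, 900), ("p3", 500, 100)],
                   max_batch_size=8))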
benchmarks/launch_tgi_server.sh (new file, +16, -0)

#!/bin/bash

PORT=8001
MODEL=$1
TOKENS=$2

docker run --gpus all --shm-size 1g -p $PORT:80 \
           -v $PWD/data:/data \
           ghcr.io/huggingface/text-generation-inference:0.8 \
           --model-id $MODEL \
           --sharded false \
           --max-input-length 1024 \
           --max-total-tokens 2048 \
           --max-best-of 5 \
           --max-concurrent-requests 5000 \
           --max-batch-total-tokens $TOKENS
benchmarks/trace.py (deleted, -116). The removed file contained:

import pickle
import random
from typing import List, Tuple

import numpy as np

from cacheflow.sampling_params import SamplingParams


def generate_text_completion_requests(
    dataset: str,
    request_rate: float,
    duration: int,
    seed: int,
    n1: float = 0.0,
    n2: float = 0.0,
    n3: float = 0.0,
    n4: float = 0.0,
    n6: float = 0.0,
    n2_beam: float = 0.0,
    n4_beam: float = 0.0,
    n6_beam: float = 0.0,
    n8_beam: float = 0.0,
    max_seq_len: int = 2048,
    time_quantum: int = 10,
) -> List[Tuple[float, List[int], SamplingParams]]:
    random.seed(seed)
    np.random.seed(seed)

    # Generate timestamps for requests using Poisson distribution.
    lam = request_rate * (time_quantum / 1000)
    quantums_per_sec = 1000 / time_quantum
    arrival_times = np.random.poisson(
        lam=lam, size=int(duration * quantums_per_sec))
    timestamps = []
    for i, n in enumerate(arrival_times):
        timestamps += [i * (time_quantum / 1000)] * n

    # Load and shuffle the dataset.
    num_requests = len(timestamps)
    with open(dataset, 'rb') as f:
        data = pickle.load(f)

    filtered = []
    for pair in data:
        input_tokens, output_tokens = pair
        input_len = len(input_tokens)
        output_len = len(output_tokens)
        # Filter out too long sequences.
        if input_len + output_len < max_seq_len:
            # Output tokens are not needed for the benchmark.
            filtered.append((input_tokens, output_len))

    data = []
    while len(data) < num_requests:
        data += filtered
    data = data[:num_requests]
    # Shuffle the data.
    assert len(data) == len(timestamps)
    random.shuffle(data)

    random_sampling_params_dict = {
        'temperature': 1.0,
        'top_p': 1.0,
        'use_beam_search': False,
        'stop_token_ids': set(),
        'num_logprobs': 0,
        'context_window_size': None,
    }
    beam_search_params_dict = {
        'temperature': 0.0,
        'top_p': 1.0,
        'use_beam_search': True,
        'stop_token_ids': set(),
        'num_logprobs': 0,
        'context_window_size': None,
    }

    # Generate requests based on the sampling parameter ratio.
    requests = []
    assert n1 + n2 + n3 + n4 + n6 + n2_beam + n4_beam + n6_beam + n8_beam == 1.0
    cum_sum = 0
    for timestamp, pair in zip(timestamps, data):
        input_tokens, output_len = pair
        if cum_sum < n1 * num_requests:
            sampling_params = SamplingParams(
                n=1, max_num_steps=output_len, **random_sampling_params_dict)
        elif cum_sum < (n1 + n2) * num_requests:
            sampling_params = SamplingParams(
                n=2, max_num_steps=output_len, **random_sampling_params_dict)
        elif cum_sum < (n1 + n2 + n3) * num_requests:
            sampling_params = SamplingParams(
                n=3, max_num_steps=output_len, **random_sampling_params_dict)
        elif cum_sum < (n1 + n2 + n3 + n4) * num_requests:
            sampling_params = SamplingParams(
                n=4, max_num_steps=output_len, **random_sampling_params_dict)
        elif cum_sum < (n1 + n2 + n3 + n4 + n6) * num_requests:
            sampling_params = SamplingParams(
                n=6, max_num_steps=output_len, **random_sampling_params_dict)
        elif cum_sum < (n1 + n2 + n3 + n4 + n6 + n2_beam) * num_requests:
            sampling_params = SamplingParams(
                n=2, max_num_steps=output_len, **beam_search_params_dict)
        elif cum_sum < (n1 + n2 + n3 + n4 + n6 + n2_beam + n4_beam) * num_requests:
            sampling_params = SamplingParams(
                n=4, max_num_steps=output_len, **beam_search_params_dict)
        elif cum_sum < (n1 + n2 + n3 + n4 + n6 + n2_beam + n4_beam + n6_beam) * num_requests:
            sampling_params = SamplingParams(
                n=6, max_num_steps=output_len, **beam_search_params_dict)
        elif cum_sum < (n1 + n2 + n3 + n4 + n6 + n2_beam + n4_beam + n6_beam + n8_beam) * num_requests:
            sampling_params = SamplingParams(
                n=8, max_num_steps=output_len, **beam_search_params_dict)
        else:
            raise ValueError('Invalid request ratio.')
        cum_sum += 1
        requests.append((timestamp, input_tokens, sampling_params))
    return requests
cacheflow/server/arg_utils.py (+3, -0)

@@ -120,6 +120,7 @@ class ServerArgs:
 class AsyncServerArgs(ServerArgs):
     """Arguments for asynchronous CacheFlow servers."""
     server_use_ray: bool = False
+    disable_log_requests: bool = False

     @staticmethod
     def add_cli_args(
@@ -129,4 +130,6 @@ class AsyncServerArgs(ServerArgs):
         parser.add_argument('--server-use-ray', action='store_true',
                             help='use Ray to start the LLM server in a '
                                  'separate process as the web server process.')
+        parser.add_argument('--disable-log-requests', action='store_true',
+                            help='disable logging requests')
         return parser
cacheflow/server/async_llm_server.py (+13, -7)

@@ -32,12 +32,14 @@ class AsyncLLMServer:
         server_use_ray: Whether to make LLMServer a Ray actor. If so, the
             async frontend will be executed in a separate process as the
             model workers.
+        log_requests: Whether to log the requests.
         *args, *kwargs: Arguments for LLMServer.
     """
     def __init__(self, worker_use_ray: bool, server_use_ray: bool,
-                 *args, **kwargs) -> None:
+                 log_requests: bool = True, *args, **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.server_use_ray = server_use_ray
+        self.log_requests = log_requests
         if not self.server_use_ray:
             server_class = LLMServer
         elif self.worker_use_ray:
@@ -106,6 +108,7 @@ class AsyncLLMServer:
         request_event = asyncio.Event()
         self.request_events[request_id] = request_event

-        logger.info(f"Received request {request_id}: "
-                    f"prompt: {prompt!r}, "
-                    f"sampling params: {sampling_params}, "
+        if self.log_requests:
+            logger.info(f"Received request {request_id}: "
+                        f"prompt: {prompt!r}, "
+                        f"sampling params: {sampling_params}, "
@@ -152,6 +155,7 @@ class AsyncLLMServer:
             # Once finished, release the resources of the sequence group.
             if request_output.finished():
-                logger.info(f"Finished request {request_id}.")
+                if self.log_requests:
+                    logger.info(f"Finished request {request_id}.")
                 del self.request_outputs[request_id]
@@ -176,6 +180,7 @@ class AsyncLLMServer:
             # The request has already finished or been aborted.
             return
-        logger.info(f"Aborted request {request_id}.")
+        if self.log_requests:
+            logger.info(f"Aborted request {request_id}.")

         if self.server_use_ray:
@@ -206,6 +211,7 @@ class AsyncLLMServer:
         # Create the LLM server.
         server = cls(server_args.worker_use_ray,
                      server_args.server_use_ray,
+                     not server_args.disable_log_requests,
                      *server_configs,
                      distributed_init_method, devices,
                      log_stats=not server_args.disable_log_stats)
examples/simple_fastapi_client.py (+8, -6)

 import argparse
 import json
-import requests
 from typing import Iterable, List

+import requests
+
 def clear_line(n: int = 1) -> None:
     LINE_UP = '\033[1A'
     LINE_CLEAR = '\x1b[2K'
-    for i in range(n):
+    for _ in range(n):
         print(LINE_UP, end=LINE_CLEAR, flush=True)
@@ -53,7 +55,7 @@ if __name__ == "__main__":
     n = args.n
     stream = args.stream

-    print(f"Prompt: {prompt}\n", flush=True)
+    print(f"Prompt: {prompt!r}\n", flush=True)
     response = post_http_request(prompt, api_url, n, stream)

     if stream:
@@ -63,8 +65,8 @@ if __name__ == "__main__":
             num_printed_lines = 0
             for i, line in enumerate(h):
                 num_printed_lines += 1
-                print(f"Beam candidate {i}: {line}", flush=True)
+                print(f"Beam candidate {i}: {line!r}", flush=True)
     else:
         output = get_response(response)
         for i, line in enumerate(output):
-            print(f"Beam candidate {i}: {line}", flush=True)
+            print(f"Beam candidate {i}: {line!r}", flush=True)