Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b95d1275
Commit
b95d1275
authored
Mar 18, 2025
by
zhuwenwen
Browse files
update benchmarks
parent
d9e67e78
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
1803 additions
and
12 deletions
+1803
-12
benchmarks/benchmark_throughput.py
benchmarks/benchmark_throughput.py
+39
-6
vllm/benchmarks/backend_request_func.py
vllm/benchmarks/backend_request_func.py
+484
-0
vllm/benchmarks/benchmark_serving.py
vllm/benchmarks/benchmark_serving.py
+1241
-0
vllm/benchmarks/benchmark_throughput.py
vllm/benchmarks/benchmark_throughput.py
+39
-6
No files found.
benchmarks/benchmark_throughput.py
View file @
b95d1275
...
@@ -5,6 +5,7 @@ import dataclasses
...
@@ -5,6 +5,7 @@ import dataclasses
import
json
import
json
import
random
import
random
import
time
import
time
from
pathlib
import
Path
from
functools
import
cache
from
functools
import
cache
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
...
@@ -215,12 +216,34 @@ def run_vllm(
...
@@ -215,12 +216,34 @@ def run_vllm(
use_beam_search
=
False
use_beam_search
=
False
if
not
use_beam_search
:
if
not
use_beam_search
:
start
=
time
.
perf_counter
()
if
args
.
profile
:
llm
.
generate
(
prompts
,
profile_dir
=
args
.
profile_result_dir
sampling_params
,
if
not
profile_dir
:
lora_request
=
lora_requests
,
profile_dir
=
Path
(
use_tqdm
=
True
)
"."
end
=
time
.
perf_counter
()
)
/
"vllm_benchmark_result"
/
f
"latency_result_
{
time
.
time
()
}
"
print
(
f
"Profiling (results will be saved to '
{
profile_dir
}
')..."
)
with
torch
.
profiler
.
profile
(
activities
=
[
torch
.
profiler
.
ProfilerActivity
.
CPU
,
torch
.
profiler
.
ProfilerActivity
.
CUDA
,
],
record_shapes
=
True
,
on_trace_ready
=
torch
.
profiler
.
tensorboard_trace_handler
(
str
(
profile_dir
))
)
as
prof
:
start
=
time
.
perf_counter
()
llm
.
generate
(
prompts
,
sampling_params
,
lora_request
=
lora_requests
,
use_tqdm
=
True
)
end
=
time
.
perf_counter
()
print
(
'Prepare time report'
)
print
(
prof
.
key_averages
(
group_by_input_shape
=
True
).
table
(
sort_by
=
"self_cuda_time_total"
,
row_limit
=-
1
))
else
:
start
=
time
.
perf_counter
()
llm
.
generate
(
prompts
,
sampling_params
,
lora_request
=
lora_requests
,
use_tqdm
=
True
)
end
=
time
.
perf_counter
()
else
:
else
:
assert
lora_requests
is
None
,
"BeamSearch API does not support LoRA"
assert
lora_requests
is
None
,
"BeamSearch API does not support LoRA"
prompts
=
[
request
.
prompt
for
request
in
requests
]
prompts
=
[
request
.
prompt
for
request
in
requests
]
...
@@ -498,6 +521,16 @@ if __name__ == "__main__":
...
@@ -498,6 +521,16 @@ if __name__ == "__main__":
type
=
int
,
type
=
int
,
default
=
None
,
default
=
None
,
help
=
"Maximum batch size for HF backend."
)
help
=
"Maximum batch size for HF backend."
)
parser
.
add_argument
(
'--profile'
,
action
=
'store_true'
,
help
=
'profile the generation process of a single batch'
)
parser
.
add_argument
(
'--profile-result-dir'
,
type
=
str
,
default
=
None
,
help
=
(
'path to save the pytorch profiler output. Can be visualized '
'with ui.perfetto.dev or Tensorboard.'
))
parser
.
add_argument
(
parser
.
add_argument
(
'--output-json'
,
'--output-json'
,
type
=
str
,
type
=
str
,
...
...
vllm/benchmarks/backend_request_func.py
0 → 100644
View file @
b95d1275
# SPDX-License-Identifier: Apache-2.0
import
json
import
os
import
sys
import
time
import
traceback
from
dataclasses
import
dataclass
,
field
from
typing
import
List
,
Optional
,
Union
import
aiohttp
import
huggingface_hub.constants
from
tqdm.asyncio
import
tqdm
from
transformers
import
(
AutoTokenizer
,
PreTrainedTokenizer
,
PreTrainedTokenizerFast
)
AIOHTTP_TIMEOUT
=
aiohttp
.
ClientTimeout
(
total
=
6
*
60
*
60
)
@
dataclass
class
RequestFuncInput
:
prompt
:
str
api_url
:
str
prompt_len
:
int
output_len
:
int
model
:
str
model_name
:
Optional
[
str
]
=
None
best_of
:
int
=
1
logprobs
:
Optional
[
int
]
=
None
extra_body
:
Optional
[
dict
]
=
None
multi_modal_content
:
Optional
[
dict
]
=
None
ignore_eos
:
bool
=
False
@
dataclass
class
RequestFuncOutput
:
generated_text
:
str
=
""
success
:
bool
=
False
latency
:
float
=
0.0
output_tokens
:
int
=
0
ttft
:
float
=
0.0
# Time to first token
itl
:
List
[
float
]
=
field
(
default_factory
=
list
)
# List of inter-token latencies
tpot
:
float
=
0.0
# avg next-token latencies
prompt_len
:
int
=
0
error
:
str
=
""
async
def
async_request_tgi
(
request_func_input
:
RequestFuncInput
,
pbar
:
Optional
[
tqdm
]
=
None
,
)
->
RequestFuncOutput
:
api_url
=
request_func_input
.
api_url
assert
api_url
.
endswith
(
"generate_stream"
)
async
with
aiohttp
.
ClientSession
(
trust_env
=
True
,
timeout
=
AIOHTTP_TIMEOUT
)
as
session
:
params
=
{
"best_of"
:
request_func_input
.
best_of
,
"max_new_tokens"
:
request_func_input
.
output_len
,
"do_sample"
:
True
,
"temperature"
:
0.01
,
# TGI does not accept 0.0 temperature.
"top_p"
:
0.99
,
# TGI does not accept 1.0 top_p.
"truncate"
:
request_func_input
.
prompt_len
,
# TGI does not accept ignore_eos flag.
}
payload
=
{
"inputs"
:
request_func_input
.
prompt
,
"parameters"
:
params
,
}
output
=
RequestFuncOutput
()
output
.
prompt_len
=
request_func_input
.
prompt_len
ttft
=
0.0
st
=
time
.
perf_counter
()
most_recent_timestamp
=
st
try
:
async
with
session
.
post
(
url
=
api_url
,
json
=
payload
)
as
response
:
if
response
.
status
==
200
:
async
for
chunk_bytes
in
response
.
content
:
chunk_bytes
=
chunk_bytes
.
strip
()
if
not
chunk_bytes
:
continue
chunk_bytes
=
chunk_bytes
.
decode
(
"utf-8"
)
# NOTE: Sometimes TGI returns a ping response without
# any data, we should skip it.
if
chunk_bytes
.
startswith
(
":"
):
continue
chunk
=
chunk_bytes
.
removeprefix
(
"data:"
)
data
=
json
.
loads
(
chunk
)
timestamp
=
time
.
perf_counter
()
# First token
if
ttft
==
0.0
:
ttft
=
time
.
perf_counter
()
-
st
output
.
ttft
=
ttft
# Decoding phase
else
:
output
.
itl
.
append
(
timestamp
-
most_recent_timestamp
)
most_recent_timestamp
=
timestamp
output
.
latency
=
most_recent_timestamp
-
st
output
.
success
=
True
output
.
generated_text
=
data
[
"generated_text"
]
else
:
output
.
error
=
response
.
reason
or
""
output
.
success
=
False
except
Exception
:
output
.
success
=
False
exc_info
=
sys
.
exc_info
()
output
.
error
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
if
pbar
:
pbar
.
update
(
1
)
return
output
async
def
async_request_trt_llm
(
request_func_input
:
RequestFuncInput
,
pbar
:
Optional
[
tqdm
]
=
None
,
)
->
RequestFuncOutput
:
api_url
=
request_func_input
.
api_url
assert
api_url
.
endswith
(
"generate_stream"
)
async
with
aiohttp
.
ClientSession
(
trust_env
=
True
,
timeout
=
AIOHTTP_TIMEOUT
)
as
session
:
assert
request_func_input
.
best_of
==
1
payload
=
{
"accumulate_tokens"
:
True
,
"text_input"
:
request_func_input
.
prompt
,
"temperature"
:
0.0
,
"top_p"
:
1.0
,
"max_tokens"
:
request_func_input
.
output_len
,
"stream"
:
True
,
}
if
request_func_input
.
ignore_eos
:
payload
[
"min_length"
]
=
request_func_input
.
output_len
output
=
RequestFuncOutput
()
output
.
prompt_len
=
request_func_input
.
prompt_len
ttft
=
0.0
st
=
time
.
perf_counter
()
most_recent_timestamp
=
st
try
:
async
with
session
.
post
(
url
=
api_url
,
json
=
payload
)
as
response
:
if
response
.
status
==
200
:
async
for
chunk_bytes
in
response
.
content
:
chunk_bytes
=
chunk_bytes
.
strip
()
if
not
chunk_bytes
:
continue
chunk
=
chunk_bytes
.
decode
(
"utf-8"
).
removeprefix
(
"data:"
)
data
=
json
.
loads
(
chunk
)
output
.
generated_text
+=
data
[
"text_output"
]
timestamp
=
time
.
perf_counter
()
# First token
if
ttft
==
0.0
:
ttft
=
timestamp
-
st
output
.
ttft
=
ttft
# Decoding phase
else
:
output
.
itl
.
append
(
timestamp
-
most_recent_timestamp
)
most_recent_timestamp
=
timestamp
output
.
latency
=
most_recent_timestamp
-
st
output
.
success
=
True
else
:
output
.
error
=
response
.
reason
or
""
output
.
success
=
False
except
Exception
:
output
.
success
=
False
exc_info
=
sys
.
exc_info
()
output
.
error
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
if
pbar
:
pbar
.
update
(
1
)
return
output
async
def
async_request_deepspeed_mii
(
request_func_input
:
RequestFuncInput
,
pbar
:
Optional
[
tqdm
]
=
None
,
)
->
RequestFuncOutput
:
async
with
aiohttp
.
ClientSession
(
trust_env
=
True
,
timeout
=
AIOHTTP_TIMEOUT
)
as
session
:
assert
request_func_input
.
best_of
==
1
payload
=
{
"prompt"
:
request_func_input
.
prompt
,
"max_tokens"
:
request_func_input
.
output_len
,
"temperature"
:
0.01
,
# deepspeed-mii does not accept 0.0 temp.
"top_p"
:
1.0
,
}
output
=
RequestFuncOutput
()
output
.
prompt_len
=
request_func_input
.
prompt_len
# NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
# will use 0 as placeholder.
# See https://github.com/microsoft/DeepSpeed-MII/pull/311
output
.
ttft
=
0
st
=
time
.
perf_counter
()
try
:
async
with
session
.
post
(
url
=
request_func_input
.
api_url
,
json
=
payload
)
as
response
:
if
response
.
status
==
200
:
parsed_resp
=
await
response
.
json
()
output
.
latency
=
time
.
perf_counter
()
-
st
output
.
generated_text
=
parsed_resp
[
"text"
][
0
]
output
.
success
=
True
else
:
output
.
error
=
response
.
reason
or
""
output
.
success
=
False
except
Exception
:
output
.
success
=
False
exc_info
=
sys
.
exc_info
()
output
.
error
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
if
pbar
:
pbar
.
update
(
1
)
return
output
async
def
async_request_openai_completions
(
request_func_input
:
RequestFuncInput
,
pbar
:
Optional
[
tqdm
]
=
None
,
)
->
RequestFuncOutput
:
api_url
=
request_func_input
.
api_url
assert
api_url
.
endswith
(
(
"completions"
,
"profile"
)
),
"OpenAI Completions API URL must end with 'completions' or 'profile'."
async
with
aiohttp
.
ClientSession
(
trust_env
=
True
,
timeout
=
AIOHTTP_TIMEOUT
)
as
session
:
payload
=
{
"model"
:
request_func_input
.
model_name
\
if
request_func_input
.
model_name
else
request_func_input
.
model
,
"prompt"
:
request_func_input
.
prompt
,
"temperature"
:
0.0
,
"best_of"
:
request_func_input
.
best_of
,
"max_tokens"
:
request_func_input
.
output_len
,
"logprobs"
:
request_func_input
.
logprobs
,
"stream"
:
True
,
"stream_options"
:
{
"include_usage"
:
True
,
},
}
if
request_func_input
.
ignore_eos
:
payload
[
"ignore_eos"
]
=
request_func_input
.
ignore_eos
if
request_func_input
.
extra_body
:
payload
.
update
(
request_func_input
.
extra_body
)
headers
=
{
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
}
output
=
RequestFuncOutput
()
output
.
prompt_len
=
request_func_input
.
prompt_len
generated_text
=
""
st
=
time
.
perf_counter
()
most_recent_timestamp
=
st
try
:
async
with
session
.
post
(
url
=
api_url
,
json
=
payload
,
headers
=
headers
)
as
response
:
if
response
.
status
==
200
:
first_chunk_received
=
False
async
for
chunk_bytes
in
response
.
content
:
chunk_bytes
=
chunk_bytes
.
strip
()
if
not
chunk_bytes
:
continue
chunk
=
chunk_bytes
.
decode
(
"utf-8"
).
removeprefix
(
"data: "
)
if
chunk
!=
"[DONE]"
:
data
=
json
.
loads
(
chunk
)
# NOTE: Some completion API might have a last
# usage summary response without a token so we
# want to check a token was generated
if
choices
:
=
data
.
get
(
"choices"
):
# Note that text could be empty here
# e.g. for special tokens
text
=
choices
[
0
].
get
(
"text"
)
timestamp
=
time
.
perf_counter
()
# First token
if
not
first_chunk_received
:
first_chunk_received
=
True
ttft
=
time
.
perf_counter
()
-
st
output
.
ttft
=
ttft
# Decoding phase
else
:
output
.
itl
.
append
(
timestamp
-
most_recent_timestamp
)
most_recent_timestamp
=
timestamp
generated_text
+=
text
or
""
elif
usage
:
=
data
.
get
(
"usage"
):
output
.
output_tokens
=
usage
.
get
(
"completion_tokens"
)
if
first_chunk_received
:
output
.
success
=
True
else
:
output
.
success
=
False
output
.
error
=
(
"Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!"
)
output
.
generated_text
=
generated_text
output
.
latency
=
most_recent_timestamp
-
st
else
:
output
.
error
=
response
.
reason
or
""
output
.
success
=
False
except
Exception
:
output
.
success
=
False
exc_info
=
sys
.
exc_info
()
output
.
error
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
if
pbar
:
pbar
.
update
(
1
)
return
output
async
def
async_request_openai_chat_completions
(
request_func_input
:
RequestFuncInput
,
pbar
:
Optional
[
tqdm
]
=
None
,
)
->
RequestFuncOutput
:
api_url
=
request_func_input
.
api_url
assert
api_url
.
endswith
(
"chat/completions"
),
"OpenAI Chat Completions API URL must end with 'chat/completions'."
async
with
aiohttp
.
ClientSession
(
trust_env
=
True
,
timeout
=
AIOHTTP_TIMEOUT
)
as
session
:
content
=
[{
"type"
:
"text"
,
"text"
:
request_func_input
.
prompt
}]
if
request_func_input
.
multi_modal_content
:
content
.
append
(
request_func_input
.
multi_modal_content
)
payload
=
{
"model"
:
request_func_input
.
model_name
\
if
request_func_input
.
model_name
else
request_func_input
.
model
,
"messages"
:
[
{
"role"
:
"user"
,
"content"
:
content
},
],
"temperature"
:
0.0
,
"max_completion_tokens"
:
request_func_input
.
output_len
,
"stream"
:
True
,
"stream_options"
:
{
"include_usage"
:
True
,
},
}
if
request_func_input
.
ignore_eos
:
payload
[
"ignore_eos"
]
=
request_func_input
.
ignore_eos
if
request_func_input
.
extra_body
:
payload
.
update
(
request_func_input
.
extra_body
)
headers
=
{
"Content-Type"
:
"application/json"
,
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
,
}
output
=
RequestFuncOutput
()
output
.
prompt_len
=
request_func_input
.
prompt_len
generated_text
=
""
ttft
=
0.0
st
=
time
.
perf_counter
()
most_recent_timestamp
=
st
try
:
async
with
session
.
post
(
url
=
api_url
,
json
=
payload
,
headers
=
headers
)
as
response
:
if
response
.
status
==
200
:
async
for
chunk_bytes
in
response
.
content
:
chunk_bytes
=
chunk_bytes
.
strip
()
if
not
chunk_bytes
:
continue
chunk
=
chunk_bytes
.
decode
(
"utf-8"
).
removeprefix
(
"data: "
)
if
chunk
!=
"[DONE]"
:
timestamp
=
time
.
perf_counter
()
data
=
json
.
loads
(
chunk
)
if
choices
:
=
data
.
get
(
"choices"
):
content
=
choices
[
0
][
"delta"
].
get
(
"content"
)
# First token
if
ttft
==
0.0
:
ttft
=
timestamp
-
st
output
.
ttft
=
ttft
# Decoding phase
else
:
output
.
itl
.
append
(
timestamp
-
most_recent_timestamp
)
generated_text
+=
content
or
""
elif
usage
:
=
data
.
get
(
"usage"
):
output
.
output_tokens
=
usage
.
get
(
"completion_tokens"
)
most_recent_timestamp
=
timestamp
output
.
generated_text
=
generated_text
output
.
success
=
True
output
.
latency
=
most_recent_timestamp
-
st
else
:
output
.
error
=
response
.
reason
or
""
output
.
success
=
False
except
Exception
:
output
.
success
=
False
exc_info
=
sys
.
exc_info
()
output
.
error
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
if
pbar
:
pbar
.
update
(
1
)
return
output
def
get_model
(
pretrained_model_name_or_path
:
str
)
->
str
:
if
os
.
getenv
(
'VLLM_USE_MODELSCOPE'
,
'False'
).
lower
()
==
'true'
:
from
modelscope
import
snapshot_download
model_path
=
snapshot_download
(
model_id
=
pretrained_model_name_or_path
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
ignore_file_pattern
=
[
".*.pt"
,
".*.safetensors"
,
".*.bin"
])
return
model_path
return
pretrained_model_name_or_path
def
get_tokenizer
(
pretrained_model_name_or_path
:
str
,
tokenizer_mode
:
str
=
"auto"
,
trust_remote_code
:
bool
=
False
,
**
kwargs
,
)
->
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]:
if
pretrained_model_name_or_path
is
not
None
and
not
os
.
path
.
exists
(
pretrained_model_name_or_path
):
pretrained_model_name_or_path
=
get_model
(
pretrained_model_name_or_path
)
if
tokenizer_mode
==
"slow"
:
if
kwargs
.
get
(
"use_fast"
,
False
):
raise
ValueError
(
"Cannot use the fast tokenizer in slow tokenizer mode."
)
kwargs
[
"use_fast"
]
=
False
if
tokenizer_mode
==
"mistral"
:
try
:
from
vllm.transformers_utils.tokenizer
import
MistralTokenizer
except
ImportError
as
e
:
raise
ImportError
(
"MistralTokenizer requires vllm package.
\n
"
"Please install it with `pip install vllm` "
"to use mistral tokenizer mode."
)
from
e
return
MistralTokenizer
.
from_pretrained
(
str
(
pretrained_model_name_or_path
))
else
:
return
AutoTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
trust_remote_code
=
trust_remote_code
,
**
kwargs
,
)
ASYNC_REQUEST_FUNCS
=
{
"tgi"
:
async_request_tgi
,
"vllm"
:
async_request_openai_completions
,
"lmdeploy"
:
async_request_openai_completions
,
"deepspeed-mii"
:
async_request_deepspeed_mii
,
"openai"
:
async_request_openai_completions
,
"openai-chat"
:
async_request_openai_chat_completions
,
"tensorrt-llm"
:
async_request_trt_llm
,
"scalellm"
:
async_request_openai_completions
,
"sglang"
:
async_request_openai_completions
,
}
vllm/benchmarks/benchmark_serving.py
0 → 100644
View file @
b95d1275
This diff is collapsed.
Click to expand it.
vllm/benchmarks/benchmark_throughput.py
View file @
b95d1275
...
@@ -5,6 +5,7 @@ import dataclasses
...
@@ -5,6 +5,7 @@ import dataclasses
import
json
import
json
import
random
import
random
import
time
import
time
from
pathlib
import
Path
from
functools
import
cache
from
functools
import
cache
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
...
@@ -215,12 +216,34 @@ def run_vllm(
...
@@ -215,12 +216,34 @@ def run_vllm(
use_beam_search
=
False
use_beam_search
=
False
if
not
use_beam_search
:
if
not
use_beam_search
:
start
=
time
.
perf_counter
()
if
args
.
profile
:
llm
.
generate
(
prompts
,
profile_dir
=
args
.
profile_result_dir
sampling_params
,
if
not
profile_dir
:
lora_request
=
lora_requests
,
profile_dir
=
Path
(
use_tqdm
=
True
)
"."
end
=
time
.
perf_counter
()
)
/
"vllm_benchmark_result"
/
f
"latency_result_
{
time
.
time
()
}
"
print
(
f
"Profiling (results will be saved to '
{
profile_dir
}
')..."
)
with
torch
.
profiler
.
profile
(
activities
=
[
torch
.
profiler
.
ProfilerActivity
.
CPU
,
torch
.
profiler
.
ProfilerActivity
.
CUDA
,
],
record_shapes
=
True
,
on_trace_ready
=
torch
.
profiler
.
tensorboard_trace_handler
(
str
(
profile_dir
))
)
as
prof
:
start
=
time
.
perf_counter
()
llm
.
generate
(
prompts
,
sampling_params
,
lora_request
=
lora_requests
,
use_tqdm
=
True
)
end
=
time
.
perf_counter
()
print
(
'Prepare time report'
)
print
(
prof
.
key_averages
(
group_by_input_shape
=
True
).
table
(
sort_by
=
"self_cuda_time_total"
,
row_limit
=-
1
))
else
:
start
=
time
.
perf_counter
()
llm
.
generate
(
prompts
,
sampling_params
,
lora_request
=
lora_requests
,
use_tqdm
=
True
)
end
=
time
.
perf_counter
()
else
:
else
:
assert
lora_requests
is
None
,
"BeamSearch API does not support LoRA"
assert
lora_requests
is
None
,
"BeamSearch API does not support LoRA"
prompts
=
[
request
.
prompt
for
request
in
requests
]
prompts
=
[
request
.
prompt
for
request
in
requests
]
...
@@ -498,6 +521,16 @@ if __name__ == "__main__":
...
@@ -498,6 +521,16 @@ if __name__ == "__main__":
type
=
int
,
type
=
int
,
default
=
None
,
default
=
None
,
help
=
"Maximum batch size for HF backend."
)
help
=
"Maximum batch size for HF backend."
)
parser
.
add_argument
(
'--profile'
,
action
=
'store_true'
,
help
=
'profile the generation process of a single batch'
)
parser
.
add_argument
(
'--profile-result-dir'
,
type
=
str
,
default
=
None
,
help
=
(
'path to save the pytorch profiler output. Can be visualized '
'with ui.perfetto.dev or Tensorboard.'
))
parser
.
add_argument
(
parser
.
add_argument
(
'--output-json'
,
'--output-json'
,
type
=
str
,
type
=
str
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment