Unverified Commit ad30d5cf authored by bjmsong's avatar bjmsong Committed by GitHub
Browse files

Benchmark with Pytorch Profiler easily (#2110)


Co-authored-by: root <bjmsong@126.com>
parent dfec7fca
...@@ -388,6 +388,24 @@ async def async_request_gserver( ...@@ -388,6 +388,24 @@ async def async_request_gserver(
raise NotImplementedError() raise NotImplementedError()
async def async_request_profile(api_url: str) -> RequestFuncOutput:
    """POST to a profiler control endpoint and report whether it succeeded.

    Used to hit the server's ``/start_profile`` and ``/stop_profile`` routes
    around a benchmark run.

    Args:
        api_url: Full URL of the profiler endpoint to POST to.

    Returns:
        A ``RequestFuncOutput`` with ``success`` set to True on HTTP 200;
        otherwise ``success`` is False and ``error`` holds the HTTP reason
        phrase (or the formatted traceback if the request itself raised).
    """
    result = RequestFuncOutput()
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        try:
            async with session.post(url=api_url) as response:
                if response.status == 200:
                    result.success = True
                else:
                    result.success = False
                    # reason may be None for some servers; keep error a str
                    result.error = response.reason or ""
        except Exception:
            result.success = False
            result.error = "".join(traceback.format_exception(*sys.exc_info()))

    return result
def get_model(pretrained_model_name_or_path: str) -> str: def get_model(pretrained_model_name_or_path: str) -> str:
if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true": if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
import huggingface_hub.constants import huggingface_hub.constants
...@@ -836,12 +854,14 @@ def calculate_metrics( ...@@ -836,12 +854,14 @@ def calculate_metrics(
async def benchmark( async def benchmark(
backend: str, backend: str,
api_url: str, api_url: str,
base_url: str,
model_id: str, model_id: str,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
input_requests: List[Tuple[str, int, int]], input_requests: List[Tuple[str, int, int]],
request_rate: float, request_rate: float,
disable_tqdm: bool, disable_tqdm: bool,
extra_request_body: Dict[str, Any], extra_request_body: Dict[str, Any],
profile: bool,
): ):
if backend in ASYNC_REQUEST_FUNCS: if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend] request_func = ASYNC_REQUEST_FUNCS[backend]
...@@ -869,6 +889,14 @@ async def benchmark( ...@@ -869,6 +889,14 @@ async def benchmark(
time.sleep(1.5) time.sleep(1.5)
if profile:
print("Starting profiler...")
profile_output = await async_request_profile(
api_url=base_url + "/start_profile"
)
if profile_output.success:
print("Profiler started")
pbar = None if disable_tqdm else tqdm(total=len(input_requests)) pbar = None if disable_tqdm else tqdm(total=len(input_requests))
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
...@@ -890,6 +918,12 @@ async def benchmark( ...@@ -890,6 +918,12 @@ async def benchmark(
) )
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
if profile:
print("Stopping profiler...")
profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
if profile_output.success:
print("Profiler stopped")
if pbar is not None: if pbar is not None:
pbar.close() pbar.close()
...@@ -1114,6 +1148,9 @@ def run_benchmark(args_: argparse.Namespace): ...@@ -1114,6 +1148,9 @@ def run_benchmark(args_: argparse.Namespace):
if args.base_url if args.base_url
else f"http://{args.host}:{args.port}/v1/models/model:predict" else f"http://{args.host}:{args.port}/v1/models/model:predict"
) )
base_url = (
f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
)
# Get model name # Get model name
if args.model is None: if args.model is None:
...@@ -1159,12 +1196,14 @@ def run_benchmark(args_: argparse.Namespace): ...@@ -1159,12 +1196,14 @@ def run_benchmark(args_: argparse.Namespace):
benchmark( benchmark(
backend=backend, backend=backend,
api_url=api_url, api_url=api_url,
base_url=base_url,
model_id=model_id, model_id=model_id,
tokenizer=tokenizer, tokenizer=tokenizer,
input_requests=input_requests, input_requests=input_requests,
request_rate=args.request_rate, request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm, disable_tqdm=args.disable_tqdm,
extra_request_body=extra_request_body, extra_request_body=extra_request_body,
profile=args.profile,
) )
) )
else: else:
...@@ -1176,12 +1215,14 @@ def run_benchmark(args_: argparse.Namespace): ...@@ -1176,12 +1215,14 @@ def run_benchmark(args_: argparse.Namespace):
benchmark( benchmark(
backend=backend, backend=backend,
api_url=api_url, api_url=api_url,
base_url=base_url,
model_id=model_id, model_id=model_id,
tokenizer=tokenizer, tokenizer=tokenizer,
input_requests=input_requests, input_requests=input_requests,
request_rate=rate, request_rate=rate,
disable_tqdm=args.disable_tqdm, disable_tqdm=args.disable_tqdm,
extra_request_body=extra_request_body, extra_request_body=extra_request_body,
profile=args.profile,
) )
) )
...@@ -1355,6 +1396,11 @@ if __name__ == "__main__": ...@@ -1355,6 +1396,11 @@ if __name__ == "__main__":
type=str, type=str,
help="Path to load previously generated input data", help="Path to load previously generated input data",
) )
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
)
args = parser.parse_args() args = parser.parse_args()
run_benchmark(args) run_benchmark(args)
...@@ -564,6 +564,7 @@ def run_bench_serving( ...@@ -564,6 +564,7 @@ def run_bench_serving(
disable_stream=disable_stream, disable_stream=disable_stream,
disable_ignore_eos=False, disable_ignore_eos=False,
extra_request_body=None, extra_request_body=None,
profile=None,
) )
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment