Unverified Commit ad30d5cf authored by bjmsong's avatar bjmsong Committed by GitHub
Browse files

Benchmark with Pytorch Profiler easily (#2110)


Co-authored-by: default avatarroot <bjmsong@126.com>
parent dfec7fca
......@@ -388,6 +388,24 @@ async def async_request_gserver(
raise NotImplementedError()
async def async_request_profile(api_url: str) -> RequestFuncOutput:
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
output = RequestFuncOutput()
try:
async with session.post(url=api_url) as response:
if response.status == 200:
output.success = True
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
return output
def get_model(pretrained_model_name_or_path: str) -> str:
if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
import huggingface_hub.constants
......@@ -836,12 +854,14 @@ def calculate_metrics(
async def benchmark(
backend: str,
api_url: str,
base_url: str,
model_id: str,
tokenizer: PreTrainedTokenizerBase,
input_requests: List[Tuple[str, int, int]],
request_rate: float,
disable_tqdm: bool,
extra_request_body: Dict[str, Any],
profile: bool,
):
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend]
......@@ -869,6 +889,14 @@ async def benchmark(
time.sleep(1.5)
if profile:
print("Starting profiler...")
profile_output = await async_request_profile(
api_url=base_url + "/start_profile"
)
if profile_output.success:
print("Profiler started")
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
benchmark_start_time = time.perf_counter()
......@@ -890,6 +918,12 @@ async def benchmark(
)
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
if profile:
print("Stopping profiler...")
profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
if profile_output.success:
print("Profiler stopped")
if pbar is not None:
pbar.close()
......@@ -1114,6 +1148,9 @@ def run_benchmark(args_: argparse.Namespace):
if args.base_url
else f"http://{args.host}:{args.port}/v1/models/model:predict"
)
base_url = (
f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
)
# Get model name
if args.model is None:
......@@ -1159,12 +1196,14 @@ def run_benchmark(args_: argparse.Namespace):
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
extra_request_body=extra_request_body,
profile=args.profile,
)
)
else:
......@@ -1176,12 +1215,14 @@ def run_benchmark(args_: argparse.Namespace):
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=rate,
disable_tqdm=args.disable_tqdm,
extra_request_body=extra_request_body,
profile=args.profile,
)
)
......@@ -1355,6 +1396,11 @@ if __name__ == "__main__":
type=str,
help="Path to load previously generated input data",
)
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
)
args = parser.parse_args()
run_benchmark(args)
......@@ -564,6 +564,7 @@ def run_bench_serving(
disable_stream=disable_stream,
disable_ignore_eos=False,
extra_request_body=None,
profile=None,
)
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment