Unverified Commit 98c00a2d authored by Yueyang Pan's avatar Yueyang Pan Committed by GitHub
Browse files

Fix torch profiler bugs for bench_offline_throughput.py (#6557)

parent 451ffe74
...@@ -52,6 +52,17 @@ ...@@ -52,6 +52,17 @@
python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8 python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8
``` ```
- Possible PyTorch Bug
If you encounter the following error (for example, when using Qwen 2.5 VL):
```bash
RuntimeError: !stack.empty() INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/autograd/profiler_python.cpp":983, please report a bug to PyTorch. Python replay stack is empty.
```
This is likely a PyTorch bug reported in [Bug: vLLM Profiler](https://github.com/vllm-project/vllm/issues/18240) and [Bug: torch.profiler.profile](https://github.com/pytorch/pytorch/issues/101632). As a workaround, you can disable `with_stack` with an environment variable, as follows:
```bash
export SGLANG_PROFILE_WITH_STACK=False
python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8
```
- View Traces - View Traces
Trace files can be loaded and visualized from: Trace files can be loaded and visualized from:
......
...@@ -88,6 +88,7 @@ SGLang supports various environment variables that can be used to configure its ...@@ -88,6 +88,7 @@ SGLang supports various environment variables that can be used to configure its
| Environment Variable | Description | Default Value | | Environment Variable | Description | Default Value |
| --- | --- | --- | | --- | --- | --- |
| `SGLANG_TORCH_PROFILER_DIR` | Directory for PyTorch profiler output | `/tmp` | | `SGLANG_TORCH_PROFILER_DIR` | Directory for PyTorch profiler output | `/tmp` |
| `SGLANG_PROFILE_WITH_STACK` | Set `with_stack` option (bool) for PyTorch profiler (capture stack trace) | `true` |
## Storage & Caching ## Storage & Caching
......
...@@ -11,7 +11,9 @@ python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1 ...@@ -11,7 +11,9 @@ python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1
""" """
import argparse import argparse
import asyncio
import dataclasses import dataclasses
import inspect
import json import json
import logging import logging
import os import os
...@@ -235,8 +237,10 @@ def throughput_test_once( ...@@ -235,8 +237,10 @@ def throughput_test_once(
latency = time.perf_counter() - st latency = time.perf_counter() - st
if profile: if profile:
dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
known_files = set(os.listdir(dir))
backend.stop_profile() backend.stop_profile()
monitor_trace_file(os.getenv("SGLANG_TORCH_PROFILER_DIR")) monitor_trace_file(known_files, dir)
if backend_name == "runtime": if backend_name == "runtime":
gen_out = json.loads(gen_out) gen_out = json.loads(gen_out)
...@@ -260,6 +264,10 @@ def throughput_test_once( ...@@ -260,6 +264,10 @@ def throughput_test_once(
measurement_results["total_input_tokens"] measurement_results["total_input_tokens"]
+ measurement_results["total_output_tokens"] + measurement_results["total_output_tokens"]
) / latency ) / latency
if inspect.isawaitable(server_info):
server_info = asyncio.run(server_info)
measurement_results["last_gen_throughput"] = server_info["internal_states"][0][ measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
"last_gen_throughput" "last_gen_throughput"
] ]
...@@ -267,11 +275,9 @@ def throughput_test_once( ...@@ -267,11 +275,9 @@ def throughput_test_once(
return measurement_results return measurement_results
def monitor_trace_file(directory, interval=1): def monitor_trace_file(known_files, directory, interval=1):
print(f"Monitoring {directory} for new trace files...") print(f"Monitoring {directory} for new trace files...")
known_files = set(os.listdir(directory))
while True: while True:
flag = False flag = False
time.sleep(interval) time.sleep(interval)
......
...@@ -85,6 +85,22 @@ class RuntimeEndpoint(BaseBackend): ...@@ -85,6 +85,22 @@ class RuntimeEndpoint(BaseBackend):
) )
self._assert_success(res) self._assert_success(res)
def start_profile(self):
    """Ask the runtime server to begin a PyTorch profiler capture.

    Sends a GET-style request to the ``/start_profile`` endpoint and
    raises (via ``_assert_success``) if the server does not accept it.
    """
    url = self.base_url + "/start_profile"
    response = http_request(url, api_key=self.api_key, verify=self.verify)
    self._assert_success(response)
def stop_profile(self):
    """Ask the runtime server to finish the active PyTorch profiler capture.

    Sends a request to the ``/stop_profile`` endpoint and raises (via
    ``_assert_success``) if the server does not accept it.
    """
    url = self.base_url + "/stop_profile"
    response = http_request(url, api_key=self.api_key, verify=self.verify)
    self._assert_success(response)
def commit_lazy_operations(self, s: StreamExecutor): def commit_lazy_operations(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}} data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data) self._add_images(s, data)
...@@ -374,7 +390,8 @@ class Runtime: ...@@ -374,7 +390,8 @@ class Runtime:
self.pid = None self.pid = None
pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False) pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)
proc = multiprocessing.Process( ctx = multiprocessing.get_context("spawn")
proc = ctx.Process(
target=launch_server, target=launch_server,
args=(self.server_args, pipe_writer), args=(self.server_args, pipe_writer),
) )
...@@ -406,6 +423,12 @@ class Runtime: ...@@ -406,6 +423,12 @@ class Runtime:
kill_process_tree(self.pid) kill_process_tree(self.pid)
self.pid = None self.pid = None
def start_profile(self):
    """Start server-side PyTorch profiling by delegating to the endpoint."""
    self.endpoint.start_profile()
def stop_profile(self):
    """Stop server-side PyTorch profiling by delegating to the endpoint."""
    self.endpoint.stop_profile()
def cache_prefix(self, prefix: str): def cache_prefix(self, prefix: str):
self.endpoint.cache_prefix(prefix) self.endpoint.cache_prefix(prefix)
......
...@@ -116,6 +116,7 @@ from sglang.srt.sampling.sampling_params import SamplingParams ...@@ -116,6 +116,7 @@ from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
dataclass_to_string_truncated, dataclass_to_string_truncated,
get_bool_env_var,
get_zmq_socket, get_zmq_socket,
kill_process_tree, kill_process_tree,
) )
...@@ -805,6 +806,8 @@ class TokenizerManager: ...@@ -805,6 +806,8 @@ class TokenizerManager:
profile_by_stage: bool = False, profile_by_stage: bool = False,
): ):
self.auto_create_handle_loop() self.auto_create_handle_loop()
env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
with_stack = False if with_stack is False or env_with_stack is False else True
req = ProfileReq( req = ProfileReq(
type=ProfileReqType.START_PROFILE, type=ProfileReqType.START_PROFILE,
output_dir=output_dir, output_dir=output_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment