Unverified Commit aaf7af1b authored by Neelabh Sinha, committed by GitHub

[FEATURE] Add Profile Trace Merger for Distributed Traces (#11413)

parent 932e2637
......@@ -74,6 +74,47 @@ python3 -m sglang.test.send_one
python3 -m sglang.profiler
```
### Profiler Trace Merger for Distributed Traces
SGLang supports automatic merging of profiling traces from distributed setups that combine multiple parallelism types (TP, DP, PP, EP). This is particularly useful for analyzing performance across all ranks of a distributed run in a single trace view.
#### Multi-Node Profiling and Shared Storage Considerations
Merging profiler output from a single node is fully supported. When profiling distributed runs that span multiple nodes, the output directory must reside on shared storage (e.g., NFS, Lustre) accessible from every node so that the trace files can be merged.
If no shared storage is accessible across nodes, automatic merging of trace files during profiling is currently not supported.
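If the per-rank trace files are later copied into a single directory by hand, they can still be merged offline by invoking the merger class added in this change directly. A minimal sketch (the directory path and profile id below are placeholders; the files must follow the `{profile_id}-TP-...` naming produced by the profiler):
```python
from sglang.srt.utils.profile_merger import ProfileMerger

# Hypothetical paths: point at the directory that now holds all per-rank traces.
merger = ProfileMerger(output_dir="/tmp/profiles", profile_id="1700000000.123")
merged_path = merger.merge_chrome_traces()  # writes merged-<profile_id>.trace.json.gz
print(merged_path)
print(merger.get_merge_summary())
```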
#### HTTP API Usage
```bash
# Start profiling with automatic trace merging enabled.
# Note: the JSON payload must not contain comments. merge_profiles is optional
# and defaults to false; output_dir controls where profile traces are stored.
curl -X POST <BASE_URL>/start_profile \
  -H "Content-Type: application/json" \
  -d '{
    "output_dir": "/tmp/profiles",
    "num_steps": 10,
    "activities": ["CPU", "GPU"],
    "merge_profiles": true
  }'
```
#### Command Line Usage
```bash
# Start profiling with merge enabled
python -m sglang.profiler \
--num-steps 10 \
--activities CPU GPU \
--output-dir /tmp/profiles \
--merge-profiles # optional argument to merge profile traces (default=False)
```
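The same request can also be issued programmatically through the CLI helper touched by this change. A minimal sketch, assuming the SGLang server is reachable at `http://localhost:30000` (adjust the URL to your deployment):
```python
from sglang.profiler import run_profile

# Profile 10 steps and ask the server to merge the per-rank traces afterwards.
link = run_profile(
    "http://localhost:30000",  # assumed server address
    10,                        # num_steps
    ["CPU", "GPU"],            # activities
    output_dir="/tmp/profiles",
    profile_name=None,
    profile_by_stage=False,
    merge_profiles=True,
)
print(link)
```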
#### Output Files
The profile merger generates:
- Individual rank trace files: `{profile_id}-TP-{tp}[-DP-{dp}][-PP-{pp}][-EP-{ep}].trace.json.gz` (the `DP`, `PP`, and `EP` components are included only when that parallelism dimension has size > 1)
- Merged trace file: `merged-{profile_id}.trace.json.gz`
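For example, a hypothetical run with TP=2 and DP=2 (profile id `1700000000.123` is a placeholder) would leave the output directory looking roughly like this after merging:
```bash
/tmp/profiles/
├── 1700000000.123-TP-0-DP-0.trace.json.gz
├── 1700000000.123-TP-1-DP-0.trace.json.gz
├── 1700000000.123-TP-0-DP-1.trace.json.gz
├── 1700000000.123-TP-1-DP-1.trace.json.gz
└── merged-1700000000.123.trace.json.gz
```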
### Possible PyTorch bugs
If you encounter the following error (for example, when using Qwen 2.5 VL):
```bash
......
......@@ -25,6 +25,7 @@ def _run_profile(
output_dir: Optional[str] = None,
profile_name: Optional[str] = None,
profile_by_stage: bool = False,
merge_profiles: bool = False,
) -> str:
if output_dir is None:
output_dir = PROFILER_DIR
......@@ -60,6 +61,7 @@ def _run_profile(
"num_steps": str(num_steps),
"activities": activities,
"profile_by_stage": profile_by_stage,
"merge_profiles": merge_profiles,
}
response = requests.post(url=url + "/start_profile", json=json_data)
......@@ -76,10 +78,17 @@ def run_profile(
output_dir: Optional[str] = None,
profile_name: Optional[str] = None,
profile_by_stage: bool = False,
merge_profiles: bool = False,
):
# step based profile will self terminate on num_steps constraints
link = _run_profile(
url,
num_steps,
activities,
output_dir,
profile_name,
profile_by_stage,
merge_profiles,
)
return link
......@@ -145,6 +154,13 @@ if __name__ == "__main__":
default=False,
help="Whether to use rpd profiler (https://github.com/ROCm/rocmProfileData)",
)
parser.add_argument(
"--merge-profiles",
action=argparse.BooleanOptionalAction,
default=False,
help="Whether to merge profiles from all ranks into a single trace file",
)
args = parser.parse_args()
activities = []
......@@ -163,4 +179,5 @@ if __name__ == "__main__":
args.output_dir,
args.profile_name,
args.profile_by_stage,
args.merge_profiles,
)
......@@ -634,6 +634,7 @@ async def start_profile_async(obj: Optional[ProfileReqInput] = None):
with_stack=obj.with_stack,
record_shapes=obj.record_shapes,
profile_by_stage=obj.profile_by_stage,
merge_profiles=obj.merge_profiles,
)
return Response(
content="Start profiling.\n",
......
......@@ -1232,6 +1232,8 @@ class ProfileReqInput(BaseReq):
profile_by_stage: bool = False
with_stack: Optional[bool] = None
record_shapes: Optional[bool] = None
# Merge profiles from all ranks into a single trace
merge_profiles: bool = False
class ProfileReqType(Enum):
......@@ -1250,6 +1252,8 @@ class ProfileReq(BaseReq):
with_stack: Optional[bool] = None
record_shapes: Optional[bool] = None
profile_id: Optional[str] = None
# Merge profiles from all ranks into a single trace
merge_profiles: bool = False
@dataclass
......
......@@ -9,6 +9,7 @@ import torch
from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import is_npu
from sglang.srt.utils.profile_merger import ProfileMerger
_is_npu = is_npu()
if _is_npu:
......@@ -25,7 +26,6 @@ logger = logging.getLogger(__name__)
class SchedulerProfilerMixin:
def init_profiler(self):
self.torch_profiler = None
self.torch_profiler_output_dir: Optional[str] = None
......@@ -41,6 +41,7 @@ class SchedulerProfilerMixin:
self.profile_steps: Optional[int] = None
self.profile_in_progress: bool = False
self.rpd_profiler = None
self.merge_profiles = False
def init_profile(
self,
......@@ -52,6 +53,7 @@ class SchedulerProfilerMixin:
record_shapes: Optional[bool],
profile_by_stage: bool,
profile_id: str,
merge_profiles: bool = False,
) -> ProfileReqOutput:
if self.profile_in_progress:
return ProfileReqOutput(
......@@ -60,6 +62,7 @@ class SchedulerProfilerMixin:
)
self.profile_by_stage = profile_by_stage
self.merge_profiles = merge_profiles
if output_dir is None:
output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
......@@ -169,6 +172,38 @@ class SchedulerProfilerMixin:
return ProfileReqOutput(success=True, message="Succeeded")
def _merge_profile_traces(self) -> str:
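# Only the scheduler that is rank 0 in every enabled parallelism dimension
# (TP/DP/PP/EP) performs the merge; all other ranks return an empty message.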
if not self.merge_profiles:
return ""
if self.tp_rank != 0:
return ""
if getattr(self, "dp_size", 1) > 1 and getattr(self, "dp_rank", 0) != 0:
return ""
if getattr(self, "pp_size", 1) > 1 and getattr(self, "pp_rank", 0) != 0:
return ""
if getattr(self, "moe_ep_size", 1) > 1 and getattr(self, "moe_ep_rank", 0) != 0:
return ""
try:
logger.info("Starting profile merge...")
merger = ProfileMerger(self.torch_profiler_output_dir, self.profile_id)
merged_path = merger.merge_chrome_traces()
summary = merger.get_merge_summary()
merge_message = (
f" Merged trace: {merged_path} "
f"(Events: {summary.get('total_events', '?')}, "
f"Files: {summary.get('total_files', '?')})"
)
logger.info(f"Profile merge completed: {merged_path}")
except Exception as e:
logger.error(f"Failed to merge profiles: {e}", exc_info=True)
return f" Merge failed: {e!s}"
else:
return merge_message
def stop_profile(
self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
......@@ -186,14 +221,21 @@ class SchedulerProfilerMixin:
if self.torch_profiler is not None:
self.torch_profiler.stop()
if not _is_npu:
# Keep the legacy TP-only filename for backward compatibility; append DP/PP/EP
# components only when that parallelism dimension is enabled (size > 1)
filename_parts = [self.profile_id, f"TP-{self.tp_rank}"]
if getattr(self, "dp_size", 1) > 1:
filename_parts.append(f"DP-{getattr(self, 'dp_rank', 0)}")
if getattr(self, "pp_size", 1) > 1:
filename_parts.append(f"PP-{getattr(self, 'pp_rank', 0)}")
if getattr(self, "moe_ep_size", 1) > 1:
filename_parts.append(f"EP-{getattr(self, 'moe_ep_rank', 0)}")
filename = "-".join(filename_parts) + stage_suffix + ".trace.json.gz"
self.torch_profiler.export_chrome_trace(
os.path.join(self.torch_profiler_output_dir, filename)
)
torch.distributed.barrier(self.tp_cpu_group)
......@@ -224,15 +266,18 @@ class SchedulerProfilerMixin:
if "CUDA_PROFILER" in self.profiler_activities:
torch.cuda.cudart().cudaProfilerStop()
merge_message = self._merge_profile_traces()
logger.info(
"Profiling done. Traces are saved to: %s",
"Profiling done. Traces are saved to: %s%s",
self.torch_profiler_output_dir,
merge_message,
)
self.torch_profiler = None
self.profile_in_progress = False
self.profiler_start_forward_ct = None
return ProfileReqOutput(success=True, message="Succeeded.")
return ProfileReqOutput(success=True, message=f"Succeeded.{merge_message}")
def _profile_batch_predicate(self, batch):
if self.profile_by_stage:
......@@ -282,6 +327,7 @@ class SchedulerProfilerMixin:
recv_req.record_shapes,
recv_req.profile_by_stage,
recv_req.profile_id,
recv_req.merge_profiles,
)
else:
self.init_profile(
......@@ -293,6 +339,7 @@ class SchedulerProfilerMixin:
recv_req.record_shapes,
recv_req.profile_by_stage,
recv_req.profile_id,
recv_req.merge_profiles,
)
return self.start_profile()
else:
......
......@@ -306,6 +306,7 @@ class TokenizerCommunicatorMixin:
with_stack: Optional[bool] = None,
record_shapes: Optional[bool] = None,
profile_by_stage: bool = False,
merge_profiles: bool = False,
):
self.auto_create_handle_loop()
env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
......@@ -320,6 +321,7 @@ class TokenizerCommunicatorMixin:
record_shapes=record_shapes,
profile_by_stage=profile_by_stage,
profile_id=str(time.time()),
merge_profiles=merge_profiles,
)
return await self._execute_profile(req)
......
"""Merge Chrome trace files from multiple ranks (TP, DP, PP, EP) into a single trace."""
import glob
import gzip
import json
import logging
import os
import re
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
class ProfileMerger:
"""Merge profile traces from all parallelism types: TP, DP, PP, EP."""
def __init__(self, output_dir: str, profile_id: str):
self.output_dir = output_dir
self.profile_id = profile_id
self.merged_trace_path = os.path.join(
output_dir, f"merged-{profile_id}.trace.json.gz"
)
# Rank types in priority order (used for sorting and labeling)
self.rank_types = ["tp", "dp", "pp", "ep"]
# Sort index multipliers: DP (highest) > EP > PP > TP (lowest)
# These ensure proper visual ordering in trace viewer
self.sort_index_multipliers = {
"dp_rank": 100_000_000,
"ep_rank": 1_000_000,
"pp_rank": 10_000,
"tp_rank": 100,
}
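# Example: pid 83 with dp_rank=2, ep_rank=4, pp_rank=3, tp_rank=1 maps to
# 83 + 2*100_000_000 + 4*1_000_000 + 3*10_000 + 1*100 = 204_030_183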
# PID threshold for sort_index updates (only update for system PIDs < 1000)
self.pid_sort_index_threshold = 1000
def merge_chrome_traces(self) -> str:
"""Merge Chrome traces from all ranks into a single trace.
Returns:
Path to merged trace file.
Raises:
ValueError: If no trace files found.
"""
trace_files = self._discover_trace_files()
if not trace_files:
raise ValueError(f"No trace files found for profile_id: {self.profile_id}")
logger.info(f"Found {len(trace_files)} trace files to merge")
merged_trace = {"traceEvents": []}
all_device_properties = []
for trace_file in sorted(trace_files, key=self._get_rank_sort_key):
rank_info = self._extract_rank_info(trace_file)
logger.info(f"Processing {trace_file} with rank info: {rank_info}")
output = self._handle_file(trace_file, rank_info)
merged_trace["traceEvents"].extend(output["traceEvents"])
if "deviceProperties" in output:
all_device_properties.extend(output["deviceProperties"])
del output["deviceProperties"]
for key, value in output.items():
if key != "traceEvents" and key not in merged_trace:
merged_trace[key] = value
if all_device_properties:
merged_trace["deviceProperties"] = all_device_properties
with gzip.open(self.merged_trace_path, "wb") as f:
f.write(json.dumps(merged_trace).encode("utf-8"))
logger.info(f"Merged profile saved to: {self.merged_trace_path}")
logger.info(f"Total events merged: {len(merged_trace['traceEvents'])}")
return self.merged_trace_path
def _discover_trace_files(self) -> List[str]:
"""Discover trace files matching profile_id (supports TP/DP/PP/EP formats)."""
patterns = [f"{self.profile_id}*.trace.json.gz"]
trace_files = []
for pattern in patterns:
search_pattern = os.path.join(self.output_dir, pattern)
trace_files.extend(glob.glob(search_pattern))
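# Drop the merged output itself and any non-per-rank artifacts; every per-rank
# trace filename contains a "TP-" component.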
trace_files = [
f
for f in trace_files
if not f.endswith(f"merged-{self.profile_id}.trace.json.gz")
and not f.endswith("-memory.pickle")
and "TP-" in f
]
trace_files = list(set(trace_files))
return trace_files
def _extract_rank_info(self, filename: str) -> Dict[str, int]:
"""Extract rank info (TP/DP/PP/EP) from filename."""
basename = os.path.basename(filename)
rank_info = {}
for rank_type in self.rank_types:
match = re.search(rf"{rank_type.upper()}-(\d+)", basename)
if match:
rank_info[f"{rank_type}_rank"] = int(match.group(1))
return rank_info
def _create_rank_label(self, rank_info: Dict[str, int]) -> str:
parts = []
for rank_type in self.rank_types:
rank_key = f"{rank_type}_rank"
if rank_key in rank_info:
parts.append(f"{rank_type.upper()}{rank_info[rank_key]:02d}")
return f"[{'-'.join(parts)}]" if parts else "[Unknown]"
def _handle_file(self, path: str, rank_info: Dict[str, int]) -> Dict[str, Any]:
logger.info(f"Processing file: {path}")
try:
with gzip.open(path, "rt", encoding="utf-8") as f:
trace = json.load(f)
output = {
key: value for key, value in trace.items() if key != "traceEvents"
}
output["traceEvents"] = self._process_events(
trace.get("traceEvents", []), rank_info
)
return output
except Exception as e:
logger.error(f"Failed to process trace file {path}: {e}")
return {"traceEvents": []}
def _process_events(
self, events: List[Dict], rank_info: Dict[str, int]
) -> List[Dict]:
"""Process events: update sort_index and add rank labels to PIDs."""
rank_label = self._create_rank_label(rank_info)
for event in events:
if event.get("name") == "process_sort_index":
pid = self._maybe_cast_int(event.get("pid"))
if pid is not None and pid < self.pid_sort_index_threshold:
event["args"]["sort_index"] = self._calculate_sort_index(
rank_info, pid
)
event["pid"] = f"{rank_label} {event['pid']}"
return events
def _calculate_sort_index(self, rank_info: Dict[str, int], pid: int) -> int:
sort_index = pid
for rank_type, multiplier in self.sort_index_multipliers.items():
sort_index += rank_info.get(rank_type, 0) * multiplier
return sort_index
def _get_rank_sort_key(self, path: str) -> Tuple[int, int, int, int]:
rank_info = self._extract_rank_info(path)
return tuple(
rank_info.get(f"{rank_type}_rank", 0)
for rank_type in ["dp", "ep", "pp", "tp"]
)
def _maybe_cast_int(self, x) -> Optional[int]:
try:
return int(x)
except (ValueError, TypeError):
return None
def get_merge_summary(self) -> Dict[str, Any]:
if not os.path.exists(self.merged_trace_path):
return {"error": "Merged trace file not found"}
try:
with gzip.open(self.merged_trace_path, "rt") as f:
merged_data = json.load(f)
trace_files = self._discover_trace_files()
return {
"merged_file": self.merged_trace_path,
"total_events": len(merged_data.get("traceEvents", [])),
"total_files": len(trace_files),
"source_files": [os.path.basename(f) for f in trace_files],
"profile_id": self.profile_id,
"device_properties_count": len(merged_data.get("deviceProperties", [])),
}
except Exception as e:
return {"error": f"Failed to read merged trace: {str(e)}"}
......@@ -115,6 +115,8 @@ suites = {
TestFile("test_srt_engine.py", 261),
TestFile("test_standalone_speculative_decoding.py", 250),
TestFile("test_start_profile.py", 60),
TestFile("test_profile_merger.py", 60),
TestFile("test_profile_merger_http_api.py", 15),
TestFile("test_swa_unittest.py", 1),
TestFile("test_torch_compile.py", 76),
TestFile("test_torch_compile_moe.py", 172),
......
"""
Unit tests for the ProfileMerger implementation.
Usage:
python test_profile_merger.py
python -m unittest test_profile_merger.py -v
"""
import gzip
import json
import os
import shutil
import tempfile
import unittest
from sglang.srt.managers.io_struct import ProfileReq, ProfileReqInput, ProfileReqType
from sglang.srt.utils.profile_merger import ProfileMerger
class TestProfileMerger(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.mkdtemp()
self.profile_id = "test_profile_123"
self.merger = ProfileMerger(self.temp_dir, self.profile_id)
def tearDown(self):
shutil.rmtree(self.temp_dir, ignore_errors=True)
def test_rank_extraction_and_labeling(self):
# Test TP-only
filename = f"{self.profile_id}-TP-0.trace.json.gz"
rank_info = self.merger._extract_rank_info(filename)
self.assertEqual(rank_info, {"tp_rank": 0})
label = self.merger._create_rank_label(rank_info)
self.assertEqual(label, "[TP00]")
# Test all parallelism types
filename = f"{self.profile_id}-TP-1-DP-2-PP-3-EP-4.trace.json.gz"
rank_info = self.merger._extract_rank_info(filename)
self.assertEqual(
rank_info, {"tp_rank": 1, "dp_rank": 2, "pp_rank": 3, "ep_rank": 4}
)
label = self.merger._create_rank_label(rank_info)
self.assertEqual(label, "[TP01-DP02-PP03-EP04]")
# Test partial ranks
filename = f"{self.profile_id}-TP-0-DP-1.trace.json.gz"
rank_info = self.merger._extract_rank_info(filename)
self.assertEqual(rank_info, {"tp_rank": 0, "dp_rank": 1})
label = self.merger._create_rank_label(rank_info)
self.assertEqual(label, "[TP00-DP01]")
# Test no ranks
filename = f"{self.profile_id}.trace.json.gz"
rank_info = self.merger._extract_rank_info(filename)
self.assertEqual(rank_info, {})
label = self.merger._create_rank_label(rank_info)
self.assertEqual(label, "[Unknown]")
def test_sort_index_calculation(self):
# Single rank
rank_info = {"tp_rank": 0}
sort_idx = self.merger._calculate_sort_index(rank_info, 83)
self.assertEqual(sort_idx, 83)
# Multiple ranks
rank_info = {"tp_rank": 1, "dp_rank": 2, "pp_rank": 3, "ep_rank": 4}
sort_idx = self.merger._calculate_sort_index(rank_info, 83)
self.assertNotEqual(sort_idx, 83)
self.assertGreater(sort_idx, 1000000)
# Empty ranks
rank_info = {}
sort_idx = self.merger._calculate_sort_index(rank_info, 83)
self.assertEqual(sort_idx, 83)
def test_rank_sort_key(self):
# Full ranks: TP-1, DP-2, PP-3, EP-4 → sorted as (DP, EP, PP, TP)
filename = f"{self.profile_id}-TP-1-DP-2-PP-3-EP-4.trace.json.gz"
sort_key = self.merger._get_rank_sort_key(filename)
self.assertEqual(sort_key, (2, 4, 3, 1))
# Missing ranks: only TP-1 → sorted as (DP=0, EP=0, PP=0, TP=1)
filename = f"{self.profile_id}-TP-1.trace.json.gz"
sort_key = self.merger._get_rank_sort_key(filename)
self.assertEqual(sort_key, (0, 0, 0, 1))
def test_discover_trace_files(self):
# Create mock trace files
trace_files = [
f"{self.profile_id}-TP-0.trace.json.gz", # Old format
f"{self.profile_id}-TP-1.trace.json.gz", # Old format
f"{self.profile_id}-TP-0-DP-1.trace.json.gz", # New format
]
for filename in trace_files:
filepath = os.path.join(self.temp_dir, filename)
with gzip.open(filepath, "wt") as f:
json.dump({"traceEvents": []}, f)
discovered = self.merger._discover_trace_files()
self.assertEqual(len(discovered), 3)
# Check that all expected files are discovered
discovered_basenames = {os.path.basename(f) for f in discovered}
expected_basenames = {
f"{self.profile_id}-TP-0.trace.json.gz",
f"{self.profile_id}-TP-1.trace.json.gz",
f"{self.profile_id}-TP-0-DP-1.trace.json.gz",
}
self.assertEqual(discovered_basenames, expected_basenames)
# Test no matches
empty_merger = ProfileMerger(self.temp_dir, "nonexistent")
discovered = empty_merger._discover_trace_files()
self.assertEqual(len(discovered), 0)
def test_merge_chrome_traces(self):
# Create multiple trace files in random order
trace_files = [
{
"filename": f"{self.profile_id}-TP-1-DP-1.trace.json.gz",
"events": [
{"ph": "X", "name": "op1", "pid": 83, "ts": 1000.0, "dur": 10.0}
],
},
{
"filename": f"{self.profile_id}-TP-0.trace.json.gz",
"events": [
{"ph": "X", "name": "op2", "pid": 84, "ts": 2000.0, "dur": 15.0}
],
},
{
"filename": f"{self.profile_id}-TP-0-DP-1.trace.json.gz",
"events": [
{"ph": "X", "name": "op3", "pid": 85, "ts": 3000.0, "dur": 20.0}
],
},
]
for trace_data in trace_files:
filepath = os.path.join(self.temp_dir, trace_data["filename"])
trace_content = {
"schemaVersion": 1,
"deviceProperties": [{"device_id": 0, "name": "GPU-0"}],
"traceEvents": trace_data["events"],
}
with gzip.open(filepath, "wt") as f:
json.dump(trace_content, f)
# Test file ordering by capturing log messages
import logging
logger = logging.getLogger("sglang.srt.utils.profile_merger")
with self.assertLogs(logger, level="INFO") as log_capture:
merged_path = self.merger.merge_chrome_traces()
# Verify files were processed in rank order
log_messages = [
record.getMessage()
for record in log_capture.records
if "Processing file:" in record.getMessage()
]
self.assertIn("TP-0.trace.json.gz", log_messages[0]) # (0,0,0,0) comes first
self.assertIn(
"TP-0-DP-1.trace.json.gz", log_messages[1]
) # (0,1,0,0) comes second
self.assertIn(
"TP-1-DP-1.trace.json.gz", log_messages[2]
) # (1,1,0,0) comes last
# Verify merged content
self.assertTrue(os.path.exists(merged_path))
with gzip.open(merged_path, "rt") as f:
merged_data = json.load(f)
self.assertEqual(len(merged_data["traceEvents"]), 3)
self.assertEqual(len(merged_data["deviceProperties"]), 3)
# Check rank labels in events
events = merged_data["traceEvents"]
pids = [event["pid"] for event in events]
self.assertIn("[TP00] 84", pids)
self.assertIn("[TP00-DP01] 85", pids)
self.assertIn("[TP01-DP01] 83", pids)
# Test merge summary
summary = self.merger.get_merge_summary()
self.assertEqual(summary["total_files"], 3)
self.assertEqual(summary["total_events"], 3)
self.assertEqual(summary["profile_id"], self.profile_id)
# Test no files error
empty_merger = ProfileMerger(self.temp_dir, "nonexistent")
with self.assertRaises(ValueError):
empty_merger.merge_chrome_traces()
class TestProfileMergerIntegration(unittest.TestCase):
def test_data_structures_merge_profiles(self):
# Test ProfileReqInput
req_input = ProfileReqInput()
self.assertFalse(req_input.merge_profiles)
req_input = ProfileReqInput(merge_profiles=True)
self.assertTrue(req_input.merge_profiles)
# Test ProfileReq
req = ProfileReq(type=ProfileReqType.START_PROFILE)
self.assertFalse(req.merge_profiles)
req = ProfileReq(type=ProfileReqType.START_PROFILE, merge_profiles=True)
self.assertTrue(req.merge_profiles)
def test_integration_parameters(self):
import inspect
# Test TokenizerManager
from sglang.srt.managers.tokenizer_communicator_mixin import (
TokenizerCommunicatorMixin,
)
sig = inspect.signature(TokenizerCommunicatorMixin.start_profile)
self.assertIn("merge_profiles", sig.parameters)
# Test SchedulerProfilerMixin
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
sig = inspect.signature(SchedulerProfilerMixin.init_profile)
self.assertIn("merge_profiles", sig.parameters)
# Test CLI profiler
from sglang.profiler import run_profile
sig = inspect.signature(run_profile)
self.assertIn("merge_profiles", sig.parameters)
class TestProfileMergerEdgeCases(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.mkdtemp()
self.profile_id = "test_edge_cases"
self.merger = ProfileMerger(self.temp_dir, self.profile_id)
def tearDown(self):
shutil.rmtree(self.temp_dir, ignore_errors=True)
def test_error_handling_and_edge_cases(self):
# Test malformed trace file
filename = f"{self.profile_id}-TP-0.trace.json.gz"
filepath = os.path.join(self.temp_dir, filename)
with gzip.open(filepath, "wt") as f:
f.write("invalid json content")
merged_path = self.merger.merge_chrome_traces()
self.assertTrue(os.path.exists(merged_path))
with gzip.open(merged_path, "rt") as f:
merged_data = json.load(f)
self.assertEqual(len(merged_data["traceEvents"]), 0)
# Test empty trace file
with gzip.open(filepath, "wt") as f:
json.dump({}, f)
merged_path = self.merger.merge_chrome_traces()
self.assertTrue(os.path.exists(merged_path))
# Test missing device properties
trace_data = {
"schemaVersion": 1,
"traceEvents": [
{"ph": "X", "name": "test", "pid": 83, "ts": 1000.0, "dur": 10.0}
],
}
with gzip.open(filepath, "wt") as f:
json.dump(trace_data, f)
merged_path = self.merger.merge_chrome_traces()
with gzip.open(merged_path, "rt") as f:
merged_data = json.load(f)
self.assertNotIn("deviceProperties", merged_data)
def test_missing_ranks_and_none_handling(self):
# Test rank extraction with missing ranks
filename = f"{self.profile_id}-TP-0.trace.json.gz"
rank_info = self.merger._extract_rank_info(filename)
self.assertEqual(rank_info, {"tp_rank": 0})
# Test rank label creation with missing ranks
label = self.merger._create_rank_label({"tp_rank": 0})
self.assertEqual(label, "[TP00]")
label = self.merger._create_rank_label({})
self.assertEqual(label, "[Unknown]")
# Test sort index calculation
sort_idx = self.merger._calculate_sort_index({"tp_rank": 0}, 83)
self.assertGreater(sort_idx, 0)
sort_idx = self.merger._calculate_sort_index({}, 83)
self.assertEqual(sort_idx, 83)
# Test sort key generation
sort_key = self.merger._get_rank_sort_key(filename)
self.assertEqual(sort_key, (0, 0, 0, 0))
# Test _maybe_cast_int with various inputs
self.assertIsNone(self.merger._maybe_cast_int(None))
self.assertIsNone(self.merger._maybe_cast_int("invalid"))
self.assertEqual(self.merger._maybe_cast_int("123"), 123)
self.assertEqual(self.merger._maybe_cast_int(456), 456)
def test_mixed_rank_scenarios(self):
trace_scenarios = [
{
"filename": f"{self.profile_id}-TP-0.trace.json.gz",
"events": [
{"ph": "X", "name": "op1", "pid": 83, "ts": 1000.0, "dur": 10.0}
],
},
{
"filename": f"{self.profile_id}-TP-1-DP-0.trace.json.gz",
"events": [
{"ph": "X", "name": "op2", "pid": 84, "ts": 2000.0, "dur": 15.0}
],
},
{
"filename": f"{self.profile_id}-TP-0-DP-1-PP-0.trace.json.gz",
"events": [
{"ph": "X", "name": "op3", "pid": 85, "ts": 3000.0, "dur": 20.0}
],
},
]
for scenario in trace_scenarios:
filepath = os.path.join(self.temp_dir, scenario["filename"])
trace_data = {
"schemaVersion": 1,
"deviceProperties": [{"device_id": 0, "name": "GPU-0"}],
"traceEvents": scenario["events"],
}
with gzip.open(filepath, "wt") as f:
json.dump(trace_data, f)
merged_path = self.merger.merge_chrome_traces()
self.assertTrue(os.path.exists(merged_path))
with gzip.open(merged_path, "rt") as f:
merged_data = json.load(f)
self.assertEqual(len(merged_data["traceEvents"]), 3)
events = merged_data["traceEvents"]
pids = [event["pid"] for event in events]
self.assertIn("[TP00] 83", pids)
self.assertIn("[TP01-DP00] 84", pids)
self.assertIn("[TP00-DP01-PP00] 85", pids)
if __name__ == "__main__":
unittest.main()
import json
import unittest
from sglang.srt.managers.io_struct import ProfileReqInput
class TestProfileMergerHTTPAPI(unittest.TestCase):
def test_profile_req_input_merge_profiles_json_serialization(self):
# Test with merge_profiles=True
req_input = ProfileReqInput(
output_dir="/tmp/test",
num_steps=5,
activities=["CPU", "GPU"],
profile_by_stage=True,
merge_profiles=True,
)
# Convert to dict (as would happen in HTTP request)
req_dict = {
"output_dir": req_input.output_dir,
"num_steps": req_input.num_steps,
"activities": req_input.activities,
"profile_by_stage": req_input.profile_by_stage,
"merge_profiles": req_input.merge_profiles,
}
# Test JSON serialization
json_str = json.dumps(req_dict)
parsed_data = json.loads(json_str)
self.assertTrue(parsed_data["merge_profiles"])
self.assertEqual(parsed_data["output_dir"], "/tmp/test")
self.assertEqual(parsed_data["num_steps"], 5)
self.assertEqual(parsed_data["activities"], ["CPU", "GPU"])
self.assertTrue(parsed_data["profile_by_stage"])
def test_profile_req_input_merge_profiles_json_deserialization(self):
# Test JSON data as would come from HTTP request
json_data = {
"output_dir": "/tmp/test",
"num_steps": 10,
"activities": ["CPU", "GPU", "MEM"],
"profile_by_stage": False,
"merge_profiles": True,
}
# Create ProfileReqInput from dict (as HTTP server would do)
req_input = ProfileReqInput(**json_data)
self.assertTrue(req_input.merge_profiles)
self.assertEqual(req_input.output_dir, "/tmp/test")
self.assertEqual(req_input.num_steps, 10)
self.assertEqual(req_input.activities, ["CPU", "GPU", "MEM"])
self.assertFalse(req_input.profile_by_stage)
def test_profile_req_input_merge_profiles_default_value(self):
# Test with minimal data
json_data = {"output_dir": "/tmp/test"}
req_input = ProfileReqInput(**json_data)
self.assertFalse(req_input.merge_profiles)
def test_profile_req_input_merge_profiles_explicit_false(self):
json_data = {"output_dir": "/tmp/test", "merge_profiles": False}
req_input = ProfileReqInput(**json_data)
self.assertFalse(req_input.merge_profiles)
def test_http_api_parameter_flow(self):
# Simulate HTTP request data
request_data = {
"output_dir": "/tmp/test",
"num_steps": 5,
"activities": ["CPU", "GPU"],
"profile_by_stage": True,
"merge_profiles": True,
}
# Create ProfileReqInput as HTTP server would
obj = ProfileReqInput(**request_data)
# Verify the parameter is set correctly
self.assertTrue(obj.merge_profiles)
self.assertEqual(obj.output_dir, "/tmp/test")
self.assertEqual(obj.num_steps, 5)
self.assertEqual(obj.activities, ["CPU", "GPU"])
self.assertTrue(obj.profile_by_stage)
def test_http_api_parameter_validation(self):
# Test with True
json_data = {"merge_profiles": True}
req_input = ProfileReqInput(**json_data)
self.assertTrue(req_input.merge_profiles)
# Test with False
json_data = {"merge_profiles": False}
req_input = ProfileReqInput(**json_data)
self.assertFalse(req_input.merge_profiles)
# Test with string "true" (the dataclass performs no type coercion, so the string passes through)
json_data = {"merge_profiles": "true"}
req_input = ProfileReqInput(**json_data)
self.assertEqual(req_input.merge_profiles, "true") # String, not boolean
def test_http_api_backward_compatibility(self):
# Test minimal request (no merge_profiles)
json_data = {}
req_input = ProfileReqInput(**json_data)
self.assertFalse(req_input.merge_profiles) # Should default to False
# Test with other parameters but no merge_profiles
json_data = {
"output_dir": "/tmp/test",
"num_steps": 5,
"activities": ["CPU", "GPU"],
}
req_input = ProfileReqInput(**json_data)
self.assertFalse(req_input.merge_profiles) # Should default to False
def test_http_api_parameter_combinations(self):
test_cases = [
{
"name": "minimal with merge_profiles",
"data": {"merge_profiles": True},
"expected_merge": True,
},
{
"name": "full parameters with merge_profiles=True",
"data": {
"output_dir": "/tmp/test",
"num_steps": 10,
"activities": ["CPU", "GPU", "MEM"],
"profile_by_stage": True,
"with_stack": True,
"record_shapes": True,
"merge_profiles": True,
},
"expected_merge": True,
},
{
"name": "full parameters with merge_profiles=False",
"data": {
"output_dir": "/tmp/test",
"num_steps": 10,
"activities": ["CPU", "GPU", "MEM"],
"profile_by_stage": False,
"with_stack": False,
"record_shapes": False,
"merge_profiles": False,
},
"expected_merge": False,
},
]
for test_case in test_cases:
with self.subTest(test_case["name"]):
req_input = ProfileReqInput(**test_case["data"])
self.assertEqual(req_input.merge_profiles, test_case["expected_merge"])
if __name__ == "__main__":
unittest.main()