#!/usr/bin/env python3
"""Run a command with RCCL/NCCL debug logging enabled and summarize its log.

The wrapper launches the given command with ``NCCL_DEBUG=INFO`` and
``NCCL_DEBUG_SUBSYS=ALL``, echoes the command's output as it arrives, and
(on rank 0 only) prints a report built from the captured log lines.
"""
import sys
import os
import subprocess
import re

import pandas as pd


class RcclLogParser:
    """Collect raw log lines and print a de-duplicated summary report.

    Lines are stored in insertion-ordered dicts (used as ordered sets) so
    that duplicates are dropped while the report keeps first-seen order and
    is therefore reproducible from run to run.
    """

    # Captures the message that follows "[<digits>] NCCL INFO|WARN|ERROR".
    _LOG_LINE_RE = re.compile(r"\[\d+\]\s+NCCL\s+(?:INFO|WARN|ERROR)\s+(.*)")
    # Captures "<NAME> set by environment to <VALUE>".
    _USER_ENV_RE = re.compile(r"(\w+)\s+set by environment to\s+(.+)")

    def __init__(self):
        # Ordered sets: keys are the lines, values are unused.
        self.output = {}
        self.raw_lines = {}

        # Pattern -> output string, or None to print the matched line as-is
        self.sys_patterns = {
            r"kernel version": None,
            r"ROCr version": None,
            r"RCCL version": None,
            r"Librccl path": None,
            r"iommu": None,
            r"Dmabuf feature disabled": "Dmabuf: disabled",
            r"Disabled GDRCopy": "GDRCopy: disabled",
        }

        # Pattern -> report column name
        self.graph_info_fields = {
            r"Pattern": "Pattern",
            r"crossNic": "crossNic",
            r"nChannels": "nChannels",
            r"bw": "bandwidth",
            r"type": "type",
            r"sameChannels": "sameChannels",
        }

        # Pattern -> report column name
        self.transfer_fields = {
            r"protocol": "protocol",
            r"nbytes": "nbytes",
            r"algorithm": "algorithm",
            r"slicesteps": "slicesteps",
            r"nchannels": "nchannels",
            r"nloops": "nloops",
            r"nsteps": "nsteps",
            r"chunksize": "chunksize",
        }

    def collect(self, line):
        """Remember one raw output line; repeated lines are stored once."""
        self.raw_lines[line] = None

    def report(self):
        """Print the full report for every line collected so far."""
        print(" RCCL Log Parser Report ".center(80, "="))
        print()

        for line in self.raw_lines:
            self._preprocess_line(line)

        self._report_sys()
        self._report_user_envs()
        self._report_graph_info()
        self._report_transfers()

        print(" End of Report ".center(80, "="))

    def _preprocess_line(self, line):
        """Keep only the message part of NCCL INFO/WARN/ERROR lines."""
        match = self._LOG_LINE_RE.search(line)
        if match:
            self.output[match.group(1)] = None

    def _report_sys(self):
        """Search patterns and print pre-defined strings if matched"""
        print("===> System Information:\n")

        # Ordered set again: distinct log lines can map to the same
        # replacement string, which should be printed only once.
        reported_lines = {}
        for line in self.output:
            for pattern, replacement in self.sys_patterns.items():
                if re.search(pattern, line, re.IGNORECASE):
                    reported_lines[replacement if replacement else line] = None
                    break  # first matching pattern wins

        for line in reported_lines:
            print(line)
        print()

    def _report_user_envs(self):
        """Search environment variables set by user"""
        print("===> User-defined Environment Variables:\n")

        for line in self.output:
            m = self._USER_ENV_RE.search(line)
            if m:
                print(f"{m.group(1)}: {m.group(2)}")
        print()

    def _report_graph_info(self):
        """Extract graph information and print it as a table."""
        print("===> Graph Info:\n")

        # Cheap substring filter before running the full regex
        filtered_lines = [
            line
            for line in self.output
            if "Pattern" in line and "crossNic" in line
        ]

        if not filtered_lines:
            print(" (No graph info found)\n")
            return

        df = pd.DataFrame(filtered_lines, columns=["raw_log"])

        # One regex with a named group per field; fields must appear in the
        # order of ``graph_info_fields`` with arbitrary text between them.
        full_regex = r".*?".join(
            rf"{key}\s+(?P<{col_name}>[^,\s]+)"
            for key, col_name in self.graph_info_fields.items()
        )
        extracted_df = df["raw_log"].str.extract(full_regex)

        # Numeric conversion so the sort below is numeric, not lexical
        if "Pattern" in extracted_df.columns:
            extracted_df["Pattern"] = pd.to_numeric(
                extracted_df["Pattern"], errors="coerce"
            )

        # Lines that passed the substring filter but did not match the regex
        # yield all-NaN rows; they carry no information, so drop them.
        extracted_df = extracted_df.dropna(how="all").drop_duplicates()
        if extracted_df.empty:
            print(" (No graph info found)\n")
            return

        # Stable sort keeps first-seen order among equal Pattern values
        extracted_df = extracted_df.sort_values(
            by="Pattern", ascending=False, kind="stable"
        )

        print(extracted_df.fillna("-").to_string(index=False))
        print()

    def _report_transfers(self):
        """Extract transfer arguments and print the unique combinations."""
        print("===> Unique Transfers:\n")

        # Cheap substring filter before running the per-field regexes
        raw_lines = [
            line
            for line in self.output
            if "protocol" in line and "nbytes" in line
        ]

        if not raw_lines:
            print(" (No transfer patterns found)\n")
            return

        df = pd.DataFrame(raw_lines, columns=["raw_log"])

        # One column per field; a field missing from a line becomes NaN
        for pattern, col_name in self.transfer_fields.items():
            df[col_name] = df["raw_log"].str.extract(
                rf"\b{pattern}\s+(\S+)", expand=False
            )

        # Type conversion for correct (numeric) sorting
        for field in ["nbytes", "nchannels"]:
            if field in df.columns:
                df[field] = pd.to_numeric(df[field], errors="coerce")

        # Drop rows where mandatory fields are missing, then de-duplicate
        mandatory_cols = [c for c in ["protocol", "nbytes"] if c in df.columns]
        df = (
            df.dropna(subset=mandatory_cols)
            .drop(columns=["raw_log"])
            .drop_duplicates()
        )

        desired_order = ["nbytes", "protocol", "nchannels"]
        sort_cols = [col for col in desired_order if col in df.columns]
        if sort_cols:
            df = df.sort_values(by=sort_cols)

        # Fill NaNs with "-" and print
        print(df.fillna("-").to_string(index=False))
        print()


def get_mpi_rank():
    """Return the rank ID found in common launcher environment variables.

    The variables are tried in a fixed order; the first one that is set and
    holds a valid integer wins. A variable that is set but not parsable as
    an integer is skipped instead of raising.

    Returns:
        The rank as an ``int``, or ``0`` if no usable variable is found.
    """
    # Common MPI Rank environment variables
    rank_vars = [
        "OMPI_COMM_WORLD_RANK",  # OpenMPI
        "PMI_RANK",  # MPICH / MVAPICH
        "SLURM_PROCID",  # Slurm
        "RANK",  # General / Torch
    ]

    for var in rank_vars:
        value = os.environ.get(var)
        if value is None:
            continue
        try:
            return int(value)
        except ValueError:
            # Set but not an integer (e.g. empty string): try the next one
            continue

    return 0


def main():
    """Run the wrapped command, echo its output, and report on rank 0.

    Exits with the command's return code, 1 on a usage error or when the
    command cannot be found, and 130 on KeyboardInterrupt.
    """
    rank = get_mpi_rank()
    log_prefix = f"[Rank {rank}]"

    # Every rank must stop when no command is given (an empty argv cannot be
    # executed); only rank 0 prints the usage text to avoid duplicates.
    if len(sys.argv) < 2:
        if rank == 0:
            script_name = os.path.basename(__file__)
            print(f"Usage: python {script_name} <command> [arguments...]")
        sys.exit(1)

    # Get the command and environment variables
    cmd = sys.argv[1:]
    env = os.environ.copy()

    # Inject RCCL environment variables
    env["NCCL_DEBUG"] = "INFO"
    env["NCCL_DEBUG_SUBSYS"] = "ALL"

    print(f"{log_prefix} [Wrapper] Running command: {' '.join(cmd)}")

    try:
        parser = RcclLogParser()

        # The context manager closes the pipe and waits for the child.
        with subprocess.Popen(
            cmd,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors="replace",  # undecodable bytes must not kill the wrapper
            bufsize=1,
        ) as process:
            # Echo each line immediately and keep it for the report
            for line in process.stdout:
                print(line, end="", flush=True)
                parser.collect(line)

        if rank == 0:
            parser.report()

        sys.exit(process.returncode)

    except KeyboardInterrupt:
        sys.exit(130)
    except FileNotFoundError:
        print(f"{log_prefix} Error: Command not found: {cmd[0]}")
        sys.exit(1)


if __name__ == "__main__":
    main()