#!/usr/bin/env python3

import sys
import os
import subprocess
import re
import pandas as pd


class RcclLogParser:
    """Collect NCCL/RCCL debug output lines and print a summary report.

    Lines are fed in through :meth:`collect`; :meth:`report` then extracts
    the message part of every line shaped like
    ``[<digits>] NCCL INFO|WARN|ERROR <message>`` and prints five sections:
    system information, user-defined environment variables, graph info,
    ring/tree transfers and P2P transfers.

    Both ``raw_lines`` and ``output`` are sets (duplicates are dropped), so
    every section iterates them in sorted order to keep the report
    reproducible from run to run.
    """

    # "[<digits>] NCCL <LEVEL> <message>"; group 1 is the message text.
    _LOG_LINE_RE = re.compile(r"\[\d+\]\s+NCCL\s+(?:INFO|WARN|ERROR)\s+(.*)")
    # "<NAME> set by environment to <value>".
    _USER_ENV_RE = re.compile(r"(\w+)\s+set by environment to\s+(.+)")

    def __init__(self):
        # Unique message texts extracted from raw_lines by _preprocess_line.
        self.output = set()
        # Unique raw lines handed to collect().
        self.raw_lines = set()

        # Pattern -> output string or as-is
        self.sys_patterns = {
            r"kernel version": None,
            r"ROCr version": None,
            r"RCCL version": None,
            r"Librccl path": None,
            r"iommu": None,
            r"Dmabuf feature disabled": "Dmabuf: disabled",
            r"Disabled GDRCopy": "GDRCopy: disabled",
        }

        # Pattern -> column
        self.graph_info_fields = {
            r"Pattern": "Pattern",
            r"crossNic": "crossNic",
            r"nChannels": "nChannels",
            r"bw": "bandwidth",
            r"type": "type",
            r"sameChannels": "sameChannels",
        }

        # Pattern -> column
        self.cl_transfer_fields = {
            r"protocol": "protocol",
            r"nbytes": "nbytes",
            r"algorithm": "algorithm",
            r"slicesteps": "slicesteps",
            r"nchannels": "nchannels",
            r"nloops": "nloops",
            r"nsteps": "nsteps",
            r"chunksize": "chunksize",
        }

        # Pattern -> column
        self.p2p_fields = {
            r"p2p : rank": "local",
            r"send rank": "send",
            r"recv rank": "recv",
            r"p2pnChannelsPerPeer": "p2pnChannelsPerPeer",
            r"p2pnChannels": "p2pnChannels",
            r"nChannelsMax": "nChannelsMax",
            r"protocol": "protocol",
        }

    def collect(self, line):
        """Remember one raw output line (duplicates are ignored)."""
        self.raw_lines.add(line)

    def report(self):
        """Parse every collected line and print the full report to stdout."""
        print(" RCCL Log Parser Report ".center(80, "="))
        print()

        for line in self.raw_lines:
            self._preprocess_line(line)

        self._report_sys()
        self._report_user_envs()
        self._report_graph_info()
        self._report_cl_transfers()
        self._report_p2p_transfers()

        print(" End of Report ".center(80, "="))

    def _preprocess_line(self, line):
        """Store the message part of *line* in ``output`` if it is a log line."""
        match = self._LOG_LINE_RE.search(line)
        if match:
            self.output.add(match.group(1))

    def _sorted_output(self):
        """Return the extracted messages in a reproducible (sorted) order."""
        return sorted(self.output)

    @staticmethod
    def _extract_fields(lines, fields, word_boundary=False):
        """Build a DataFrame with one column per ``pattern -> column`` entry.

        Each column holds the first whitespace-delimited token that follows
        the pattern on the line (NaN when the pattern is absent). A trailing
        comma on the token is stripped so that values such as ``"1024,"``
        still convert to numbers later on.
        """
        raw = pd.Series(lines)
        prefix = r"\b" if word_boundary else ""
        table = pd.DataFrame(index=raw.index)
        for pattern, column in fields.items():
            values = raw.str.extract(rf"{prefix}{pattern}\s+(\S+)", expand=False)
            table[column] = values.str.rstrip(",")
        return table

    def _report_sys(self):
        """Search patterns and print pre-defined strings if matched"""
        print("===> System Information:\n")
        compiled = [
            (re.compile(pattern, re.IGNORECASE), text)
            for pattern, text in self.sys_patterns.items()
        ]

        # (pattern position, text) pairs: the set removes repeated entries and
        # sorting orders the section like the pattern table above.
        hits = set()
        for line in self.output:
            for position, (regex, text) in enumerate(compiled):
                if regex.search(line):
                    hits.add((position, text if text else line))
                    break  # only the first matching pattern counts

        for _, text in sorted(hits):
            print(text)
        print()

    def _report_user_envs(self):
        """Search environment variables set by user"""
        print("===> User-defined Environment Variables:\n")
        found = set()
        for line in self.output:
            m = self._USER_ENV_RE.search(line)
            if m:
                found.add((m.group(1), m.group(2)))
        for name, value in sorted(found):
            print(f"{name}: {value}")
        print()

    def _report_graph_info(self):
        """Extract graph information from lines with 'Pattern' and 'crossNic'"""
        print("===> Graph Info:\n")

        filtered_lines = [
            line
            for line in self._sorted_output()
            if "Pattern" in line and "crossNic" in line
        ]

        if not filtered_lines:
            print("  (No graph info found)\n")
            return

        # One combined regex: the fields must appear in the order of
        # graph_info_fields, separated by arbitrary text.
        regex_parts = [
            rf"{key}\s+(?P<{col_name}>[^,\s]+)"
            for key, col_name in self.graph_info_fields.items()
        ]
        full_regex = r".*?".join(regex_parts)

        extracted_df = pd.Series(filtered_lines).str.extract(full_regex)

        # Lines that did not match the combined regex carry no information.
        extracted_df = extracted_df.dropna(how="all")
        if extracted_df.empty:
            print("  (No graph info found)\n")
            return

        # Numeric conversion so the sort below is numeric, not lexical.
        if "Pattern" in extracted_df.columns:
            extracted_df["Pattern"] = pd.to_numeric(
                extracted_df["Pattern"], errors="coerce"
            )

        extracted_df = extracted_df.drop_duplicates()
        if "Pattern" in extracted_df.columns:
            extracted_df = extracted_df.sort_values(
                by="Pattern", ascending=False, kind="mergesort"
            )

        print(extracted_df.fillna("-").to_string(index=False))
        print()

    def _report_cl_transfers(self):
        """Extract non-P2P transfer arguments"""
        print("===> Unique Ring/Tree Transfers:\n")

        # Filter lines by looking for 'protocol' and 'nbytes'
        raw_lines = [
            line
            for line in self._sorted_output()
            if "protocol" in line and "nbytes" in line
        ]

        if not raw_lines:
            print("  (No transfer patterns found)\n")
            return

        df = self._extract_fields(
            raw_lines, self.cl_transfer_fields, word_boundary=True
        )

        # Type conversion for correct sorting
        for field in ["nbytes", "nchannels"]:
            if field in df.columns:
                df[field] = pd.to_numeric(df[field], errors="coerce")

        # Drop rows where mandatory fields are missing
        mandatory_cols = [c for c in ["protocol", "nbytes"] if c in df.columns]
        if mandatory_cols:
            df = df.dropna(subset=mandatory_cols)
        df = df.drop_duplicates()

        if df.empty:
            print("  (No transfer patterns found)\n")
            return

        sort_cols = [
            col for col in ["nbytes", "protocol", "nchannels"] if col in df.columns
        ]
        if sort_cols:
            df = df.sort_values(by=sort_cols)

        # Fill NaNs with "-" and print
        print(df.fillna("-").to_string(index=False))
        print()

    def _report_p2p_transfers(self):
        """Extract P2P transfer details"""
        print("===> Unique P2P Transfers:\n")

        # Filter lines by looking for 'p2p :' and 'send rank'
        raw_lines = [
            line
            for line in self._sorted_output()
            if "p2p :" in line and "send rank" in line
        ]

        if not raw_lines:
            print("  (No P2P transfers found)\n")
            return

        df = self._extract_fields(raw_lines, self.p2p_fields)

        # Type conversion for correct sorting
        numeric_cols = [
            "local",
            "send",
            "recv",
            "p2pnChannelsPerPeer",
            "p2pnChannels",
            "nChannelsMax",
        ]
        for col in numeric_cols:
            if col in df.columns:
                df[col] = pd.to_numeric(df[col], errors="coerce")

        df = df.drop_duplicates()

        sort_cols = [
            c for c in ["protocol", "local", "send", "recv"] if c in df.columns
        ]
        if sort_cols:
            df = df.sort_values(by=sort_cols)

        # Move 'protocol' to the first column
        if "protocol" in df.columns:
            others = [c for c in df.columns if c != "protocol"]
            df = df[["protocol"] + others]

        # Fill NaNs with "-" and print
        print(df.fillna("-").to_string(index=False))
        print()


def get_mpi_rank():
    """Return the rank ID read from common launcher environment variables.

    The variables are tried in order (OpenMPI, MPICH/MVAPICH, Slurm, then the
    generic ``RANK``). A variable that is unset, empty or not an integer is
    skipped instead of raising.

    Returns:
        int: the first valid rank found, or ``0`` when none is available.
    """
    # Common MPI Rank environment variables
    rank_vars = (
        "OMPI_COMM_WORLD_RANK",  # OpenMPI
        "PMI_RANK",  # MPICH / MVAPICH
        "SLURM_PROCID",  # Slurm
        "RANK",  # General / Torch
    )

    for var in rank_vars:
        value = os.environ.get(var)
        if value is None:
            continue
        try:
            return int(value)
        except ValueError:
            # Malformed value (e.g. empty string): fall through to the next
            # candidate rather than crashing the wrapper.
            continue
    return 0


def main():
    """Run the wrapped command with NCCL debug logging and report on rank 0.

    The command is taken from ``sys.argv[1:]`` and started with
    ``NCCL_DEBUG=INFO`` and ``NCCL_DEBUG_SUBSYS=ALL`` added to the
    environment. Its combined stdout/stderr is echoed line by line and fed to
    an ``RcclLogParser``; the report is printed only when the rank is 0.

    Exit status: the child's return code (``128 + signal`` when the child was
    killed by a signal), ``130`` on Ctrl-C, ``1`` when no command was given
    or the command could not be started.
    """
    rank = get_mpi_rank()
    log_prefix = f"[Rank {rank}]"

    cmd = sys.argv[1:]
    if not cmd:
        # Every rank must stop here (there is nothing to run); only rank 0
        # prints the usage text so it is not repeated once per rank.
        if rank == 0:
            script_name = os.path.basename(__file__)
            print(f"Usage: python {script_name} <executable> [arguments...]")
        sys.exit(1)

    env = os.environ.copy()

    # Inject RCCL environment variables
    env["NCCL_DEBUG"] = "INFO"
    env["NCCL_DEBUG_SUBSYS"] = "ALL"

    print(f"{log_prefix} [Wrapper] Running command: {' '.join(cmd)}")

    try:
        parser = RcclLogParser()
        # The context manager closes the pipe and waits for the child on
        # every exit path.
        with subprocess.Popen(
            cmd,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors="replace",  # undecodable bytes must not abort the run
            bufsize=1,
        ) as process:
            # Echo and collect all output lines
            for line in process.stdout:
                print(line, end="", flush=True)
                parser.collect(line)

        if rank == 0:
            parser.report()

        returncode = process.returncode
        if returncode is not None and returncode < 0:
            # Popen reports "killed by signal N" as -N; map it to the shell
            # convention 128 + N, matching the 130 used for Ctrl-C below.
            returncode = 128 - returncode
        sys.exit(returncode)
    except KeyboardInterrupt:
        sys.exit(130)
    except FileNotFoundError:
        print(f"{log_prefix} Error: Command not found: {cmd[0]}")
        sys.exit(1)
    except OSError as exc:
        # e.g. permission denied or exec format error when starting the child
        print(f"{log_prefix} Error: Failed to run {cmd[0]}: {exc}")
        sys.exit(1)


# Run the wrapper only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()