#!/usr/bin/env python3

import argparse
import os
import subprocess
import sys

from .parser import RcclLogParser


def get_mpi_rank():
    """
    Return this process's rank as advertised by the launcher's environment.

    The variables below are checked in priority order; the first one that is
    set to a valid integer wins. A variable that is set but empty or not an
    integer is skipped (instead of raising ``ValueError``) so a stray export
    cannot crash the wrapper.

    Returns:
        int: the rank, or 0 when no variable provides a usable value.
    """
    # Common MPI Rank environment variables, highest priority first.
    rank_vars = (
        "OMPI_COMM_WORLD_RANK",  # OpenMPI
        "PMI_RANK",  # MPICH / MVAPICH
        "SLURM_PROCID",  # Slurm
        "RANK",  # General / Torch
    )

    for var in rank_vars:
        value = os.environ.get(var)
        if value is None:
            continue
        try:
            return int(value)
        except ValueError:
            # Malformed value (e.g. exported as ""): fall through to the next
            # launcher's variable rather than aborting.
            continue
    return 0


def main():
    """
    Run a command with RCCL debug logging enabled and summarize its log.

    The command (everything after the wrapper's own options) is launched with
    ``NCCL_DEBUG=INFO`` and ``NCCL_DEBUG_SUBSYS=ALL`` added to the inherited
    environment. Its stdout and stderr are merged and fed line by line to an
    ``RcclLogParser``; with ``--raw`` each line is also echoed as it arrives.
    Only rank 0 prints the parser report, so a multi-rank launch produces a
    single report instead of one per process.

    Never returns normally. Exit status:
        * the child's exit code, or ``128 + N`` if the child was killed by
          signal N (shell convention);
        * 130 on Ctrl+C;
        * 1 when no command is given or the command cannot be started.
    """
    rank = get_mpi_rank()
    log_prefix = f"[Rank {rank}]"

    # Parse command line arguments
    arg_parser = argparse.ArgumentParser(description="RCCL Log Parser Wrapper")
    arg_parser.add_argument(
        "--raw", action="store_true", help="Print raw log lines in addition to the report"
    )
    arg_parser.add_argument("-v", "--verbose", action="store_true", help="Print verbose reports")
    arg_parser.add_argument(
        "command", nargs=argparse.REMAINDER, help="The executable and arguments to run"
    )

    args = arg_parser.parse_args()

    cmd = args.command
    # argparse keeps a leading "--" separator in a REMAINDER positional; drop it
    # so `wrapper -- ./app` runs ./app instead of trying to execute "--".
    if cmd and cmd[0] == "--":
        cmd = cmd[1:]

    # Every rank must stop when there is nothing to run (otherwise Popen([])
    # fails); only rank 0 prints the usage text to avoid N duplicate copies.
    if not cmd:
        if rank == 0:
            arg_parser.print_help()
        sys.exit(1)

    # Inherit the caller's environment and inject RCCL debug settings
    env = os.environ.copy()
    env["NCCL_DEBUG"] = "INFO"
    env["NCCL_DEBUG_SUBSYS"] = "ALL"

    print(f"{log_prefix} [Wrapper] Running command: {' '.join(cmd)}")

    log_parser = RcclLogParser()

    # Only the launch itself can mean "command not found"; keep this try narrow
    # so errors from the log parser are not misreported as a missing command.
    try:
        process = subprocess.Popen(
            cmd,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            # Child output is arbitrary bytes; never let a bad byte kill the reader.
            errors="replace",
            bufsize=1,
        )
    except FileNotFoundError:
        print(f"{log_prefix} Error: Command not found: {cmd[0]}", file=sys.stderr)
        sys.exit(1)
    except PermissionError:
        print(f"{log_prefix} Error: Permission denied: {cmd[0]}", file=sys.stderr)
        sys.exit(1)

    try:
        # The context manager closes the pipe and reaps the child even if the
        # read loop raises.
        with process:
            # Stream output line by line into the parser
            for line in process.stdout:
                if args.raw:
                    print(line, end="", flush=True)
                log_parser.collect(line)
            process.wait()

        if rank == 0:
            log_parser.report(verbose=args.verbose)
    except KeyboardInterrupt:
        sys.exit(130)

    returncode = process.returncode
    # Popen reports death-by-signal N as -N, and sys.exit(-N) would surface as
    # 256 - N; use the shell convention 128 + N instead.
    sys.exit(returncode if returncode >= 0 else 128 - returncode)


if __name__ == "__main__":
    # Run the wrapper when this module is executed rather than imported.
    # NOTE(review): RcclLogParser is pulled in via a relative import
    # (`from .parser import ...`), so this only works when run as a package
    # module (e.g. `python -m <package>.main`), not as `python main.py`.
    main()