main.py 2.15 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#!/usr/bin/env python3

import argparse
import os
import subprocess
import sys

from .parser import RcclLogParser


def get_mpi_rank():
    """
    Return this process's rank as an int, read from launcher environment variables.

    The variables are checked in a fixed order (OpenMPI, MPICH/MVAPICH, Slurm,
    then the generic ``RANK`` used by Torch launchers). The first variable that
    is set to a valid integer wins. Variables that are unset, empty, or hold a
    non-integer value are skipped instead of raising ``ValueError``.

    Returns:
        The rank as an ``int``, or ``0`` when no usable variable is found.
    """
    # Common MPI Rank environment variables, in priority order
    rank_vars = (
        "OMPI_COMM_WORLD_RANK",  # OpenMPI
        "PMI_RANK",  # MPICH / MVAPICH
        "SLURM_PROCID",  # Slurm
        "RANK",  # General / Torch
    )

    for var in rank_vars:
        value = os.environ.get(var)
        if value is None:
            continue
        try:
            return int(value)
        except ValueError:
            # Malformed value (e.g. empty string): fall through to the next
            # candidate rather than crashing the wrapper before it starts.
            continue
    return 0


def main():
    """
    Run a command with NCCL/RCCL debug logging enabled and report on its output.

    Command line: ``[-v/--verbose] [--] <executable> [args...]``.

    The child process is started with ``NCCL_DEBUG=INFO`` and
    ``NCCL_DEBUG_SUBSYS=ALL`` added to a copy of the current environment. Its
    stdout and stderr are merged and every line is passed to
    ``RcclLogParser.collect``; with ``--verbose`` each line is also echoed.
    After the child exits, rank 0 (as determined by ``get_mpi_rank``) calls
    ``RcclLogParser.report``.

    This function never returns normally; it always ends via ``sys.exit``:
        * the child's exit code, or ``128 + N`` if the child was killed by
          signal ``N``;
        * ``1`` if no command was given, or the command cannot be found or
          executed;
        * ``130`` on ``KeyboardInterrupt``.
    """
    rank = get_mpi_rank()
    log_prefix = f"[Rank {rank}]"

    # Parse command line arguments
    arg_parser = argparse.ArgumentParser(description="RCCL Log Parser Wrapper")
    arg_parser.add_argument(
        "-v", "--verbose", action="store_true", help="Print raw log lines in addition to the report"
    )
    arg_parser.add_argument(
        "command", nargs=argparse.REMAINDER, help="The executable and arguments to run"
    )

    args = arg_parser.parse_args()

    verbose = args.verbose
    cmd = args.command

    # argparse.REMAINDER keeps a leading "--" separator in the captured list;
    # drop it so "wrapper -- prog args" does not try to execute "--".
    if cmd and cmd[0] == "--":
        cmd = cmd[1:]

    # Every rank must stop when there is nothing to run; only rank 0 prints
    # the usage text so it is not repeated once per rank.
    if not cmd:
        if rank == 0:
            arg_parser.print_help()
        sys.exit(1)

    # Copy the environment and inject the RCCL debug variables
    env = os.environ.copy()
    env["NCCL_DEBUG"] = "INFO"
    env["NCCL_DEBUG_SUBSYS"] = "ALL"

    print(f"{log_prefix} [Wrapper] Running command: {' '.join(cmd)}")

    log_parser = RcclLogParser()

    # Keep the launch in its own try so that only failures to start the
    # command are reported as "not found" / "permission denied".
    try:
        process = subprocess.Popen(
            cmd,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors="replace",  # undecodable bytes must not abort log collection
            bufsize=1,  # line-buffered so lines are handled as they arrive
        )
    except FileNotFoundError:
        print(f"{log_prefix} Error: Command not found: {cmd[0]}", file=sys.stderr)
        sys.exit(1)
    except PermissionError:
        print(f"{log_prefix} Error: Permission denied: {cmd[0]}", file=sys.stderr)
        sys.exit(1)

    try:
        # The context manager closes the pipe and waits for the child on exit,
        # including when the loop is interrupted.
        with process:
            for line in process.stdout:
                if verbose:
                    print(line, end="", flush=True)
                log_parser.collect(line)

        if rank == 0:
            log_parser.report()
    except KeyboardInterrupt:
        sys.exit(130)

    # Popen reports death-by-signal N as returncode -N; map it to the
    # conventional shell exit status 128 + N instead of passing a negative
    # value to sys.exit.
    returncode = process.returncode
    sys.exit(128 - returncode if returncode < 0 else returncode)

# Script entry point: run the wrapper when this module is executed rather than
# imported. NOTE(review): the module uses a relative import
# (`from .parser import RcclLogParser`), so it presumably has to be launched
# from within its package (e.g. `python -m <package>.main` or an installed
# console script) rather than as a bare file — confirm against packaging.
if __name__ == "__main__":
    main()