utils.py 1.82 KB
Newer Older
hepj's avatar
hepj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# SPDX-License-Identifier: Apache-2.0

import os
import subprocess
import sys
from typing import List, Optional

from fastvideo.v1.logger import init_logger

logger = init_logger(__name__)


def launch_distributed(num_gpus: int,
                       args: List[str],
                       master_port: Optional[int] = None) -> int:
    """
    Launch a distributed job with the given arguments
    
    Args:
        num_gpus: Number of GPUs to use
        args: Arguments to pass to v1_fastvideo_inference.py (defaults to sys.argv[1:])
        master_port: Port for the master process (default: random)
    """

    current_env = os.environ.copy()
    python_executable = sys.executable
    project_root = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "../../../.."))
    main_script = os.path.join(project_root,
                               "fastvideo/v1/sample/v1_fastvideo_inference.py")

    cmd = [
        python_executable, "-m", "torch.distributed.run",
        f"--nproc_per_node={num_gpus}"
    ]

    if master_port is not None:
        cmd.append(f"--master_port={master_port}")

    cmd.append(main_script)
    cmd.extend(args)

    logger.info("Running inference with %d GPU(s)", num_gpus)
    logger.info("Launching command: %s", " ".join(cmd))

    current_env["PYTHONIOENCODING"] = "utf-8"
    process = subprocess.Popen(cmd,
                               env=current_env,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.STDOUT,
                               universal_newlines=True,
                               bufsize=1,
                               encoding='utf-8',
                               errors='replace')

    if process.stdout:
        for line in iter(process.stdout.readline, ''):
            print(line.strip())

    return process.wait()