aws_component.py 1.19 KB
Newer Older
xinghao's avatar
xinghao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os

import torchx.specs as specs
from torchx.components.dist import ddp


def run_dlrm_main(num_trainers: int = 8, *script_args: str) -> specs.AppDef:
    """
    Args:
        num_trainers: The number of trainers to use.
        script_args: A variable number of parameters to provide dlrm_main.py.
    """
    cwd = os.getcwd()
    entrypoint = os.path.join(cwd, "dlrm_main.py")

    user = os.environ.get("USER")
    image = f"/data/home/{user}"

    if num_trainers > 8 and num_trainers % 8 != 0:
        raise ValueError(
            "Trainer jobs spanning multiple hosts must be in multiples of 8."
        )
    nproc_per_node = 8 if num_trainers >= 8 else num_trainers
    num_replicas = max(num_trainers // 8, 1)

    return ddp(
        *script_args,
        name="train_dlrm",
        image=image,
        # AWS p4d instance (https://aws.amazon.com/ec2/instance-types/p4/).
        cpu=96,
        gpu=8,
        memMB=-1,
        script=entrypoint,
        j=f"{num_replicas}x{nproc_per_node}",
    )