main.py 672 Bytes
Newer Older
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import typer

from pathlib import Path
from torch.distributed.launcher import launch_agent, LaunchConfig
from typing import Optional

from bloom_inference.server import serve


def main(
    model_name: str,
    num_gpus: int = 1,
    shard_directory: Optional[Path] = None,
):
    """Launch the inference server on one GPU or sharded across several.

    With a single GPU, the server runs directly in this process. With more,
    a single-node torch.distributed elastic agent starts one worker process
    per GPU, and every worker runs the server. The GPU count must be at
    least 1.
    """
    # NOTE: typer renders the docstring above as this command's --help text,
    # so implementation details are kept in comments instead.
    if num_gpus < 1:
        # Without this guard, 0 or a negative count would fall through to the
        # launcher as a non-positive ``nproc_per_node`` rather than failing
        # with a clear message.
        raise ValueError(f"num_gpus must be at least 1, got {num_gpus}")

    if num_gpus == 1:
        # Second argument False: presumably serve's "sharded" flag, i.e. run
        # unsharded in-process — confirm against bloom_inference.server.serve.
        serve(model_name, False, shard_directory)
        return

    config = LaunchConfig(
        min_nodes=1,  # single-node job: min_nodes == max_nodes == 1
        max_nodes=1,
        nproc_per_node=num_gpus,  # one worker process per requested GPU
        rdzv_backend="c10d",
        max_restarts=0,  # failed workers are not restarted by the agent
    )
    # Every worker calls serve(model_name, True, shard_directory); True
    # presumably enables sharded mode — confirm against serve's signature.
    launch_agent(config, serve, [model_name, True, shard_directory])


if __name__ == "__main__":
    # Script entry point: typer builds a command-line interface from main's
    # signature (parameters without defaults become positional arguments,
    # parameters with defaults become options) and invokes it.
    typer.run(main)