cli.py 905 Bytes
Newer Older
Olivier Dehaene's avatar
Olivier Dehaene committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import typer

from pathlib import Path
from torch.distributed.launcher import launch_agent, LaunchConfig
from typing import Optional

from bloom_inference import server

# Typer application object; the functions decorated with @app.command()
# below are registered on it as CLI commands.
app = typer.Typer()


@app.command()
def launcher(
        model_name: str,
        num_gpus: int = 1,
        shard_directory: Optional[Path] = None,
):
    """Start the inference server on one or several GPUs.

    With ``num_gpus == 1`` the server is started in the current process
    through ``serve`` with sharding disabled. For larger values,
    ``launch_agent`` is asked to start ``num_gpus`` processes on a single
    node, each running ``server.serve`` with sharding enabled.

    Args:
        model_name: Model identifier, forwarded unchanged to the server.
        num_gpus: Number of processes to start; must be at least 1.
        shard_directory: Optional path, forwarded unchanged to the server.

    Raises:
        typer.BadParameter: If ``num_gpus`` is smaller than 1.
    """
    # Reject zero/negative counts up front with a clear CLI error instead of
    # handing them to the launcher as a per-node process count.
    if num_gpus < 1:
        raise typer.BadParameter("num_gpus must be at least 1")

    if num_gpus == 1:
        # Single process: no launcher needed, sharding disabled.
        serve(model_name, False, shard_directory)
        return

    # Exactly one node (min_nodes == max_nodes == 1), one process per
    # requested GPU, "c10d" rendezvous backend, and no restarts.
    config = LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=num_gpus,
        rdzv_backend="c10d",
        max_restarts=0,
    )
    # Every launched process calls server.serve(model_name, True, shard_directory).
    launch_agent(config, server.serve, [model_name, True, shard_directory])


@app.command()
def serve(
    model_name: str,
    sharded: bool = False,
    shard_directory: Optional[Path] = None,
):
    """Run the inference server in the current process.

    The three arguments are handed to ``server.serve`` positionally and
    unchanged, in the order ``model_name``, ``sharded``, ``shard_directory``.
    """
    # Collected in call order; unpacked positionally below.
    forwarded = (model_name, sharded, shard_directory)
    server.serve(*forwarded)


if __name__ == "__main__":
    # Run the Typer application when this file is executed as a script.
    app()