cli.py 947 Bytes
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
import os
Olivier Dehaene's avatar
Olivier Dehaene committed
2
3
4
5
import typer

from pathlib import Path

Nicolas Patry's avatar
Nicolas Patry committed
6
from bloom_inference import server, utils
Olivier Dehaene's avatar
Olivier Dehaene committed
7
8
9
10
11

# Typer application object; the functions decorated with @app.command() in
# this module are registered on it as CLI sub-commands.
app = typer.Typer()


@app.command()
def serve(
    model_name: str,
    sharded: bool = False,
    uds_path: Path = Path("/tmp/bloom-inference"),
):
    # CLI command: validate the environment when `sharded` is requested, then
    # hand all three arguments unchanged to `server.serve`.
    #
    # Documented with comments rather than a docstring on purpose: Typer uses
    # a command's docstring as its --help text, and that output is kept as is.
    if sharded:
        # Each of these environment variables must be present in sharded mode.
        # They are checked in a fixed order so the first missing one is the
        # one reported.
        for env_var in ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
            if os.getenv(env_var) is None:
                # Raised explicitly instead of via `assert` so the check is
                # not stripped when Python runs with -O. AssertionError and
                # the message text are kept for backward compatibility.
                raise AssertionError(f"{env_var} must be set when sharded is True")

    server.serve(model_name, sharded, uds_path)
Olivier Dehaene's avatar
Olivier Dehaene committed
32
33
34


@app.command()
def download_weights(model_name: str):
    # CLI command that passes `model_name` straight to `utils.download_weights`.
    # Kept as a comment rather than a docstring because Typer would surface a
    # docstring as this command's --help text.
    utils.download_weights(model_name)
Olivier Dehaene's avatar
Olivier Dehaene committed
39
40
41
42


# Invoke the Typer app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    app()