# cli.py
import os
import typer

from pathlib import Path
from typing import Optional

from bloom_inference import prepare_weights, server

# Typer application object; the functions decorated with ``@app.command()``
# in this file are registered on it.
app = typer.Typer()


@app.command()
def serve(
    model_name: str,
    sharded: bool = False,
    shard_directory: Optional[Path] = None,
    uds_path: Path = Path("/tmp/bloom-inference"),
):
    """Check the sharding configuration, then call ``server.serve``.

    Args:
        model_name: Passed through to ``server.serve`` unchanged.
        sharded: When true, ``shard_directory`` and the environment variables
            ``RANK``, ``WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT`` must
            all be set before the server is started.
        shard_directory: Passed through to ``server.serve``; required when
            ``sharded`` is true.
        uds_path: Passed through to ``server.serve``.

    Raises:
        AssertionError: If ``sharded`` is true and ``shard_directory`` or one
            of the required environment variables is missing.
    """
    if sharded:
        # Raise explicitly instead of using ``assert`` statements: asserts are
        # removed under ``python -O``, which would silently skip these checks.
        # AssertionError and the messages are kept so existing behaviour in
        # non-optimized runs is unchanged.
        if shard_directory is None:
            raise AssertionError(
                "shard_directory must be set when sharded is True"
            )
        for variable in ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
            if os.getenv(variable) is None:
                raise AssertionError(
                    f"{variable} must be set when sharded is True"
                )

    server.serve(model_name, sharded, uds_path, shard_directory)


@app.command()
def prepare_weights(
    model_name: str,
    shard_directory: Path,
    cache_directory: Path,
    num_shard: int = 1,
):
    """Call ``bloom_inference.prepare_weights.prepare_weights`` with the CLI arguments.

    Args:
        model_name: Passed through as the first argument.
        shard_directory: Passed through as the third argument.
        cache_directory: Passed through as the second argument.
        num_shard: Passed through as the fourth argument.
    """
    # This command has the same name as the ``prepare_weights`` module imported
    # at the top of the file, so once this function is defined the module-level
    # name refers to the function, not the module. Looking up
    # ``prepare_weights.prepare_weights`` would then fail with AttributeError.
    # Import the module under a distinct local alias instead.
    from bloom_inference import prepare_weights as weights_module

    # Note the callee receives cache_directory before shard_directory.
    weights_module.prepare_weights(
        model_name, cache_directory, shard_directory, num_shard
    )


# Run the Typer application only when this file is executed as a script.
if __name__ == "__main__":
    app()