cli.py 1.01 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
import os
Olivier Dehaene's avatar
Olivier Dehaene committed
2
3
4
5
import typer

from pathlib import Path

6
from text_generation import server, utils
Olivier Dehaene's avatar
Olivier Dehaene committed
7
8
9
10
11

# Typer application object; every function decorated with @app.command()
# in this module is registered on it as a CLI sub-command.
app = typer.Typer()


@app.command()
def serve(
    model_name: str,
    sharded: bool = False,
    quantize: bool = False,
    uds_path: Path = "/tmp/text-generation",
):
    """Start serving MODEL_NAME.

    All arguments are forwarded unchanged to ``server.serve``. When
    ``sharded`` is enabled, the environment variables RANK, WORLD_SIZE,
    MASTER_ADDR and MASTER_PORT must all be set; otherwise a RuntimeError
    naming every missing variable is raised before the server starts.
    """
    if sharded:
        # Explicit check rather than ``assert``: assert statements are
        # removed when Python runs with -O, which would silently let a
        # misconfigured sharded launch proceed. Collect every missing
        # variable so the user can fix them all in one go.
        missing = [
            name
            for name in ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT")
            if os.getenv(name) is None
        ]
        if missing:
            raise RuntimeError(
                f"{', '.join(missing)} must be set when sharded is True"
            )

    server.serve(model_name, sharded, quantize, uds_path)
Olivier Dehaene's avatar
Olivier Dehaene committed
33
34
35


@app.command()
def download_weights(model_name: str, extension: str = ".safetensors"):
    """Download the weights for MODEL_NAME.

    Both arguments are forwarded unchanged to ``utils.download_weights``.
    """
    utils.download_weights(model_name, extension)
Olivier Dehaene's avatar
Olivier Dehaene committed
41
42
43
44


if __name__ == "__main__":
    # Run the Typer app when this module is executed as a script; it
    # dispatches to whichever @app.command() function the command line
    # selects.
    app()