cli.py 1.36 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
import os
2
import sys
Olivier Dehaene's avatar
Olivier Dehaene committed
3
4
5
import typer

from pathlib import Path
6
from loguru import logger
Olivier Dehaene's avatar
Olivier Dehaene committed
7

8
from text_generation import server, utils
Olivier Dehaene's avatar
Olivier Dehaene committed
9
10
11
12
13

# Typer application; the functions below decorated with ``@app.command()`` are
# registered as its CLI commands, and it is invoked from the ``__main__`` guard.
app = typer.Typer()


# Environment variables that must be present when ``serve`` runs with
# ``sharded=True``.
_SHARDED_ENV_VARS = ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT")


def _check_sharded_env() -> None:
    """Raise if any variable required for sharded serving is unset.

    A variable counts as set when it exists in the environment, even if its
    value is the empty string (same rule as ``os.getenv(name) is not None``).

    Raises:
        RuntimeError: naming every variable from ``_SHARDED_ENV_VARS`` that is
            missing, so the user can fix them all in one go.
    """
    missing = [name for name in _SHARDED_ENV_VARS if os.getenv(name) is None]
    if missing:
        raise RuntimeError(f"{', '.join(missing)} must be set when sharded is True")


@app.command()
def serve(
    model_name: str,
    sharded: bool = False,
    quantize: bool = False,
    uds_path: Path = "/tmp/text-generation",
    logger_level: str = "INFO",
    json_output: bool = False,
):
    """Configure logging, validate the environment, and call ``server.serve``.

    Args:
        model_name: Model identifier forwarded to ``server.serve``.
        sharded: Forwarded to ``server.serve``. When True, the variables in
            ``_SHARDED_ENV_VARS`` must be set in the environment.
        quantize: Forwarded to ``server.serve``.
        uds_path: Forwarded to ``server.serve``. Presumably the base path of
            the Unix domain socket(s); confirm in ``server.serve``.
        logger_level: Minimum loguru level for the stdout sink. It is
            upper-cased first, so ``info`` and ``INFO`` are equivalent.
        json_output: When True, loguru serializes each record as JSON.

    Raises:
        RuntimeError: If ``sharded`` is True and any required environment
            variable is unset.
    """
    # Replace loguru's default handler with a single stdout sink that only
    # accepts records coming from the ``text_generation`` package.
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation",
        # loguru level names are case-sensitive ("INFO", not "info").
        level=logger_level.upper(),
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    if sharded:
        # Explicit check instead of ``assert`` so validation still runs under
        # ``python -O``, which strips assert statements.
        _check_sharded_env()

    server.serve(model_name, sharded, quantize, uds_path)
Olivier Dehaene's avatar
Olivier Dehaene committed
48
49
50


@app.command()
def download_weights(model_name: str, extension: str = ".safetensors"):
    """CLI entry point that delegates to ``utils.download_weights``.

    Both arguments are passed through unchanged and positionally; this
    command adds no logic of its own.

    Args:
        model_name: Model identifier handed to ``utils.download_weights``.
        extension: Extension string handed to ``utils.download_weights``.
            Presumably it selects which weight files are fetched; confirm in
            ``utils``.
    """
    utils.download_weights(model_name, extension)
Olivier Dehaene's avatar
Olivier Dehaene committed
56
57
58
59


if __name__ == "__main__":
    # Hand control to the Typer app when this file is executed as a script.
    app()