"""Typer command-line interface exposing the ``serve`` and ``download_weights`` commands."""

import os
import sys
from pathlib import Path
from typing import Optional

import typer
from loguru import logger

from text_generation import server, utils

# Typer application; each function decorated with @app.command() below
# becomes a sub-command of this CLI.
app = typer.Typer()


@app.command()
def serve(
    model_name: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: bool = False,
    uds_path: Path = "/tmp/text-generation",
    logger_level: str = "INFO",
    json_output: bool = False,
):
    """Start serving MODEL_NAME.

    When --sharded is set, the RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT
    environment variables must all be defined.
    """
    # Replace loguru's default handler with a single stdout sink restricted
    # to records coming from the text_generation package.
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation",
        level=logger_level,
        serialize=json_output,  # emit JSON-serialized records when requested
        backtrace=True,
        diagnose=False,
    )

    if sharded:
        # Validated explicitly instead of with `assert`: asserts are stripped
        # when Python runs with -O, and the server would then be started
        # without the distributed configuration it needs.
        required = ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT")
        missing = [name for name in required if os.getenv(name) is None]
        if missing:
            # Report every missing variable at once rather than one per run.
            raise RuntimeError(
                f"{', '.join(missing)} must be set when sharded is True"
            )

    server.serve(model_name, revision, sharded, quantize, uds_path)


@app.command()
def download_weights(
    model_name: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
):
    """Download the weights for MODEL_NAME."""
    # The three CLI values are handed over positionally and untouched; how
    # `revision` and `extension` are interpreted is up to utils.download_weights.
    forwarded = (model_name, revision, extension)
    utils.download_weights(*forwarded)


def _main() -> None:
    """Run the Typer application when this file is executed as a script."""
    app()


if __name__ == "__main__":
    _main()