cli.py
import os
import sys
import typer

from pathlib import Path
from loguru import logger
from typing import Optional

from text_generation import server, utils
from text_generation.tracing import setup_tracing

app = typer.Typer()
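# Every function decorated with @app.command() below becomes a subcommand;
# Typer derives the command names (`serve`, `download-weights`) from the
# function names.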


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: bool = False,
    uds_path: Path = "/tmp/text-generation",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
):
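    # In sharded mode each `serve` process handles a single shard and relies on
    # the standard torch.distributed environment variables, which are expected
    # to be set by whatever launches the shard processes.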
    if sharded:
        assert (
            os.getenv("RANK", None) is not None
        ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"

    # Replace loguru's default handler: keep only records from the
    # `text_generation` package, at `logger_level`, serialized as JSON
    # when `json_output` is set.
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )
    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

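    # Start serving the model; `uds_path` is the base path for the unix domain
    # socket(s) the server listens on.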
    server.serve(model_id, revision, sharded, quantize, uds_path)


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
):
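    # Download the weight files matching `extension` ahead of time so that
    # `serve` can start without fetching them from the Hugging Face Hub.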
    utils.download_weights(model_id, revision, extension)


if __name__ == "__main__":
    app()
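

# A sketch of how the CLI could be invoked once the package is installed; the
# console-script name registered in the project metadata may differ, and the
# model id below is only illustrative:
#
#   python -m text_generation.cli download-weights bigscience/bloom-560m
#   python -m text_generation.cli serve bigscience/bloom-560m --uds-path /tmp/text-generation
#
# Typer maps the keyword arguments to flags, e.g. `json_output` becomes
# `--json-output` and `otlp_endpoint` becomes `--otlp-endpoint`.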