cli.py
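# Command-line entry points for the Python model server, built with Typer:
# `serve` starts the server for a given model, `download-weights` pre-fetches
# model weights so that startup does not need to hit the Hugging Face Hub.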
import os
import sys
import typer

from pathlib import Path
from loguru import logger
from typing import Optional

from text_generation_server import server, utils
from text_generation_server.tracing import setup_tracing

app = typer.Typer()


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: bool = False,
    uds_path: Path = "/tmp/text-generation",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
):
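    # Sharded serving expects the torch.distributed-style environment variables
    # (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) to be set by whatever
    # launches the processes, with one server process per rank.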
    if sharded:
        assert (
            os.getenv("RANK", None) is not None
        ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"

    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )
    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

    server.serve(model_id, revision, sharded, quantize, uds_path)


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    logger_level: str = "INFO",
    json_output: bool = False,
):
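    # Download the weights ahead of time: prefer `.safetensors` files and fall
    # back to downloading and converting PyTorch `.bin` checkpoints when no
    # safetensors weights are published for the model.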
    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Check whether the weight files are already cached locally
    try:
        utils.weight_files(model_id, revision, extension)
        logger.info(
            "Files are already present in the local cache. Skipping download."
        )
        return
    # Local files not found
    except utils.LocalEntryNotFoundError:
        pass

    # Download weights directly
    try:
        filenames = utils.weight_hub_files(model_id, revision, extension)
        utils.download_weights(filenames, model_id, revision)
    except utils.EntryNotFoundError as e:
        if extension != ".safetensors":
            raise e

        logger.warning(
            f"No safetensors weights found for model {model_id} at revision {revision}. "
            f"Converting PyTorch weights instead."
        )

        # Try to see if there are pytorch weights
        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
        # Download pytorch weights
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
        # Build the target safetensors filenames; use removeprefix (Python 3.9+)
        # rather than lstrip, which strips a character set instead of a prefix.
        local_st_files = [
            p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
            for p in local_pt_files
        ]
        # Convert pytorch weights to safetensors
        utils.convert_files(local_pt_files, local_st_files)


if __name__ == "__main__":
    app()
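# Example invocations, assuming the package exposes this Typer app as the
# `text-generation-server` console script (the model id is illustrative):
#   text-generation-server download-weights bigscience/bloom-560m
#   text-generation-server serve bigscience/bloom-560m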