cli.py 4.52 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
import os
2
import sys
Olivier Dehaene's avatar
Olivier Dehaene committed
3
4
5
import typer

from pathlib import Path
6
from loguru import logger
7
from typing import Optional
Olivier Dehaene's avatar
Olivier Dehaene committed
8
9
10
11
12
13


# Typer application object; the CLI commands in this file are registered on it
# through the `@app.command()` decorator.
app = typer.Typer()


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: bool = False,
    uds_path: Path = Path("/tmp/text-generation-server"),
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
):
    """Configure logging and optional tracing, then start the server.

    Args:
        model_id: Model identifier, forwarded unchanged to ``server.serve``.
        revision: Optional model revision, forwarded to ``server.serve``.
        sharded: When true, the ``RANK``, ``WORLD_SIZE``, ``MASTER_ADDR`` and
            ``MASTER_PORT`` environment variables must be set. Forwarded to
            ``server.serve``.
        quantize: Forwarded to ``server.serve``.
        uds_path: Forwarded to ``server.serve`` (presumably the unix domain
            socket path the server uses; confirm in ``server.serve``).
        logger_level: Minimum level of the stdout log sink.
        json_output: Passed as the sink's ``serialize`` option.
        otlp_endpoint: When not ``None``, tracing is set up with this endpoint.

    Raises:
        AssertionError: If ``sharded`` is true and one of the required
            environment variables is missing.
    """
    if sharded:
        # Explicit check rather than `assert`: assert statements are removed
        # when Python runs with -O, which would silently skip this validation.
        # AssertionError and the messages are kept for backward compatibility.
        for env_var in ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
            if os.getenv(env_var) is None:
                raise AssertionError(f"{env_var} must be set when sharded is True")

    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import server
    from text_generation_server.tracing import setup_tracing

    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        # NOTE: RANK is passed as read from the environment (a string when set,
        # the integer 0 otherwise).
        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

    server.serve(model_id, revision, sharded, quantize, uds_path)
Olivier Dehaene's avatar
Olivier Dehaene committed
59
60
61


def _safetensors_path(pt_file: Path) -> Path:
    """Return the ``.safetensors`` path located next to a PyTorch weight file.

    A leading ``pytorch_`` is removed from the file stem as a *prefix*.
    ``str.lstrip("pytorch_")`` must not be used for this: it strips any leading
    run of the characters ``p y t o r c h _`` and mangles names such as
    ``pytorch_tf_model`` or ``rust_model``.
    """
    prefix = "pytorch_"
    stem = pt_file.stem
    if stem.startswith(prefix):
        stem = stem[len(prefix):]
    return pt_file.parent / f"{stem}.safetensors"


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    auto_convert: bool = True,
    logger_level: str = "INFO",
    json_output: bool = False,
):
    """Make the weights of a model available locally.

    Steps, in order:

    1. Return immediately if files with ``extension`` are already present.
    2. Unless the model is local (``model_id`` is an existing directory or
       ``WEIGHTS_CACHE_OVERRIDE`` is set), download files with ``extension``
       from the hub and return.
    3. Otherwise fall back to PyTorch ``.bin`` weights (local, else downloaded
       from the hub) and, when ``auto_convert`` is true, convert them to
       safetensors.

    Args:
        model_id: Model identifier or path to a local model directory.
        revision: Optional model revision.
        extension: Extension of the weight files to look for.
        auto_convert: Whether ``.bin`` weights may be converted to safetensors
            when no safetensors weights are found.
        logger_level: Minimum level of the stdout log sink.
        json_output: Passed as the sink's ``serialize`` option.

    Raises:
        utils.EntryNotFoundError: If no file with ``extension`` exists on the
            hub and falling back to ``.bin`` weights is not allowed
            (``extension`` is not ``".safetensors"`` or ``auto_convert`` is
            false).
    """
    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import utils

    # Test if files were already downloaded
    try:
        utils.weight_files(model_id, revision, extension)
    except (utils.LocalEntryNotFoundError, FileNotFoundError):
        # Local files not found: continue with the download logic below
        pass
    else:
        logger.info("Files are already present on the host. Skipping download.")
        return

    model_dir = Path(model_id)
    is_local_model = (model_dir.exists() and model_dir.is_dir()) or os.getenv(
        "WEIGHTS_CACHE_OVERRIDE", None
    ) is not None

    if not is_local_model:
        # Try to download weights from the hub
        try:
            filenames = utils.weight_hub_files(model_id, revision, extension)
            utils.download_weights(filenames, model_id, revision)
            # Successfully downloaded weights
            return

        # No weights found on the hub with this extension
        except utils.EntryNotFoundError:
            # Only fall through to the .bin weights when safetensors were
            # requested and automatic conversion is allowed; otherwise re-raise.
            if extension != ".safetensors" or not auto_convert:
                raise

    # Try to see if there are local pytorch weights
    try:
        # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
        local_pt_files = utils.weight_files(model_id, revision, ".bin")

    # No local pytorch weights
    except utils.LocalEntryNotFoundError:
        if extension == ".safetensors":
            logger.warning(
                f"No safetensors weights found for model {model_id} at revision {revision}. "
                f"Downloading PyTorch weights."
            )

        # Try to see if there are pytorch weights on the hub
        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
        # Download pytorch weights
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)

    if auto_convert:
        logger.warning(
            f"No safetensors weights found for model {model_id} at revision {revision}. "
            f"Converting PyTorch weights to safetensors."
        )

        # Safetensors final filenames
        local_st_files = [_safetensors_path(p) for p in local_pt_files]
        # Convert pytorch weights to safetensors
        utils.convert_files(local_pt_files, local_st_files)
Olivier Dehaene's avatar
Olivier Dehaene committed
143
144
145
146


# Run the Typer application only when this file is executed as a script.
if __name__ == "__main__":
    app()