import os
import sys

import typer

from pathlib import Path
from loguru import logger
from typing import Optional
from enum import Enum


# Root Typer application; the functions below register themselves as
# sub-commands via the @app.command() decorator.
app = typer.Typer()


class Quantization(str, Enum):
    """Supported quantization backends for the `--quantize` CLI option.

    Mixing in `str` lets typer parse/compare the raw CLI string directly.
    """

    bitsandbytes = "bitsandbytes"
    gptq = "gptq"


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: Optional[Quantization] = None,
    trust_remote_code: bool = False,
    uds_path: Path = "/tmp/text-generation-server",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
):
    """Launch the text-generation server for `model_id`.

    Args:
        model_id: model identifier (hub id or local path) passed to the server.
        revision: optional model revision.
        sharded: when True, requires torch.distributed-style env vars to be set.
        quantize: optional quantization backend (downgraded to its str value).
        trust_remote_code: forwarded to the server.
        uds_path: unix-domain-socket path the server binds to (presumably;
            verify against `server.serve`).
        logger_level: loguru level for the package-scoped handler.
        json_output: when True, loguru serializes records as JSON.
        otlp_endpoint: when set, OpenTelemetry tracing is configured.
    """
    # Sharded serving is coordinated through environment variables; fail
    # fast with an explicit message if any of them is missing.
    if sharded:
        for required in ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
            assert (
                os.getenv(required, None) is not None
            ), f"{required} must be set when sharded is True"

    # Replace loguru's default handler with one scoped to this package.
    logger.remove()
    logger.add(
        sys.stdout,
        level=logger_level,
        format="{message}",
        filter="text_generation_server",
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import server
    from text_generation_server.tracing import setup_tracing

    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

    # Downgrade enum into str for easier management later on
    quantize = quantize.value if quantize is not None else None

    server.serve(model_id, revision, sharded, quantize, trust_remote_code, uds_path)


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    auto_convert: bool = True,
    logger_level: str = "INFO",
    json_output: bool = False,
):
    """Download model weights for `model_id`, converting to safetensors if needed.

    Resolution order: skip if files already exist locally; otherwise try the
    hub with `extension`; otherwise fall back to `.bin` PyTorch weights
    (local, then hub) and, when `auto_convert` is True, convert them to
    safetensors next to the originals.

    Args:
        model_id: hub model id or local directory path.
        revision: optional model revision.
        extension: preferred weight-file extension to fetch.
        auto_convert: convert downloaded `.bin` weights to safetensors.
        logger_level: loguru level for the package-scoped handler.
        json_output: when True, loguru serializes records as JSON.
    """
    # Remove default handler and install one scoped to this package.
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import utils

    # Test if files were already download
    try:
        utils.weight_files(model_id, revision, extension)
        logger.info("Files are already present on the host. " "Skipping download.")
        return
    # Local files not found
    except (utils.LocalEntryNotFoundError, FileNotFoundError):
        pass

    # A local directory or a cache override means we never hit the hub
    # for the preferred extension.
    is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
        "WEIGHTS_CACHE_OVERRIDE", None
    ) is not None

    if not is_local_model:
        # Try to download weights from the hub
        try:
            filenames = utils.weight_hub_files(model_id, revision, extension)
            utils.download_weights(filenames, model_id, revision)
            # Successfully downloaded weights
            return

        # No weights found on the hub with this extension
        except utils.EntryNotFoundError as e:
            # Check if we want to automatically convert to safetensors or if we can use .bin weights instead
            if not extension == ".safetensors" or not auto_convert:
                raise e

    # Try to see if there are local pytorch weights
    try:
        # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
        local_pt_files = utils.weight_files(model_id, revision, ".bin")

    # No local pytorch weights
    except utils.LocalEntryNotFoundError:
        if extension == ".safetensors":
            logger.warning(
                f"No safetensors weights found for model {model_id} at revision {revision}. "
                f"Downloading PyTorch weights."
            )

        # Try to see if there are pytorch weights on the hub
        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
        # Download pytorch weights
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)

    if auto_convert:
        logger.warning(
            f"No safetensors weights found for model {model_id} at revision {revision}. "
            f"Converting PyTorch weights to safetensors."
        )

        # Safetensors final filenames.
        # BUG FIX: the original used `p.stem.lstrip("pytorch_")`, but
        # str.lstrip strips any leading run of the *character set*
        # {p, y, t, o, r, c, h, _} — e.g. "pytorch_tokenizer" would become
        # "kenizer". Strip the literal "pytorch_" prefix instead.
        local_st_files = []
        for p in local_pt_files:
            stem = p.stem
            if stem.startswith("pytorch_"):
                stem = stem[len("pytorch_") :]
            local_st_files.append(p.parent / f"{stem}.safetensors")

        # Convert pytorch weights to safetensors
        utils.convert_files(local_pt_files, local_st_files)


@app.command()
def quantize(
    model_id: str,
    output_dir: str,
    revision: Optional[str] = None,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
    upload_to_model_id: Optional[str] = None,
    percdamp: float = 0.01,
    act_order: bool = False,
):
    """GPTQ-quantize `model_id` (4-bit, groupsize 128) into `output_dir`.

    Weights are fetched first via `download_weights`; the actual quantization
    is delegated to `text_generation_server.utils.gptq.quantize`.
    """
    # Make sure the weights are available locally before quantizing.
    download_weights(
        model_id=model_id,
        revision=revision,
        logger_level=logger_level,
        json_output=json_output,
    )

    # Deferred import; aliased because it would otherwise shadow this command.
    from text_generation_server.utils.gptq.quantize import quantize as _gptq_quantize

    _gptq_quantize(
        model_id=model_id,
        bits=4,
        groupsize=128,
        output_dir=output_dir,
        trust_remote_code=trust_remote_code,
        upload_to_model_id=upload_to_model_id,
        percdamp=percdamp,
        act_order=act_order,
    )


# Script entry point: Typer dispatches to serve / download-weights / quantize
# based on the command-line arguments.
if __name__ == "__main__":
    app()