import os
import sys
import typer

from pathlib import Path
from loguru import logger
from typing import Optional
from enum import Enum


app = typer.Typer()


class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    gptq = "gptq"


class Dtype(str, Enum):
    float16 = "float16"
    bfloat16 = "bfloat16"


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: Optional[Quantization] = None,
    dtype: Optional[Dtype] = None,
    trust_remote_code: bool = False,
    uds_path: Path = Path("/tmp/text-generation-server"),
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
):
    if sharded:
        assert (
            os.getenv("RANK", None) is not None
        ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"
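        # All four are the standard torch.distributed environment variables;
        # a launcher such as `torchrun --nproc_per_node=<n>` exports them
        # automatically (e.g. RANK=0 WORLD_SIZE=2 MASTER_ADDR=127.0.0.1
        # MASTER_PORT=29500).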

    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )
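    # With serialize=True, loguru emits each record as a single JSON object;
    # this is what the --json-output flag toggles.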

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import server
    from text_generation_server.tracing import setup_tracing

    # Set up OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
    dtype = None if dtype is None else dtype.value
    if dtype is not None and quantize is not None:
        raise RuntimeError(
            "Only one of `dtype` and `quantize` may be set, as both determine "
            "how the final model is loaded."
        )
    server.serve(
        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
    )


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    auto_convert: bool = True,
    logger_level: str = "INFO",
    json_output: bool = False,
):
    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import utils

    # Test if files were already downloaded
    try:
        utils.weight_files(model_id, revision, extension)
        logger.info("Files are already present on the host. Skipping download.")
        return
    # Local files not found
    except (utils.LocalEntryNotFoundError, FileNotFoundError):
        pass

    is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
        "WEIGHTS_CACHE_OVERRIDE", None
    ) is not None

    if not is_local_model:
        # Try to download weights from the hub
        try:
            filenames = utils.weight_hub_files(model_id, revision, extension)
            utils.download_weights(filenames, model_id, revision)
            # Successfully downloaded weights
            return

        # No weights found on the hub with this extension
        except utils.EntryNotFoundError as e:
            # Check if we want to automatically convert to safetensors or if we can use .bin weights instead
            if extension != ".safetensors" or not auto_convert:
                raise e

    # Try to see if there are local pytorch weights
    try:
        # Get weights for a local model, a hub-cached model, or the WEIGHTS_CACHE_OVERRIDE directory
        local_pt_files = utils.weight_files(model_id, revision, ".bin")

    # No local pytorch weights
    except utils.LocalEntryNotFoundError:
        if extension == ".safetensors":
            logger.warning(
                f"No safetensors weights found for model {model_id} at revision {revision}. "
                f"Downloading PyTorch weights."
            )

        # Try to see if there are pytorch weights on the hub
        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
        # Download pytorch weights
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)

    if auto_convert:
        logger.warning(
            f"No safetensors weights found for model {model_id} at revision {revision}. "
            f"Converting PyTorch weights to safetensors."
        )

        # Safetensors final filenames
        local_st_files = [
            # str.removeprefix (Python 3.9+) drops the literal "pytorch_"
            # prefix; lstrip would strip any leading run of those characters.
            p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
            for p in local_pt_files
        ]
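        # e.g. pytorch_model-00001-of-00002.bin -> model-00001-of-00002.safetensors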
        try:
            from transformers import AutoConfig
            import transformers

            config = AutoConfig.from_pretrained(
                model_id,
                revision=revision,
            )
            architecture = config.architectures[0]

            class_ = getattr(transformers, architecture)

            # The attribute name depends on the transformers version; copy the
            # list so the class attribute is not mutated in place.
            discard_names = list(getattr(class_, "_tied_weights_keys", None) or [])
            discard_names.extend(
                getattr(class_, "_keys_to_ignore_on_load_missing", None) or []
            )

        except Exception:
            discard_names = []
        # Convert pytorch weights to safetensors
        utils.convert_files(local_pt_files, local_st_files, discard_names)


@app.command()
def quantize(
    model_id: str,
    output_dir: str,
    revision: Optional[str] = None,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
    upload_to_model_id: Optional[str] = None,
    percdamp: float = 0.01,
    act_order: bool = False,
):
    download_weights(
        model_id=model_id,
        revision=revision,
        logger_level=logger_level,
        json_output=json_output,
    )
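    # Import here, after the logger has been configured by download_weights,
    # mirroring the lazy-import pattern used in `serve` above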
    from text_generation_server.utils.gptq.quantize import quantize

    quantize(
        model_id=model_id,
        bits=4,
        groupsize=128,
        output_dir=output_dir,
        trust_remote_code=trust_remote_code,
        upload_to_model_id=upload_to_model_id,
        percdamp=percdamp,
        act_order=act_order,
    )


if __name__ == "__main__":
    app()
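

# Example invocations (a sketch: when the package is installed, this app is
# typically exposed as the `text-generation-server` console script; otherwise
# substitute `python cli.py`. Model id and output path are illustrative):
#   text-generation-server download-weights bigscience/bloom-560m
#   text-generation-server serve bigscience/bloom-560m --dtype float16
#   text-generation-server quantize bigscience/bloom-560m /tmp/quantized-bloom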