import os
import sys
import typer

from pathlib import Path
from loguru import logger
from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download

from text_generation_server.utils.log import log_master

app = typer.Typer()


class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    bitsandbytes_nf4 = "bitsandbytes-nf4"
    bitsandbytes_fp4 = "bitsandbytes-fp4"
    gptq = "gptq"
    awq = "awq"
    eetq = "eetq"
    exl2 = "exl2"
    fp8 = "fp8"
    marlin = "marlin"


class Dtype(str, Enum):
    float16 = "float16"
    bfloat16 = "bfloat16"


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: Optional[Quantization] = None,
    speculate: Optional[int] = None,
    dtype: Optional[Dtype] = None,
    trust_remote_code: bool = False,
    uds_path: Path = "/tmp/text-generation-server",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
    otlp_service_name: str = "text-generation-inference.server",
    max_input_tokens: Optional[int] = None,
):
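    """Start the text-generation gRPC server for `model_id`.

    Example invocation (illustrative only, assuming the standard
    `text-generation-server` entrypoint):

        text-generation-server serve bigscience/bloom-560m --dtype float16
    """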
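    # The launcher (or torchrun) is expected to inject the torch.distributed
    # rendezvous variables checked below, e.g. RANK=0, WORLD_SIZE=2,
    # MASTER_ADDR=127.0.0.1, MASTER_PORT=29500 (illustrative values only).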
    if sharded:
        assert (
            os.getenv("RANK", None) is not None
        ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"

    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )
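    # With json_output=True, loguru's `serialize` flag emits each record as a
    # single JSON object per line instead of the bare "{message}" format.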

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import server
    from text_generation_server.tracing import setup_tracing

    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)

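    # Optional comma-separated list of adapters to load, e.g.
    # LORA_ADAPTERS="org/adapter-a,org/adapter-b" (hypothetical adapter ids).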
    lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)

    # split on comma and strip whitespace
    lora_adapter_ids = (
        [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
    )

    if len(lora_adapter_ids) > 0:
        log_master(
            logger.warning,
            "LoRA adapters are enabled. This is an experimental feature and may not work as expected.",
        )

    # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
    # and warn the user
    if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
        log_master(
            logger.warning,
            "LoRA adapters are not supported with CUDA Graphs. Disabling CUDA Graphs.",
        )
        global CUDA_GRAPHS
        CUDA_GRAPHS = None

    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
    dtype = None if dtype is None else dtype.value
    if dtype is not None and quantize not in {
        None,
        "bitsandbytes",
        "bitsandbytes-nf4",
        "bitsandbytes-fp4",
    }:
        raise RuntimeError(
            "Only one of `dtype` and `quantize` can be set, as they both determine the final model representation."
        )
    server.serve(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        trust_remote_code,
        uds_path,
        max_input_tokens,
    )


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    auto_convert: bool = True,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
    merge_lora: bool = False,
):
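    """Download the weights for `model_id`, converting them to safetensors when needed.

    Example invocation (illustrative only):

        text-generation-server download-weights bigscience/bloom-560m
    """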
    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import utils

    # Test if the files were already downloaded
    try:
        utils.weight_files(model_id, revision, extension)
        logger.info("Files are already present on the host. Skipping download.")
        return
    # Local files not found
    except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
        pass

    is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
        "WEIGHTS_CACHE_OVERRIDE", None
    ) is not None

    if not is_local_model:
        # TODO: maybe reverse the default value of merge_lora?
        # currently by default we don't merge the weights with the base model
        if merge_lora:
            try:
                # Raises if `adapter_config.json` is absent, i.e. the repo is not a PEFT adapter
                hf_hub_download(
                    model_id, revision=revision, filename="adapter_config.json"
                )
                utils.download_and_unload_peft(
                    model_id, revision, trust_remote_code=trust_remote_code
                )
                is_local_model = True
                utils.weight_files(model_id, revision, extension)
                return
            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
                pass
        else:
            try:
                utils.peft.download_peft(
                    model_id, revision, trust_remote_code=trust_remote_code
                )
            except Exception:
                pass

        try:
            import json

            config = hf_hub_download(
                model_id, revision=revision, filename="config.json"
            )
            with open(config, "r") as f:
                config = json.load(f)

            base_model_id = config.get("base_model_name_or_path", None)
            if base_model_id and base_model_id != model_id:
                try:
                    logger.info(f"Downloading parent model {base_model_id}")
                    download_weights(
                        model_id=base_model_id,
                        revision="main",
                        extension=extension,
                        auto_convert=auto_convert,
                        logger_level=logger_level,
                        json_output=json_output,
                        trust_remote_code=trust_remote_code,
                    )
                except Exception:
                    pass
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

        # Try to download weights from the hub
        try:
            filenames = utils.weight_hub_files(model_id, revision, extension)
            utils.download_weights(filenames, model_id, revision)
            # Successfully downloaded weights
            return

        # No weights found on the hub with this extension
        except utils.EntryNotFoundError as e:
            # Check if we want to automatically convert to safetensors or if we can use .bin weights instead
            if extension != ".safetensors" or not auto_convert:
                raise e

    elif (Path(model_id) / "adapter_config.json").exists():
        # Try to load as a local PEFT model
        try:
            utils.download_and_unload_peft(
                model_id, revision, trust_remote_code=trust_remote_code
            )
            utils.weight_files(model_id, revision, extension)
            return
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass
    elif (Path(model_id) / "config.json").exists():
        # Try to load as a local Medusa model
        try:
            import json

            config = Path(model_id) / "config.json"
            with open(config, "r") as f:
                config = json.load(f)

            base_model_id = config.get("base_model_name_or_path", None)
            if base_model_id:
                try:
                    logger.info(f"Downloading parent model {base_model_id}")
                    download_weights(
                        model_id=base_model_id,
                        revision="main",
                        extension=extension,
                        auto_convert=auto_convert,
                        logger_level=logger_level,
                        json_output=json_output,
                        trust_remote_code=trust_remote_code,
                    )
                except Exception:
                    pass
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

    # Try to see if there are local pytorch weights
    try:
        # Get weights for a local model, a hub-cached model, or one under WEIGHTS_CACHE_OVERRIDE
        try:
            local_pt_files = utils.weight_files(model_id, revision, ".bin")
        except Exception:
            local_pt_files = utils.weight_files(model_id, revision, ".pt")

    # No local pytorch weights
    except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
        if extension == ".safetensors":
            logger.warning(
                f"No safetensors weights found for model {model_id} at revision {revision}. "
                f"Downloading PyTorch weights."
            )

        # Try to see if there are pytorch weights on the hub
        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
        # Download pytorch weights
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)

    if auto_convert:
        if not trust_remote_code:
            logger.warning(
                "🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
                "Pickle files are unsafe and can essentially contain remote code execution! "
                "Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
            )

        logger.warning(
            f"No safetensors weights found for model {model_id} at revision {revision}. "
            f"Converting PyTorch weights to safetensors."
        )

        # Safetensors final filenames
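        # e.g. pytorch_model-00001-of-00002.bin -> model-00001-of-00002.safetensors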
        local_st_files = [
            # `removeprefix` (not `lstrip`, which strips characters) so only the
            # literal "pytorch_" prefix is dropped
            p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
            for p in local_pt_files
        ]
        try:
            import transformers
            import json

            if is_local_model:
                config_filename = os.path.join(model_id, "config.json")
            else:
                config_filename = hf_hub_download(
                    model_id, revision=revision, filename="config.json"
                )
            with open(config_filename, "r") as f:
                config = json.load(f)
            architecture = config["architectures"][0]

            class_ = getattr(transformers, architecture)

            # The name of this attribute depends on the transformers version.
            discard_names = getattr(class_, "_tied_weights_keys", [])

        except Exception:
            discard_names = []
        # Convert pytorch weights to safetensors
        utils.convert_files(local_pt_files, local_st_files, discard_names)


@app.command()
def quantize(
    model_id: str,
    output_dir: str,
    revision: Optional[str] = None,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
    upload_to_model_id: Optional[str] = None,
    percdamp: float = 0.01,
    act_order: bool = False,
    groupsize: int = 128,
):
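    """Quantize `model_id` to 4-bit GPTQ and write the result to `output_dir`.

    Example invocation (illustrative only):

        text-generation-server quantize bigscience/bloom-560m /data/bloom-gptq
    """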
    if revision is None:
        revision = "main"
    download_weights(
        model_id=model_id,
        revision=revision,
        logger_level=logger_level,
        json_output=json_output,
    )
    from text_generation_server.layers.gptq.quantize import quantize

    quantize(
        model_id=model_id,
        bits=4,
        groupsize=groupsize,
        output_dir=output_dir,
        revision=revision,
        trust_remote_code=trust_remote_code,
        upload_to_model_id=upload_to_model_id,
        percdamp=percdamp,
        act_order=act_order,
        sym=True,
    )


if __name__ == "__main__":
    app()