"vscode:/vscode.git/clone" did not exist on "84290a10941f0a8ff28f50d42aa6f3dfa1054ddf"
cli.py 12.1 KB
Newer Older
jixx's avatar
init  
jixx committed
1
2
3
4
5
6
7
8
9
import os
import sys
import typer

from pathlib import Path
from loguru import logger
from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download
jixx's avatar
jixx committed
10
from text_generation_server.utils.adapter import parse_lora_adapters
jixx's avatar
init  
jixx committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32


# Root Typer application; the CLI commands below register themselves on it
# via the @app.command() decorator.
app = typer.Typer()


class Quantization(str, Enum):
    """Supported quantization backends for the `--quantize` CLI option.

    Subclassing ``str`` lets Typer parse the value directly from the
    command line and lets the rest of the code compare against plain strings.
    """

    bitsandbytes = "bitsandbytes"
    bitsandbytes_nf4 = "bitsandbytes-nf4"
    bitsandbytes_fp4 = "bitsandbytes-fp4"
    gptq = "gptq"
    awq = "awq"
    eetq = "eetq"
    exl2 = "exl2"
    fp8 = "fp8"
    marlin = "marlin"


class Dtype(str, Enum):
    """Model compute dtypes accepted by the `--dtype` CLI option."""

    float16 = "float16"
    # NOTE(review): member name "bloat16" looks like a typo for "bfloat16",
    # but the CLI value is the member *value* ("bfloat16"), which is correct.
    # Renaming the member could break code referencing Dtype.bloat16 — leave as is.
    bloat16 = "bfloat16"


jixx's avatar
jixx committed
33
34
35
36
37
class KVCacheDtype(str, Enum):
    """FP8 storage formats accepted by the `--kv-cache-dtype` CLI option."""

    fp8_e4m3fn = "fp8_e4m3fn"
    fp8_e5m2 = "fp8_e5m2"


jixx's avatar
init  
jixx committed
38
39
40
41
42
43
44
45
@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: Optional[Quantization] = None,
    speculate: Optional[int] = None,
    dtype: Optional[Dtype] = None,
    kv_cache_dtype: Optional[KVCacheDtype] = None,
    trust_remote_code: bool = False,
    uds_path: Path = "/tmp/text-generation-server",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
    otlp_service_name: str = "text-generation-inference.server",
    max_input_tokens: Optional[int] = None,
):
    """Start the text-generation server for ``model_id``.

    Sets up logging and (optionally) OpenTelemetry tracing, reads LoRA
    adapters from the ``LORA_ADAPTERS`` environment variable, validates the
    dtype/quantize combination, then hands off to ``server.serve``.
    Statement order matters: the logger is configured *before* the project
    imports so that import-time exceptions are logged.
    """
    # Distributed sharding relies on torchrun-style environment variables.
    if sharded:
        assert (
            os.getenv("RANK", None) is not None
        ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"

    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import server
    from text_generation_server.tracing import setup_tracing

    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)

    lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))

    # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
    # and warn the user
    if lora_adapters:
        logger.warning("LoRA adapters enabled (experimental feature).")

        if "CUDA_GRAPHS" in os.environ:
            logger.warning(
                "LoRA adapters incompatible with CUDA Graphs. Disabling CUDA Graphs."
            )
            # NOTE(review): this writes a module-level CUDA_GRAPHS in *this*
            # module; nothing in this file reads it afterwards. Presumably
            # other code inspects it — verify it actually disables CUDA graphs.
            global CUDA_GRAPHS
            CUDA_GRAPHS = None

    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
    dtype = None if dtype is None else dtype.value
    kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
    # dtype and quantize both decide the final weight format, so only the
    # bitsandbytes family may be combined with an explicit dtype.
    if dtype is not None and quantize not in {
        None,
        "bitsandbytes",
        "bitsandbytes-nf4",
        "bitsandbytes-fp4",
    }:
        raise RuntimeError(
            "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
        )
    server.serve(
        model_id,
        lora_adapters,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        kv_cache_dtype,
        trust_remote_code,
        uds_path,
        max_input_tokens,
    )


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    auto_convert: bool = True,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
    merge_lora: bool = False,
):
    """Download weights for ``model_id`` so the server can start offline.

    Tries, in order: already-present weight files, PEFT adapter handling
    (merged or side-loaded depending on ``merge_lora``), the parent/base
    model referenced by a Medusa-style ``config.json``, hub safetensors,
    and finally PyTorch ``.bin``/``.pt`` weights which are converted to
    safetensors when ``auto_convert`` is set.
    """
    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import utils

    # Test if files were already downloaded
    try:
        utils.weight_files(model_id, revision, extension)
        logger.info("Files are already present on the host. Skipping download.")
        return
    # Local files not found
    except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
        pass

    is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
        "WEIGHTS_CACHE_OVERRIDE", None
    ) is not None

    if not is_local_model:
        # TODO: maybe reverse the default value of merge_lora?
        # currently by default we don't merge the weights with the base model
        if merge_lora:
            try:
                # An adapter_config.json marks the repo as a PEFT adapter:
                # download it and merge it into its base model.
                hf_hub_download(
                    model_id, revision=revision, filename="adapter_config.json"
                )
                utils.download_and_unload_peft(
                    model_id, revision, trust_remote_code=trust_remote_code
                )
                is_local_model = True
                utils.weight_files(model_id, revision, extension)
                return
            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
                pass
        else:
            # Best-effort: fetch the adapter files without merging them.
            try:
                utils.peft.download_peft(
                    model_id, revision, trust_remote_code=trust_remote_code
                )
            except Exception:
                pass

        try:
            import json

            config = hf_hub_download(
                model_id, revision=revision, filename="config.json"
            )
            with open(config, "r") as f:
                config = json.load(f)

            # Medusa-style configs point at a parent model whose weights are
            # also needed; recurse to fetch them (failures are best-effort).
            base_model_id = config.get("base_model_name_or_path", None)
            if base_model_id and base_model_id != model_id:
                try:
                    logger.info(f"Downloading parent model {base_model_id}")
                    download_weights(
                        model_id=base_model_id,
                        revision="main",
                        extension=extension,
                        auto_convert=auto_convert,
                        logger_level=logger_level,
                        json_output=json_output,
                        trust_remote_code=trust_remote_code,
                    )
                except Exception:
                    pass
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

        # Try to download weights from the hub
        try:
            filenames = utils.weight_hub_files(model_id, revision, extension)
            utils.download_weights(filenames, model_id, revision)
            # Successfully downloaded weights
            return

        # No weights found on the hub with this extension
        except utils.EntryNotFoundError as e:
            # Check if we want to automatically convert to safetensors or if we can use .bin weights instead
            if extension != ".safetensors" or not auto_convert:
                raise e

    elif (Path(model_id) / "adapter_config.json").exists():
        # Try to load as a local PEFT model
        try:
            utils.download_and_unload_peft(
                model_id, revision, trust_remote_code=trust_remote_code
            )
            utils.weight_files(model_id, revision, extension)
            return
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass
    elif (Path(model_id) / "config.json").exists():
        # Try to load as a local Medusa model
        try:
            import json

            config = Path(model_id) / "config.json"
            with open(config, "r") as f:
                config = json.load(f)

            base_model_id = config.get("base_model_name_or_path", None)
            if base_model_id:
                try:
                    logger.info(f"Downloading parent model {base_model_id}")
                    download_weights(
                        model_id=base_model_id,
                        revision="main",
                        extension=extension,
                        auto_convert=auto_convert,
                        logger_level=logger_level,
                        json_output=json_output,
                        trust_remote_code=trust_remote_code,
                    )
                except Exception:
                    pass
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

    # Try to see if there are local pytorch weights
    try:
        # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
        try:
            local_pt_files = utils.weight_files(model_id, revision, ".bin")
        except Exception:
            local_pt_files = utils.weight_files(model_id, revision, ".pt")

    # No local pytorch weights
    except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
        if extension == ".safetensors":
            logger.warning(
                f"No safetensors weights found for model {model_id} at revision {revision}. "
                f"Downloading PyTorch weights."
            )

        # Try to see if there are pytorch weights on the hub
        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
        # Download pytorch weights
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)

    if auto_convert:
        if not trust_remote_code:
            logger.warning(
                "🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
                "Pickle files are unsafe and can essentially contain remote code execution! "
                "Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
            )

        logger.warning(
            f"No safetensors weights found for model {model_id} at revision {revision}. "
            f"Converting PyTorch weights to safetensors."
        )

        def _stem_without_pytorch_prefix(p: Path) -> str:
            # BUGFIX: the old code used p.stem.lstrip("pytorch_"), which strips
            # any leading characters from the set {p,y,t,o,r,c,h,_} rather than
            # the literal prefix (e.g. "consolidated.00" -> "nsolidated.00").
            stem = p.stem
            return stem[len("pytorch_"):] if stem.startswith("pytorch_") else stem

        # Safetensors final filenames
        local_st_files = [
            p.parent / f"{_stem_without_pytorch_prefix(p)}.safetensors"
            for p in local_pt_files
        ]
        try:
            import transformers
            import json

            if is_local_model:
                config_filename = os.path.join(model_id, "config.json")
            else:
                config_filename = hf_hub_download(
                    model_id, revision=revision, filename="config.json"
                )
            with open(config_filename, "r") as f:
                config = json.load(f)
            architecture = config["architectures"][0]

            class_ = getattr(transformers, architecture)

            # Name for this variable depends on transformers version.
            discard_names = getattr(class_, "_tied_weights_keys", [])

        except Exception:
            # Best-effort: without the architecture we simply discard nothing.
            discard_names = []
        # Convert pytorch weights to safetensors
        utils.convert_files(local_pt_files, local_st_files, discard_names)


@app.command()
def quantize(
    model_id: str,
    output_dir: str,
    revision: Optional[str] = None,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
    upload_to_model_id: Optional[str] = None,
    percdamp: float = 0.01,
    act_order: bool = False,
    groupsize: int = 128,
):
    """Quantize ``model_id`` to 4-bit GPTQ weights under ``output_dir``.

    Ensures the source weights are downloaded first, then delegates the
    actual quantization to the GPTQ implementation.
    """
    # Pin the revision so the download and the quantization agree on it.
    revision = "main" if revision is None else revision

    # Make sure the model weights are available locally before quantizing.
    download_weights(
        model_id=model_id,
        revision=revision,
        logger_level=logger_level,
        json_output=json_output,
    )

    # Deferred import: the GPTQ stack is heavy and only needed here. Aliased
    # to avoid shadowing this command's own name inside the function body.
    from text_generation_server.layers.gptq.quantize import quantize as gptq_quantize

    gptq_quantize(
        model_id=model_id,
        bits=4,
        groupsize=groupsize,
        output_dir=output_dir,
        revision=revision,
        trust_remote_code=trust_remote_code,
        upload_to_model_id=upload_to_model_id,
        percdamp=percdamp,
        act_order=act_order,
        sym=True,
    )


if __name__ == "__main__":
    # Dispatch to the Typer CLI when this module is executed directly.
    app()