"examples/vscode:/vscode.git/clone" did not exist on "6b0f2e908815acc3dbcf6630b5cdff4b9fbece72"
cli.py 11.7 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
import os
2
import sys
Olivier Dehaene's avatar
Olivier Dehaene committed
3
4
5
import typer

from pathlib import Path
6
from loguru import logger
7
from typing import Optional
8
from enum import Enum
9
from huggingface_hub import hf_hub_download
Olivier Dehaene's avatar
Olivier Dehaene committed
10
11
12
13
14


# Typer application; every function decorated with ``@app.command()`` in
# this module is registered on it as a CLI subcommand.
app = typer.Typer()


15
16
class Quantization(str, Enum):
    """Weight-quantization methods accepted by ``serve --quantize``.

    The string values are the CLI spellings; ``serve`` converts the chosen
    member back to its ``.value`` before forwarding it to the server.
    """

    bitsandbytes = "bitsandbytes"
    bitsandbytes_nf4 = "bitsandbytes-nf4"
    bitsandbytes_fp4 = "bitsandbytes-fp4"
    gptq = "gptq"
    awq = "awq"
    eetq = "eetq"
    exl2 = "exl2"
    fp8 = "fp8"
    marlin = "marlin"
25
26


27
28
29
30
31
class Dtype(str, Enum):
    """Floating-point dtypes accepted by ``serve --dtype``.

    The string values are the CLI spellings; ``serve`` forwards ``.value``.
    """

    float16 = "float16"
    bfloat16 = "bfloat16"
    # Backward-compatible alias for the original misspelled member name.
    # Because it shares the value "bfloat16", Enum makes it an alias of
    # ``bfloat16``: it is not a separate member and is not yielded when
    # iterating the enum, so the accepted CLI choices are unchanged.
    bloat16 = "bfloat16"


Olivier Dehaene's avatar
Olivier Dehaene committed
32
@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: Optional[Quantization] = None,
    speculate: Optional[int] = None,
    dtype: Optional[Dtype] = None,
    trust_remote_code: bool = False,
    uds_path: Path = "/tmp/text-generation-server",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
    otlp_service_name: str = "text-generation-inference.server",
    max_input_tokens: Optional[int] = None,
):
    """Run the model server for ``model_id``.

    Configures logging (and, optionally, OpenTelemetry tracing), reads LoRA
    adapter ids from the ``LORA_ADAPTERS`` environment variable, validates
    the ``dtype``/``quantize`` combination, then hands everything off to
    ``text_generation_server.server.serve``.

    Args:
        model_id: Model identifier or local path forwarded to the server.
        revision: Optional model revision forwarded to the server.
        sharded: Enable sharded mode. Requires ``RANK``, ``WORLD_SIZE``,
            ``MASTER_ADDR`` and ``MASTER_PORT`` to be set in the environment.
        quantize: Optional quantization method.
        speculate: Optional value forwarded to the server.
        dtype: Optional dtype. Only allowed without quantization or with a
            bitsandbytes variant.
        trust_remote_code: Forwarded to the server.
        uds_path: Path forwarded to the server (presumably its Unix domain
            socket; confirm in ``server.serve``).
        logger_level: Minimum level of the stdout log handler.
        json_output: Serialize log records as JSON.
        otlp_endpoint: If set, tracing is configured to export to it.
        otlp_service_name: Service name passed to the tracing setup.
        max_input_tokens: Optional limit forwarded to the server.

    Raises:
        RuntimeError: If ``sharded`` is set and a required environment
            variable is missing, or if ``dtype`` is combined with a
            non-bitsandbytes quantization.
    """
    if sharded:
        # Explicit raise rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently skip this validation.
        for env_var in ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
            if os.getenv(env_var) is None:
                raise RuntimeError(f"{env_var} must be set when sharded is True")

    # Replace loguru's default handler with a stdout handler that is filtered
    # to ``text_generation_server`` records and optionally serialized as JSON.
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import server
    from text_generation_server.tracing import setup_tracing

    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)

    # LORA_ADAPTERS is a comma-separated list of adapter ids. Whitespace is
    # stripped and blank entries (e.g. from a trailing comma) are dropped.
    raw_lora_adapters = os.getenv("LORA_ADAPTERS", "")
    lora_adapter_ids = [
        adapter_id.strip()
        for adapter_id in raw_lora_adapters.split(",")
        if adapter_id.strip()
    ]

    if lora_adapter_ids:
        logger.warning(
            "LoRA adapters are enabled. This is an experimental feature and may not work as expected."
        )

    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
    dtype = None if dtype is None else dtype.value

    # An explicit dtype may only be combined with no quantization or with a
    # bitsandbytes variant; any other quantization method conflicts with it.
    dtype_compatible_quantizations = {
        None,
        "bitsandbytes",
        "bitsandbytes-nf4",
        "bitsandbytes-fp4",
    }
    if dtype is not None and quantize not in dtype_compatible_quantizations:
        raise RuntimeError(
            "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
        )

    server.serve(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        trust_remote_code,
        uds_path,
        max_input_tokens,
    )
Olivier Dehaene's avatar
Olivier Dehaene committed
118
119
120


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    auto_convert: bool = True,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
    merge_lora: bool = False,
):
    """Make the weights of ``model_id`` available locally, converting if needed.

    Resolution order:

    1. If weight files with ``extension`` are already present, return.
    2. Hub model (not a local directory and ``WEIGHTS_CACHE_OVERRIDE`` unset):
       with ``merge_lora``, try to merge a PEFT adapter via
       ``utils.download_and_unload_peft``; otherwise try (best effort) to fetch
       one via ``utils.peft.download_peft``. Then fetch the parent model named
       by ``base_model_name_or_path`` in the repo's ``config.json``, and try to
       download ``extension`` weights from the hub.
    3. Local directory with ``adapter_config.json``: merge it via
       ``utils.download_and_unload_peft``.
    4. Local directory with ``config.json``: fetch its parent model, if any.
    5. Fall back to PyTorch ``.bin``/``.pt`` weights (local, or ``.bin``
       downloaded from the hub) and, if ``auto_convert`` is set, convert them
       to safetensors.

    Raises:
        utils.EntryNotFoundError: re-raised when the hub has no ``extension``
            weights and falling back to converted ``.bin`` weights is not
            possible (``extension`` is not ``.safetensors`` or
            ``auto_convert`` is off).
    """
    import json

    # Replace loguru's default handler with a stdout handler that is filtered
    # to ``text_generation_server`` records and optionally serialized as JSON.
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import utils

    # Test if files were already downloaded
    try:
        utils.weight_files(model_id, revision, extension)
    except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
        pass  # Local files not found: continue with the download logic.
    else:
        logger.info("Files are already present on the host. Skipping download.")
        return

    # Treated as local when model_id is a directory on disk, or when the
    # WEIGHTS_CACHE_OVERRIDE environment variable is set.
    is_local_model = (
        Path(model_id).is_dir() or os.getenv("WEIGHTS_CACHE_OVERRIDE") is not None
    )

    if not is_local_model:
        # TODO: maybe reverse the default value of merge_lora?
        # currently by default we don't merge the weights with the base model
        if merge_lora:
            try:
                # Probe for an adapter config; a missing file is expected to
                # surface as one of the errors caught below.
                hf_hub_download(
                    model_id, revision=revision, filename="adapter_config.json"
                )
                utils.download_and_unload_peft(
                    model_id, revision, trust_remote_code=trust_remote_code
                )
                utils.weight_files(model_id, revision, extension)
                return
            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
                pass
        else:
            try:
                utils.peft.download_peft(
                    model_id, revision, trust_remote_code=trust_remote_code
                )
            except Exception as e:
                # Best effort: failing to fetch an adapter (e.g. the repo is
                # not one) must not abort the weight download.
                logger.debug(f"No PEFT adapter downloaded for {model_id}: {e}")

        try:
            config_path = hf_hub_download(
                model_id, revision=revision, filename="config.json"
            )
            with open(config_path, "r") as f:
                config = json.load(f)
            _download_parent_model(
                config.get("base_model_name_or_path", None),
                model_id,
                extension=extension,
                auto_convert=auto_convert,
                logger_level=logger_level,
                json_output=json_output,
                trust_remote_code=trust_remote_code,
            )
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

        # Try to download weights from the hub
        try:
            filenames = utils.weight_hub_files(model_id, revision, extension)
            utils.download_weights(filenames, model_id, revision)
            # Successfully downloaded weights
            return
        # No weights found on the hub with this extension
        except utils.EntryNotFoundError:
            # Only continue if we may fall back to .bin weights and convert
            # them to safetensors; otherwise surface the original error.
            if extension != ".safetensors" or not auto_convert:
                raise

    elif (Path(model_id) / "adapter_config.json").exists():
        # Try to load as a local PEFT model
        try:
            utils.download_and_unload_peft(
                model_id, revision, trust_remote_code=trust_remote_code
            )
            utils.weight_files(model_id, revision, extension)
            return
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

    elif (Path(model_id) / "config.json").exists():
        # Try to load as a local Medusa model: fetch the parent model it names.
        try:
            with open(Path(model_id) / "config.json", "r") as f:
                config = json.load(f)
            _download_parent_model(
                config.get("base_model_name_or_path", None),
                model_id,
                extension=extension,
                auto_convert=auto_convert,
                logger_level=logger_level,
                json_output=json_output,
                trust_remote_code=trust_remote_code,
            )
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

    # Try to see if there are local pytorch weights
    try:
        # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
        try:
            local_pt_files = utils.weight_files(model_id, revision, ".bin")
        except Exception:
            local_pt_files = utils.weight_files(model_id, revision, ".pt")
    # No local pytorch weights
    except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
        if extension == ".safetensors":
            logger.warning(
                f"No safetensors weights found for model {model_id} at revision {revision}. "
                "Downloading PyTorch weights."
            )

        # Try to see if there are pytorch weights on the hub
        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
        # Download pytorch weights
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)

    if auto_convert:
        if not trust_remote_code:
            logger.warning(
                "🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
                "Pickle files are unsafe and can essentially contain remote code execution! "
                "Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
            )
            # NOTE(review): despite the message above, nothing returns here, so
            # the conversion below still runs without trust_remote_code.
            # Confirm whether an early return is intended.

        logger.warning(
            f"No safetensors weights found for model {model_id} at revision {revision}. "
            "Converting PyTorch weights to safetensors."
        )

        # Safetensors final filenames
        local_st_files = [_safetensors_path(p) for p in local_pt_files]
        discard_names = _tied_weight_names(model_id, revision, is_local_model)

        # Convert pytorch weights to safetensors
        utils.convert_files(local_pt_files, local_st_files, discard_names)


def _download_parent_model(
    base_model_id: Optional[str],
    model_id: str,
    extension: str,
    auto_convert: bool,
    logger_level: str,
    json_output: bool,
    trust_remote_code: bool,
) -> None:
    """Best-effort download of the parent model ``base_model_id`` at ``main``.

    Skipped when ``base_model_id`` is empty or equals ``model_id`` (which
    would re-enter ``download_weights`` for the same model). Any failure is
    logged as a warning rather than raised.
    """
    if not base_model_id or base_model_id == model_id:
        return
    logger.info(f"Downloading parent model {base_model_id}")
    try:
        download_weights(
            model_id=base_model_id,
            revision="main",
            extension=extension,
            auto_convert=auto_convert,
            logger_level=logger_level,
            json_output=json_output,
            trust_remote_code=trust_remote_code,
        )
    except Exception as e:
        logger.warning(f"Failed to download parent model {base_model_id}: {e}")


def _safetensors_path(pt_path):
    """Return the ``.safetensors`` sibling path for a PyTorch weight file.

    A leading ``pytorch_`` prefix is removed from the stem, e.g.
    ``pytorch_model-00001-of-00002.bin`` -> ``model-00001-of-00002.safetensors``.
    Unlike ``str.lstrip``, only the exact prefix is removed, so a stem such as
    ``optimizer`` is left intact.
    """
    stem = pt_path.stem
    prefix = "pytorch_"
    if stem.startswith(prefix):
        stem = stem[len(prefix):]
    return pt_path.parent / f"{stem}.safetensors"


def _tied_weight_names(model_id: str, revision: Optional[str], is_local_model: bool):
    """Return the ``discard_names`` passed to ``utils.convert_files``.

    Reads the first entry of ``architectures`` from the model's
    ``config.json`` (from disk for local models, otherwise via
    ``hf_hub_download``) and returns the matching ``transformers`` class's
    ``_tied_weights_keys`` attribute. Any failure (transformers missing, no
    config, unknown architecture, ...) yields ``[]``.
    """
    import json

    try:
        import transformers

        if is_local_model:
            config_filename = os.path.join(model_id, "config.json")
        else:
            config_filename = hf_hub_download(
                model_id, revision=revision, filename="config.json"
            )
        with open(config_filename, "r") as f:
            config = json.load(f)
        architecture = config["architectures"][0]

        model_class = getattr(transformers, architecture)

        # Name for this variable depends on transformers version.
        return getattr(model_class, "_tied_weights_keys", [])
    except Exception:
        return []
Olivier Dehaene's avatar
Olivier Dehaene committed
322
323


324
325
326
327
328
329
330
331
332
333
334
335
@app.command()
def quantize(
    model_id: str,
    output_dir: str,
    revision: Optional[str] = None,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
    upload_to_model_id: Optional[str] = None,
    percdamp: float = 0.01,
    act_order: bool = False,
):
    """Quantize ``model_id`` with GPTQ (4 bits, group size 128) into ``output_dir``.

    The model weights are first fetched with ``download_weights``.

    Args:
        model_id: Model to quantize.
        output_dir: Output directory passed to the GPTQ quantizer.
        revision: Revision to download and quantize; ``None`` means ``"main"``.
        logger_level: Log level for the download step.
        json_output: Serialize download-step logs as JSON.
        trust_remote_code: Forwarded to both the download step and the
            quantizer, so the user's trust choice applies to both.
        upload_to_model_id: Optional id forwarded to the quantizer
            (presumably a hub repo to upload the result to; confirm).
        percdamp: Forwarded to the quantizer.
        act_order: Forwarded to the quantizer.
    """
    if revision is None:
        revision = "main"

    download_weights(
        model_id=model_id,
        revision=revision,
        logger_level=logger_level,
        json_output=json_output,
        trust_remote_code=trust_remote_code,
    )

    # Aliased so the import does not shadow this command's own name.
    from text_generation_server.layers.gptq.quantize import quantize as gptq_quantize

    gptq_quantize(
        model_id=model_id,
        bits=4,
        groupsize=128,
        output_dir=output_dir,
        revision=revision,
        trust_remote_code=trust_remote_code,
        upload_to_model_id=upload_to_model_id,
        percdamp=percdamp,
        act_order=act_order,
    )


Olivier Dehaene's avatar
Olivier Dehaene committed
359
360
if __name__ == "__main__":
    # Only when executed as a script: parse the command line and dispatch to
    # the matching ``@app.command()`` function in this module.
    app()