import os
import sys
import typer

from pathlib import Path
from loguru import logger
from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download


app = typer.Typer()
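
# Example invocations (a sketch; assumes the package exposes this Typer app
# as the `text-generation-server` console script):
#
#   text-generation-server serve bigscience/bloom-560m
#   text-generation-server download-weights bigscience/bloom-560m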


class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    bitsandbytes_nf4 = "bitsandbytes-nf4"
    bitsandbytes_fp4 = "bitsandbytes-fp4"
    gptq = "gptq"
    awq = "awq"
    eetq = "eetq"


class Dtype(str, Enum):
    float16 = "float16"
    bloat16 = "bfloat16"


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: Optional[Quantization] = None,
    speculate: Optional[int] = None,
    dtype: Optional[Dtype] = None,
    trust_remote_code: bool = False,
    uds_path: Path = "/tmp/text-generation-server",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
):
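    # Each shard runs as its own process; when sharded, the launcher is
    # expected to provide the usual torch.distributed environment variables
    # checked below.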
    if sharded:
        assert (
            os.getenv("RANK", None) is not None
        ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"

    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import server
    from text_generation_server.tracing import setup_tracing

    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
    dtype = None if dtype is None else dtype.value
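    # An explicit dtype is only compatible with the bitsandbytes variants,
    # which quantize at load time from a regular torch dtype; the other
    # schemes decide the final dtype themselves.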
    if dtype is not None and quantize not in {
        None,
        "bitsandbytes",
        "bitsandbytes-nf4",
        "bitsandbytes-fp4",
    }:
        raise RuntimeError(
            "Only one of `dtype` and `quantize` can be set, as they both determine the dtype of the final model."
        )
    server.serve(
        model_id,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        trust_remote_code,
        uds_path,
    )


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    auto_convert: bool = True,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
):
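    # Resolution order: weights already on disk, PEFT adapters and Medusa
    # heads (hub or local), safetensors from the hub, and finally .bin
    # weights (local or hub) converted to safetensors.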
    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import utils

    # Check whether the weight files were already downloaded
    try:
        utils.weight_files(model_id, revision, extension)
        logger.info("Files are already present on the host. " "Skipping download.")
        return
    # Local files not found
    except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
        pass

    is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
        "WEIGHTS_CACHE_OVERRIDE", None
    ) is not None

    if not is_local_model:
        try:
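            # hf_hub_download raises EntryNotFoundError if the repo has no
            # adapter_config.json, i.e. if this is not a PEFT adapter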
            adapter_config_filename = hf_hub_download(
                model_id, revision=revision, filename="adapter_config.json"
            )
            utils.download_and_unload_peft(
                model_id, revision, trust_remote_code=trust_remote_code
            )
            is_local_model = True
            utils.weight_files(model_id, revision, extension)
            return
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

        try:
            import json

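            # Medusa repos ship a speculation head plus a config.json whose
            # `base_model_name_or_path` points at the base model, whose
            # weights are what we actually need to download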
            medusa_head = hf_hub_download(
                model_id, revision=revision, filename="medusa_lm_head.pt"
            )
            if auto_convert:
                medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors")
                if not medusa_sf.exists():
                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
            medusa_config = hf_hub_download(
                model_id, revision=revision, filename="config.json"
            )
            with open(medusa_config, "r") as f:
                config = json.load(f)

            model_id = config["base_model_name_or_path"]
            revision = "main"
            try:
                utils.weight_files(model_id, revision, extension)
                logger.info(
                    f"Files for parent {model_id} are already present on the host. "
                    "Skipping download."
                )
                return
            # Local files not found
            except (
                utils.LocalEntryNotFoundError,
                FileNotFoundError,
                utils.EntryNotFoundError,
            ):
                pass
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

        # Try to download weights from the hub
        try:
            filenames = utils.weight_hub_files(model_id, revision, extension)
            utils.download_weights(filenames, model_id, revision)
            # Successfully downloaded weights
            return

        # No weights found on the hub with this extension
        except utils.EntryNotFoundError as e:
            # Check if we want to automatically convert to safetensors or if we can use .bin weights instead
            if extension != ".safetensors" or not auto_convert:
                raise e

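    # Same Medusa handling as above, for a checkpoint on local disk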
    elif (Path(model_id) / "medusa_lm_head.pt").exists():
        # Try to load as a local Medusa model
        try:
            import json

            medusa_head = Path(model_id) / "medusa_lm_head.pt"
            if auto_convert:
                medusa_sf = Path(model_id) / "medusa_lm_head.safetensors"
                if not medusa_sf.exists():
                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
            medusa_config = Path(model_id) / "config.json"
            with open(medusa_config, "r") as f:
                config = json.load(f)

            model_id = config["base_model_name_or_path"]
            revision = "main"
            try:
                utils.weight_files(model_id, revision, extension)
                logger.info(
                    f"Files for parent {model_id} are already present on the host. "
                    "Skipping download."
                )
                return
            # Local files not found
            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
                pass
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass
    elif (Path(model_id) / "adapter_config.json").exists():
        # Try to load as a local PEFT model
        try:
            utils.download_and_unload_peft(
                model_id, revision, trust_remote_code=trust_remote_code
            )
            utils.weight_files(model_id, revision, extension)
            return
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

    # Try to see if there are local pytorch weights
    try:
        # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
        local_pt_files = utils.weight_files(model_id, revision, ".bin")

    # No local pytorch weights
    except utils.LocalEntryNotFoundError:
        if extension == ".safetensors":
            logger.warning(
                f"No safetensors weights found for model {model_id} at revision {revision}. "
                f"Downloading PyTorch weights."
            )

        # Try to see if there are pytorch weights on the hub
        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
        # Download pytorch weights
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)

    if auto_convert:
        logger.warning(
            f"No safetensors weights found for model {model_id} at revision {revision}. "
            f"Converting PyTorch weights to safetensors."
        )

        # Safetensors final filenames
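        # e.g. pytorch_model-00001-of-00002.bin -> model-00001-of-00002.safetensors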
        local_st_files = [
            p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
            for p in local_pt_files
        ]
        try:
            import transformers
            import json

            if is_local_model:
                config_filename = os.path.join(model_id, "config.json")
            else:
                config_filename = hf_hub_download(
                    model_id, revision=revision, filename="config.json"
                )
            with open(config_filename, "r") as f:
                config = json.load(f)
            architecture = config["architectures"][0]

            class_ = getattr(transformers, architecture)

            # The name of this attribute depends on the transformers version.
            discard_names = getattr(class_, "_tied_weights_keys", [])

        except Exception:
            discard_names = []
        # Convert pytorch weights to safetensors
        utils.convert_files(local_pt_files, local_st_files, discard_names)


@app.command()
def quantize(
    model_id: str,
    output_dir: str,
    revision: Optional[str] = None,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
    upload_to_model_id: Optional[str] = None,
    percdamp: float = 0.01,
    act_order: bool = False,
):
    if revision is None:
        revision = "main"
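    # Make sure the weights are downloaded (and converted) before quantizing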
    download_weights(
        model_id=model_id,
        revision=revision,
        logger_level=logger_level,
        json_output=json_output,
    )
    from text_generation_server.utils.gptq.quantize import quantize

    quantize(
        model_id=model_id,
        bits=4,
        groupsize=128,
        output_dir=output_dir,
        revision=revision,
        trust_remote_code=trust_remote_code,
        upload_to_model_id=upload_to_model_id,
        percdamp=percdamp,
        act_order=act_order,
    )


if __name__ == "__main__":
    app()