import os
import sys
import typer

from pathlib import Path
from loguru import logger
from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download


app = typer.Typer()


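# Quantization backends that can be requested at serve time.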
class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    bitsandbytes_nf4 = "bitsandbytes-nf4"
    bitsandbytes_fp4 = "bitsandbytes-fp4"
    gptq = "gptq"
    awq = "awq"
    eetq = "eetq"


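# Torch dtypes the model weights can be loaded in.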
class Dtype(str, Enum):
    float16 = "float16"
    bfloat16 = "bfloat16"


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: Optional[Quantization] = None,
    speculate: Optional[int] = None,
    dtype: Optional[Dtype] = None,
    trust_remote_code: bool = False,
    uds_path: Path = "/tmp/text-generation-server",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
):
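    # Sharded serving relies on torch.distributed; the launcher must export
    # its environment variables before this process starts.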
    if sharded:
        assert (
            os.getenv("RANK", None) is not None
        ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"

    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import server
    from text_generation_server.tracing import setup_tracing

    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

    # Downgrade the enums into strings for easier management later on
    quantize = None if quantize is None else quantize.value
    dtype = None if dtype is None else dtype.value
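    # Only the bitsandbytes backends are compatible with an explicit dtype.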
    if dtype is not None and quantize not in {
        None,
        "bitsandbytes",
        "bitsandbytes-nf4",
        "bitsandbytes-fp4",
    }:
        raise RuntimeError(
            "Only one of `dtype` and `quantize` can be set, as they both determine how the final model is loaded."
        )
    server.serve(
        model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code, uds_path
    )


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    auto_convert: bool = True,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
):
    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import utils

    # Check whether the files were already downloaded
    try:
        utils.weight_files(model_id, revision, extension)
        logger.info("Files are already present on the host. Skipping download.")
        return
    # Local files not found
    except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
        pass

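    # The model counts as local if model_id is an existing directory or if the
    # weights location is overridden through WEIGHTS_CACHE_OVERRIDE.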
    is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
        "WEIGHTS_CACHE_OVERRIDE", None
    ) is not None

    if not is_local_model:
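        # A repo that ships an adapter_config.json is a PEFT adapter: download
        # and merge the adapter into the base model weights. hf_hub_download
        # raises EntryNotFoundError when there is no adapter config.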
        try:
            adapter_config_filename = hf_hub_download(
                model_id, revision=revision, filename="adapter_config.json"
            )
            utils.download_and_unload_peft(
                model_id, revision, trust_remote_code=trust_remote_code
            )
            is_local_model = True
            utils.weight_files(model_id, revision, extension)
            return
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

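        # Look for a Medusa speculative-decoding head; if one exists, convert
        # it to safetensors and fall through to its base model's weights.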
        try:
            import json
            medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt")
            if auto_convert:
                medusa_sf = Path(medusa_head[:-len(".pt")] + ".safetensors")
                if not medusa_sf.exists():
                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
            medusa_config = hf_hub_download(model_id, revision=revision, filename="config.json")
            with open(medusa_config, "r") as f:
                config = json.load(f)

            model_id = config["base_model_name_or_path"]
            revision = "main"
            try:
                utils.weight_files(model_id, revision, extension)
                logger.info(
                    f"Files for parent {model_id} are already present on the host. Skipping download."
                )
                return
            # Local files not found
            except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
                pass
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

        # Try to download weights from the hub
        try:
            filenames = utils.weight_hub_files(model_id, revision, extension)
            utils.download_weights(filenames, model_id, revision)
            # Successfully downloaded weights
            return

        # No weights found on the hub with this extension
        except utils.EntryNotFoundError as e:
            # Check if we want to automatically convert to safetensors or if we can use .bin weights instead
            if extension != ".safetensors" or not auto_convert:
                raise e

    else:
        # Try to load as a local PEFT model
        try:
            utils.download_and_unload_peft(
                model_id, revision, trust_remote_code=trust_remote_code
            )
            utils.weight_files(model_id, revision, extension)
            return
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

    # Try to see if there are local pytorch weights
    try:
        # Get weights for a local model, a hub-cached model, or one stored in WEIGHTS_CACHE_OVERRIDE
        local_pt_files = utils.weight_files(model_id, revision, ".bin")

    # No local pytorch weights
    except utils.LocalEntryNotFoundError:
        if extension == ".safetensors":
            logger.warning(
                f"No safetensors weights found for model {model_id} at revision {revision}. "
                f"Downloading PyTorch weights."
            )

        # Try to see if there are pytorch weights on the hub
        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
        # Download pytorch weights
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)

    if auto_convert:
        logger.warning(
            f"No safetensors weights found for model {model_id} at revision {revision}. "
            f"Converting PyTorch weights to safetensors."
        )

        # Safetensors final filenames
        local_st_files = [
            # removeprefix drops the "pytorch_" prefix; lstrip would strip characters, not the prefix
            p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
            for p in local_pt_files
        ]
        try:
            import transformers
            import json

            if is_local_model:
                config_filename = os.path.join(model_id, "config.json")
            else:
                config_filename = hf_hub_download(
                    model_id, revision=revision, filename="config.json"
                )
            with open(config_filename, "r") as f:
                config = json.load(f)
            architecture = config["architectures"][0]

            class_ = getattr(transformers, architecture)

            # Tied weights must be discarded when converting to safetensors, since
            # shared tensors are not allowed; the attribute name depends on the
            # transformers version.
            discard_names = getattr(class_, "_tied_weights_keys", [])

        except Exception:
            discard_names = []
        # Convert pytorch weights to safetensors
        utils.convert_files(local_pt_files, local_st_files, discard_names)


@app.command()
def quantize(
    model_id: str,
    output_dir: str,
    revision: Optional[str] = None,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
    upload_to_model_id: Optional[str] = None,
    percdamp: float = 0.01,
    act_order: bool = False,
):
    if revision is None:
        revision = "main"
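    # Make sure the weights are available locally, then run GPTQ quantization
    # (4-bit weights, group size 128).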
    download_weights(
        model_id=model_id,
        revision=revision,
        logger_level=logger_level,
        json_output=json_output,
    )
    from text_generation_server.utils.gptq.quantize import quantize

    quantize(
        model_id=model_id,
        bits=4,
        groupsize=128,
        output_dir=output_dir,
        revision=revision,
        trust_remote_code=trust_remote_code,
        upload_to_model_id=upload_to_model_id,
        percdamp=percdamp,
        act_order=act_order,
    )


if __name__ == "__main__":
    app()