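"""Typer-based command line interface for text-generation-server.

Commands:
  serve             start the gRPC model-serving process
  download-weights  fetch model weights, converting them to safetensors
  quantize          GPTQ-quantize a model and save the result
"""
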
import os
import sys
import typer

from pathlib import Path
from loguru import logger
from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download


app = typer.Typer()


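# Values accepted by the CLI `--quantize` option.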
class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    bitsandbytes_nf4 = "bitsandbytes-nf4"
    bitsandbytes_fp4 = "bitsandbytes-fp4"
    gptq = "gptq"


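# Values accepted by the CLI `--dtype` option.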
class Dtype(str, Enum):
    float16 = "float16"
    bfloat16 = "bfloat16"


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: Optional[Quantization] = None,
    dtype: Optional[Dtype] = None,
    trust_remote_code: bool = False,
    uds_path: Path = Path("/tmp/text-generation-server"),
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
):
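    """Start the gRPC text-generation server for `model_id`.

    When `sharded` is True, this process is expected to be a single shard of
    a tensor-parallel deployment driven by an external launcher.
    """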
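    # A launcher must provide the standard torch.distributed rendezvous
    # environment variables for every shard process.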
    if sharded:
        assert (
            os.getenv("RANK", None) is not None
        ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"

    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import server
    from text_generation_server.tracing import setup_tracing

    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        setup_tracing(
            shard=int(os.getenv("RANK", "0")), otlp_endpoint=otlp_endpoint
        )

    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
    dtype = None if dtype is None else dtype.value
    if dtype is not None and quantize is not None:
        raise RuntimeError(
            "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
        )
    server.serve(
        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
    )


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    auto_convert: bool = True,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
):
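    """Make sure the weights for `model_id` are available locally.

    Prefers `.safetensors` files; when `auto_convert` is True, falls back to
    downloading PyTorch `.bin` weights and converting them to safetensors.
    """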
    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import utils

    # Test if files were already downloaded
    try:
        utils.weight_files(model_id, revision, extension)
        logger.info("Files are already present on the host. " "Skipping download.")
        return
    # Local files not found
    except (utils.LocalEntryNotFoundError, FileNotFoundError):
        pass

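    # A local model directory (or a weights cache override) means there is
    # nothing to fetch from the hub.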
    is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
        "WEIGHTS_CACHE_OVERRIDE", None
    ) is not None

    if not is_local_model:
        try:
            # An adapter_config.json on the hub means `model_id` is a PEFT
            # adapter: download it and merge it into standalone weights.
            hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
            utils.download_and_unload_peft(
                model_id, revision, trust_remote_code=trust_remote_code
            )
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

        # Try to download weights from the hub
        try:
            filenames = utils.weight_hub_files(model_id, revision, extension)
            utils.download_weights(filenames, model_id, revision)
            # Successfully downloaded weights
            return

        # No weights found on the hub with this extension
        except utils.EntryNotFoundError as e:
            # Check if we want to automatically convert to safetensors or if we can use .bin weights instead
            if extension != ".safetensors" or not auto_convert:
                raise e

    # Try to see if there are local pytorch weights
    try:
        # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
        local_pt_files = utils.weight_files(model_id, revision, ".bin")

    # No local pytorch weights
    except utils.LocalEntryNotFoundError:
        if extension == ".safetensors":
            logger.warning(
                f"No safetensors weights found for model {model_id} at revision {revision}. "
                f"Downloading PyTorch weights."
            )

        # Try to see if there are pytorch weights on the hub
        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
        # Download pytorch weights
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)

    if auto_convert:
        logger.warning(
            f"No safetensors weights found for model {model_id} at revision {revision}. "
            f"Converting PyTorch weights to safetensors."
        )

        # Safetensors final filenames
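        # e.g. "pytorch_model-00001-of-00002.bin" -> "model-00001-of-00002.safetensors"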
        local_st_files = [
            p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
            for p in local_pt_files
        ]
        try:
            import transformers
            import json

            config_filename = hf_hub_download(
                model_id, revision=revision, filename="config.json"
            )
            with open(config_filename, "r") as f:
                config = json.load(f)
            architecture = config["architectures"][0]

            class_ = getattr(transformers, architecture)

            # The attribute name depends on the transformers version; copy the
            # list so extending it cannot mutate the class attribute in place.
            discard_names = list(getattr(class_, "_tied_weights_keys", []))
            discard_names.extend(
                getattr(class_, "_keys_to_ignore_on_load_missing", [])
            )

        except Exception:
            discard_names = []
        # Convert pytorch weights to safetensors
        utils.convert_files(local_pt_files, local_st_files, discard_names)


@app.command()
def quantize(
    model_id: str,
    output_dir: str,
    revision: Optional[str] = None,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
    upload_to_model_id: Optional[str] = None,
    percdamp: float = 0.01,
    act_order: bool = False,
):
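    """Quantize `model_id` to 4-bit GPTQ, write the result to `output_dir`,
    and optionally push the quantized weights to `upload_to_model_id`.
    """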
    if revision is None:
        revision = "main"
    download_weights(
        model_id=model_id,
        revision=revision,
        logger_level=logger_level,
        json_output=json_output,
    )
    from text_generation_server.utils.gptq.quantize import quantize

    quantize(
        model_id=model_id,
        bits=4,
        groupsize=128,
        output_dir=output_dir,
        revision=revision,
        trust_remote_code=trust_remote_code,
        upload_to_model_id=upload_to_model_id,
        percdamp=percdamp,
        act_order=act_order,
    )


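# Example invocations, assuming the package exposes this module as the
# `text-generation-server` console script (Typer derives the option names
# from the parameters above):
#
#   text-generation-server download-weights bigscience/bloom-560m
#   text-generation-server serve bigscience/bloom-560m --quantize bitsandbytes
#   text-generation-server quantize bigscience/bloom-560m ./bloom-560m-gptq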
if __name__ == "__main__":
    app()