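"""Model registry for text_generation_server.

get_model() maps a model id and its configuration to a concrete Model
implementation: Flash Attention variants when available, sharded
implementations, or AutoModel-based fallbacks.
"""
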
import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download
from typing import Optional
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashRWSharded")
    __all__.append("FlashSantacoderSharded")
    __all__.append("FlashLlama")
    __all__.append("IDEFICSSharded")
    __all__.append("FlashMistral")
    __all__.append("FlashMixtral")
    __all__.append("FlashPhi")
    __all__.append("FlashStarcoder2")

MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append("Mamba")


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
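    """Instantiate the best available Model implementation for `model_id`.

    The model config is loaded (from the Hub or a local path) to determine the
    architecture and any Medusa speculative-decoding setup, then a matching
    Flash Attention / sharded implementation is returned when available, with
    AutoModel-based CausalLM / Seq2SeqLM as the fallback.
    """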
    if dtype is None:
        # Keep it as default for now and let
        # every model resolve their own default dtype.
        dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )

    use_medusa = None
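    # A Medusa checkpoint only carries the speculative heads: swap `model_id` to the
    # base model and keep the Medusa weights around via `use_medusa`.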
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this Medusa model only has "
                    f"`{speculate_medusa}` heads; please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            use_medusa = Path(medusa_config).parent
        else:
            use_medusa = Path(medusa_model_id)

        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict.get("model_type", None)
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # This only happens for Mamba checkpoints
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

    if model_type == "ssm":
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "gpt_bigcode" or (
        model_type == "gpt2" and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "llama" or model_type == "baichuan":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

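    # Models with sliding-window attention (Mistral/Mixtral/Starcoder2) need Flash
    # Attention v2 on CUDA; without a sliding window any Flash Attention build works.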
    if model_type == "mistral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "mixtral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "starcoder2":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashStarcoder2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "opt":
        return OPTSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

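    # No dedicated implementation matched: fall back to transformers' AutoModel
    # classes. Sharding and most quantization schemes are unsupported on this path.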
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

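    # Last resort: custom modeling code advertised through the config's `auto_map`
    # (only usable when trust_remote_code is enabled).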
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")