__init__.py 20.1 KB
Newer Older
1
2
import torch

3
from loguru import logger
4
from transformers.configuration_utils import PretrainedConfig
5
from transformers.models.auto import modeling_auto
6
from huggingface_hub import hf_hub_download
7
from typing import Optional
8
from pathlib import Path
9

Nicolas Patry's avatar
Nicolas Patry committed
10
from text_generation_server.utils.speculate import get_speculate, set_speculate
11
12
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
13
from text_generation_server.models.flash_causal_lm import FlashCausalLM
14
from text_generation_server.models.bloom import BLOOMSharded
15
from text_generation_server.models.mpt import MPTSharded
16
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
17
from text_generation_server.models.rw import RW
18
19
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
20
21
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
22
from text_generation_server.models.gpt_neox import GPTNeoxSharded
drbh's avatar
drbh committed
23
from text_generation_server.models.phi import Phi
24

25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# Allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later,
# so it is switched on explicitly here at import time.
torch.backends.cuda.matmul.allow_tf32 = True

# Allow TF32 on cuDNN. This flag already defaults to True; it is set explicitly
# so both TF32 switches are visible in one place.
torch.backends.cudnn.allow_tf32 = True

# Disable gradient tracking globally as soon as this package is imported.
torch.set_grad_enabled(False)

# Names exported by `from <package> import *`. Optional model classes are
# registered further down in this module when their imports succeed.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

47
# Template for the error raised when a model can only be served through its
# Flash Attention implementation; `{}` receives a human-readable model label.
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Optimistically assume the Flash Attention models can be imported; this is
# flipped to False below if any of the imports fails.
FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_qwen2 import (
        FlashQwen2,
    )
    from text_generation_server.models.flash_cohere import (
        FlashCohere,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
    from text_generation_server.models.flash_dbrx import FlashDbrx
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA

except ImportError as e:
    # A single failing import disables every Flash Attention model: the names
    # imported after the failure point would be undefined anyway.
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
    # `__all__` must hold *names* (strings). Appending the class objects
    # themselves makes `from <package> import *` fail with a TypeError.
    __all__.extend(
        [
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
            "IDEFICSSharded",
            "FlashMistral",
            "FlashMixtral",
            "FlashDbrx",
            "FlashPhi",
            "FlashQwen2",
            "FlashStarcoder2",
            "FlashGemma",
            "FlashCohere",
        ]
    )
OlivierDehaene's avatar
OlivierDehaene committed
96

drbh's avatar
drbh committed
97
98
99
100
101
102
103
104
105
# Mamba is optional: record whether its import succeeded so the rest of the
# module can check availability instead of hitting an undefined name.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # `__all__` entries must be strings, not the class object itself,
    # otherwise `from <package> import *` raises a TypeError.
    __all__.append("Mamba")
OlivierDehaene's avatar
OlivierDehaene committed
106

107

108
def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Instantiate the model implementation matching ``model_id``'s config.

    The model configuration is fetched with
    ``PretrainedConfig.get_config_dict`` and its ``model_type`` (plus, for a
    few models, the ``model_id`` prefix) selects the class to build. Flash
    Attention implementations are preferred whenever they were imported
    successfully.

    Args:
        model_id: Model identifier or local path handed to the config loader.
            If the config contains ``medusa_num_heads`` this is treated as a
            medusa checkpoint and the base model is read from
            ``base_model_name_or_path``.
        revision: Revision forwarded to the config loader and the model class.
        sharded: Whether a sharded implementation is requested.
        quantize: Quantization scheme name, forwarded to the model class.
        speculate: Number of speculated ids; ``None`` means 0, or the number
            of medusa heads for a medusa checkpoint.
        dtype: ``None``, ``"float16"`` or ``"bfloat16"``. ``None`` lets each
            model pick its own default.
        trust_remote_code: Forwarded to the config loader and the model class.

    Returns:
        The instantiated model.

    Raises:
        RuntimeError: Unknown ``dtype``, ``speculate`` larger than the number
            of medusa heads, or the model type cannot be determined.
        NotImplementedError: The requested combination (sharding, quantization,
            missing Flash Attention / Mamba) is not supported for the model.
        ValueError: No implementation exists for the model type.
    """
    if dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    elif dtype is not None:
        raise RuntimeError(f"Unknown dtype {dtype}")
    # When dtype is None it is kept as is, so that every model resolves its
    # own default dtype.

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )

    use_medusa = None
    if "medusa_num_heads" in config_dict:
        # `model_id` points at the medusa heads; the actual model to load is
        # the base model referenced by the medusa config.
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        # Reload the config, this time for the base model.
        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            # Download the heads next to the config; only the containing
            # directory is passed on to the model.
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            use_medusa = Path(medusa_config).parent
        else:
            use_medusa = Path(medusa_model_id)

        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict.get("model_type", None)
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

    # Keyword arguments shared by (almost) every model constructor below.
    model_kwargs = dict(
        quantize=quantize,
        use_medusa=use_medusa,
        dtype=dtype,
        trust_remote_code=trust_remote_code,
    )

    if model_type == "ssm":
        if not MAMBA_AVAILABLE:
            # Without this guard the call below fails with a bare NameError.
            raise NotImplementedError(
                "Mamba could not be imported, `ssm` models are not supported."
            )
        return Mamba(model_id, revision, **model_kwargs)

    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(model_id, revision, **model_kwargs)

    # `and` binds tighter than `or`; the parentheses only make that explicit.
    if model_type == "gpt_bigcode" or (
        model_type == "gpt2" and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(model_id, revision, **model_kwargs)

    if model_type == "bloom":
        return BLOOMSharded(model_id, revision, **model_kwargs)

    if model_type == "mpt":
        return MPTSharded(model_id, revision, **model_kwargs)

    if model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(model_id, revision, **model_kwargs)
        elif sharded:
            return GPTNeoxSharded(model_id, revision, **model_kwargs)
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(model_id, revision, **model_kwargs)
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return Phi(model_id, revision, **model_kwargs)

    if model_type == "llama" or model_type == "baichuan":
        if FLASH_ATTENTION:
            return FlashLlama(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "cohere":
        if FLASH_ATTENTION:
            return FlashCohere(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "dbrx":
        if FLASH_ATTENTION:
            return FlashDbrx(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        uses_alibi = config_dict.get("alibi", False)
        if sharded:
            if FLASH_ATTENTION:
                if uses_alibi:
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(model_id, revision, **model_kwargs)
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not uses_alibi:
                return FlashRWSharded(model_id, revision, **model_kwargs)
            else:
                return RW(model_id, revision, **model_kwargs)

    # For the sliding-window model families below, the flash implementation is
    # used either when the config declares no sliding window (None or -1) and
    # Flash Attention is available, or whenever HAS_FLASH_ATTN_V2_CUDA is set.
    sliding_window = config_dict.get("sliding_window", -1)
    flash_usable = (
        (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
    ) or HAS_FLASH_ATTN_V2_CUDA

    if model_type == "mistral":
        if flash_usable:
            return FlashMistral(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "mixtral":
        if flash_usable:
            return FlashMixtral(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "starcoder2":
        if flash_usable:
            # NOTE(review): unlike the other branches, `use_medusa` is not
            # forwarded here - confirm whether FlashStarcoder2 accepts it.
            return FlashStarcoder2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "qwen2":
        if flash_usable:
            # NOTE(review): unlike the other branches, `use_medusa` is not
            # forwarded here - confirm whether FlashQwen2 accepts it.
            return FlashQwen2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "opt":
        return OPTSharded(model_id, revision, **model_kwargs)

    if model_type == "t5":
        return T5Sharded(model_id, revision, **model_kwargs)

    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(model_id, revision, **model_kwargs)
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    # Generic AutoModel fallback: no sharding and only a subset of the
    # quantization schemes are supported from here on.
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(model_id, revision, **model_kwargs)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(model_id, revision, **model_kwargs)

    # Last resort: custom modeling code declared through `auto_map`, only
    # honoured when the caller opted into remote code.
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(model_id, revision, **model_kwargs)
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(model_id, revision, **model_kwargs)

    raise ValueError(f"Unsupported model type {model_type}")