import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download
from typing import Optional
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True
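
# Flash Attention based models are optional: if their imports fail (for example
# because the CUDA kernels are not installed), log a warning and fall back to
# the non-flash implementations.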
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_qwen2 import (
        FlashQwen2,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
    __all__.append(FlashNeoXSharded)
    __all__.append(FlashRWSharded)
    __all__.append(FlashSantacoderSharded)
    __all__.append(FlashLlama)
    __all__.append(IDEFICSSharded)
    __all__.append(FlashMistral)
    __all__.append(FlashMixtral)
    __all__.append(FlashPhi)
    __all__.append(FlashQwen2)
    __all__.append(FlashStarcoder2)

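# Mamba support is optional as well: importing it can fail when its extra
# dependencies are missing, in which case state-space models are unavailable.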
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append(Mamba)


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
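    """Instantiate the `Model` implementation that matches the given checkpoint.

    The config is fetched with `PretrainedConfig.get_config_dict`, Medusa
    speculative heads are resolved if the config declares them, and the
    `model_type` field then selects a Flash Attention, sharded, or plain
    `transformers` AutoModel backend.
    """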
    if dtype is None:
        # Keep it as default for now and let
        # every model resolve their own default dtype.
        dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

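    # A speculate value passed on the CLI takes precedence; otherwise start at 0
    # and let a Medusa config (handled below) raise it.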
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )

    use_medusa = None
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this Medusa model only has `{speculate_medusa}` heads; please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            use_medusa = Path(medusa_config).parent
        else:
            use_medusa = Path(medusa_model_id)

        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

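    # Dispatch on the `model_type` declared in the checkpoint config.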
    model_type = config_dict.get("model_type", None)
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # only happens in the Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

    if model_type == "ssm":
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "gpt_bigcode" or (
        model_type == "gpt2" and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "llama" or model_type == "baichuan":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

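    # Sliding-window attention needs the Flash Attention v2 CUDA kernels; when
    # no window is configured (None or -1), regular Flash Attention is enough.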
    if model_type == "mistral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "mixtral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "starcoder2":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashStarcoder2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "qwen2":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashQwen2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "opt":
        return OPTSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

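    # Nothing above matched: fall back to the generic `transformers` auto-model
    # path, which supports neither sharding nor GPTQ/AWQ/4-bit/eetq quantization.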
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")