__init__.py 3.7 KB
Newer Older
1
import os
2
3
import torch

4
from loguru import logger
5
from transformers import AutoConfig
6
from transformers.models.auto import modeling_auto
7
8
from typing import Optional

9
10
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
11
from text_generation_server.models.flash_causal_lm import FlashCausalLM
12
13
14
15
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
16
from text_generation_server.models.gpt_neox import GPTNeoxSharded
17
from text_generation_server.models.t5 import T5Sharded
18

19
20
try:
    from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
21
    from text_generation_server.models.flash_santacoder import FlashSantacoder
22

23
24
25
    FLASH_ATTENTION = (
        torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1
    )
26
except ImportError:
27
28
29
    if int(os.environ.get("FLASH_ATTENTION", 0)) == 1:
        logger.exception("Could not import Flash Attention models")
    FLASH_ATTENTION = False
30

31
32
33
34
35
__all__ = [
    "Model",
    "BLOOM",
    "BLOOMSharded",
    "CausalLM",
36
    "FlashCausalLM",
37
38
39
    "Galactica",
    "GalacticaSharded",
    "GPTNeoxSharded",
40
41
    "Seq2SeqLM",
    "SantaCoder",
42
    "T5Sharded",
43
44
45
    "get_model",
]

46
if FLASH_ATTENTION:
47
48
    __all__.append(FlashNeoX)
    __all__.append(FlashNeoXSharded)
49
    __all__.append(FlashSantacoder)
50

51
52
53
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
54

55
56
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True
57

58
59
60
# Disable gradients
torch.set_grad_enabled(False)

61

62
def get_model(
63
    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
64
) -> Model:
65
    if "facebook/galactica" in model_id:
66
67
68
69
70
71
        if sharded:
            return GalacticaSharded(model_id, revision, quantize=quantize)
        else:
            return Galactica(model_id, revision, quantize=quantize)

    if "santacoder" in model_id:
72
73
74
75
76
        if sharded:
            raise NotImplementedError("sharded is not supported for Santacoder")
        else:
            santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
            return santacoder_cls(model_id, revision, quantize)
77

78
    config = AutoConfig.from_pretrained(model_id, revision=revision)
79
    model_type = config.model_type
80

81
    if model_type == "bloom":
82
        if sharded:
83
            return BLOOMSharded(model_id, revision, quantize=quantize)
84
        else:
85
            return BLOOM(model_id, revision, quantize=quantize)
86

87
    if model_type == "gpt_neox":
88
        if sharded:
89
            neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
90
            return neox_cls(model_id, revision, quantize=quantize)
91
        else:
92
            neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
93
            return neox_cls(model_id, revision, quantize=quantize)
94

95
    if model_type == "t5":
96
97
98
99
        if sharded:
            return T5Sharded(model_id, revision, quantize=quantize)
        else:
            return Seq2SeqLM(model_id, revision, quantize=quantize)
100
101
102

    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
103
104

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
105
        return CausalLM(model_id, revision, quantize=quantize)
106
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
107
        return Seq2SeqLM(model_id, revision, quantize=quantize)
108
109

    raise ValueError(f"Unsupported model type {model_type}")