import torch

from loguru import logger
from transformers import AutoConfig
from transformers.models.auto import modeling_auto
from typing import Optional

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.opt import OPT, OPTSharded
from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded

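# Flash Attention models need custom CUDA kernels and a supported GPU; if either is
# missing, FLASH_ATTENTION is set to False and the regular implementations are used.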
try:
    from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
        FlashLlamaSharded,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoder,
        FlashSantacoderSharded,
    )

    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
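        # Flash Attention kernels only support compute capability 7.5, 8.x and 9.0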
        is_sm75 = major == 7 and minor == 5
        is_sm8x = major == 8 and minor >= 0
        is_sm90 = major == 9 and minor == 0

        supported = is_sm75 or is_sm8x or is_sm90
        if not supported:
            raise ImportError(f"GPU with CUDA capability {major} {minor} is not supported")
        FLASH_ATTENTION = True
    else:
        FLASH_ATTENTION = False
except ImportError:
    logger.opt(exception=True).warning(
        "Could not import Flash Attention enabled models"
    )
    FLASH_ATTENTION = False

__all__ = [
    "Model",
    "BLOOM",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "Galactica",
    "GalacticaSharded",
    "GPTNeoxSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPT",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

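# Only export the Flash Attention classes when their kernels imported successfully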
if FLASH_ATTENTION:
    __all__.append("FlashNeoX")
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashSantacoder")
    __all__.append("FlashSantacoderSharded")
    __all__.append("FlashLlama")
    __all__.append("FlashLlamaSharded")

FLASH_ATT_ERROR_MESSAGE = (
    "{} requires Flash Attention CUDA kernels to be installed.\n"
    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
    "or install flash attention with `cd server && make install install-flash-attention`"
)

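# TF32 trades a small amount of matmul precision for higher throughput on Ampere and newer GPUs.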
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)


def get_model(
    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
) -> Model:
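    """Instantiate the appropriate model implementation for `model_id`.

    Dispatches on the Hub id / `config.model_type` and the `sharded` flag,
    preferring Flash Attention variants when their kernels are available.
    """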
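    # Models matched on their Hub id rather than on config.model_type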
    if "facebook/galactica" in model_id:
        if sharded:
            return GalacticaSharded(model_id, revision, quantize=quantize)
        else:
            return Galactica(model_id, revision, quantize=quantize)

    if "bigcode" in model_id:
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
                )
            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
        else:
            santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
            return santacoder_cls(model_id, revision, quantize=quantize)

    config = AutoConfig.from_pretrained(model_id, revision=revision)
    model_type = config.model_type

    if model_type == "bloom":
        if sharded:
            return BLOOMSharded(model_id, revision, quantize=quantize)
        else:
            return BLOOM(model_id, revision, quantize=quantize)

    if model_type == "gpt_neox":
        if sharded:
            neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
            return neox_cls(model_id, revision, quantize=quantize)
        else:
            neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
            return neox_cls(model_id, revision, quantize=quantize)

    if model_type == "llama":
        if sharded:
            if FLASH_ATTENTION:
                return FlashLlamaSharded(model_id, revision, quantize=quantize)
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
            return llama_cls(model_id, revision, quantize=quantize)

    if model_type == "opt":
        if sharded:
            return OPTSharded(model_id, revision, quantize=quantize)
        else:
            return OPT(model_id, revision, quantize=quantize)

    if model_type == "t5":
        if sharded:
            return T5Sharded(model_id, revision, quantize=quantize)
        else:
            return Seq2SeqLM(model_id, revision, quantize=quantize)

    if sharded:
        raise ValueError("sharded is not supported for AutoModel")

    # Fall back to the generic transformers auto classes for anything not handled above
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(model_id, revision, quantize=quantize)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(model_id, revision, quantize=quantize)

    raise ValueError(f"Unsupported model type {model_type}")