import os

import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto

from typing import Optional

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = (
    "{} requires CUDA and Flash Attention kernels to be installed.\n"
    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
    "or install flash attention with `cd server && make install install-flash-attention`"
)

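# Flash Attention models are only enabled when CUDA is available, the GPU has
# capability 7.5, 8.x or 9.0, the kernels import successfully, and the
# USE_FLASH_ATTENTION environment variable is not set to "false".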
try:
    if os.getenv("USE_FLASH_ATTENTION", "").lower() != "false":
        if not torch.cuda.is_available():
            FLASH_ATT_ERROR_MESSAGE = (
                "{} requires CUDA. No compatible CUDA devices found."
            )
            raise ImportError("CUDA is not available")

        major, minor = torch.cuda.get_device_capability()
        is_sm75 = major == 7 and minor == 5
        is_sm8x = major == 8 and minor >= 0
        is_sm90 = major == 9 and minor == 0

        supported = is_sm75 or is_sm8x or is_sm90
        if not supported:
            FLASH_ATT_ERROR_MESSAGE = (
                "{} requires a CUDA device with capability 7.5, 8.x or 9.0. "
                "No compatible CUDA device found."
            )
            raise ImportError(
                f"GPU with CUDA capability {major}.{minor} is not supported"
            )

        from text_generation_server.models.flash_rw import FlashRWSharded
        from text_generation_server.models.flash_neox import FlashNeoXSharded
        from text_generation_server.models.flash_llama import (
            FlashLlama,
        )
        from text_generation_server.models.flash_santacoder import (
            FlashSantacoderSharded,
        )

        FLASH_ATTENTION = True
    else:
        FLASH_ATTENTION = False
except ImportError:
    logger.opt(exception=True).warning(
        "Could not import Flash Attention enabled models"
    )
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashRWSharded")
    __all__.append("FlashSantacoderSharded")
    __all__.append("FlashLlama")


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    trust_remote_code: bool,
) -> Model:
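    """Instantiate the appropriate ``Model`` implementation for ``model_id``.

    The model type is read from the checkpoint's config (with a few special
    cases keyed on the model id) and routed to a Flash Attention, sharded or
    plain implementation depending on hardware support and the ``sharded`` and
    ``quantize`` arguments.
    """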
    if "facebook/galactica" in model_id:
        return GalacticaSharded(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

    if model_id.startswith("bigcode/"):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "llama":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel"]:
        if sharded:
            if FLASH_ATTENTION:
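                # Sharded Flash RW is not supported with alibi position
                # embeddings, nor for RefinedWebModel checkpoints using
                # multi-query attention.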
                if config_dict.get("alibi", False) or (
                    model_type == "RefinedWebModel"
                    and config_dict.get("multi_query", True)
                ):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded RefinedWeb")
            )
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    trust_remote_code=trust_remote_code,
                )

    elif model_type == "opt":
        return OPTSharded(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

    elif model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            trust_remote_code=trust_remote_code,
        )

    # None of the dedicated implementations matched: fall back to the generic
    # transformers AutoModel classes.
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

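    # If the checkpoint ships custom modeling code and the caller opted into
    # trust_remote_code, dispatch on its auto_map entries.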
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")
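
# Illustrative usage (a sketch only; the real call site lives in the server
# entrypoint outside this module, and the model id below is just an example):
#
#     model = get_model(
#         model_id="bigscience/bloom-560m",
#         revision=None,
#         sharded=False,
#         quantize=None,
#         trust_remote_code=False,
#     )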