# __init__.py (9.99 KB)
1
2
import torch

3
from loguru import logger
4
from transformers.configuration_utils import PretrainedConfig
5
from transformers.models.auto import modeling_auto
6
7
from typing import Optional

8
9
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
10
from text_generation_server.models.flash_causal_lm import FlashCausalLM
11
12
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
13
from text_generation_server.models.rw import RW
14
from text_generation_server.models.opt import OPT, OPTSharded
15
16
from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
17
from text_generation_server.models.gpt_neox import GPTNeoxSharded
18
from text_generation_server.models.t5 import T5Sharded
19

20
# Flash Attention model variants are optional. They are imported only when a
# CUDA device is present and reports compute capability 7.5, 8.x or 9.0. An
# unsupported device is signalled by raising ImportError here, so it shares the
# fallback path with missing flash modules: the traceback is logged and
# FLASH_ATTENTION stays False, which makes get_model pick non-flash classes.
try:
    FLASH_ATTENTION = False
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        if not (major == 8 or (major, minor) in ((7, 5), (9, 0))):
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            )

        from text_generation_server.models.flash_neox import (
            FlashNeoX,
            FlashNeoXSharded,
        )
        from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded
        from text_generation_server.models.flash_llama import (
            FlashLlama,
            FlashLlamaSharded,
        )
        from text_generation_server.models.flash_santacoder import (
            FlashSantacoder,
            FlashSantacoderSharded,
        )

        # Only reached when every flash import above succeeded.
        FLASH_ATTENTION = True
except ImportError:
    logger.opt(exception=True).warning(
        "Could not import Flash Attention enabled models"
    )
    FLASH_ATTENTION = False
52

53
54
55
56
57
# Public API of this package. Every entry must be a *string* (the attribute
# name): ``from package import *`` raises TypeError on non-str items in
# ``__all__``.
__all__ = [
    "Model",
    "BLOOM",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "Galactica",
    "GalacticaSharded",
    "GPTNeoxSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPT",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

if FLASH_ATTENTION:
    # The Flash Attention classes exist in this module's namespace only when
    # FLASH_ATTENTION is True, so they are advertised only in that case.
    # Names (not class objects) are listed so star-imports keep working.
    __all__.extend(
        [
            "FlashNeoX",
            "FlashNeoXSharded",
            "FlashRW",
            "FlashRWSharded",
            "FlashSantacoder",
            "FlashSantacoderSharded",
            "FlashLlama",
            "FlashLlamaSharded",
        ]
    )

80
81
82
83
84
# Error-message template for model variants that need the Flash Attention
# CUDA kernels. Format it with the name of the requested variant,
# e.g. FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama").
FLASH_ATT_ERROR_MESSAGE = "\n".join(
    (
        "{} requires Flash Attention CUDA kernels to be installed.",
        "Use the official Docker image "
        "(ghcr.io/huggingface/text-generation-inference:latest) "
        "or install flash attention with "
        "`cd server && make install install-flash-attention`",
    )
)
85

86
87
88
# Opt in to TF32 for both CUDA matmuls and cuDNN. PyTorch 1.12+ ships the
# matmul flag as False; the cuDNN flag already defaults to True and is set
# explicitly here for clarity.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Disable autograd gradient tracking.
torch.set_grad_enabled(False)

96

97
def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Instantiate the ``Model`` implementation that should serve ``model_id``.

    Resolution order:

    1. ids containing ``facebook/galactica`` -> ``Galactica`` / ``GalacticaSharded``.
    2. ids starting with ``bigcode/`` -> Santacoder classes (no config lookup).
    3. Otherwise the model config is loaded and dispatched on its
       ``model_type``. There are dedicated classes for ``gpt_bigcode``,
       ``bloom``, ``gpt_neox``, ``RefinedWeb``/``RefinedWebModel``, ``llama``,
       ``opt`` and ``t5``; Flash Attention variants are preferred whenever
       ``FLASH_ATTENTION`` is set.
    4. Non-sharded fallback to the generic ``CausalLM`` / ``Seq2SeqLM``. This
       uses transformers' auto-model mappings, or the config's ``auto_map``
       when ``trust_remote_code`` is enabled.

    Args:
        model_id: Model identifier. It is matched against the prefixes above
            and forwarded to config loading and to the model constructor.
        revision: Forwarded to config loading and to the model constructor.
        sharded: Whether a sharded implementation is required.
        quantize: Forwarded unchanged to the model constructor.
        trust_remote_code: Forwarded to config loading and to the model
            constructor. It also enables the ``auto_map`` fallback.

    Returns:
        An instance of the selected ``Model`` subclass.

    Raises:
        KeyError: The loaded config has no ``"model_type"`` entry.
        NotImplementedError: A sharded variant needs Flash Attention kernels
            that could not be loaded, or sharding is not supported for this
            RefinedWeb configuration.
        ValueError: Sharding was requested for a model type that has no
            sharded implementation, or the model type is unsupported.
    """

    def build(model_cls):
        # All model classes are constructed with the same arguments.
        return model_cls(
            model_id,
            revision,
            quantize=quantize,
            trust_remote_code=trust_remote_code,
        )

    def build_santacoder():
        # Shared by the "bigcode/" id prefix and the "gpt_bigcode" model type.
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
                )
            return build(FlashSantacoderSharded)
        # The conditional only evaluates FlashSantacoder when it was imported.
        return build(FlashSantacoder if FLASH_ATTENTION else SantaCoder)

    # Id-based dispatch: these branches avoid downloading the config.
    if "facebook/galactica" in model_id:
        return build(GalacticaSharded if sharded else Galactica)

    if model_id.startswith("bigcode/"):
        return build_santacoder()

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        return build_santacoder()

    if model_type == "bloom":
        return build(BLOOMSharded if sharded else BLOOM)

    if model_type == "gpt_neox":
        if sharded:
            return build(FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded)
        return build(FlashNeoX if FLASH_ATTENTION else CausalLM)

    if model_type in ["RefinedWeb", "RefinedWebModel"]:
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded RefinedWeb")
                )
            # The sharded flash implementation rejects ALiBi configs and
            # multi-query "RefinedWebModel" configs (multi_query defaults to
            # True when the key is absent).
            if config_dict.get("alibi", False) or (
                model_type == "RefinedWebModel"
                and config_dict.get("multi_query", True)
            ):
                raise NotImplementedError("sharded is not supported for this model")
            return build(FlashRWSharded)
        # The single-device flash path is also skipped for ALiBi configs.
        if FLASH_ATTENTION and not config_dict.get("alibi", False):
            return build(FlashRW)
        return build(RW)

    if model_type == "llama":
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")
                )
            return build(FlashLlamaSharded)
        return build(FlashLlama if FLASH_ATTENTION else CausalLM)

    if model_type == "opt":
        return build(OPTSharded if sharded else OPT)

    if model_type == "t5":
        return build(T5Sharded if sharded else Seq2SeqLM)

    # Everything below uses generic transformers auto-classes, which have no
    # sharded implementation.
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return build(CausalLM)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return build(Seq2SeqLM)

    # Remote-code models declare their auto-classes in the config's auto_map.
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map:
            return build(CausalLM)
        if "AutoModelForSeq2SeqLM" in auto_map:
            return build(Seq2SeqLM)

    raise ValueError(f"Unsupported model type {model_type}")