import torch

from loguru import logger
from transformers import AutoConfig
from transformers.models.auto import modeling_auto
from typing import Optional

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPT, OPTSharded
from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded

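# Import the Flash Attention implementations only when a supported GPU
# (compute capability 7.5, 8.x or 9.0) and the custom kernels are available;
# any failure simply disables them and the generic implementations are used.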
try:
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        is_sm75 = major == 7 and minor == 5
        is_sm8x = major == 8 and minor >= 0
        is_sm90 = major == 9 and minor == 0

        supported = is_sm75 or is_sm8x or is_sm90
        if not supported:
            raise ImportError(
                f"GPU with CUDA capability {major}.{minor} is not supported"
            )

        from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
        from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded
        from text_generation_server.models.flash_llama import (
            FlashLlama,
            FlashLlamaSharded,
        )
        from text_generation_server.models.flash_santacoder import (
            FlashSantacoder,
            FlashSantacoderSharded,
        )

        FLASH_ATTENTION = True
    else:
        FLASH_ATTENTION = False
except ImportError:
    logger.opt(exception=True).warning(
        "Could not import Flash Attention enabled models"
    )
    FLASH_ATTENTION = False

__all__ = [
    "Model",
    "BLOOM",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "Galactica",
    "GalacticaSharded",
    "GPTNeoxSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPT",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

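# Only export the Flash Attention classes when they were imported successfully.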
if FLASH_ATTENTION:
    __all__.append("FlashNeoX")
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashRW")
    __all__.append("FlashRWSharded")
    __all__.append("FlashSantacoder")
    __all__.append("FlashSantacoderSharded")
    __all__.append("FlashLlama")
    __all__.append("FlashLlamaSharded")

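# Message used when a model requires the Flash Attention kernels but they are
# not installed.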
FLASH_ATT_ERROR_MESSAGE = (
    "{} requires Flash Attention CUDA kernels to be installed.\n"
    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
    "or install flash attention with `cd server && make install install-flash-attention`"
)

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    trust_remote_code: bool,
) -> Model:
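    """Instantiate the most appropriate `Model` implementation for `model_id`.

    Dispatch is based on the model id and on `config.model_type`, preferring the
    Flash Attention implementations when the kernels are available and falling
    back to the generic `CausalLM` / `Seq2SeqLM` auto-loading otherwise.

    Illustrative example (the arguments shown are just one possible combination):

        model = get_model(
            "bigscience/bloom-560m",
            revision=None,
            sharded=False,
            quantize=None,
            trust_remote_code=False,
        )
    """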
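    # Galactica is matched on the model id so that the Galactica-specific
    # implementation is used instead of the one chosen from the config.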
    if "facebook/galactica" in model_id:
        if sharded:
            return GalacticaSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        else:
            return Galactica(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

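    # bigcode/ checkpoints (e.g. SantaCoder) are also matched on the model id; the
    # same architecture is handled via model_type == "gpt_bigcode" below.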
    if model_id.startswith("bigcode/"):
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
                )
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        else:
            santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
            return santacoder_cls(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    config = AutoConfig.from_pretrained(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config.model_type

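    # Everything below dispatches on the model_type declared in the checkpoint config.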
    if model_type == "gpt_bigcode":
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
                )
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        else:
            santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
            return santacoder_cls(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        if sharded:
            return BLOOMSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        else:
            return BLOOM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

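    # GPT-NeoX: prefer the flash implementations; otherwise use GPTNeoxSharded or
    # the generic CausalLM.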
    if model_type == "gpt_neox":
        if sharded:
            neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
            return neox_cls(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        else:
            neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
            return neox_cls(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

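    # RefinedWeb / RefinedWebModel (Falcon): sharding requires the Flash Attention
    # kernels and is not supported for alibi models or when n_head_kv != n_head.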
    if model_type in ["RefinedWeb", "RefinedWebModel"]:
        if sharded:
            if FLASH_ATTENTION:
                if config.alibi or (
                    config.model_type == "RefinedWebModel"
                    and config.n_head_kv != config.n_head
                ):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded RefinedWeb")
            )
        else:
            if FLASH_ATTENTION and not config.alibi:
                return FlashRW(
                    model_id,
                    revision,
                    quantize=quantize,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    trust_remote_code=trust_remote_code,
                )

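    # Llama: sharding is only supported with the Flash Attention kernels.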
    if model_type == "llama":
        if sharded:
            if FLASH_ATTENTION:
                return FlashLlamaSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
            return llama_cls(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "opt":
        if sharded:
            return OPTSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        else:
            return OPT(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

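    # T5: use the sharded implementation when requested, otherwise the generic
    # Seq2SeqLM.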
    if model_type == "t5":
        if sharded:
            return T5Sharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        else:
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    if sharded:
        raise ValueError("sharded is not supported for AutoModel")

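    # Fallback: anything not matched above is loaded through transformers' Auto
    # mappings.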
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

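    # Custom remote-code models: respect the config's auto_map when
    # trust_remote_code is enabled.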
    auto_map = getattr(config, "auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")