# __init__.py
import os
2
3
import torch

4
from loguru import logger
5
from transformers.configuration_utils import PretrainedConfig
6
from transformers.models.auto import modeling_auto
7
8
from typing import Optional

9
10
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
11
from text_generation_server.models.flash_causal_lm import FlashCausalLM
12
from text_generation_server.models.bloom import BLOOMSharded
13
from text_generation_server.models.mpt import MPTSharded
14
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
15
from text_generation_server.models.rw import RW
16
17
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
18
19
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
20
from text_generation_server.models.gpt_neox import GPTNeoxSharded
21

22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Opt in to TensorFloat-32 math for both cuBLAS matmuls (PyTorch >= 1.12 leaves
# this off by default) and cuDNN (already on by default; set explicitly here).
for _tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
    _tf32_backend.allow_tf32 = True
del _tf32_backend

# Turn off autograd gradient tracking for subsequent torch operations.
torch.set_grad_enabled(False)

# Public API of this package: the always-available model classes plus
# `get_model`. The Flash Attention classes are appended further down only when
# FLASH_ATTENTION ends up True.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "GPTNeoxSharded",
    "MPTSharded",
    "RW",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

# Template for errors raised when a model variant needs Flash Attention; `{}`
# is filled with a short model description (e.g. "Sharded Llama"). The Flash
# Attention detection below may rebind it to a more specific message (no CUDA
# device / unsupported compute capability).
FLASH_ATT_ERROR_MESSAGE = (
    "{} requires CUDA and Flash Attention kernels to be installed.\n"
    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
    "or install flash attention with `cd server && make install install-flash-attention`"
)

51
# Decide whether the Flash Attention model classes can be used.
#
# Setting USE_FLASH_ATTENTION=false (case-insensitive) skips detection and
# leaves FLASH_ATTENTION False. Otherwise the CUDA device and its compute
# capability are checked and the Flash* classes are imported. Hardware
# problems are raised as ImportError on purpose so that the single `except`
# below handles them exactly like a missing kernel package. In those cases
# FLASH_ATT_ERROR_MESSAGE is narrowed to explain the actual reason.
try:
    if os.getenv("USE_FLASH_ATTENTION", "").lower() != "false":
        if not torch.cuda.is_available():
            FLASH_ATT_ERROR_MESSAGE = (
                "{} requires CUDA. No compatible CUDA devices found."
            )
            raise ImportError("CUDA is not available")

        # Supported compute capabilities: 7.5, any 8.x, and 9.0.
        major, minor = torch.cuda.get_device_capability()
        is_sm75 = major == 7 and minor == 5
        is_sm8x = major == 8
        is_sm90 = major == 9 and minor == 0

        supported = is_sm75 or is_sm8x or is_sm90
        if not supported:
            FLASH_ATT_ERROR_MESSAGE = (
                "{} requires a CUDA device with capability 7.5, 8.x or 9.0. "
                "No compatible CUDA device found."
            )
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            )

        from text_generation_server.models.flash_rw import FlashRWSharded
        from text_generation_server.models.flash_neox import FlashNeoXSharded
        from text_generation_server.models.flash_llama import (
            FlashLlama,
        )
        from text_generation_server.models.flash_santacoder import (
            FlashSantacoderSharded,
        )

        FLASH_ATTENTION = True
    else:
        FLASH_ATTENTION = False
except ImportError:
    # Best effort: log the reason with its traceback and fall back to the
    # non-flash implementations.
    logger.opt(exception=True).warning(
        "Could not import Flash Attention enabled models"
    )
    FLASH_ATTENTION = False
91

92
if FLASH_ATTENTION:
    # Export the Flash Attention classes only when they were imported above.
    # `__all__` must hold names (str), not the class objects: a non-str entry
    # makes `from <this package> import *` raise TypeError.
    __all__.extend(
        [
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
        ]
    )

98

99
def _get_santacoder(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    dtype: torch.dtype,
    trust_remote_code: bool,
) -> Model:
    """Build a SantaCoder-family model, preferring the Flash Attention class.

    Shared by the ``bigcode/`` model-id shortcut and the ``gpt_bigcode``
    model type in :func:`get_model`.

    Raises:
        NotImplementedError: if ``sharded`` is set while Flash Attention is
            unavailable.
    """
    if FLASH_ATTENTION:
        return FlashSantacoderSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if sharded:
        raise NotImplementedError(
            FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
        )
    return SantaCoder(
        model_id,
        revision,
        quantize=quantize,
        dtype=dtype,
        trust_remote_code=trust_remote_code,
    )


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Instantiate the serving model class that matches ``model_id``.

    Dispatch order:

    1. model-id prefixes (``facebook/galactica``, ``bigcode/``), checked
       before any config is fetched;
    2. the ``model_type`` field of the model's config;
    3. the generic transformers causal-LM / seq2seq AutoModel mappings;
    4. the ``auto_map`` entries of remote-code models (only when
       ``trust_remote_code`` is set).

    Args:
        model_id: Model identifier forwarded to
            ``PretrainedConfig.get_config_dict`` and to the model classes.
        revision: Optional model revision, forwarded likewise.
        sharded: Whether a sharded model was requested.
        quantize: Optional quantization method (e.g. ``"gptq"``).
        dtype: ``None`` or ``"float16"`` selects ``torch.float16``;
            ``"bfloat16"`` selects ``torch.bfloat16``.
        trust_remote_code: Forwarded to the config loader and model classes.

    Returns:
        An instance of the matching model class.

    Raises:
        RuntimeError: for an unrecognised ``dtype``.
        KeyError: if the fetched config has no ``model_type``.
        NotImplementedError: when sharding is requested for a variant that
            can only be sharded with Flash Attention and it is unavailable,
            or for RefinedWeb configs the flash implementation rejects.
        ValueError: when sharding or ``"gptq"`` is requested for a generic
            AutoModel, or when the model type is unsupported.
    """
    if dtype is None or dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    # Model-id based dispatch, done before the config is downloaded.
    if "facebook/galactica" in model_id:
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("bigcode/"):
        return _get_santacoder(
            model_id, revision, sharded, quantize, dtype, trust_remote_code
        )

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        return _get_santacoder(
            model_id, revision, sharded, quantize, dtype, trust_remote_code
        )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        # NOTE(review): unlike its siblings, MPTSharded is not given `dtype`;
        # confirm whether its constructor accepts one before forwarding it.
        return MPTSharded(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )
    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == "llama":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel"]:
        if sharded:
            if FLASH_ATTENTION:
                # Sharding is refused for ALiBi configs and for
                # RefinedWebModel configs that use multi-query attention
                # (`multi_query` is treated as True when absent).
                if config_dict.get("alibi", False) or (
                    model_type == "RefinedWebModel"
                    and config_dict.get("multi_query", True)
                ):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded RefinedWeb")
            )
        else:
            # The flash implementation is only used for non-ALiBi configs.
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
    elif model_type == "opt":
        return OPTSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    # Generic transformers AutoModel fallbacks: no sharding and no GPTQ.
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    # Remote-code models advertise their Auto* classes in the config's auto_map.
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")