__init__.py 10.1 KB
Newer Older
1
import os
2
3
import torch

4
from loguru import logger
5
from transformers.configuration_utils import PretrainedConfig
6
from transformers.models.auto import modeling_auto
7
8
from typing import Optional

9
10
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
11
from text_generation_server.models.flash_causal_lm import FlashCausalLM
12
from text_generation_server.models.bloom import BLOOMSharded
13
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
14
from text_generation_server.models.rw import RW
15
16
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
17
18
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
19
from text_generation_server.models.gpt_neox import GPTNeoxSharded
20

21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Global PyTorch settings. These statements run at module level, so they take
# effect as soon as this module is imported.

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients globally via torch.set_grad_enabled(False).
torch.set_grad_enabled(False)

# Public names of this package. Every entry must be a string naming an object
# that is unconditionally imported/defined in this module. Flash Attention
# model names are added further down only when those models could be imported.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "GPTNeoxSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "RW",
    "T5Sharded",
    "get_model",
]

# Default message template used when a model needs the Flash Attention code
# path but it is unavailable. The single `{}` placeholder is filled with the
# model name through `str.format`.
FLASH_ATT_ERROR_MESSAGE = "\n".join(
    (
        "{} requires CUDA and Flash Attention kernels to be installed.",
        "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
        "or install flash attention with `cd server && make install install-flash-attention`",
    )
)

50
# Probe for Flash Attention support and import the flash model classes.
#
# Every failure mode in the `try` body (no CUDA device, unsupported compute
# capability, flash model modules failing to import) surfaces as ImportError.
# The handler logs it as a warning and sets FLASH_ATTENTION = False rather
# than aborting the import of this module. FLASH_ATT_ERROR_MESSAGE is
# overwritten with a more specific template when the exact cause is known.
try:
    # Flash Attention is attempted unless USE_FLASH_ATTENTION is explicitly
    # set to "false" (case-insensitive).
    if os.getenv("USE_FLASH_ATTENTION", "").lower() != "false":
        if not torch.cuda.is_available():
            FLASH_ATT_ERROR_MESSAGE = (
                "{} requires CUDA. No compatible CUDA devices found."
            )
            raise ImportError("CUDA is not available")

        # Accepted compute capabilities: 7.5, any 8.x, and 9.0.
        major, minor = torch.cuda.get_device_capability()
        is_sm75 = major == 7 and minor == 5
        is_sm8x = major == 8
        is_sm90 = major == 9 and minor == 0

        if not (is_sm75 or is_sm8x or is_sm90):
            FLASH_ATT_ERROR_MESSAGE = (
                "{} requires a CUDA device with capability 7.5, 8.x or 9.0. "
                "No compatible CUDA device found."
            )
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            )

        from text_generation_server.models.flash_rw import FlashRWSharded
        from text_generation_server.models.flash_neox import FlashNeoXSharded
        from text_generation_server.models.flash_llama import (
            FlashLlama,
        )
        from text_generation_server.models.flash_santacoder import (
            FlashSantacoderSharded,
        )

        FLASH_ATTENTION = True
    else:
        FLASH_ATTENTION = False
except ImportError:
    # Log with the traceback attached, then fall back to non-flash models.
    logger.opt(exception=True).warning(
        "Could not import Flash Attention enabled models"
    )
    FLASH_ATTENTION = False
90

91
# Export the Flash Attention models only when they were imported.
# `__all__` entries must be name strings, not the class objects themselves:
# `from <package> import *` raises TypeError on a non-string entry.
if FLASH_ATTENTION:
    __all__.extend(
        [
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
        ]
    )

97

98
def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Instantiate the model implementation matching ``model_id``.

    Selection order: the model id itself (Galactica, ``bigcode/`` prefix),
    then the ``model_type`` read from the model's config, then the generic
    transformers auto-model mappings, and finally the config's ``auto_map``
    when ``trust_remote_code`` is set. Flash Attention implementations are
    used whenever the module-level ``FLASH_ATTENTION`` flag is true.

    Args:
        model_id: Model identifier passed to the config loader and to the
            model constructors.
        revision: Optional revision forwarded to the config loader and the
            model constructors.
        sharded: Whether a sharded implementation is requested.
        quantize: Optional quantization scheme forwarded to the model
            constructor; ``"gptq"`` is rejected for the auto-model fallback.
        dtype: ``None`` or ``"float16"`` for ``torch.float16``,
            ``"bfloat16"`` for ``torch.bfloat16``.
        trust_remote_code: Forwarded to the config loader and the model
            constructors; also gates the ``auto_map`` fallback.

    Returns:
        The constructed model instance.

    Raises:
        RuntimeError: If ``dtype`` is not one of the accepted values.
        NotImplementedError: If a sharded model is requested for an
            architecture that needs Flash Attention and it is unavailable,
            or for a RefinedWeb variant that cannot be sharded.
        ValueError: If sharding or gptq quantization is requested for the
            auto-model fallback, or the model type is unsupported.
    """
    if dtype is None or dtype == "float16":
        torch_dtype = torch.float16
    elif dtype == "bfloat16":
        torch_dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    def build(model_cls):
        # Every model class in this function is constructed with the same
        # arguments; keeping the call in one place prevents keyword typos.
        return model_cls(
            model_id,
            revision,
            quantize=quantize,
            dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )

    # --- Dispatch on the model id (no config lookup needed) ---
    if "facebook/galactica" in model_id:
        return build(GalacticaSharded)

    if model_id.startswith("bigcode/"):
        if FLASH_ATTENTION:
            return build(FlashSantacoderSharded)
        if sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        return build(SantaCoder)

    # --- Dispatch on the model type declared in the config ---
    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        if FLASH_ATTENTION:
            return build(FlashSantacoderSharded)
        if sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        return build(SantaCoder)

    if model_type == "bloom":
        return build(BLOOMSharded)

    if model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return build(FlashNeoXSharded)
        if sharded:
            return build(GPTNeoxSharded)
        return build(CausalLM)

    if model_type == "llama":
        if FLASH_ATTENTION:
            return build(FlashLlama)
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        return build(CausalLM)

    if model_type in ["RefinedWeb", "RefinedWebModel"]:
        uses_alibi = config_dict.get("alibi", False)
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded RefinedWeb")
                )
            # Sharding is rejected for alibi models and for multi-query
            # "RefinedWebModel" configs (multi_query defaults to True).
            if uses_alibi or (
                model_type == "RefinedWebModel"
                and config_dict.get("multi_query", True)
            ):
                raise NotImplementedError("sharded is not supported for this model")
            return build(FlashRWSharded)
        if FLASH_ATTENTION and not uses_alibi:
            return build(FlashRWSharded)
        return build(RW)

    if model_type == "opt":
        return build(OPTSharded)

    if model_type == "t5":
        return build(T5Sharded)

    # --- Generic fallback through the transformers auto-model mappings ---
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return build(CausalLM)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return build(Seq2SeqLM)

    # --- Remote-code models that declare their auto classes in the config ---
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map:
            return build(CausalLM)
        if "AutoModelForSeq2SeqLM" in auto_map:
            return build(Seq2SeqLM)

    raise ValueError(f"Unsupported model type {model_type}")