import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto

from typing import Optional

from text_generation_server.utils.speculate import get_speculate, set_speculate

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True
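# Flash Attention model classes are optional: importing them fails on machines
# without the required CUDA kernels, in which case we fall back to the generic
# implementations below.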
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashRWSharded")
    __all__.append("FlashSantacoderSharded")
    __all__.append("FlashLlama")
    __all__.append("IDEFICSSharded")
    __all__.append("FlashMistral")
    __all__.append("FlashMixtral")


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
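    """Instantiate the most suitable ``Model`` implementation for ``model_id``.

    The concrete class is chosen from the checkpoint's ``model_type``, whether the
    Flash Attention kernels could be imported, and the requested sharding,
    quantization, dtype and speculative-decoding options.
    """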
    if dtype is None:
        # Keep it as the default for now and let
        # every model resolve its own default dtype.
        dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

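    # Number of speculative tokens: honour an explicit request, otherwise start
    # at 0 (disabled); a medusa checkpoint may override this below.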
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    if "facebook/galactica" in model_id:
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("bigcode/"):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )

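    # A medusa config points at a base model: swap `model_id` to that base model
    # and keep the original repo (which holds the medusa heads) in `use_medusa`.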
    use_medusa = None
    if "medusa_num_heads" in config_dict:
        use_medusa = model_id
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa model only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "llama" or model_type == "baichuan":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                use_medusa=use_medusa,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

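    # Use the flash implementation when the model has no sliding window (regular
    # Flash Attention is enough) or when the Flash Attention v2 CUDA kernels are
    # available (needed for sliding-window attention).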
    if model_type == "mistral":
284
285
286
287
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
288
289
290
291
292
293
294
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "mixtral":
297
298
299
300
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
OlivierDehaene's avatar
OlivierDehaene committed
301
302
303
304
305
306
307
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "opt":
310
        return OPTSharded(
311
312
313
314
315
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
316
        )
317

    if model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

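    # No dedicated implementation matched: fall back to the generic transformers
    # Auto classes, which support neither sharding nor the quantization schemes
    # rejected below.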
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")