import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto

from typing import Optional

from text_generation_server.utils.speculate import get_speculate, set_speculate

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True
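# Importing the Flash Attention model classes is best-effort: on installs without
# the flash-attn kernels (or on unsupported hardware) the ImportError below simply
# disables FLASH_ATTENTION and the server falls back to the non-flash implementations.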
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
73
    __all__.append(FlashNeoXSharded)
74
    __all__.append(FlashRWSharded)
75
    __all__.append(FlashSantacoderSharded)
76
    __all__.append(FlashLlama)
77
    __all__.append(IDEFICSSharded)
78
    __all__.append(FlashMistral)
OlivierDehaene's avatar
OlivierDehaene committed
79
    __all__.append(FlashMixtral)
drbh's avatar
drbh committed
80
    __all__.append(FlashPhi)
OlivierDehaene's avatar
OlivierDehaene committed
81

drbh's avatar
drbh committed
82
83
84
85
86
87
88
89
90
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append("Mamba")


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
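    """Instantiate the concrete Model implementation for `model_id`.

    Resolves the requested dtype, configures speculative decoding (Medusa or
    n-gram), then dispatches on the checkpoint's `model_type`, preferring the
    Flash Attention implementation when available and falling back to the
    generic CausalLM / Seq2SeqLM wrappers otherwise.
    """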
    if dtype is None:
        # Keep it as default for now and let
        # every model resolve their own default dtype.
        dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    if "facebook/galactica" in model_id:
119
        return GalacticaSharded(
120
121
122
123
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
Nicolas Patry's avatar
Nicolas Patry committed
124
            trust_remote_code=trust_remote_code,
125
        )
126

    if model_id.startswith("bigcode/"):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )

    use_medusa = None
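    # A Medusa checkpoint only carries the speculative heads: remember its repo id
    # in `use_medusa` and load the actual weights from the base model it references.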
    if "medusa_num_heads" in config_dict:
        use_medusa = model_id
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
OlivierDehaene's avatar
OlivierDehaene committed
161
162
163
                raise RuntimeError(
                    "Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
                )
Nicolas Patry's avatar
Nicolas Patry committed
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict.get("model_type", None)
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

    if model_type == "ssm":
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "gpt_bigcode":
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
245
                dtype=dtype,
246
247
248
249
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
250
251
252
                model_id,
                revision,
                quantize=quantize,
253
                dtype=dtype,
254
255
                trust_remote_code=trust_remote_code,
            )
256
        else:
257
            return CausalLM(
258
259
260
                model_id,
                revision,
                quantize=quantize,
261
                dtype=dtype,
262
263
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
264

    elif model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                use_medusa=use_medusa,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "llama" or model_type == "baichuan":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                use_medusa=use_medusa,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                use_medusa=use_medusa,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Golden Gate")
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

    if model_type == "mistral":
        sliding_window = config_dict.get("sliding_window", -1)
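        # A real sliding window needs the flash-attn v2 CUDA kernels; when the config
        # defines no window (None or -1), plain Flash Attention support is enough.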
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "mixtral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "opt":
        return OPTSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")