import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto

from typing import Optional

from text_generation_server.utils.speculate import get_speculate, set_speculate

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashRWSharded")
    __all__.append("FlashSantacoderSharded")
    __all__.append("FlashLlama")
    __all__.append("IDEFICSSharded")
    __all__.append("FlashMistral")
    __all__.append("FlashMixtral")
    __all__.append("FlashPhi")

MAMBA_AVAILABLE = True
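# Mamba support is optional: the import can fail, e.g. when its extra dependencies are not installed.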
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append("Mamba")


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
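    """Instantiate the serving `Model` that matches `model_id`'s configuration.

    Dispatches on the config's `model_type`, preferring Flash Attention
    implementations when available and falling back to the generic
    `CausalLM`/`Seq2SeqLM` wrappers otherwise.
    """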
    if dtype is None:
        # Keep it as default for now and let
        # every model resolve their own default dtype.
        dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    if "facebook/galactica" in model_id:
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("bigcode/"):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )

    use_medusa = None
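    # Medusa repos expose `medusa_num_heads` in their config and point to their base model
    # via `base_model_name_or_path`; load the base model and keep the Medusa repo as the speculator.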
    if "medusa_num_heads" in config_dict:
        use_medusa = model_id
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa model only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict.get("model_type", None)
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

    if model_type == "ssm":
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "gpt_bigcode":
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                use_medusa=use_medusa,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "llama" or model_type == "baichuan":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                use_medusa=use_medusa,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

    if model_type == "mistral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "mixtral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "opt":
374
        return OPTSharded(
375
376
377
378
379
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
380
        )
381

    if model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type == "idefics":
391
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
392
393
394
395
396
397
398
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
399
400
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
401
402

    if sharded:
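        # Beyond this point only the generic AutoModel fallbacks remain, and they do not support sharding.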
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
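    # As a last resort, honour custom classes advertised in the config's `auto_map` when remote code is trusted.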
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")