import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto

from typing import Optional

from text_generation_server.utils.speculate import get_speculate, set_speculate

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

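# Flash Attention model implementations are optional: importing them fails on
# hardware or installs without the custom kernels, in which case the server
# falls back to the non-flash code paths below.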
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
    __all__.append(FlashNeoXSharded)
    __all__.append(FlashRWSharded)
    __all__.append(FlashSantacoderSharded)
    __all__.append(FlashLlama)
    __all__.append(IDEFICSSharded)
    __all__.append(FlashMistral)
    __all__.append(FlashMixtral)
    __all__.append(FlashPhi)

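# Mamba models depend on optional selective state-space kernels; keep the
# server importable when they are not installed.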
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append(Mamba)

def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
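    """Resolve and instantiate the Model implementation for `model_id`.

    Routing is driven by the model config (plus a few model_id prefixes),
    preferring Flash Attention implementations when the kernels are available
    and falling back to the generic AutoModel-based wrappers otherwise.

    A minimal illustrative call (argument values are hypothetical):

        model = get_model(
            "bigscience/bloom-560m",
            revision=None,
            sharded=False,
            quantize=None,
            speculate=None,
            dtype="float16",
            trust_remote_code=False,
        )
    """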
    if dtype is None:
        # Keep it as default for now and let
        # every model resolve their own default dtype.
        dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

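    # Speculative decoding is disabled (0 speculative ids) unless explicitly requested.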
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

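    # A few model families are routed on their model_id prefix alone, before
    # the config is fetched.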
    if "facebook/galactica" in model_id:
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("bigcode/"):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )

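    # Medusa checkpoints add speculative-decoding heads on top of a base model:
    # when one is detected, point model_id at the base model and reload that config.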
    use_medusa = None
    if "medusa_num_heads" in config_dict:
        use_medusa = model_id
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa model only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict.get("model_type", None)
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

    if model_type == "ssm":
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "gpt_bigcode":
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                use_medusa=use_medusa,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "llama" or model_type == "baichuan":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                use_medusa=use_medusa,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

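    # Mistral/Mixtral use sliding-window attention: the flash implementations
    # are used only when the window is disabled or flash-attention v2 CUDA
    # kernels are available.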
    if model_type == "mistral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "mixtral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "opt":
        return OPTSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

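    # No dedicated implementation matched: fall back to the generic
    # AutoModel-based wrappers, which support neither sharding nor the
    # quantization schemes rejected below.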
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

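    # As a last resort, honor a custom `auto_map` from the config when
    # trust_remote_code is enabled.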
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")