import os
import torch

from loguru import logger
from transformers import AutoConfig
from transformers.models.auto import modeling_auto
from typing import Optional

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded
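
# The Flash Attention model classes are optional: FLASH_ATTENTION is only True when
# their custom CUDA kernels import successfully and a CUDA device is available;
# otherwise the regular implementations above are used.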
try:
    from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
    from text_generation_server.models.flash_santacoder import FlashSantacoder
    from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded

    FLASH_ATTENTION = torch.cuda.is_available()
except ImportError:
    logger.exception("Could not import Flash Attention enabled models")
    FLASH_ATTENTION = False

__all__ = [
    "Model",
    "BLOOM",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "Galactica",
    "GalacticaSharded",
    "GPTNeoxSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "T5Sharded",
    "get_model",
]

if FLASH_ATTENTION:
    __all__.append("FlashNeoX")
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashSantacoder")
    __all__.append("FlashLlama")
    __all__.append("FlashLlamaSharded")

FLASH_ATT_ERROR_MESSAGE = (
    "{} requires Flash Attention CUDA kernels to be installed.\n"
    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
    "or install flash attention with `cd server && make install install-flash-attention`"
)

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)


def get_model(
    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
) -> Model:
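    """Instantiate the most appropriate Model subclass for `model_id`.

    Galactica and SantaCoder checkpoints are matched on the model id string;
    everything else is dispatched on `config.model_type` from the Hugging Face
    config, preferring the Flash Attention implementations when the CUDA
    kernels were successfully imported.
    """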
    if "facebook/galactica" in model_id:
        if sharded:
            return GalacticaSharded(model_id, revision, quantize=quantize)
        else:
            return Galactica(model_id, revision, quantize=quantize)

    if "santacoder" in model_id:
        if sharded:
            raise NotImplementedError("sharded is not supported for Santacoder")
        else:
            santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
            return santacoder_cls(model_id, revision, quantize)

    config = AutoConfig.from_pretrained(model_id, revision=revision)
    model_type = config.model_type

    if model_type == "bloom":
        if sharded:
            return BLOOMSharded(model_id, revision, quantize=quantize)
        else:
            return BLOOM(model_id, revision, quantize=quantize)

    if model_type == "gpt_neox":
        if sharded:
            neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
            return neox_cls(model_id, revision, quantize=quantize)
        else:
            neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
            return neox_cls(model_id, revision, quantize=quantize)

    if model_type == "llama":
        if sharded:
            if FLASH_ATTENTION:
                return FlashLlamaSharded(model_id, revision, quantize=quantize)
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")
            )
        else:
            llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
            return llama_cls(model_id, revision, quantize=quantize)

    if model_type == "t5":
        if sharded:
            return T5Sharded(model_id, revision, quantize=quantize)
        else:
            return Seq2SeqLM(model_id, revision, quantize=quantize)

    if sharded:
        raise ValueError("sharded is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(model_id, revision, quantize=quantize)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(model_id, revision, quantize=quantize)

    raise ValueError(f"Unsupported model type {model_type}")
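
# A minimal usage sketch (illustrative only; in practice the server entrypoint calls
# this). "bigscience/bloom-560m" stands in for any Hugging Face Hub model id:
#
#     model = get_model(
#         model_id="bigscience/bloom-560m",
#         revision=None,
#         sharded=False,
#         quantize=False,
#     )
#     # model_type "bloom" -> BLOOM (or BLOOMSharded when sharded=True)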