"examples/community/composable_stable_diffusion.py" did not exist on "ca749513230f5c425049d47c3dad9c099fb2a769"
__init__.py 13.5 KB
Newer Older
1
2
import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto

from typing import Optional

from text_generation_server.utils.speculate import get_speculate, set_speculate

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

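# Flash Attention backed model classes are imported lazily: if the import fails
# (e.g. the CUDA kernels are not installed), a warning is logged and `get_model`
# falls back to the non-flash implementations selected further down.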
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
    __all__.append(FlashNeoXSharded)
    __all__.append(FlashRWSharded)
    __all__.append(FlashSantacoderSharded)
    __all__.append(FlashLlama)
    __all__.append(IDEFICSSharded)
    __all__.append(FlashMistral)
    __all__.append(FlashMixtral)
    __all__.append(FlashPhi)


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
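    """Instantiate the serving implementation that matches `model_id`.

    The choice is based on the repository id and the `model_type` found in the
    model config, preferring Flash Attention implementations when available and
    falling back to the generic `CausalLM` / `Seq2SeqLM` wrappers otherwise.
    """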
    if dtype is None:
        # Keep it as the default for now and let
        # every model resolve its own default dtype.
        dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

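    # Speculative decoding: honor an explicit `speculate` value, otherwise disable it (0).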
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

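    # A few model families are routed by repository id before the config is inspected.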
    if "facebook/galactica" in model_id:
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("bigcode/"):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

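    # Fetch the model config; the remaining routing is driven by it.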
    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )

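    # Medusa checkpoints reference a base model: swap `model_id` to it and derive
    # the speculation depth from `medusa_num_heads`.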
    use_medusa = None
    if "medusa_num_heads" in config_dict:
        use_medusa = model_id
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this Medusa model only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
214
                dtype=dtype,
215
216
217
218
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
219
220
221
                model_id,
                revision,
                quantize=quantize,
222
                dtype=dtype,
223
224
                trust_remote_code=trust_remote_code,
            )
225
        else:
226
            return CausalLM(
227
228
229
                model_id,
                revision,
                quantize=quantize,
230
                dtype=dtype,
231
232
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
233

    elif model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                use_medusa=use_medusa,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "llama" or model_type == "baichuan":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                use_medusa=use_medusa,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

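    # Falcon checkpoints appear under several `model_type` aliases; alibi-based
    # variants cannot be served sharded with Flash Attention.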
    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

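    # Mistral and Mixtral use sliding-window attention, which needs Flash Attention
    # v2 on CUDA; models configured without a sliding window can use the regular
    # Flash Attention path.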
    if model_type == "mistral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "mixtral":
333
334
335
336
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
OlivierDehaene's avatar
OlivierDehaene committed
337
338
339
340
341
342
343
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
344
345

    if model_type == "opt":
346
        return OPTSharded(
347
348
349
350
351
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
352
        )
353

    if model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

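    # No dedicated implementation matched: fall back to the generic transformers
    # AutoModel path (sharding and most quantization schemes are unsupported here).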
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

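    # As a last resort, honor custom modeling code advertised through `auto_map`
    # when `trust_remote_code` is enabled.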
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")