# __init__.py
from text_generation.models.model import Model
from text_generation.models.bloom import BLOOMSharded

# Names re-exported as this package's public API (see the imports above).
__all__ = [
    "Model",
    "BLOOMSharded",
]


def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
    """Build the model instance to serve for ``model_name``.

    Args:
        model_name: Model identifier. Names starting with
            ``"bigscience/bloom"`` are handled as BLOOM models.
        sharded: Whether to load the model sharded. Only accepted for
            BLOOM models.
        quantize: Whether to quantize the model. Only accepted for sharded
            BLOOM models (it is forwarded to ``BLOOMSharded``).

    Returns:
        A ``BLOOMSharded`` for sharded BLOOM models; otherwise a plain
        ``Model`` built from ``model_name``.

    Raises:
        ValueError: If ``sharded`` is requested for a non-BLOOM model, or
            ``quantize`` is requested for anything other than a sharded
            BLOOM model.
    """
    is_bloom = model_name.startswith("bigscience/bloom")

    if is_bloom:
        if sharded:
            return BLOOMSharded(model_name, quantize)
        # Quantization is only wired through the sharded BLOOM path.
        if quantize:
            raise ValueError("quantization is not supported for non-sharded BLOOM")
        return Model(model_name)

    # Non-BLOOM models: only the plain (unsharded, unquantized) path exists.
    # The sharded check comes first so it takes precedence when both are set.
    if sharded:
        raise ValueError("sharded is only supported for BLOOM models")
    if quantize:
        raise ValueError("Quantization is only supported for BLOOM models")
    return Model(model_name)