# __init__.py (595 bytes)
1
from text_generation.models.model import Model
2
from text_generation.models.bloom import BLOOM, BLOOMSharded
3

4
# Public names exported by a star-import of this package; includes the
# ``get_model`` factory defined in this module.
__all__ = ["Model", "BLOOM", "BLOOMSharded", "get_model"]
5
6
7
8
9
10
11
12
13


def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
    """Build the model wrapper that matches ``model_name``.

    Only names starting with ``"bigscience/bloom"`` are recognised.

    Args:
        model_name: Model identifier; selection is done by prefix match.
        sharded: When true, a ``BLOOMSharded`` is built instead of a ``BLOOM``.
        quantize: Forwarded to ``BLOOMSharded``; rejected for the
            non-sharded ``BLOOM``.

    Returns:
        ``BLOOMSharded(model_name, quantize)`` when ``sharded`` is true,
        otherwise ``BLOOM(model_name)``.

    Raises:
        ValueError: If ``model_name`` does not start with
            ``"bigscience/bloom"``, or if ``quantize`` is requested
            without ``sharded``.
    """
    # Reject unknown model names first so the supported path stays flat.
    if not model_name.startswith("bigscience/bloom"):
        raise ValueError(f"model {model_name} is not supported yet")

    if sharded:
        return BLOOMSharded(model_name, quantize)

    # The non-sharded BLOOM constructor takes no quantize argument here.
    if quantize:
        raise ValueError("quantization is not supported for non-sharded BLOOM")
    return BLOOM(model_name)