vlm.py 1.81 KB
Newer Older
Nicolas Patry's avatar
Nicolas Patry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
def load_text_model(prefix, config, weights, name=None):
    """Instantiate the language-model half of a vision-language model.

    The concrete class is selected from ``config.model_type``. Each backend
    module is imported only inside the branch that needs it, so just the
    chosen architecture is loaded.

    Args:
        prefix: Weight-name prefix forwarded to the model constructor.
        config: Text config; ``config.model_type`` picks the architecture and
            the object is forwarded unchanged to the constructor.
        weights: Weight store forwarded unchanged to the constructor.
        name: Optional name, forwarded only to the Mistral constructor.

    Returns:
        A ``FlashLlamaForCausalLM``, ``FlashMistralForCausalLM`` or
        ``FlashGemmaForCausalLM`` instance.

    Raises:
        RuntimeError: If ``config.model_type`` is not one of the supported
            text architectures.
    """
    model_type = config.model_type

    if model_type == "llama":
        from text_generation_server.models.custom_modeling.flash_llama_modeling import (
            FlashLlamaForCausalLM,
        )

        return FlashLlamaForCausalLM(prefix, config, weights)

    if model_type == "mistral":
        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
            FlashMistralForCausalLM,
        )

        return FlashMistralForCausalLM(prefix, config, weights, name=name)

    if model_type in ("gemma", "paligemma"):
        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
            FlashGemmaForCausalLM,
        )

        # Plain "gemma" is built with causal=False; "paligemma" relies on the
        # constructor's default for that argument.
        if model_type == "gemma":
            return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
        return FlashGemmaForCausalLM(prefix, config, weights)

    raise RuntimeError(f"Unsupported model type {model_type}")


def load_vision_model(prefix, config, weights):
    """Instantiate the vision encoder of a vision-language model.

    The concrete class is selected from ``config.model_type``. Each backend
    module is imported only inside the branch that needs it, so just the
    chosen architecture is loaded.

    Args:
        prefix: Weight-name prefix of the vision tower. Both encoders read
            their weights from ``f"{prefix}.vision_model"``.
        config: Vision config; ``config.model_type`` picks the architecture
            and the object is forwarded unchanged to the constructor.
        weights: Weight store forwarded unchanged to the constructor.

    Returns:
        A ``CLIPVisionTransformer`` or ``SiglipVisionTransformer`` instance.

    Raises:
        RuntimeError: If ``config.model_type`` is not one of the supported
            vision architectures.
    """
    if config.model_type == "clip_vision_model":
        from text_generation_server.models.custom_modeling.clip import (
            CLIPVisionTransformer,
        )

        return CLIPVisionTransformer(
            prefix=f"{prefix}.vision_model", config=config, weights=weights
        )
    elif config.model_type == "siglip_vision_model":
        from text_generation_server.models.custom_modeling.siglip import (
            SiglipVisionTransformer,
        )

        # Derive the weight prefix from the caller's ``prefix``, exactly like
        # the CLIP branch, instead of a hard-coded "vision_tower.vision_model"
        # that silently ignored ``prefix``. NOTE(review): this is identical for
        # callers passing prefix="vision_tower" — confirm no caller relied on
        # the hard-coded path with a different prefix.
        return SiglipVisionTransformer(
            prefix=f"{prefix}.vision_model", config=config, weights=weights
        )
    else:
        raise RuntimeError(f"Unsupported model type {config.model_type}")