"tests/vscode:/vscode.git/clone" did not exist on "7536f647e4144c7acaf9e140893ff7edb85bf9a3"
pali_gemma.py 4.37 KB
Newer Older
drbh's avatar
drbh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional, Tuple
from text_generation_server.models.vlm_causal_lm import (
    VlmCausalLM,
    VlmCausalLMBatch,
    image_text_replacement,
    load_data_uri,
    split,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
    PaliGemmaForConditionalGeneration,
)
from transformers import AutoProcessor, AutoConfig, AutoImageProcessor

tracer = trace.get_tracer(__name__)


class PaliGemmaBatch(VlmCausalLMBatch):
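    """VLM batch for PaliGemma.

    Overrides prompt construction: text chunks are prefixed with "<bos>" and
    followed by a newline, and image chunks are replaced by their placeholder
    token string before tokenization.
    """
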
    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
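        """Tokenize a batch of requests into input ids and image tensors.

        Each request is split into text and image chunks. Text is appended
        with a "<bos>" prefix and a trailing newline; images must be data
        URIs, are converted to RGB, preprocessed, and substituted in the
        prompt by their placeholder tokens. Returns the tokenized prompts and
        the concatenated image tensors (or None if the batch has no images).
        """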
        batch_inputs = []
        image_inputs = []
        max_truncation = 0
        for r in requests:
            chunks = split(r.inputs)
            full_text = ""
            image_id = 0
            for chunk in chunks:
                if chunk["type"] == "text":
                    full_text += "<bos>" + chunk["content"] + "\n"
                elif chunk["type"] == "image":
                    image = chunk["content"]
                    # URLs should never reach this point anymore; image
                    # fetching is handled on the Rust layer. This avoids
                    # making n queries per TP rank.
                    # if image.startswith("https://") or image.startswith("http://"):
                    #     image = processor.image_processor.fetch_images(image)
                    if image.startswith("data:"):
                        image = load_data_uri(image)
                    else:
                        raise RuntimeError(
                            "Cannot process an input image that does not start with 'data:'"
                        )
                    # TODO: should do_convert_RGB be on by default?
                    image = image.convert("RGB")
                    image_input = processor.image_processor(image, return_tensors="pt")
                    full_text += image_text_replacement(image_input, config, image_id)
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")

            batch_inputs.append(full_text)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs,
            truncation=True,
            max_length=max_truncation,
            add_special_tokens=False,
        )["input_ids"]
        if image_inputs:
            image_input = image_inputs[0]
            new_image_inputs = {
                "pixel_values": torch.cat(
                    [img["pixel_values"] for img in image_inputs], dim=0
                ),
            }
            if "pixel_attention_mask" in image_input:
                new_image_inputs["pixel_attention_mask"] = torch.cat(
                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
                )
            if "image_sizes" in image_input:
                new_image_inputs["image_sizes"] = torch.cat(
                    [img["image_sizes"] for img in image_inputs], dim=0
                )
            image_inputs = new_image_inputs
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs


class PaliGemma(VlmCausalLM):
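    """PaliGemma wired into the VLM causal LM runtime.

    Loads the Hugging Face processor for image preprocessing and defers model
    loading to VlmCausalLM with the PaliGemma-specific flash modeling class.
    """
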
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        super().__init__(
            config_cls=AutoConfig,
            model_cls=PaliGemmaForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    @property
    def batch_type(self):
        return PaliGemmaBatch

    def get_layer_config(self, model) -> Tuple[int, int, int]:
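        """Return (num_layers, num_kv_heads, head_size) of the text model."""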
        return (
            len(model.text_model.model.layers),
            model.text_model.model.num_key_value_heads,
            model.text_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)
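

# A minimal usage sketch (illustrative only; the server normally constructs
# this class itself, and the checkpoint id below is a hypothetical example):
#
#     model = PaliGemma(
#         model_id="google/paligemma-3b-pt-224",  # hypothetical checkpoint id
#         dtype=torch.bfloat16,
#     )
#     assert model.batch_type is PaliGemmaBatch
#     print(model.max_past())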