hf_predictor.py
from io import BytesIO
from typing import Iterable, List, Optional, Union

import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer, BitsAndBytesConfig

from ...model.vlm_hf_model import Mineru2QwenForCausalLM
from ...model.vlm_hf_model.image_processing_mineru2 import process_images
from .base_predictor import (
    DEFAULT_MAX_NEW_TOKENS,
    DEFAULT_NO_REPEAT_NGRAM_SIZE,
    DEFAULT_PRESENCE_PENALTY,
    DEFAULT_REPETITION_PENALTY,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    BasePredictor,
)
from .utils import load_resource


class HuggingfacePredictor(BasePredictor):
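    """Runs the Mineru2 VLM through Hugging Face Transformers for image-to-text prediction.

    Wraps tokenizer/model loading (optionally quantized via bitsandbytes) and exposes
    single-image, batch, and (not yet implemented) streaming prediction interfaces.
    """
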
    def __init__(
        self,
        model_path: str,
        device_map="auto",
        device="cuda",
        torch_dtype="auto",
        load_in_8bit=False,
        load_in_4bit=False,
        use_flash_attn=False,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
        **kwargs,
    ):
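        """Load the tokenizer and Mineru2QwenForCausalLM from `model_path`.

        `load_in_8bit` / `load_in_4bit` enable bitsandbytes quantization,
        `use_flash_attn` switches attention to flash_attention_2, and any
        remaining keyword arguments are forwarded to `from_pretrained`.
        """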
        super().__init__(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )

        kwargs = {"device_map": device_map, **kwargs}

        # Pin the whole model to the requested device when it is not the default "cuda".
        if device != "cuda":
            kwargs["device_map"] = {"": device}

        # Quantization options are mutually exclusive: 8-bit takes precedence over
        # 4-bit, and full-precision loading falls back to the requested torch_dtype.
        if load_in_8bit:
            kwargs["load_in_8bit"] = True
        elif load_in_4bit:
            kwargs["load_in_4bit"] = True
            kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
        else:
            kwargs["torch_dtype"] = torch_dtype

        if use_flash_attn:
            kwargs["attn_implementation"] = "flash_attention_2"

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = Mineru2QwenForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            **kwargs,
        )
        setattr(self.model.config, "_name_or_path", model_path)
        self.model.eval()

        vision_tower = self.model.get_model().vision_tower
        if device_map != "auto":
            # When a concrete device was requested, move the vision tower onto it
            # with the language model's dtype.
            vision_tower.to(device=device_map, dtype=self.model.dtype)

        self.image_processor = vision_tower.image_processor
        self.eos_token_id = self.model.config.eos_token_id

    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,  # not supported by hf
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        **kwargs,
    ) -> str:
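        """Generate a text prediction for a single image.

        The image may be passed as raw bytes or as a resource string that is
        resolved through `load_resource`. Generation parameters left as None
        fall back to the values given at construction time.
        """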
        prompt = self.build_prompt(prompt)

        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens

        # Sample only when both temperature and top_k permit it; otherwise decode greedily.
        do_sample = (temperature > 0.0) and (top_k > 1)

        generate_kwargs = {
            "repetition_penalty": repetition_penalty,
            "no_repeat_ngram_size": no_repeat_ngram_size,
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
        }
        if do_sample:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p
            generate_kwargs["top_k"] = top_k

        if isinstance(image, str):
            image = load_resource(image)

        image_obj = Image.open(BytesIO(image))
        image_tensor = process_images([image_obj], self.image_processor, self.model.config)
        image_tensor = image_tensor[0].unsqueeze(0)
        image_tensor = image_tensor.to(device=self.model.device, dtype=self.model.dtype)
        image_sizes = [[*image_obj.size]]

        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.to(device=self.model.device)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                image_sizes=image_sizes,
                use_cache=True,
                **generate_kwargs,
                **kwargs,
            )

        # Remove the last token if it is the eos_token_id
        if len(output_ids[0]) > 0 and output_ids[0, -1] == self.eos_token_id:
            output_ids = output_ids[:, :-1]

        output = self.tokenizer.batch_decode(
            output_ids,
            skip_special_tokens=False,
        )[0].strip()

        return output

    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,  # not supported by hf
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        **kwargs,
    ) -> List[str]:
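        """Run `predict` sequentially over a list of images.

        `prompts` may be a single string (reused for every image) or a list whose
        length matches `images`. Images are processed one at a time; no true
        batched generation is performed.
        """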
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)

        assert len(prompts) == len(images), "Length of prompts and images must match."

        outputs = []
        for prompt, image in tqdm(zip(prompts, images), total=len(images), desc="Predict"):
            output = self.predict(
                image,
                prompt,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                presence_penalty=presence_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                max_new_tokens=max_new_tokens,
                **kwargs,
            )
            outputs.append(output)
        return outputs

    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]:
        raise NotImplementedError("Streaming is not supported yet.")
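

# Usage sketch (illustrative only): how this predictor might be driven once the
# surrounding package is importable. The checkpoint path and image file below are
# placeholders, not real locations.
#
#   predictor = HuggingfacePredictor(model_path="path/to/mineru2-checkpoint", device="cuda")
#   with open("page.png", "rb") as f:
#       result = predictor.predict(f.read(), prompt="")
#   print(result)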