Unverified commit 40213c95 authored by drbh, committed by GitHub

Pali gemma modeling (#1895)

This PR adds the PaliGemma modeling code.

Blog post: https://huggingface.co/blog/paligemma
Transformers PR: https://github.com/huggingface/transformers/pull/30814

Install the latest changes and run with:
```bash
# optionally pre-download the weights
# text-generation-server download-weights gv-hf/PaliGemma-base-224px-hf

# run TGI
text-generation-launcher --model-id gv-hf/PaliGemma-base-224px-hf
```


Basic example sending various requests:
```python
from huggingface_hub import InferenceClient

client = InferenceClient("http://127.0.0.1:3000")


images = [
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png",
]

prompts = [
    "What animal is in this image?",
    "Name three colors in this image.",
    "What are 10 colors in this image?",
    "Where is the cow standing?",
    "answer en Where is the cow standing?",
    "Is there a bird in the image?",
    "Is ther a cow in the image?",
    "Is there a rabbit in the image?",
    "how many birds are in the image?",
    "how many rabbits are in the image?",
]

for img in images:
    print(f"\nImage: {img.split('/')[-1]}")
    for prompt in prompts:
        # TGI multimodal prompt format: the image as markdown, followed by the text prompt
        inputs = f"![]({img}){prompt}\n"
        generated_output = client.text_generation(
            inputs,
            max_new_tokens=30,
            do_sample=False,
            stream=False,
        )
        print([f"{prompt}\n{generated_output}"])

```
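For completeness, the same request can be sent without the client library by POSTing to TGI's `/generate` route; the payload below mirrors the `inputs`/`parameters` structure used above (a sketch, assuming the server from the launcher command is listening on port 3000):

```python
import requests

# Raw HTTP equivalent of the client call above.
image = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
prompt = "Where is the cow standing?"

json_data = {
    "inputs": f"![]({image}){prompt}\n",
    "parameters": {
        "max_new_tokens": 30,
        "do_sample": False,
    },
}
response = requests.post("http://127.0.0.1:3000/generate", json=json_data)
print(response.json()["generated_text"])
```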

---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent 6c715f81
@@ -15,6 +15,7 @@ from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     FlashMistralBatch,
 )
+from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
 from text_generation_server.models.cache_manager import (
     get_cache_manager,
 )
@@ -80,6 +81,9 @@ def image_text_replacement(image_input, config, image_id) -> str:
         logger.info(f"Found {num_features} in image of resolution {height}x{width}")
         return "<image>" * num_features
+    elif config.model_type == "paligemma":
+        return "<image>" * config.text_config.num_image_tokens
     else:
         raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
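For context, the `paligemma` branch above expands each image into a fixed number of `<image>` placeholder tokens read from the model config. A rough sketch of the resulting prompt string (256 is an assumed value for the 224px checkpoint, not something stated in this diff):

```python
# Illustrative only: how the paligemma branch turns one image into placeholder tokens.
num_image_tokens = 256  # assumed value of config.text_config.num_image_tokens for 224px
prompt = "Where is the cow standing?\n"

# The image markdown in the request is replaced by num_image_tokens copies of "<image>";
# the vision tower's projected features later fill those positions.
full_text = "<image>" * num_image_tokens + prompt
print(full_text[:40], "...", len(full_text))
```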
@@ -193,7 +197,10 @@ class VlmCausalLMBatch(FlashMistralBatch):
             max_truncation = max(max_truncation, r.truncate)
         batch_tokenized_inputs = tokenizer(
-            batch_inputs, truncation=True, max_length=max_truncation
+            batch_inputs,
+            truncation=True,
+            max_length=max_truncation,
+            add_special_tokens=not config.model_type == "paligemma",
         )["input_ids"]
         if image_inputs:
             image_input = image_inputs[0]
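A note on the `add_special_tokens` change above (this is a reading of the diff, not documented behavior): for paligemma the tokenizer is told not to prepend its usual special tokens, presumably because the `<image>` placeholders must come first and the batch-building code controls where `<bos>` goes. A toy illustration of the flag's effect, with a made-up tokenize helper:

```python
# Toy illustration only; the tokenize helper and "<bos>" handling here are invented.
def tokenize(text: str, add_special_tokens: bool) -> list[str]:
    tokens = text.split()
    return (["<bos>"] + tokens) if add_special_tokens else tokens

model_type = "paligemma"
pieces = tokenize("Where is the cow standing?", add_special_tokens=not model_type == "paligemma")
print(pieces)  # no leading <bos>; image placeholders and <bos> are inserted elsewhere
```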
@@ -14,7 +14,10 @@ from typing import List, Optional
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
-from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
+from text_generation_server.models.pali_gemma import PaliGemmaBatch
+from text_generation_server.models.vlm_causal_lm import (
+    VlmCausalLMBatch,
+)
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
@@ -98,6 +101,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if self.model.batch_type in {
             IdeficsCausalLMBatch,
             VlmCausalLMBatch,
+            PaliGemmaBatch,
         }: # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb_processor(
                 request.batch,
@@ -122,6 +126,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if self.model.batch_type in {
             IdeficsCausalLMBatch,
             VlmCausalLMBatch,
+            PaliGemmaBatch,
         }: # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb_processor(
                 request.batch,
@@ -116,6 +116,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
         max_s,
         softmax_scale,
         window_size_left=-1,
+        causal=True,
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
@@ -134,7 +135,7 @@
             0.0,
             softmax_scale,
             False,
-            True,
+            causal,
             window_size_left,
             0,
             False,
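The `causal` flag added above exists because PaliGemma does not use a purely causal mask: the image and prompt tokens form a prefix that attends bidirectionally, while generated tokens remain causal. The sketch below is not the TGI kernel call, just a small illustration of that prefix-LM mask (the helper name and shapes are invented for the example):

```python
# Illustrative prefix-LM mask, not TGI's flash-attention code path.
import torch

def prefix_lm_mask(prefix_len: int, total_len: int) -> torch.Tensor:
    """True where attention is allowed: bidirectional inside the prefix, causal after it."""
    mask = torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))  # causal baseline
    mask[:, :prefix_len] = True  # every position may attend to the whole prefix
    return mask

# 3 prefix tokens (image + prompt), 2 generated tokens
print(prefix_lm_mask(prefix_len=3, total_len=5).int())
```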