Unverified Commit 9f5c9a5e authored by drbh's avatar drbh Committed by GitHub
Browse files

Enable paligemma2 (#2807)

* feat: support loading gemma2 as vlm text model

* feat: add test for paligemma2
parent 08f6fa0b
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 20,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 108,
"logprob": -0.73046875,
"special": false,
"text": "\n"
},
{
"id": 30234,
"logprob": -2.328125,
"special": false,
"text": "Brown"
},
{
"id": 108,
"logprob": -0.12060547,
"special": false,
"text": "\n"
},
{
"id": 3726,
"logprob": -1.7734375,
"special": false,
"text": "Car"
},
{
"id": 108,
"logprob": -0.041503906,
"special": false,
"text": "\n"
},
{
"id": 2915,
"logprob": -1.796875,
"special": false,
"text": "Color"
},
{
"id": 108,
"logprob": -0.039794922,
"special": false,
"text": "\n"
},
{
"id": 19178,
"logprob": -1.96875,
"special": false,
"text": "Cool"
},
{
"id": 108,
"logprob": -0.080566406,
"special": false,
"text": "\n"
},
{
"id": 40544,
"logprob": -2.1875,
"special": false,
"text": "Decor"
},
{
"id": 108,
"logprob": -0.033935547,
"special": false,
"text": "\n"
},
{
"id": 13936,
"logprob": -1.6328125,
"special": false,
"text": "Green"
},
{
"id": 108,
"logprob": -0.16210938,
"special": false,
"text": "\n"
},
{
"id": 955,
"logprob": -2.015625,
"special": false,
"text": "..."
},
{
"id": 108,
"logprob": -0.14746094,
"special": false,
"text": "\n"
},
{
"id": 955,
"logprob": -0.73828125,
"special": false,
"text": "..."
},
{
"id": 108,
"logprob": -0.051513672,
"special": false,
"text": "\n"
},
{
"id": 955,
"logprob": -0.34765625,
"special": false,
"text": "..."
},
{
"id": 108,
"logprob": -0.020141602,
"special": false,
"text": "\n"
},
{
"id": 955,
"logprob": -0.11767578,
"special": false,
"text": "..."
}
],
"top_tokens": null
},
"generated_text": "\nBrown\nCar\nColor\nCool\nDecor\nGreen\n...\n...\n...\n..."
}
import pytest
@pytest.fixture(scope="module")
def flash_pali_gemma_handle(launcher):
    """Yield a launcher handle for the PaliGemma2 checkpoint.

    The handle is produced by entering the context manager returned by the
    ``launcher`` fixture and is shared by every test in this module
    (``scope="module"``). Leaving the ``with`` block after the last test
    lets the launcher's context manager run its exit logic.
    """
    model_id = "google/paligemma2-3b-pt-224"
    with launcher(model_id) as server_handle:
        yield server_handle
@pytest.fixture(scope="module")
async def flash_pali_gemma(flash_pali_gemma_handle):
    """Return the client of the launched model once its health check passes.

    Awaits ``health(300)`` on the handle before handing out ``handle.client``.
    NOTE(review): 300 is presumably a timeout in seconds -- confirm against
    the launcher handle's ``health`` implementation.
    """
    handle = flash_pali_gemma_handle
    await handle.health(300)
    return handle.client
async def test_flash_pali_gemma_image(flash_pali_gemma, response_snapshot):
    """Generate from an image-only prompt and compare with the stored snapshot.

    The prompt is a single markdown image link pointing at a car picture.
    Generation is capped at 20 new tokens; the test checks both the generated
    text and the full response object against ``response_snapshot``.
    """
    car_image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
    prompt = f"![]({car_image})"
    expected_text = "\nBrown\nCar\nColor\nCool\nDecor\nGreen\n...\n...\n...\n..."

    response = await flash_pali_gemma.generate(prompt, max_new_tokens=20)

    # Check the text first so a mismatch is reported in readable form before
    # the full-response comparison runs.
    assert response.generated_text == expected_text
    assert response == response_snapshot
...@@ -17,6 +17,12 @@ def load_text_model(prefix, config, weights, name=None): ...@@ -17,6 +17,12 @@ def load_text_model(prefix, config, weights, name=None):
) )
return FlashGemmaForCausalLM(prefix, config, weights, causal=False) return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
elif config.model_type == "gemma2":
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)
return FlashGemma2ForCausalLM(prefix, config, weights)
elif config.model_type == "paligemma": elif config.model_type == "paligemma":
from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM, FlashGemmaForCausalLM,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment