Unverified commit 40213c95, authored by drbh, committed by GitHub

Pali gemma modeling (#1895)

This PR adds PaliGemma modeling code.

Blog post: https://huggingface.co/blog/paligemma
Transformers PR: https://github.com/huggingface/transformers/pull/30814

Install the latest changes and run with:
```bash
# get the weights
# text-generation-server download-weights gv-hf/PaliGemma-base-224px-hf

# run TGI
text-generation-launcher --model-id gv-hf/PaliGemma-base-224px-hf
```


A basic example sending various requests:
```python
from huggingface_hub import InferenceClient

client = InferenceClient("http://127.0.0.1:3000")


images = [
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png",
]

prompts = [
    "What animal is in this image?",
    "Name three colors in this image.",
    "What are 10 colors in this image?",
    "Where is the cow standing?",
    "answer en Where is the cow standing?",
    "Is there a bird in the image?",
    "Is ther a cow in the image?",
    "Is there a rabbit in the image?",
    "how many birds are in the image?",
    "how many rabbits are in the image?",
]

for img in images:
    print(f"\nImage: {img.split('/')[-1]}")
    for prompt in prompts:
        # Prepend the image to the prompt using TGI's markdown image syntax
        inputs = f"![]({img}){prompt}\n"
        generated_output = client.text_generation(
            inputs, max_new_tokens=30, stream=False
        )
        print(f"{prompt}\n{generated_output}")

```
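If you prefer raw HTTP to the client library, the equivalent request can be posted directly to TGI's `/generate` route. A minimal sketch, assuming the default host and port from the launcher command above:

```python
import requests

# Same request as client.text_generation above, sent as a raw POST.
payload = {
    "inputs": "![](https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png)Where is the cow standing?\n",
    "parameters": {"max_new_tokens": 30, "do_sample": False},
}
resp = requests.post("http://127.0.0.1:3000/generate", json=payload)
print(resp.json()["generated_text"])
```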

---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent 6c715f81
CI workflow:
```diff
@@ -27,7 +27,7 @@ jobs:
   runs-on: ubuntu-latest
   env:
     AWS_REGION: us-east-1
-    EC2_AMI_ID: ami-03cfed9ea28f4b002
+    EC2_AMI_ID: ami-0789b6925c11b1fb2
     EC2_INSTANCE_TYPE: g5.12xlarge
     EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
     EC2_SECURITY_GROUP: sg-030175c435ac141d6
```
Dockerfile:
```diff
@@ -43,7 +43,7 @@ ARG PYTORCH_VERSION=2.3.0
 ARG PYTHON_VERSION=3.10
 # Keep in sync with `server/pyproject.toml
 ARG CUDA_VERSION=12.1
-ARG MAMBA_VERSION=23.3.1-1
+ARG MAMBA_VERSION=24.3.0-0
 ARG CUDA_CHANNEL=nvidia
 ARG INSTALL_CHANNEL=pytorch
 # Automatically set by buildx
@@ -181,6 +181,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     ca-certificates \
     make \
     curl \
+    git \
     && rm -rf /var/lib/apt/lists/*
 # Copy conda with PyTorch installed
```
New integration-test snapshot (expected generation for the cow-on-beach image):
```json
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "eos_token",
    "generated_tokens": 2,
    "prefill": [],
    "seed": null,
    "tokens": [
      {
        "id": 54901,
        "logprob": -0.72753906,
        "special": false,
        "text": "beach"
      },
      {
        "id": 1,
        "logprob": -0.011009216,
        "special": true,
        "text": "<eos>"
      }
    ],
    "top_tokens": null
  },
  "generated_text": "beach"
}
```
New integration test:
```python
import pytest
import requests
import io
import base64


@pytest.fixture(scope="module")
def flash_pali_gemma_handle(launcher):
    with launcher(
        "google/paligemma-3b-pt-224",
        num_shard=1,
        revision="float16",
        max_input_length=4000,
        max_total_tokens=4096,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_pali_gemma(flash_pali_gemma_handle):
    await flash_pali_gemma_handle.health(300)
    return flash_pali_gemma_handle.client


def get_cow_beach():
    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
    cow = get_cow_beach()
    inputs = f"![]({cow})Where is the cow standing?\n"
    response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)

    assert response.generated_text == "beach"
    assert response == response_snapshot
```
Router model configuration (Rust):
```diff
@@ -100,7 +100,6 @@ impl LlavaNext {
 }

 #[derive(Clone, Debug, Serialize, Deserialize)]
-#[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
 pub struct ClipVisionModel {
     image_size: usize,
@@ -108,7 +107,6 @@ pub struct ClipVisionModel {
 }

 #[derive(Clone, Debug, Serialize, Deserialize)]
-#[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
 pub struct Idefics2 {}
@@ -118,6 +116,24 @@ impl Idefics2 {
     }
 }

+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct PaliTextConfig {
+    num_image_tokens: usize,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct Paligemma {
+    text_config: PaliTextConfig,
+}
+
+impl Paligemma {
+    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
+        self.text_config.num_image_tokens
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
@@ -140,6 +156,7 @@ pub enum Config {
     Phi3,
     Llama,
     Baichuan,
+    Paligemma(Paligemma),
     Gemma,
     Cohere,
     Drbx,
```
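Unlike LLaVA-Next, where the number of image features depends on the image resolution, PaliGemma advertises a fixed `num_image_tokens` in its text config, so `get_number_of_features` ignores height and width. As a rough sanity check of where that number comes from (a sketch, assuming a SigLIP vision tower with 224px input and 14px patches, as used by the 224px checkpoints):

```python
def paligemma_num_image_tokens(image_size: int = 224, patch_size: int = 14) -> int:
    # The image is split into a fixed grid of patches; every patch
    # becomes one image-token slot in the prompt.
    return (image_size // patch_size) ** 2

assert paligemma_num_image_tokens() == 256  # expected num_image_tokens for 224px models
```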
Router prompt preparation (Rust):
```diff
@@ -544,6 +544,30 @@ fn prepare_input(
             inputs = modified_inputs;
             tokenizer_query
         }
+        Some(Config::Paligemma(config)) => {
+            let mut modified_inputs = String::with_capacity(inputs.len());
+            let mut tokenizer_query = String::with_capacity(inputs.len());
+            let mut start = 0;
+            for chunk in RE.find_iter(&inputs) {
+                let chunk_start = chunk.start();
+                let chunk_end = chunk.end();
+                if chunk_start != start {
+                    modified_inputs.push_str(&inputs[start..chunk_start]);
+                    tokenizer_query.push_str(&inputs[start..chunk_start]);
+                }
+                let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                let slots = config.get_number_of_features(height, width);
+                tokenizer_query.push_str(&"<image>".repeat(slots));
+                modified_inputs.push_str(&image_uri);
+                start = chunk_end;
+            }
+            if start != inputs.len() - 1 {
+                modified_inputs.push_str(&inputs[start..]);
+                tokenizer_query.push_str(&inputs[start..]);
+            }
+            inputs = modified_inputs;
+            tokenizer_query
+        }
         Some(Config::Idefics2(config)) => {
             let mut modified_inputs = String::with_capacity(inputs.len());
             let mut tokenizer_query = String::with_capacity(inputs.len());
```
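The router scans the prompt for markdown image links, inlines each image for the model server, and substitutes `<image>` placeholder tokens for the tokenizer. A minimal Python sketch of the same rewrite (the regex here is an assumption matching the `![](...)` form used in the examples; the real image fetch is left out and the link kept as-is for illustration):

```python
import re

IMAGE_RE = re.compile(r"!\[\]\(([^)]*)\)")

def prepare_input(inputs: str, slots: int) -> tuple[str, str]:
    """Return (modified_inputs, tokenizer_query), mirroring the Rust arm above."""
    modified, query = [], []
    start = 0
    for m in IMAGE_RE.finditer(inputs):
        modified.append(inputs[start:m.start()])
        query.append(inputs[start:m.start()])
        modified.append(m.group(0))        # router would re-embed as a data URI
        query.append("<image>" * slots)    # fixed number of image slots
        start = m.end()
    modified.append(inputs[start:])
    query.append(inputs[start:])
    return "".join(modified), "".join(query)
```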
This diff is collapsed.
Server pyproject:
```diff
@@ -25,8 +25,9 @@ opentelemetry-instrumentation-grpc = "^0.36b0"
 hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
 tokenizers = "^0.19.1"
-huggingface-hub = "^0.19.3"
-transformers = "^4.40"
+huggingface-hub = "^0.23"
+# transformers = "^4.40"
+transformers = { git = "https://github.com/huggingface/transformers.git", rev="b8aee2e" }
 einops = "^0.6.1"
 texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
```
Pinned Python requirements (the same change lands in two identical lockfiles):
```diff
@@ -13,7 +13,7 @@ grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
 grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
 grpcio==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.23.0 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
@@ -40,7 +40,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
 setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.40.2 ; python_version >= "3.9" and python_version < "3.13"
+transformers @ git+https://github.com/huggingface/transformers.git@b8aee2e918d7ba2d5e9e80162ae26b4806873307 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
```
FastLinear layer:
```diff
@@ -10,9 +10,9 @@ class FastLinear(torch.nn.Module):
         bias,
     ) -> None:
         super().__init__()
-        self.weight = torch.nn.Parameter(weight)
+        self.weight = torch.nn.Parameter(weight, requires_grad=False)
         if bias is not None:
-            self.bias = torch.nn.Parameter(bias)
+            self.bias = torch.nn.Parameter(bias, requires_grad=False)
         else:
             self.bias = None
```
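Marking inference-only parameters with `requires_grad=False` keeps autograd from building a graph through them. A small standalone illustration of the difference (not part of the PR):

```python
import torch

w = torch.randn(4, 4)
p_default = torch.nn.Parameter(w.clone())                      # requires_grad=True
p_frozen = torch.nn.Parameter(w.clone(), requires_grad=False)  # inference-only

x = torch.randn(2, 4)
assert (x @ p_default).requires_grad     # autograd tracks the op
assert not (x @ p_frozen).requires_grad  # no graph built, no grad state kept
```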
Model dispatch (`get_model`):
```diff
@@ -65,6 +65,9 @@ try:
     from text_generation_server.models.flash_gemma import (
         FlashGemma,
     )
+    from text_generation_server.models.pali_gemma import (
+        PaliGemma,
+    )
     from text_generation_server.models.flash_santacoder import (
         FlashSantacoderSharded,
     )
@@ -676,6 +679,18 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
+    if model_type == "paligemma":
+        if FLASH_ATTENTION:
+            return PaliGemma(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == "llava_next":
         if FLASH_ATTENTION:
```
Flash Gemma modeling, starting with the RMSNorm load path:
```diff
@@ -99,8 +99,13 @@ class GemmaConfig(PretrainedConfig):
 class GemmaFastRMSNorm(FastRMSNorm):
     @classmethod
     def load(cls, prefix, weights, eps=1e-6):
+        dtype = weights.dtype
+        weights.dtype = torch.float32
         weight = weights.get_tensor(f"{prefix}.weight") + 1
-        return cls(weight, eps)
+        weights.dtype = dtype
+        new = cls(weight, eps)
+        new.dtype = dtype
+        return new

     # perform the multiplication in full precision and downcast after
     def forward(self, hidden_states, residual=None):
@@ -111,7 +116,7 @@ class GemmaFastRMSNorm(FastRMSNorm):
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         hidden_states = hidden_states * self.weight
-        return hidden_states.to(self.weight.dtype), residual
+        return hidden_states.to(self.dtype), residual

 def load_attention(config, prefix, weights):
```
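The norm weight is now loaded in float32 (with Gemma's characteristic `+ 1` offset already applied), and the normalization runs in full precision before downcasting to the model dtype. A standalone sketch of the numerics, assuming a float16 model dtype:

```python
import torch

def gemma_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Upcast activations; `weight` is stored in float32 and already
    # includes the `+ 1` offset applied at load time.
    dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (x * weight).to(dtype)  # downcast only at the end

h = torch.randn(2, 8, dtype=torch.float16)
w = torch.zeros(8, dtype=torch.float32) + 1  # raw checkpoint weight of 0 -> scale 1
assert gemma_rms_norm(h, w).dtype == torch.float16
```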
The attention layer gains an explicit `causal` flag:
```diff
@@ -153,15 +158,11 @@ def _load_gqa(config, prefix: str, weights):
 class FlashGemmaAttention(torch.nn.Module):
-    def __init__(
-        self,
-        prefix: str,
-        config,
-        weights,
-    ):
+    def __init__(self, prefix: str, config, weights, causal: bool):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.head_size = config.head_dim
+        self.causal = causal

         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
@@ -238,6 +239,7 @@ class FlashGemmaAttention(torch.nn.Module):
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
+                causal=self.causal,
             )
         # Decode
         else:
```
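The `causal` flag exists because PaliGemma is a prefix-LM: the image tokens and prompt attend to each other bidirectionally, and only generated tokens are causal, which is why the VLM loader further down passes `causal=False` for the Gemma text tower inside PaliGemma. A toy mask construction illustrating the idea (an illustration only, not the flash-attention kernel path):

```python
import torch

def prefix_lm_mask(prefix_len: int, total_len: int) -> torch.Tensor:
    """True = attention allowed. Bidirectional over the prefix, causal after."""
    mask = torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))
    mask[:prefix_len, :prefix_len] = True  # image + prompt see each other fully
    return mask

m = prefix_lm_mask(prefix_len=4, total_len=6)
assert m[0, 3]      # first prefix token sees the whole prefix
assert not m[4, 5]  # generated tokens remain causal
```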
Still in the same file, layer and model constructors now take an explicit `prefix` and `causal`, and embedding moves out of the model:
```diff
@@ -295,11 +297,10 @@ class GemmaMLP(nn.Module):
 class FlashGemmaLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix, config, weights, causal: bool):
         super().__init__()
-        prefix = f"model.layers.{layer_id}"
         self.self_attn = FlashGemmaAttention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
         )
         self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@@ -351,30 +352,25 @@ class FlashGemmaLayer(nn.Module):
 class FlashGemmaModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights, causal: bool):
         super().__init__()

         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        embed_norm = config.hidden_size**0.5
-        self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
-        )
-        self.embed_tokens.weight *= embed_norm
         self.layers = nn.ModuleList(
             [
                 FlashGemmaLayer(
-                    layer_id,
-                    config,
-                    weights,
+                    prefix=f"{prefix}.layers.{layer_id}",
+                    config=config,
+                    weights=weights,
+                    causal=causal,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = GemmaFastRMSNorm.load(
-            prefix="model.norm", weights=weights, eps=config.rms_norm_eps
+            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
         )

         self.gradient_checkpointing = False
@@ -385,7 +381,7 @@ class FlashGemmaModel(torch.nn.Module):
     def forward(
         self,
-        input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -394,7 +390,7 @@ class FlashGemmaModel(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
+        hidden_states = inputs_embeds

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -423,13 +419,30 @@ class FlashGemmaModel(torch.nn.Module):
 class FlashGemmaForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights, causal: bool):
         super().__init__()

-        self.model = FlashGemmaModel(config, weights)
+        embed_norm = config.hidden_size**0.5
+        if prefix is None:
+            prefix = "model"
+        else:
+            prefix = f"{prefix}.model"
+
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix=f"{prefix}.embed_tokens", weights=weights
+        )
+        self.embed_tokens.weight *= embed_norm
+
+        self.model = FlashGemmaModel(
+            prefix=prefix, config=config, weights=weights, causal=causal
+        )
         self.lm_head = SpeculativeHead.load(
-            config,
-            prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head",
+            prefix=(
+                f"{prefix}.embed_tokens"
+                if config.tie_word_embeddings
+                else f"{prefix}.lm_head"
+            ),
+            config=config,
             weights=weights,
         )
@@ -445,8 +458,9 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        input_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
-            input_ids,
+            input_embeds,
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
```
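Note that the `sqrt(hidden_size)` embedding scaling moves out of `FlashGemmaModel` into `FlashGemmaForCausalLM`, next to the embedding itself, because the model now receives `inputs_embeds`: PaliGemma has to embed tokens first so it can splice image features in before the transformer runs. The scaling itself is Gemma's usual embedding normalization, folded into the weights once at load time; sketched, with an assumed hidden size:

```python
import torch

hidden_size = 2048  # assumed; the real value comes from the model config
embed = torch.nn.Embedding(256_000, hidden_size)

# Fold sqrt(hidden_size) into the embedding weights once, instead of
# multiplying on every forward pass.
with torch.no_grad():
    embed.weight *= hidden_size ** 0.5

tokens = torch.tensor([[2, 106, 1645]])
inputs_embeds = embed(tokens)  # already scaled, ready for the first layer
```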
New file: PaliGemma conditional-generation model:
```python
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
from text_generation_server.models.custom_modeling.vlm import (
    load_text_model,
    load_vision_model,
)


class PaliGemmaForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        config.vision_config.quantize = config.quantize
        self.vision_tower = load_vision_model(
            prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
            config=config.vision_config,
            weights=weights,
        )

        self.multi_modal_projector = TensorParallelColumnLinear.load(
            config,
            prefix="multi_modal_projector.linear",
            weights=weights,
            bias=True,
        )

        self.vocab_size = config.vocab_size
        self.config = config

        text_config = config.text_config
        text_config.speculator = config.speculator
        text_config.quantize = config.quantize
        self.text_model = load_text_model(
            prefix="language_model" if not prefix else f"{prefix}.language_model",
            config=config.text_config,
            weights=weights,
        )
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        pixel_values: torch.FloatTensor = None,
        # Unused here
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        # TODO This is odd but apparently pali gemma position ids start at 1.
        if cu_seqlen_prefill is not None:
            max_s += 1
            position_ids += 1

        if pixel_values is not None:
            pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
            image_outputs = self.vision_tower(pixel_values)
            image_features = self.multi_modal_projector(image_outputs.last_hidden_state)

            # mask where image or padding tokens
            mask = input_ids == self.config.image_token_index

            # insert image features into input embeddings
            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])

        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
        )

        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.text_model.lm_head(hidden_states)

        return logits, speculative_logits
```
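The splice relies on the router having inserted exactly `num_image_tokens` copies of the image placeholder, so the flattened vision features fill the masked positions one-to-one. In miniature (shapes are illustrative):

```python
import torch

IMAGE_TOKEN = 99
hidden = 8

input_ids = torch.tensor([IMAGE_TOKEN] * 4 + [5, 6, 7])  # 4 image slots + 3 text tokens
inputs_embeds = torch.zeros(len(input_ids), hidden)
image_features = torch.randn(1, 4, hidden)  # one image, 4 projected features

mask = input_ids == IMAGE_TOKEN
inputs_embeds[mask] = image_features.view(-1, hidden)  # one feature per placeholder
assert inputs_embeds[:4].abs().sum() > 0  # image rows filled, text rows untouched
```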
This diff is collapsed.
VLM loader helpers:
```diff
@@ -11,6 +11,18 @@ def load_text_model(prefix, config, weights, name=None):
         )

         return FlashMistralForCausalLM(prefix, config, weights, name=name)
+    elif config.model_type == "gemma":
+        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
+            FlashGemmaForCausalLM,
+        )
+
+        return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
+    elif config.model_type == "paligemma":
+        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
+            FlashGemmaForCausalLM,
+        )
+
+        return FlashGemmaForCausalLM(prefix, config, weights)
     else:
         raise RuntimeError(f"Unsupported model type {config.model_type}")
@@ -24,5 +36,13 @@ def load_vision_model(prefix, config, weights):
         return CLIPVisionTransformer(
             prefix=f"{prefix}.vision_model", config=config, weights=weights
         )
+    if config.model_type == "siglip_vision_model":
+        from text_generation_server.models.custom_modeling.siglip import (
+            SiglipVisionTransformer,
+        )
+
+        return SiglipVisionTransformer(
+            prefix=f"vision_tower.vision_model", config=config, weights=weights
+        )
     else:
         raise RuntimeError(f"Unsupported model type {config.model_type}")
```
Flash causal LM batch:
```diff
@@ -133,6 +133,17 @@ class FlashCausalLMBatch(Batch):
         device: torch.device,
     ) -> "FlashCausalLMBatch":
         batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
+        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+
+    @classmethod
+    def from_tokenized(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        batch_tokenized_inputs,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "FlashCausalLMBatch":
         position_ids = []
         speculative_ids = []
         cu_seqlen_prefill = [0]
@@ -207,6 +218,7 @@ class FlashCausalLMBatch(Batch):
             # Paged attention
             # Remove one as the first token des not have a past
             speculative_length = get_speculate()
+            speculative_length = 0 if speculative_length is None else speculative_length
             total_tokens = input_length + max_new_tokens - 1 + speculative_length
             needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
             blocks += needed_blocks
```
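Splitting batch construction into `batch_tokenized_inputs` followed by `from_tokenized` lets VLM batches substitute their own image-aware tokenization while reusing the shared flash batch setup. A self-contained sketch of the pattern (names here are illustrative, not the real classes):

```python
class Batch:
    @classmethod
    def from_pb(cls, requests):
        return cls.from_tokenized(cls.tokenize(requests))

    @classmethod
    def tokenize(cls, requests):
        return [r.split() for r in requests]

    @classmethod
    def from_tokenized(cls, tokenized):
        # shared construction step, reused by every subclass
        batch = cls()
        batch.input_ids = tokenized
        return batch

class VlmBatch(Batch):
    @classmethod
    def tokenize(cls, requests):
        # image-aware tokenization replaces the default here
        return [["<image>"] + r.split() for r in requests]
```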
FlashGemma front-end:
```diff
@@ -3,12 +3,11 @@ import torch.distributed
 from opentelemetry import trace
 from typing import Optional

-from transformers.models.gemma import GemmaTokenizerFast
+from transformers import AutoConfig, AutoTokenizer

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
     FlashGemmaForCausalLM,
-    GemmaConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -36,17 +35,15 @@ class FlashGemma(FlashCausalLM):
         else:
             raise NotImplementedError("FlashGemma is only available on GPU")

-        tokenizer = GemmaTokenizerFast.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
             truncation_side="left",
             trust_remote_code=trust_remote_code,
-            use_fast=True,
-            from_slow=False,
         )

-        config = GemmaConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
@@ -59,7 +56,9 @@ class FlashGemma(FlashCausalLM):
         if config.quantize in ["gptq", "awq"]:
             weights._set_gptq_params(model_id, revision)

-        model = FlashGemmaForCausalLM(config, weights)
+        # TODO hardcoded
+        prefix = "language_model"
+        model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)

         torch.distributed.barrier(group=self.process_group)
         super(FlashGemma, self).__init__(
```
New file: PaliGemma model front-end:
```python
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional, Tuple

from text_generation_server.models.vlm_causal_lm import (
    VlmCausalLM,
    VlmCausalLMBatch,
    image_text_replacement,
    load_data_uri,
    split,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
    PaliGemmaForConditionalGeneration,
)
from transformers import AutoProcessor, AutoConfig, AutoImageProcessor

tracer = trace.get_tracer(__name__)


class PaliGemmaBatch(VlmCausalLMBatch):
    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
        batch_inputs = []
        image_inputs = []
        max_truncation = 0
        for r in requests:
            chunks = split(r.inputs)
            full_text = ""
            image_id = 0
            for chunk in chunks:
                if chunk["type"] == "text":
                    full_text += "<bos>" + chunk["content"] + "\n"
                elif chunk["type"] == "image":
                    image = chunk["content"]
                    # Should never receive URLs anymore, processing should be done
                    # On the rust layer.
                    # This avoid making n queries per TP
                    # if image.startswith("https://") or image.startswith("http://"):
                    #     image = processor.image_processor.fetch_images(image)
                    if image.startswith("data:"):
                        image = load_data_uri(image)
                    else:
                        raise RuntimeError(
                            "Cannot process input image not starting with data:"
                        )
                    # TODO do_convert_RGB should be on by default ?
                    image = image.convert("RGB")
                    image_input = processor.image_processor(image, return_tensors="pt")
                    full_text += image_text_replacement(image_input, config, image_id)
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")

            batch_inputs.append(full_text)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs,
            truncation=True,
            max_length=max_truncation,
            add_special_tokens=False,
        )["input_ids"]
        if image_inputs:
            image_input = image_inputs[0]
            new_image_inputs = {
                "pixel_values": torch.cat(
                    [img["pixel_values"] for img in image_inputs], dim=0
                ),
            }
            if "pixel_attention_mask" in image_input:
                new_image_inputs["pixel_attention_mask"] = torch.cat(
                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
                )
            if "image_sizes" in image_input:
                new_image_inputs["image_sizes"] = torch.cat(
                    [img["image_sizes"] for img in image_inputs], dim=0
                )
            image_inputs = new_image_inputs
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs


class PaliGemma(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        super().__init__(
            config_cls=AutoConfig,
            model_cls=PaliGemmaForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    @property
    def batch_type(self):
        return PaliGemmaBatch

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.text_model.model.layers),
            model.text_model.model.num_key_value_heads,
            model.text_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)
```
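Note the prompt construction above: each text chunk is wrapped as `"<bos>" + text + "\n"`, matching PaliGemma's expected layout of image tokens first, then BOS, then the prefix text terminated by a newline. Assembled by hand it looks like this (a sketch; the token count is fixed per checkpoint resolution, 256 for the 224px models):

```python
N_IMAGE_TOKENS = 256  # assumed: 224px checkpoint

def paligemma_prompt(text: str) -> str:
    # <image>...<image><bos>prompt\n -- the order the model was trained on
    return "<image>" * N_IMAGE_TOKENS + "<bos>" + text + "\n"

print(paligemma_prompt("Where is the cow standing?")[-40:])
```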