Unverified Commit bfddfa59 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Idefics2. (#1756)

# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
parent 986b4044
## Speculation
Speculative decoding, assisted generation, Medusa, and others are a few different names for the same idea.
The idea is to generate tokens *before* the large model actually runs, and only *check* if those tokens where valid.
......
......@@ -293,6 +293,7 @@ def launcher(event_loop):
dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None,
):
port = random.randint(8000, 10_000)
......@@ -334,6 +335,9 @@ def launcher(event_loop):
if max_input_length:
args.append("--max-input-length")
args.append(str(max_input_length))
if max_batch_prefill_tokens:
args.append("--max-batch-prefill-tokens")
args.append(str(max_batch_prefill_tokens))
if max_total_tokens:
args.append("--max-total-tokens")
args.append(str(max_total_tokens))
......@@ -371,6 +375,7 @@ def launcher(event_loop):
dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None,
):
port = random.randint(8000, 10_000)
......@@ -395,6 +400,9 @@ def launcher(event_loop):
if max_input_length:
args.append("--max-input-length")
args.append(str(max_input_length))
if max_batch_prefill_tokens:
args.append("--max-batch-prefill-tokens")
args.append(str(max_batch_prefill_tokens))
if max_total_tokens:
args.append("--max-total-tokens")
args.append(str(max_total_tokens))
......
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -8.5625,
"text": "Test"
},
{
"id": 2159,
"logprob": -10.78125,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 288,
"logprob": -0.2854004,
"special": false,
"text": "ing"
},
{
"id": 264,
"logprob": -0.37573242,
"special": false,
"text": " a"
},
{
"id": 633,
"logprob": -0.09301758,
"special": false,
"text": " new"
},
{
"id": 4480,
"logprob": -0.3322754,
"special": false,
"text": " feature"
},
{
"id": 297,
"logprob": -0.8510742,
"special": false,
"text": " in"
},
{
"id": 272,
"logprob": -0.13464355,
"special": false,
"text": " the"
},
{
"id": 2039,
"logprob": 0.0,
"special": false,
"text": " game"
},
{
"id": 28723,
"logprob": -0.89990234,
"special": false,
"text": "."
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "Test requesting a new feature in the game.\n\n"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 330,
"logprob": -0.13000488,
"special": false,
"text": " A"
},
{
"id": 13088,
"logprob": -0.6713867,
"special": false,
"text": " chicken"
},
{
"id": 349,
"logprob": -0.2980957,
"special": false,
"text": " is"
},
{
"id": 6398,
"logprob": -0.060638428,
"special": false,
"text": " sitting"
},
{
"id": 356,
"logprob": -0.27319336,
"special": false,
"text": " on"
},
{
"id": 264,
"logprob": -0.140625,
"special": false,
"text": " a"
},
{
"id": 17972,
"logprob": -0.040405273,
"special": false,
"text": " pile"
},
{
"id": 302,
"logprob": -0.0002708435,
"special": false,
"text": " of"
},
{
"id": 2445,
"logprob": -0.095336914,
"special": false,
"text": " money"
},
{
"id": 28723,
"logprob": -0.0068359375,
"special": false,
"text": "."
}
],
"top_tokens": null
},
"generated_text": " A chicken is sitting on a pile of money."
}
import pytest
import base64
# TODO fix the server parsser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.fixture(scope="module")
def flash_idefics2_next_handle(launcher):
with launcher(
"HuggingFaceM4/idefics2-8b",
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_idefics2_next(flash_idefics2_next_handle):
await flash_idefics2_next_handle.health(300)
return flash_idefics2_next_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot):
chicken = get_chicken()
response = await flash_idefics2_next.generate(
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
max_new_tokens=10,
)
assert (
response.generated_text == " A chicken is sitting on a pile of money."
), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot):
response = await flash_idefics2_next.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_load(
flash_idefics2_next, generate_load, response_snapshot
):
chicken = get_chicken()
responses = await generate_load(
flash_idefics2_next,
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
max_new_tokens=10,
n=4,
)
generated_texts = [r.generated_text for r in responses]
assert generated_texts[0] == " A chicken is sitting on a pile of money."
assert len(generated_texts) == 4
assert all([r.generated_text == generated_texts[0] for r in responses])
assert responses == response_snapshot
......@@ -114,8 +114,12 @@ impl Client {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut inputs = String::new();
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=");
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
}
requests.push(Request {
id: 0,
......
......@@ -57,6 +57,31 @@ fn select_best_resolution(
best_fit.unwrap_or((original_height, original_width))
}
fn get_unpadded_features(
height: usize,
width: usize,
npatches: usize,
num_patch_height: usize,
num_patch_width: usize,
) -> (usize, usize) {
let current_height = npatches * num_patch_height;
let current_width = npatches * num_patch_width;
let aspect_ratio: f64 = width as f64 / height as f64;
let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
let new_height = (height * current_width) / width;
(new_height, current_width)
} else {
let new_width = (width * current_height) / height;
(current_height, new_width)
};
let unpadded_features = current_height * current_width;
let newline_features = current_height;
(unpadded_features, newline_features)
}
impl LlavaNext {
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
let image_size = self.vision_config.image_size;
......@@ -65,11 +90,9 @@ impl LlavaNext {
let npatches = image_size / patch_size;
let (num_patch_height, num_patch_width) =
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
// Ceil
let height_of_patch = (height * npatches + width - 1) / width;
let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width;
// They are only added after width
let newline_features = height_of_patch * num_patch_width;
let (unpadded_features, newline_features) =
get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width);
// The base patch covers the entire image
let base_features = npatches.pow(2);
unpadded_features + newline_features + base_features
......@@ -84,6 +107,17 @@ pub struct ClipVisionModel {
patch_size: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct Idefics2 {}
impl Idefics2 {
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
320
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
......@@ -92,6 +126,7 @@ pub enum Config {
ClipVisionModel(ClipVisionModel),
Mistral,
Idefics,
Idefics2(Idefics2),
Ssm,
GptBigcode,
Santacoder,
......@@ -146,13 +181,17 @@ mod test {
],
};
let slots = config.get_number_of_features(20, 20);
assert_eq!(slots, 1176);
let slots = config.get_number_of_features(640, 640);
assert_eq!(slots, 2928);
let slots = config.get_number_of_features(480, 640);
assert_eq!(slots, 2340);
let slots = config.get_number_of_features(899, 1024);
assert_eq!(slots, 2732);
assert_eq!(slots, 2634);
let slots = config.get_number_of_features(1024, 899);
assert_eq!(slots, 3320);
assert_eq!(slots, 2640);
let slots = config.get_number_of_features(1067, 1600);
assert_eq!(slots, 2144);
}
}
......@@ -540,7 +540,57 @@ fn prepare_input(
inputs = modified_inputs;
tokenizer_query
}
Some(Config::Idefics) => RE.replace_all(&inputs, "<image>").into(),
Some(Config::Idefics2(config)) => {
let mut modified_inputs = String::with_capacity(inputs.len());
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]);
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = config.get_number_of_features(height, width);
tokenizer_query.push_str("<fake_token_around_image>");
tokenizer_query.push_str(&"<image>".repeat(slots));
tokenizer_query.push_str("<fake_token_around_image>");
modified_inputs.push_str(&image_uri);
start = chunk_end;
}
if start != inputs.len() - 1 {
modified_inputs.push_str(&inputs[start..]);
tokenizer_query.push_str(&inputs[start..]);
}
inputs = modified_inputs;
tokenizer_query
}
Some(Config::Idefics) => {
let mut modified_inputs = String::with_capacity(inputs.len());
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]);
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (image_uri, _height, _width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = 1;
tokenizer_query.push_str(&"<image>".repeat(slots));
modified_inputs.push_str(&image_uri);
start = chunk_end;
}
if start != inputs.len() - 1 {
modified_inputs.push_str(&inputs[start..]);
tokenizer_query.push_str(&inputs[start..]);
}
inputs = modified_inputs;
tokenizer_query
}
_ => inputs.clone(),
};
......
......@@ -68,6 +68,7 @@ try:
)
from text_generation_server.models.idefics import IDEFICSSharded
from text_generation_server.models.llava_next import LlavaNext
from text_generation_server.models.idefics2 import Idefics2
from text_generation_server.models.flash_mistral import FlashMistral
from text_generation_server.models.flash_mixtral import FlashMixtral
from text_generation_server.models.flash_phi import FlashPhi
......@@ -579,6 +580,18 @@ def get_model(
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "idefics2":
if FLASH_ATTENTION:
return Idefics2(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "llava_next":
if FLASH_ATTENTION:
......
......@@ -409,23 +409,29 @@ class MistralModel(torch.nn.Module):
class FlashMistralForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, name=None):
if name is None:
name = "model"
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix=(
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
f"{name}.embed_tokens"
if not prefix
else f"{prefix}.{name}.embed_tokens"
),
weights=weights,
)
self.model = MistralModel(
prefix="model" if not prefix else f"{prefix}.model",
prefix=name if not prefix else f"{prefix}.{name}",
config=config,
weights=weights,
)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
# TODO dirty hack for idefics2.
prefix=(
"lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
),
weights=weights,
)
self.max_past = config.sliding_window
......
......@@ -23,6 +23,10 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
......@@ -105,36 +109,6 @@ class LlavaNextMultiModalProjector(nn.Module):
return hidden_states
def load_vision_model(prefix, config, weights):
if config.model_type == "clip_vision_model":
from text_generation_server.models.custom_modeling.clip import (
CLIPVisionTransformer,
)
return CLIPVisionTransformer(
prefix=f"{prefix}.vision_model", config=config, weights=weights
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
def load_text_model(prefix, config, weights):
if config.model_type == "llama":
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
return FlashLlamaForCausalLM(prefix, config, weights)
elif config.model_type == "mistral":
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
return FlashMistralForCausalLM(prefix, config, weights)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
class LlavaNextForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
......@@ -180,7 +154,12 @@ class LlavaNextForConditionalGeneration(nn.Module):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = input_ids == self.config.image_token_index
# Let's pray we have enabled enough slots !
try:
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
except Exception as e:
raise RuntimeError(
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
)
return inputs_embeds
def forward(
......@@ -196,6 +175,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.language_model.embed_tokens(input_ids)
......
def load_text_model(prefix, config, weights, name=None):
if config.model_type == "llama":
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
return FlashLlamaForCausalLM(prefix, config, weights)
elif config.model_type == "mistral":
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
return FlashMistralForCausalLM(prefix, config, weights, name=name)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
def load_vision_model(prefix, config, weights):
if config.model_type == "clip_vision_model":
from text_generation_server.models.custom_modeling.clip import (
CLIPVisionTransformer,
)
return CLIPVisionTransformer(
prefix=f"{prefix}.vision_model", config=config, weights=weights
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
......@@ -511,6 +511,21 @@ class BaseFlashMistral(FlashCausalLM):
cuda_graph = self.cuda_graphs.get(padded_bs, None)
if cu_seqlen_prefill is not None or cuda_graph is None:
if cu_seqlen_prefill is None:
logits, speculative_logits = self.compiled_model(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
)
else:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
......
import torch
from typing import Optional, Tuple
from transformers import (
AutoProcessor,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
class Idefics2(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
size={"longest_edge": 448, "shortest_edge": 378},
)
super().__init__(
model_cls=Idefics2ForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.text_model.model.layers),
model.text_model.model.num_key_value_heads,
model.text_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)
import torch
from typing import Optional
from typing import Optional, Tuple
from transformers import (
AutoProcessor,
......@@ -34,3 +34,13 @@ class LlavaNext(VlmCausalLM):
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.language_model.model.layers),
model.language_model.model.num_key_value_heads,
model.language_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.language_model, "max_past", None)
......@@ -64,6 +64,46 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
return height // patch_size, width // patch_size
def image_text_replacement(image_input, config, image_id) -> str:
if config.model_type == "idefics2":
# TODO technically depends on image splitting which is not implemented.
num_features = 320
return (
"<fake_token_around_image>"
+ "<image>" * num_features
+ "<fake_token_around_image>"
)
elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id]
num_features = get_number_of_features(height, width, config)
from loguru import logger
logger.info(f"Found {num_features} in image of resolution {height}x{width}")
return "<image>" * num_features
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
def get_unpadded_features(
height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
) -> Tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
aspect_ratio: float = width / height
current_aspect_ratio: float = current_width / current_height
if aspect_ratio > current_aspect_ratio:
new_height = (height * current_width) // width
current_height = new_height
else:
new_width = (width * current_height) // height
current_width = new_width
unpadded_features = current_height * current_width
newline_features = current_height
return (unpadded_features, newline_features)
def get_number_of_features(height: int, width: int, config) -> int:
# From config
# Hardcoded for CLIP for now
......@@ -81,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
image_grid_pinpoints,
image_size,
)
height_of_patch = math.ceil(height / width * npatches)
unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
# They are only added after width
newline_features = height_of_patch * num_patch_width
unpadded_features, newline_features = get_unpadded_features(
height, width, npatches, num_patch_height, num_patch_width
)
# The base patch covers the entire image
base_features = npatches**2
return unpadded_features + newline_features + base_features
......@@ -99,12 +136,9 @@ def load_data_uri(image_uri: str) -> Image.Image:
return image
# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
# assert get_number_of_features(640, 640) == 2928
class VlmCausalLMBatch(FlashMistralBatch):
pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
@classmethod
......@@ -112,6 +146,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
def concatenate(cls, batches):
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
......@@ -119,6 +154,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
def filter(self, request_ids: List[int]):
batch = super().filter(request_ids)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
......@@ -130,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
for r in requests:
chunks = split(r.inputs)
full_text = ""
image_id = 0
for chunk in chunks:
if chunk["type"] == "text":
full_text += chunk["content"]
......@@ -147,9 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
"Cannot process input image not starting with data:"
)
image_input = processor.image_processor(image, return_tensors="pt")
height, width = image_input["image_sizes"][0]
num_features = get_number_of_features(height, width, config)
full_text += "<image>" * num_features
full_text += image_text_replacement(image_input, config, image_id)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
......@@ -161,12 +196,21 @@ class VlmCausalLMBatch(FlashMistralBatch):
batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
if image_inputs:
image_inputs = {
image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
"image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
}
if "pixel_attention_mask" in image_input:
new_image_inputs["pixel_attention_mask"] = torch.cat(
[img["pixel_attention_mask"] for img in image_inputs], dim=0
)
if "image_sizes" in image_input:
new_image_inputs["image_sizes"] = torch.cat(
[img["image_sizes"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs
......@@ -187,9 +231,19 @@ class VlmCausalLMBatch(FlashMistralBatch):
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
if "pixel_attention_mask" in image_inputs:
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
device=device
)
else:
batch.pixel_attention_mask = None
if "image_sizes" in image_inputs:
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.image_sizes = None
else:
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
......@@ -199,16 +253,6 @@ class VlmCausalLM(BaseFlashMistral):
def batch_type(self) -> Type[VlmCausalLMBatch]:
return VlmCausalLMBatch
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.language_model.model.layers),
model.language_model.model.num_key_value_heads,
model.language_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.language_model, "max_past", None)
def forward(
self, batch: VlmCausalLMBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
......@@ -270,17 +314,14 @@ class VlmCausalLM(BaseFlashMistral):
max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
padded_bs = bs
if bs == 3:
padded_bs = 4
elif 3 < bs <= 8:
padded_bs = 8
elif bs > 8:
padded_bs = (bs + 7) // 8 * 8
# Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(padded_bs, None)
bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if sorted_padded_bs:
# Get associated cuda graph
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
else:
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
......@@ -294,12 +335,15 @@ class VlmCausalLM(BaseFlashMistral):
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
return logits, speculative_logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment