Unverified Commit bfddfa59 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Idefics2. (#1756)

# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
parent 986b4044
## Speculation ## Speculation
Speculative decoding, assisted generation, Medusa, and others are a few different names for the same idea. Speculative decoding, assisted generation, Medusa, and others are a few different names for the same idea.
The idea is to generate tokens *before* the large model actually runs, and only *check* if those tokens where valid. The idea is to generate tokens *before* the large model actually runs, and only *check* if those tokens where valid.
......
...@@ -293,6 +293,7 @@ def launcher(event_loop): ...@@ -293,6 +293,7 @@ def launcher(event_loop):
dtype: Optional[str] = None, dtype: Optional[str] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
max_input_length: Optional[int] = None, max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
...@@ -334,6 +335,9 @@ def launcher(event_loop): ...@@ -334,6 +335,9 @@ def launcher(event_loop):
if max_input_length: if max_input_length:
args.append("--max-input-length") args.append("--max-input-length")
args.append(str(max_input_length)) args.append(str(max_input_length))
if max_batch_prefill_tokens:
args.append("--max-batch-prefill-tokens")
args.append(str(max_batch_prefill_tokens))
if max_total_tokens: if max_total_tokens:
args.append("--max-total-tokens") args.append("--max-total-tokens")
args.append(str(max_total_tokens)) args.append(str(max_total_tokens))
...@@ -371,6 +375,7 @@ def launcher(event_loop): ...@@ -371,6 +375,7 @@ def launcher(event_loop):
dtype: Optional[str] = None, dtype: Optional[str] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
max_input_length: Optional[int] = None, max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
...@@ -395,6 +400,9 @@ def launcher(event_loop): ...@@ -395,6 +400,9 @@ def launcher(event_loop):
if max_input_length: if max_input_length:
args.append("--max-input-length") args.append("--max-input-length")
args.append(str(max_input_length)) args.append(str(max_input_length))
if max_batch_prefill_tokens:
args.append("--max-batch-prefill-tokens")
args.append(str(max_batch_prefill_tokens))
if max_total_tokens: if max_total_tokens:
args.append("--max-total-tokens") args.append("--max-total-tokens")
args.append(str(max_total_tokens)) args.append(str(max_total_tokens))
......
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -8.5625,
"text": "Test"
},
{
"id": 2159,
"logprob": -10.78125,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 288,
"logprob": -0.2854004,
"special": false,
"text": "ing"
},
{
"id": 264,
"logprob": -0.37573242,
"special": false,
"text": " a"
},
{
"id": 633,
"logprob": -0.09301758,
"special": false,
"text": " new"
},
{
"id": 4480,
"logprob": -0.3322754,
"special": false,
"text": " feature"
},
{
"id": 297,
"logprob": -0.8510742,
"special": false,
"text": " in"
},
{
"id": 272,
"logprob": -0.13464355,
"special": false,
"text": " the"
},
{
"id": 2039,
"logprob": 0.0,
"special": false,
"text": " game"
},
{
"id": 28723,
"logprob": -0.89990234,
"special": false,
"text": "."
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "Test requesting a new feature in the game.\n\n"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 330,
"logprob": -0.13000488,
"special": false,
"text": " A"
},
{
"id": 13088,
"logprob": -0.6713867,
"special": false,
"text": " chicken"
},
{
"id": 349,
"logprob": -0.2980957,
"special": false,
"text": " is"
},
{
"id": 6398,
"logprob": -0.060638428,
"special": false,
"text": " sitting"
},
{
"id": 356,
"logprob": -0.27319336,
"special": false,
"text": " on"
},
{
"id": 264,
"logprob": -0.140625,
"special": false,
"text": " a"
},
{
"id": 17972,
"logprob": -0.040405273,
"special": false,
"text": " pile"
},
{
"id": 302,
"logprob": -0.0002708435,
"special": false,
"text": " of"
},
{
"id": 2445,
"logprob": -0.095336914,
"special": false,
"text": " money"
},
{
"id": 28723,
"logprob": -0.0068359375,
"special": false,
"text": "."
}
],
"top_tokens": null
},
"generated_text": " A chicken is sitting on a pile of money."
}
import pytest
import base64
# TODO fix the server parsser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.fixture(scope="module")
def flash_idefics2_next_handle(launcher):
with launcher(
"HuggingFaceM4/idefics2-8b",
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_idefics2_next(flash_idefics2_next_handle):
await flash_idefics2_next_handle.health(300)
return flash_idefics2_next_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot):
chicken = get_chicken()
response = await flash_idefics2_next.generate(
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
max_new_tokens=10,
)
assert (
response.generated_text == " A chicken is sitting on a pile of money."
), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot):
response = await flash_idefics2_next.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_load(
flash_idefics2_next, generate_load, response_snapshot
):
chicken = get_chicken()
responses = await generate_load(
flash_idefics2_next,
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
max_new_tokens=10,
n=4,
)
generated_texts = [r.generated_text for r in responses]
assert generated_texts[0] == " A chicken is sitting on a pile of money."
assert len(generated_texts) == 4
assert all([r.generated_text == generated_texts[0] for r in responses])
assert responses == response_snapshot
...@@ -114,8 +114,12 @@ impl Client { ...@@ -114,8 +114,12 @@ impl Client {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens); let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut inputs = String::new(); let mut inputs = String::new();
inputs.push_str("![](");
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
inputs.push_str("![]()");
}
requests.push(Request { requests.push(Request {
id: 0, id: 0,
......
...@@ -57,6 +57,31 @@ fn select_best_resolution( ...@@ -57,6 +57,31 @@ fn select_best_resolution(
best_fit.unwrap_or((original_height, original_width)) best_fit.unwrap_or((original_height, original_width))
} }
fn get_unpadded_features(
height: usize,
width: usize,
npatches: usize,
num_patch_height: usize,
num_patch_width: usize,
) -> (usize, usize) {
let current_height = npatches * num_patch_height;
let current_width = npatches * num_patch_width;
let aspect_ratio: f64 = width as f64 / height as f64;
let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
let new_height = (height * current_width) / width;
(new_height, current_width)
} else {
let new_width = (width * current_height) / height;
(current_height, new_width)
};
let unpadded_features = current_height * current_width;
let newline_features = current_height;
(unpadded_features, newline_features)
}
impl LlavaNext { impl LlavaNext {
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize { pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
let image_size = self.vision_config.image_size; let image_size = self.vision_config.image_size;
...@@ -65,11 +90,9 @@ impl LlavaNext { ...@@ -65,11 +90,9 @@ impl LlavaNext {
let npatches = image_size / patch_size; let npatches = image_size / patch_size;
let (num_patch_height, num_patch_width) = let (num_patch_height, num_patch_width) =
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size); get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
// Ceil
let height_of_patch = (height * npatches + width - 1) / width; let (unpadded_features, newline_features) =
let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width; get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width);
// They are only added after width
let newline_features = height_of_patch * num_patch_width;
// The base patch covers the entire image // The base patch covers the entire image
let base_features = npatches.pow(2); let base_features = npatches.pow(2);
unpadded_features + newline_features + base_features unpadded_features + newline_features + base_features
...@@ -84,6 +107,17 @@ pub struct ClipVisionModel { ...@@ -84,6 +107,17 @@ pub struct ClipVisionModel {
patch_size: usize, patch_size: usize,
} }
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct Idefics2 {}
impl Idefics2 {
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
320
}
}
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")] #[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
...@@ -92,6 +126,7 @@ pub enum Config { ...@@ -92,6 +126,7 @@ pub enum Config {
ClipVisionModel(ClipVisionModel), ClipVisionModel(ClipVisionModel),
Mistral, Mistral,
Idefics, Idefics,
Idefics2(Idefics2),
Ssm, Ssm,
GptBigcode, GptBigcode,
Santacoder, Santacoder,
...@@ -146,13 +181,17 @@ mod test { ...@@ -146,13 +181,17 @@ mod test {
], ],
}; };
let slots = config.get_number_of_features(20, 20);
assert_eq!(slots, 1176);
let slots = config.get_number_of_features(640, 640); let slots = config.get_number_of_features(640, 640);
assert_eq!(slots, 2928); assert_eq!(slots, 2928);
let slots = config.get_number_of_features(480, 640); let slots = config.get_number_of_features(480, 640);
assert_eq!(slots, 2340); assert_eq!(slots, 2340);
let slots = config.get_number_of_features(899, 1024); let slots = config.get_number_of_features(899, 1024);
assert_eq!(slots, 2732); assert_eq!(slots, 2634);
let slots = config.get_number_of_features(1024, 899); let slots = config.get_number_of_features(1024, 899);
assert_eq!(slots, 3320); assert_eq!(slots, 2640);
let slots = config.get_number_of_features(1067, 1600);
assert_eq!(slots, 2144);
} }
} }
...@@ -540,7 +540,57 @@ fn prepare_input( ...@@ -540,7 +540,57 @@ fn prepare_input(
inputs = modified_inputs; inputs = modified_inputs;
tokenizer_query tokenizer_query
} }
Some(Config::Idefics) => RE.replace_all(&inputs, "<image>").into(), Some(Config::Idefics2(config)) => {
let mut modified_inputs = String::with_capacity(inputs.len());
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]);
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = config.get_number_of_features(height, width);
tokenizer_query.push_str("<fake_token_around_image>");
tokenizer_query.push_str(&"<image>".repeat(slots));
tokenizer_query.push_str("<fake_token_around_image>");
modified_inputs.push_str(&image_uri);
start = chunk_end;
}
if start != inputs.len() - 1 {
modified_inputs.push_str(&inputs[start..]);
tokenizer_query.push_str(&inputs[start..]);
}
inputs = modified_inputs;
tokenizer_query
}
Some(Config::Idefics) => {
let mut modified_inputs = String::with_capacity(inputs.len());
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]);
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (image_uri, _height, _width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = 1;
tokenizer_query.push_str(&"<image>".repeat(slots));
modified_inputs.push_str(&image_uri);
start = chunk_end;
}
if start != inputs.len() - 1 {
modified_inputs.push_str(&inputs[start..]);
tokenizer_query.push_str(&inputs[start..]);
}
inputs = modified_inputs;
tokenizer_query
}
_ => inputs.clone(), _ => inputs.clone(),
}; };
......
...@@ -68,6 +68,7 @@ try: ...@@ -68,6 +68,7 @@ try:
) )
from text_generation_server.models.idefics import IDEFICSSharded from text_generation_server.models.idefics import IDEFICSSharded
from text_generation_server.models.llava_next import LlavaNext from text_generation_server.models.llava_next import LlavaNext
from text_generation_server.models.idefics2 import Idefics2
from text_generation_server.models.flash_mistral import FlashMistral from text_generation_server.models.flash_mistral import FlashMistral
from text_generation_server.models.flash_mixtral import FlashMixtral from text_generation_server.models.flash_mixtral import FlashMixtral
from text_generation_server.models.flash_phi import FlashPhi from text_generation_server.models.flash_phi import FlashPhi
...@@ -579,6 +580,18 @@ def get_model( ...@@ -579,6 +580,18 @@ def get_model(
) )
else: else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "idefics2":
if FLASH_ATTENTION:
return Idefics2(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "llava_next": if model_type == "llava_next":
if FLASH_ATTENTION: if FLASH_ATTENTION:
......
...@@ -409,23 +409,29 @@ class MistralModel(torch.nn.Module): ...@@ -409,23 +409,29 @@ class MistralModel(torch.nn.Module):
class FlashMistralForCausalLM(torch.nn.Module): class FlashMistralForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix, config, weights, name=None):
if name is None:
name = "model"
super().__init__() super().__init__()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix=( prefix=(
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens" f"{name}.embed_tokens"
if not prefix
else f"{prefix}.{name}.embed_tokens"
), ),
weights=weights, weights=weights,
) )
self.model = MistralModel( self.model = MistralModel(
prefix="model" if not prefix else f"{prefix}.model", prefix=name if not prefix else f"{prefix}.{name}",
config=config, config=config,
weights=weights, weights=weights,
) )
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head" if not prefix else f"{prefix}.lm_head", # TODO dirty hack for idefics2.
prefix=(
"lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
),
weights=weights, weights=weights,
) )
self.max_past = config.sliding_window self.max_past = config.sliding_window
......
...@@ -23,6 +23,10 @@ from torch import nn ...@@ -23,6 +23,10 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution from transformers.image_processing_utils import select_best_resolution
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelRowLinear, TensorParallelRowLinear,
...@@ -105,36 +109,6 @@ class LlavaNextMultiModalProjector(nn.Module): ...@@ -105,36 +109,6 @@ class LlavaNextMultiModalProjector(nn.Module):
return hidden_states return hidden_states
def load_vision_model(prefix, config, weights):
if config.model_type == "clip_vision_model":
from text_generation_server.models.custom_modeling.clip import (
CLIPVisionTransformer,
)
return CLIPVisionTransformer(
prefix=f"{prefix}.vision_model", config=config, weights=weights
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
def load_text_model(prefix, config, weights):
if config.model_type == "llama":
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
return FlashLlamaForCausalLM(prefix, config, weights)
elif config.model_type == "mistral":
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
return FlashMistralForCausalLM(prefix, config, weights)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
class LlavaNextForConditionalGeneration(nn.Module): class LlavaNextForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
...@@ -180,7 +154,12 @@ class LlavaNextForConditionalGeneration(nn.Module): ...@@ -180,7 +154,12 @@ class LlavaNextForConditionalGeneration(nn.Module):
"""In place merges in vision_embeddings with inputs_embeds.""" """In place merges in vision_embeddings with inputs_embeds."""
mask = input_ids == self.config.image_token_index mask = input_ids == self.config.image_token_index
# Let's pray we have enabled enough slots ! # Let's pray we have enabled enough slots !
try:
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
except Exception as e:
raise RuntimeError(
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
)
return inputs_embeds return inputs_embeds
def forward( def forward(
...@@ -196,6 +175,8 @@ class LlavaNextForConditionalGeneration(nn.Module): ...@@ -196,6 +175,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None, pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None, image_sizes: Optional[torch.LongTensor] = None,
): ):
inputs_embeds = self.language_model.embed_tokens(input_ids) inputs_embeds = self.language_model.embed_tokens(input_ids)
......
def load_text_model(prefix, config, weights, name=None):
if config.model_type == "llama":
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
return FlashLlamaForCausalLM(prefix, config, weights)
elif config.model_type == "mistral":
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
return FlashMistralForCausalLM(prefix, config, weights, name=name)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
def load_vision_model(prefix, config, weights):
if config.model_type == "clip_vision_model":
from text_generation_server.models.custom_modeling.clip import (
CLIPVisionTransformer,
)
return CLIPVisionTransformer(
prefix=f"{prefix}.vision_model", config=config, weights=weights
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
...@@ -511,6 +511,21 @@ class BaseFlashMistral(FlashCausalLM): ...@@ -511,6 +511,21 @@ class BaseFlashMistral(FlashCausalLM):
cuda_graph = self.cuda_graphs.get(padded_bs, None) cuda_graph = self.cuda_graphs.get(padded_bs, None)
if cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
if cu_seqlen_prefill is None:
logits, speculative_logits = self.compiled_model(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
)
else:
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
......
import torch
from typing import Optional, Tuple
from transformers import (
AutoProcessor,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
class Idefics2(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
size={"longest_edge": 448, "shortest_edge": 378},
)
super().__init__(
model_cls=Idefics2ForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.text_model.model.layers),
model.text_model.model.num_key_value_heads,
model.text_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)
import torch import torch
from typing import Optional from typing import Optional, Tuple
from transformers import ( from transformers import (
AutoProcessor, AutoProcessor,
...@@ -34,3 +34,13 @@ class LlavaNext(VlmCausalLM): ...@@ -34,3 +34,13 @@ class LlavaNext(VlmCausalLM):
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.language_model.model.layers),
model.language_model.model.num_key_value_heads,
model.language_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.language_model, "max_past", None)
...@@ -64,6 +64,46 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): ...@@ -64,6 +64,46 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
return height // patch_size, width // patch_size return height // patch_size, width // patch_size
def image_text_replacement(image_input, config, image_id) -> str:
if config.model_type == "idefics2":
# TODO technically depends on image splitting which is not implemented.
num_features = 320
return (
"<fake_token_around_image>"
+ "<image>" * num_features
+ "<fake_token_around_image>"
)
elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id]
num_features = get_number_of_features(height, width, config)
from loguru import logger
logger.info(f"Found {num_features} in image of resolution {height}x{width}")
return "<image>" * num_features
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
def get_unpadded_features(
height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
) -> Tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
aspect_ratio: float = width / height
current_aspect_ratio: float = current_width / current_height
if aspect_ratio > current_aspect_ratio:
new_height = (height * current_width) // width
current_height = new_height
else:
new_width = (width * current_height) // height
current_width = new_width
unpadded_features = current_height * current_width
newline_features = current_height
return (unpadded_features, newline_features)
def get_number_of_features(height: int, width: int, config) -> int: def get_number_of_features(height: int, width: int, config) -> int:
# From config # From config
# Hardcoded for CLIP for now # Hardcoded for CLIP for now
...@@ -81,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int: ...@@ -81,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
image_grid_pinpoints, image_grid_pinpoints,
image_size, image_size,
) )
unpadded_features, newline_features = get_unpadded_features(
height_of_patch = math.ceil(height / width * npatches) height, width, npatches, num_patch_height, num_patch_width
)
unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
# They are only added after width
newline_features = height_of_patch * num_patch_width
# The base patch covers the entire image # The base patch covers the entire image
base_features = npatches**2 base_features = npatches**2
return unpadded_features + newline_features + base_features return unpadded_features + newline_features + base_features
...@@ -99,12 +136,9 @@ def load_data_uri(image_uri: str) -> Image.Image: ...@@ -99,12 +136,9 @@ def load_data_uri(image_uri: str) -> Image.Image:
return image return image
# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
# assert get_number_of_features(640, 640) == 2928
class VlmCausalLMBatch(FlashMistralBatch): class VlmCausalLMBatch(FlashMistralBatch):
pixel_values: Optional[List[torch.Tensor]] pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]] image_sizes: Optional[List[Tuple[int, int]]]
@classmethod @classmethod
...@@ -112,6 +146,7 @@ class VlmCausalLMBatch(FlashMistralBatch): ...@@ -112,6 +146,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
def concatenate(cls, batches): def concatenate(cls, batches):
batch = super(VlmCausalLMBatch, cls).concatenate(batches) batch = super(VlmCausalLMBatch, cls).concatenate(batches)
batch.pixel_values = None batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None batch.image_sizes = None
return batch return batch
...@@ -119,6 +154,7 @@ class VlmCausalLMBatch(FlashMistralBatch): ...@@ -119,6 +154,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
def filter(self, request_ids: List[int]): def filter(self, request_ids: List[int]):
batch = super().filter(request_ids) batch = super().filter(request_ids)
batch.pixel_values = None batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None batch.image_sizes = None
return batch return batch
...@@ -130,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch): ...@@ -130,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
for r in requests: for r in requests:
chunks = split(r.inputs) chunks = split(r.inputs)
full_text = "" full_text = ""
image_id = 0
for chunk in chunks: for chunk in chunks:
if chunk["type"] == "text": if chunk["type"] == "text":
full_text += chunk["content"] full_text += chunk["content"]
...@@ -147,9 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch): ...@@ -147,9 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
"Cannot process input image not starting with data:" "Cannot process input image not starting with data:"
) )
image_input = processor.image_processor(image, return_tensors="pt") image_input = processor.image_processor(image, return_tensors="pt")
height, width = image_input["image_sizes"][0] full_text += image_text_replacement(image_input, config, image_id)
num_features = get_number_of_features(height, width, config)
full_text += "<image>" * num_features
image_inputs.append(image_input) image_inputs.append(image_input)
else: else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}") raise RuntimeError(f"Invalid chunk type {chunk['type']}")
...@@ -161,12 +196,21 @@ class VlmCausalLMBatch(FlashMistralBatch): ...@@ -161,12 +196,21 @@ class VlmCausalLMBatch(FlashMistralBatch):
batch_inputs, truncation=True, max_length=max_truncation batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"] )["input_ids"]
if image_inputs: if image_inputs:
image_inputs = { image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat( "pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0 [img["pixel_values"] for img in image_inputs], dim=0
), ),
"image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
} }
if "pixel_attention_mask" in image_input:
new_image_inputs["pixel_attention_mask"] = torch.cat(
[img["pixel_attention_mask"] for img in image_inputs], dim=0
)
if "image_sizes" in image_input:
new_image_inputs["image_sizes"] = torch.cat(
[img["image_sizes"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
else: else:
image_inputs = None image_inputs = None
return batch_tokenized_inputs, image_inputs return batch_tokenized_inputs, image_inputs
...@@ -187,9 +231,19 @@ class VlmCausalLMBatch(FlashMistralBatch): ...@@ -187,9 +231,19 @@ class VlmCausalLMBatch(FlashMistralBatch):
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if image_inputs is not None: if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device) batch.pixel_values = image_inputs["pixel_values"].to(device=device)
if "pixel_attention_mask" in image_inputs:
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
device=device
)
else:
batch.pixel_attention_mask = None
if "image_sizes" in image_inputs:
batch.image_sizes = image_inputs["image_sizes"].to(device=device) batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.image_sizes = None
else: else:
batch.pixel_values = None batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None batch.image_sizes = None
return batch return batch
...@@ -199,16 +253,6 @@ class VlmCausalLM(BaseFlashMistral): ...@@ -199,16 +253,6 @@ class VlmCausalLM(BaseFlashMistral):
def batch_type(self) -> Type[VlmCausalLMBatch]: def batch_type(self) -> Type[VlmCausalLMBatch]:
return VlmCausalLMBatch return VlmCausalLMBatch
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.language_model.model.layers),
model.language_model.model.num_key_value_heads,
model.language_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.language_model, "max_past", None)
def forward( def forward(
self, batch: VlmCausalLMBatch self, batch: VlmCausalLMBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
...@@ -270,17 +314,14 @@ class VlmCausalLM(BaseFlashMistral): ...@@ -270,17 +314,14 @@ class VlmCausalLM(BaseFlashMistral):
max_s = min(self.max_past(), max_s) max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0] bs = input_ids.shape[0]
padded_bs = bs
if bs == 3:
padded_bs = 4
elif 3 < bs <= 8:
padded_bs = 8
elif bs > 8:
padded_bs = (bs + 7) // 8 * 8
# Try to find an associated cuda graph # Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(padded_bs, None) bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if sorted_padded_bs:
# Get associated cuda graph
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
else:
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
...@@ -294,12 +335,15 @@ class VlmCausalLM(BaseFlashMistral): ...@@ -294,12 +335,15 @@ class VlmCausalLM(BaseFlashMistral):
prefill_cache_indices=batch.prefill_cache_indices, prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices, lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values, pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes, image_sizes=batch.image_sizes,
) )
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None batch.prefill_cache_indices = None
if batch.pixel_values is not None: if batch.pixel_values is not None:
batch.pixel_values = None batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None: if batch.image_sizes is not None:
batch.image_sizes = None batch.image_sizes = None
return logits, speculative_logits return logits, speculative_logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment