Unverified Commit 4634b00c authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Adding Llava-Next (Llava 1.6) with full support. (#1709)

# What does this PR do?

- Changed all models to extract `embed_tokens` in order to enable llava
to separately call the embeddings and the core model layers.
- Added VlmCausalLM to inherit from FlashMistral in order to be
maximally supported. The only added logics sits on top and parses images
into pixel values, preallocates input_ids space for the image
embeddings, and passes them for the model.
- Added Clip for the vision tower.
- Didn't add flash for the vision tower since there's no padding anyway.
- Added heuristic (potentially incomplete) to calculate number of
features *before* calculating the clip patches (allows for easier logic
reuse of the LLM under the hood).


Still needs to be done:

- [x] Implement the image parsing in the controller side, to avoid
downloading n times per TP shard and also refusing requests too large
early and avoid issues where the truncation actually truncates the
image.
- [ ] Make sure it works with quantization properly.
- [x] Make sure it works with TP>1



<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
parent 106d8ee8
This diff is collapsed.
......@@ -22,6 +22,8 @@ The following models are optimized and can be served with TGI, which uses custom
- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
- [Phi](https://huggingface.co/microsoft/phi-2)
- [Idefics](HuggingFaceM4/idefics-9b-instruct) (Multimodal)
- [Llava-next](llava-hf/llava-v1.6-mistral-7b-hf) (Multimodal)
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
......
......@@ -277,6 +277,8 @@ def launcher(event_loop):
disable_grammar_support: bool = False,
dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_total_tokens: Optional[int] = None,
):
port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000)
......@@ -314,6 +316,12 @@ def launcher(event_loop):
args.append(revision)
if trust_remote_code:
args.append("--trust-remote-code")
if max_input_length:
args.append("--max-input-length")
args.append(str(max_input_length))
if max_total_tokens:
args.append("--max-total-tokens")
args.append(str(max_total_tokens))
env["LOG_LEVEL"] = "info,text_generation_router=debug"
......@@ -347,6 +355,8 @@ def launcher(event_loop):
disable_grammar_support: bool = False,
dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_total_tokens: Optional[int] = None,
):
port = random.randint(8000, 10_000)
......@@ -367,6 +377,12 @@ def launcher(event_loop):
args.append(revision)
if trust_remote_code:
args.append("--trust-remote-code")
if max_input_length:
args.append("--max-input-length")
args.append(str(max_input_length))
if max_total_tokens:
args.append("--max-total-tokens")
args.append(str(max_total_tokens))
client = docker.from_env()
......
{
"details": {
"best_of_sequences": null,
"finish_reason": "stop_sequence",
"generated_tokens": 6,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -10.5,
"text": "Test"
},
{
"id": 2159,
"logprob": -12.140625,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -1.0654297,
"special": false,
"text": "\n"
},
{
"id": 1014,
"logprob": -2.7460938,
"special": false,
"text": "The"
},
{
"id": 6032,
"logprob": -1.359375,
"special": false,
"text": " purpose"
},
{
"id": 302,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 456,
"logprob": 0.0,
"special": false,
"text": " this"
},
{
"id": 1369,
"logprob": -0.40063477,
"special": false,
"text": " test"
}
],
"top_tokens": null
},
"generated_text": "Test request\nThe purpose of this test"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -0.00756073,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.20117188,
"special": false,
"text": "\n"
},
{
"id": 16114,
"logprob": -1.2597656,
"special": false,
"text": "Once"
},
{
"id": 3714,
"logprob": -0.20825195,
"special": false,
"text": " upon"
},
{
"id": 264,
"logprob": -0.00178051,
"special": false,
"text": " a"
},
{
"id": 727,
"logprob": -0.011955261,
"special": false,
"text": " time"
},
{
"id": 28725,
"logprob": -0.17541504,
"special": false,
"text": ","
},
{
"id": 736,
"logprob": -0.91308594,
"special": false,
"text": " there"
},
{
"id": 403,
"logprob": -0.058410645,
"special": false,
"text": " was"
},
{
"id": 264,
"logprob": -0.009689331,
"special": false,
"text": " a"
}
],
"top_tokens": null
},
"generated_text": "\n\nOnce upon a time, there was a"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 9,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 0,
......@@ -14,7 +14,7 @@
"tokens": [
{
"id": 16017,
"logprob": -0.30908203,
"logprob": 0.0,
"special": false,
"text": " blue"
},
......@@ -26,39 +26,45 @@
},
{
"id": 259,
"logprob": -0.28271484,
"logprob": -0.4716797,
"special": false,
"text": " "
},
{
"id": 15484,
"logprob": -1.7929688,
"id": 261,
"logprob": -0.044677734,
"special": false,
"text": "appear"
"text": ","
},
{
"id": 345,
"logprob": -0.8935547,
"id": 35622,
"logprob": -0.79589844,
"special": false,
"text": "ed"
"text": " cloud"
},
{
"id": 281,
"logprob": 0.0,
"id": 263,
"logprob": -1.2958984,
"special": false,
"text": " in"
"text": "s"
},
{
"id": 287,
"id": 305,
"logprob": 0.0,
"special": false,
"text": " the"
"text": " and"
},
{
"id": 20495,
"logprob": -0.32299805,
"id": 35622,
"logprob": -1.1630859,
"special": false,
"text": " sky"
"text": " cloud"
},
{
"id": 263,
"logprob": 0.0,
"special": false,
"text": "s"
},
{
"id": 1,
......@@ -66,7 +72,8 @@
"special": true,
"text": "</s>"
}
]
],
"top_tokens": null
},
"generated_text": "Why is the sky blue?blue sky appeared in the sky"
"generated_text": "Why is the sky blue?blue sky, clouds and clouds"
}
......@@ -33,6 +33,9 @@ async def test_idefics(idefics, response_snapshot):
)
assert response.details.generated_tokens == 10
assert (
response.generated_text == " \nAssistant: A rooster stands"
), f"{repr(response.generated_text)}"
assert response == response_snapshot
......@@ -48,6 +51,9 @@ async def test_idefics_load(idefics, generate_load, response_snapshot):
generated_texts = [r.generated_text for r in responses]
assert (
generated_texts[0] == " \nAssistant: A rooster stands"
), f"{response.generated_text}"
assert len(generated_texts) == 4
assert generated_texts, all(
[text == generated_texts[0] for text in generated_texts]
......
import pytest
import base64
# TODO fix the server parsser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.fixture(scope="module")
def flash_llava_next_handle(launcher):
with launcher(
"llava-hf/llava-v1.6-mistral-7b-hf",
num_shard=4,
max_input_length=4000,
max_total_tokens=4096,
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llava_next(flash_llava_next_handle):
await flash_llava_next_handle.health(300)
return flash_llava_next_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
chicken = get_chicken()
response = await flash_llava_next.generate(
f"User:![]({chicken})Can you tell me a very short story based on the image?",
max_new_tokens=10,
)
assert (
response.generated_text == "\n\nOnce upon a time, there was a"
), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
response = await flash_llava_next.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 6
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_load(
flash_llava_next, generate_load, response_snapshot
):
chicken = get_chicken()
responses = await generate_load(
flash_llava_next,
f"User:![]({chicken})Can you tell me a very short story based on the image?",
max_new_tokens=10,
n=4,
)
generated_texts = [r.generated_text for r in responses]
assert generated_texts[0] == "\n\nOnce upon a time, there was a"
assert len(generated_texts) == 4
assert all([r.generated_text == generated_texts[0] for r in responses])
assert responses == response_snapshot
......@@ -45,7 +45,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
seed=0,
)
assert response.details.generated_tokens == 9
assert response.details.generated_tokens == 10
assert response == response_snapshot
......
......@@ -44,10 +44,12 @@ utoipa = { version = "3.5.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", branch = "main", commit = "5cd4efb" }
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", rev = "5cd4efb" }
futures-util = "0.3.30"
regex = "1.10.3"
once_cell = "1.19.0"
image = "0.25.1"
base64 = "0.22.0"
[build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
......
......@@ -112,10 +112,15 @@ impl Client {
// Create requests
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut inputs = String::new();
inputs.push_str("![](");
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
requests.push(Request {
id: 0,
// We truncate the input on the server side to be sure that it has the correct size
inputs: "_test ".to_string().repeat(max_input_length as usize),
inputs,
truncate,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
......
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct LlavaNext {
text_config: TextConfig,
vision_config: VisionConfig,
image_grid_pinpoints: Vec<(usize, usize)>,
}
fn get_anyres_image_grid_shape(
height: usize,
width: usize,
grid_pinpoints: &[(usize, usize)],
patch_size: usize,
) -> (usize, usize) {
let (height, width) = select_best_resolution(height, width, grid_pinpoints);
(height / patch_size, width / patch_size)
}
/// Selects the best resolution from a list of possible resolutions based on the original size.
/// This is done by calculating the effective and wasted resolution for each possible resolution.
/// The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.
fn select_best_resolution(
original_height: usize,
original_width: usize,
possible_resolutions: &[(usize, usize)],
) -> (usize, usize) {
let mut best_fit = None;
let mut max_effective_resolution = 0;
let mut min_wasted_resolution = f32::NEG_INFINITY;
for (height, width) in possible_resolutions {
let wscale = *width as f32 / original_width as f32;
let hscale = *height as f32 / original_height as f32;
// f32 partial ord.
let scale = if wscale > hscale { hscale } else { wscale };
let downscaled_width = (*width as f32 * scale) as usize;
let downscaled_height = (*height as f32 * scale) as usize;
let effective_resolution = std::cmp::min(
downscaled_width * downscaled_height,
original_width * original_height,
);
let wasted_resolution = (width * height) - effective_resolution;
if effective_resolution > max_effective_resolution
|| (effective_resolution == max_effective_resolution
&& (wasted_resolution as f32) < min_wasted_resolution)
{
max_effective_resolution = effective_resolution;
min_wasted_resolution = wasted_resolution as f32;
best_fit = Some((*height, *width));
}
}
best_fit.unwrap_or((original_height, original_width))
}
impl LlavaNext {
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
let image_size = self.vision_config.image_size;
let patch_size = self.vision_config.patch_size;
assert!(image_size % patch_size == 0);
let npatches = image_size / patch_size;
let (num_patch_height, num_patch_width) =
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
// Ceil
let height_of_patch = (height * npatches + width - 1) / width;
let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width;
// They are only added after width
let newline_features = height_of_patch * num_patch_width;
// The base patch covers the entire image
let base_features = npatches.pow(2);
unpadded_features + newline_features + base_features
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct ClipVisionModel {
image_size: usize,
patch_size: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub enum Config {
LlavaNext(LlavaNext),
ClipVisionModel(ClipVisionModel),
Mistral,
Idefics,
Ssm,
GptBigcode,
Santacoder,
Bloom,
Mpt,
GptNeox,
Phi,
#[serde(rename = "phi-msft")]
PhiMsft,
Llama,
Baichuan,
Gemma,
Cohere,
Drbx,
Falcon,
Mixtral,
Starcoder2,
Qwen2,
Opt,
T5,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct TextConfig {}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct VisionConfig {
image_size: usize,
patch_size: usize,
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_llava_next_features() {
let config = LlavaNext {
text_config: TextConfig {},
vision_config: VisionConfig {
image_size: 336,
patch_size: 14,
},
image_grid_pinpoints: vec![
(336, 672),
(672, 336),
(672, 672),
(1008, 336),
(336, 1008),
],
};
let slots = config.get_number_of_features(640, 640);
assert_eq!(slots, 2928);
let slots = config.get_number_of_features(480, 640);
assert_eq!(slots, 2340);
let slots = config.get_number_of_features(899, 1024);
assert_eq!(slots, 2732);
let slots = config.get_number_of_features(1024, 899);
assert_eq!(slots, 3320);
}
}
pub mod config;
mod health;
/// Text Generation Inference Webserver
mod infer;
......
......@@ -13,6 +13,7 @@ use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use text_generation_client::{ClientError, ShardedClient};
use text_generation_router::config::Config;
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
use thiserror::Error;
use tokenizers::Tokenizer;
......@@ -191,15 +192,19 @@ async fn main() -> Result<(), RouterError> {
};
// Load tokenizer and model info
let (tokenizer, model_info) = if local_model {
let (tokenizer, model_info, config) = if local_model {
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
let model_info = HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
};
let config: Option<Config> = std::fs::read_to_string(local_path.join("config.json"))
.ok()
.as_ref()
.and_then(|c| serde_json::from_str(c).ok());
(tokenizer, model_info)
(tokenizer, model_info, config)
} else if let Some(api) = api.clone() {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
......@@ -212,6 +217,19 @@ async fn main() -> Result<(), RouterError> {
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| {
std::fs::read_to_string(filename)
.ok()
.as_ref()
.and_then(|c| {
let config: Result<Config, _> = serde_json::from_str(c);
if let Err(err) = &config {
tracing::warn!("Could not parse config {err:?}");
}
config.ok()
})
});
let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
HubModelInfo {
......@@ -221,7 +239,7 @@ async fn main() -> Result<(), RouterError> {
}
});
(tokenizer, model_info)
(tokenizer, model_info, config)
} else {
// No API and no local model
return Err(RouterError::ArgumentValidation(
......@@ -229,6 +247,8 @@ async fn main() -> Result<(), RouterError> {
));
};
tracing::info!("Using config {config:?}");
// Load tokenizer config if found locally, or check if we can get it from the API if needed
let tokenizer_config = if let Some(path) = tokenizer_config_path {
tracing::info!("Using local tokenizer config from user specified path");
......@@ -363,6 +383,7 @@ async fn main() -> Result<(), RouterError> {
max_batch_size,
sharded_client,
tokenizer,
config,
validation_workers,
addr,
cors_allow_origin,
......
use crate::config::Config;
/// HTTP Server logic
use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse};
......@@ -164,7 +165,8 @@ async fn generate(
let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count");
tracing::debug!("Input: {}", req.inputs);
// Do not long ultra long inputs, like image payloads.
tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);
let compute_characters = req.inputs.chars().count();
let mut add_prompt = None;
......@@ -1154,6 +1156,7 @@ pub async fn run(
max_batch_size: Option<usize>,
client: ShardedClient,
tokenizer: Option<Tokenizer>,
config: Option<Config>,
validation_workers: usize,
addr: SocketAddr,
allow_origin: Option<AllowOrigin>,
......@@ -1236,6 +1239,7 @@ pub async fn run(
let validation = Validation::new(
validation_workers,
tokenizer,
config,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
......
use crate::config::Config;
/// Payload validation logic
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest, GrammarType};
use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
use serde_json::Value;
use std::io::Cursor;
use text_generation_client::{
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::TruncationDirection;
// use tokenizers::TruncationDirection;
use base64::{engine::general_purpose::STANDARD, Engine};
use image::{io::Reader as ImageReader, ImageFormat};
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tracing::{instrument, Span};
......@@ -34,6 +38,7 @@ impl Validation {
pub(crate) fn new(
workers: usize,
tokenizer: Option<Tokenizer>,
config: Option<Config>,
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
......@@ -50,12 +55,13 @@ impl Validation {
// Create workers
for _ in 0..workers {
let tokenizer_clone = tokenizer.clone();
let config_clone = config.clone();
let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
senders.push(tokenizer_sender);
// Spawn worker
tokio::task::spawn_blocking(move || {
tokenizer_worker(tokenizer_clone, tokenizer_receiver)
tokenizer_worker(tokenizer_clone, config_clone, tokenizer_receiver)
});
}
......@@ -408,49 +414,138 @@ async fn round_robin_task(
}
/// Start tokenization workers
fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
fn tokenizer_worker(
tokenizer: Tokenizer,
config: Option<Config>,
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
) {
// Loop over requests
let is_multimodal = {
let vocab = tokenizer.get_vocab(true);
vocab.contains_key("<image>")
};
while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
parent_span.in_scope(|| {
response_tx
.send(prepare_input(inputs, truncate, &tokenizer, is_multimodal))
.send(prepare_input(inputs, truncate, &tokenizer, &config))
.unwrap_or(())
})
}
}
fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
match mimetype {
"image/png" => Some(ImageFormat::Png),
"image/jpeg" => Some(ImageFormat::Jpeg),
"image/jpg" => Some(ImageFormat::Jpeg),
"image/gif" => Some(ImageFormat::Gif),
"image/webp" => Some(ImageFormat::WebP),
"image/tiff" => Some(ImageFormat::Tiff),
// "image/pnm"=>Some(ImageFormat::Pnm),
// "image/tga"=>Some(ImageFormat::Tga),
// "image/dds"=>Some(ImageFormat::Dds),
// "image/bmp"=>Some(ImageFormat::Bmp),
// "image/ico"=>Some(ImageFormat::Ico),
// "image/x-exr"=>Some(ImageFormat::OpenExr),
_ => None,
}
}
fn format_to_mimetype(format: ImageFormat) -> String {
match format {
ImageFormat::Png => "image/png",
ImageFormat::Jpeg => "image/jpeg",
ImageFormat::Gif => "image/gif",
ImageFormat::WebP => "image/webp",
ImageFormat::Tiff => "image/tiff",
_ => "application/octet-stream",
}
.to_string()
}
fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
if input.starts_with("![](http://") || input.starts_with("![](https://") {
let url = &input["![](".len()..input.len() - 1];
let data = reqwest::blocking::get(url)?.bytes()?;
let format = image::guess_format(&data)?;
// TODO Remove this clone
let img = ImageReader::with_format(Cursor::new(data.clone()), format).decode()?;
let height: usize = img.height().try_into()?;
let width: usize = img.width().try_into()?;
let mimetype = format_to_mimetype(format);
let encoded = STANDARD.encode(data);
let data_uri = format!("![](data:{mimetype};base64,{encoded})");
Ok((data_uri, height, width))
} else if input.starts_with("![](data:") {
// Remove ![](....)
let content = &input["![](data:".len()..input.len() - 1];
let tokens: Vec<_> = content.split(';').collect();
if tokens.len() != 2 {
return Err(ValidationError::InvalidImageContent(content.to_string()));
}
let mimetype = tokens[0];
let content = tokens[1];
if !content.starts_with("base64,") {
return Err(ValidationError::InvalidImageContent(content.to_string()));
}
let data = STANDARD.decode(content["base64,".len()..].as_bytes())?;
let img = if let Some(format) = format_from_mimetype(mimetype) {
ImageReader::with_format(Cursor::new(data), format).decode()?
} else {
ImageReader::new(Cursor::new(data))
.with_guessed_format()
.map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))?
.decode()?
};
let height: usize = img.height().try_into()?;
let width: usize = img.width().try_into()?;
Ok((input.to_string(), height, width))
} else {
Err(ValidationError::InvalidImageContent(input.to_string()))
}
}
/// Get input length and optionally truncate it
fn prepare_input(
mut inputs: String,
truncate: Option<usize>,
_truncate: Option<usize>,
tokenizer: &Tokenizer,
is_multimodal: bool,
config: &Option<Config>,
) -> Result<(tokenizers::Encoding, String), ValidationError> {
let simplified_query = if is_multimodal {
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
RE.replace_all(&inputs, "<image>").into()
} else {
inputs.clone()
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let tokenizer_query = match config {
Some(Config::LlavaNext(config)) => {
let mut modified_inputs = String::with_capacity(inputs.len());
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]);
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = config.get_number_of_features(height, width);
tokenizer_query.push_str(&"<image>".repeat(slots));
modified_inputs.push_str(&image_uri);
start = chunk_end;
}
if start != inputs.len() - 1 {
modified_inputs.push_str(&inputs[start..]);
tokenizer_query.push_str(&inputs[start..]);
}
inputs = modified_inputs;
tokenizer_query
}
Some(Config::Idefics) => RE.replace_all(&inputs, "<image>").into(),
_ => inputs.clone(),
};
// Get the number of tokens in the input
let mut encoding = tokenizer
.encode(simplified_query, true)
let encoding = tokenizer
.encode(tokenizer_query, true)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
// Optionally truncate
if let Some(truncate) = truncate {
if truncate < encoding.len() && !is_multimodal {
encoding.truncate(truncate, 0, TruncationDirection::Left);
inputs = tokenizer
.decode(encoding.get_ids(), false)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
}
}
Ok((encoding, inputs))
}
......@@ -523,6 +618,16 @@ pub enum ValidationError {
Grammar,
#[error("grammar is not valid: {0}")]
InvalidGrammar(String),
#[error("base64 encoding is invalid: {0}")]
InvalidBase64(#[from] base64::DecodeError),
#[error("invalid image: {0}")]
InvalidImage(#[from] image::ImageError),
#[error("invalid integer: {0}")]
InvalidInt(#[from] core::num::TryFromIntError),
#[error("invalid image content: {0}")]
InvalidImageContent(String),
#[error("Could not fetch image: {0}")]
FailedFetchImage(#[from] reqwest::Error),
}
#[cfg(test)]
......@@ -541,9 +646,11 @@ mod tests {
let max_total_tokens = 6;
let workers = 1;
let disable_grammar_support = true;
let config = None;
let validation = Validation::new(
workers,
tokenizer,
config,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
......@@ -572,9 +679,11 @@ mod tests {
let max_total_tokens = 6;
let disable_grammar_support = true;
let workers = 1;
let config = None;
let validation = Validation::new(
workers,
tokenizer,
config,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
......@@ -603,9 +712,11 @@ mod tests {
let max_total_tokens = 6;
let workers = 1;
let disable_grammar_support = true;
let config = None;
let validation = Validation::new(
workers,
tokenizer,
config,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
......@@ -639,9 +750,11 @@ mod tests {
let max_total_tokens = 106;
let workers = 1;
let disable_grammar_support = true;
let config = None;
let validation = Validation::new(
workers,
tokenizer,
config,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
......@@ -704,9 +817,11 @@ mod tests {
let max_total_tokens = 106;
let workers = 1;
let disable_grammar_support = true;
let config = None;
let validation = Validation::new(
workers,
tokenizer,
config,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
......
......@@ -67,6 +67,7 @@ try:
FlashSantacoderSharded,
)
from text_generation_server.models.idefics import IDEFICSSharded
from text_generation_server.models.llava_next import LlavaNext
from text_generation_server.models.flash_mistral import FlashMistral
from text_generation_server.models.flash_mixtral import FlashMixtral
from text_generation_server.models.flash_phi import FlashPhi
......@@ -579,6 +580,19 @@ def get_model(
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "llava_next":
if FLASH_ATTENTION:
return LlavaNext(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))
if sharded:
raise NotImplementedError("sharded is not supported for AutoModel")
if quantize == "gptq":
......
This diff is collapsed.
......@@ -281,9 +281,8 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, prefix, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = FlashLlamaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
......@@ -337,27 +336,30 @@ class FlashLlamaLayer(nn.Module):
class FlashLlamaModel(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
FlashLlamaLayer(
layer_id,
config,
weights,
prefix=(
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config,
weights=weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
prefix="model.norm" if not prefix else f"{prefix}.model.norm",
weights=weights,
eps=config.rms_norm_eps,
)
self.gradient_checkpointing = False
......@@ -368,7 +370,7 @@ class FlashLlamaModel(torch.nn.Module):
def forward(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
......@@ -376,8 +378,10 @@ class FlashLlamaModel(torch.nn.Module):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
......@@ -406,13 +410,19 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix, config, weights):
super().__init__()
self.model = FlashLlamaModel(config, weights)
self.embed_tokens = TensorParallelEmbedding(
prefix=(
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
),
weights=weights,
)
self.model = FlashLlamaModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head",
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
weights=weights,
)
......@@ -426,10 +436,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
input_ids,
inputs_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
......@@ -437,6 +449,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
slots,
input_lengths,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment