"doc/git@developer.sourcefind.cn:ox696c/ktransformers.git" did not exist on "ecc3028c13d3f7ad1e8f09206f9bca425e99da8e"
Unverified commit 4634b00c authored by Nicolas Patry, committed by GitHub

Adding Llava-Next (Llava 1.6) with full support. (#1709)

# What does this PR do?

- Changed all models to extract `embed_tokens` in order to enable llava
to call the embeddings and the core model layers separately.
- Added VlmCausalLM to inherit from FlashMistral in order to be
maximally supported. The only added logic sits on top and parses images
into pixel values, preallocates input_ids space for the image
embeddings, and passes them to the model.
- Added Clip for the vision tower.
- Didn't add flash for the vision tower since there's no padding anyway.
- Added a heuristic (potentially incomplete) to calculate the number of
image features *before* computing the clip patches (allows for easier
reuse of the LLM logic under the hood).
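
To make the flow above concrete, here is a minimal, self-contained sketch of the placeholder-token idea. This is illustrative only, not the PR's code: the token id, hidden size, and feature count below are made-up values; in the PR the feature count comes from the `get_number_of_features` heuristic and the merge happens in `_merge_input_ids_with_image_features`, both visible in the diff below.

```python
import torch

# Assumed values for illustration only.
IMAGE_TOKEN_ID = 32000      # assumed id of the "<image>" placeholder token
HIDDEN_SIZE = 16
NUM_IMAGE_FEATURES = 4      # in the PR this is computed from the image size *before* running CLIP

with torch.no_grad():
    # The prompt is tokenized with one "<image>" placeholder per expected image feature,
    # so the text sequence already reserves the right number of positions.
    input_ids = torch.tensor([1, 15] + [IMAGE_TOKEN_ID] * NUM_IMAGE_FEATURES + [42])

    # Text embeddings come from the now separately exposed `embed_tokens`.
    embed_tokens = torch.nn.Embedding(32001, HIDDEN_SIZE)
    inputs_embeds = embed_tokens(input_ids)

    # The vision tower + projector would produce one embedding per placeholder.
    image_features = torch.randn(NUM_IMAGE_FEATURES, HIDDEN_SIZE)

    # In-place merge: overwrite the placeholder positions with the vision embeddings.
    mask = input_ids == IMAGE_TOKEN_ID
    inputs_embeds[mask] = image_features

    # `inputs_embeds` is then fed to the core LLM layers, which now accept
    # embeddings instead of token ids (see the MistralModel.forward change below).
    print(inputs_embeds.shape)  # torch.Size([7, 16])
```

Because the placeholders are ordinary tokens, the existing batching and KV-cache logic of the underlying LLM is reused unchanged.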


Still needs to be done:

- [x] Implement the image parsing on the controller side, to avoid
downloading the image n times (once per TP shard), to refuse requests
that are too large early, and to avoid issues where truncation actually
truncates the image.
- [ ] Make sure it works with quantization properly.
- [x] Make sure it works with TP>1



parent 106d8ee8
@@ -285,9 +285,8 @@ class MistralMLP(nn.Module):
 class MistralLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
-        prefix = f"model.layers.{layer_id}"
         self.self_attn = MistralAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
@@ -343,27 +342,24 @@ class MistralLayer(nn.Module):
 class MistralModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
-        )
         self.layers = nn.ModuleList(
             [
                 MistralLayer(
-                    layer_id,
-                    config,
-                    weights,
+                    prefix=f"{prefix}.layers.{layer_id}",
+                    config=config,
+                    weights=weights,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = FastRMSNorm.load(
-            prefix="model.norm", weights=weights, eps=config.rms_norm_eps
+            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
         )
         self.gradient_checkpointing = False
@@ -374,7 +370,7 @@ class MistralModel(torch.nn.Module):
     def forward(
         self,
-        input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -384,9 +380,8 @@ class MistralModel(torch.nn.Module):
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
+    ):
+        hidden_states = inputs_embeds
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@@ -410,18 +405,27 @@ class MistralModel(torch.nn.Module):
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

 class FlashMistralForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
-        self.model = MistralModel(config, weights)
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix=(
+                "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
+            ),
+            weights=weights,
+        )
+        self.model = MistralModel(
+            prefix="model" if not prefix else f"{prefix}.model",
+            config=config,
+            weights=weights,
+        )
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head",
+            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
             weights=weights,
         )
         self.max_past = config.sliding_window
@@ -453,8 +457,9 @@ class FlashMistralForCausalLM(torch.nn.Module):
             # kernel requires the true values
             input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)

+        inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
-            input_ids,
+            inputs_embeds,
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
...
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Llava-NeXT model."""
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (`tuple`):
The size of the input image in the format (width, height).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (width, height).
"""
if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists")
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
def unpad_image(tensor, original_size):
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (`torch.Tensor`):
The image tensor, assumed to be of shape (num_channels, height, width).
original_size (`tuple`):
The original size of the image (height, width).
Returns:
`torch.Tensor`: The unpadded image tensor.
"""
original_height, original_width = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
return unpadded_tensor
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
class LlavaNextMultiModalProjector(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.linear_1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
def load_vision_model(prefix, config, weights):
if config.model_type == "clip_vision_model":
from text_generation_server.models.custom_modeling.clip import (
CLIPVisionTransformer,
)
return CLIPVisionTransformer(
prefix=f"{prefix}.vision_model", config=config, weights=weights
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
def load_text_model(prefix, config, weights):
if config.model_type == "llama":
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
return FlashLlamaForCausalLM(prefix, config, weights)
elif config.model_type == "mistral":
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
return FlashMistralForCausalLM(prefix, config, weights)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
class LlavaNextForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = config.quantize
vision_config = config.vision_config
# Instead of selecting hidden_states[-2],
# compute only the first (n - 2 + 1) layers and don't pool
if config.vision_feature_layer < 0:
vision_config.num_hidden_layers += config.vision_feature_layer + 1
else:
vision_config.num_hidden_layers = config.vision_feature_layer + 1
self.vision_tower = load_vision_model(
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
config=config.vision_config,
weights=weights,
)
self.multi_modal_projector = LlavaNextMultiModalProjector(
prefix="multi_modal_projector", config=config, weights=weights
)
self.image_newline = weights.get_tensor("image_newline")
self.vocab_size = config.text_config.vocab_size
self.config = config
config.text_config.quantize = config.quantize
config.text_config.use_medusa = config.use_medusa
self.language_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
)
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
def _merge_input_ids_with_image_features(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = input_ids == self.config.image_token_index
# Let's pray we have enabled enough slots !
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.language_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings
# 2. Merge text and images
num_images, num_patches, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(
num_images * num_patches, channels, height, width
)
image_features = self.vision_tower(pixel_values)
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
# Already done within the clip model
selected_image_feature = image_features.last_hidden_state
if self.config.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise RuntimeError(
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
)
image_features = self.multi_modal_projector(selected_image_feature)
# split up image_features for each of the individual images
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
# if we assume each image has 5 image features (base image + 4 patches)
split_sizes = [num_patches] * num_images
image_features = torch.split(image_features, split_sizes, dim=0)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
)
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
if height * width != base_image_feature.shape[0]:
raise ValueError(
"The number of patches is not consistent with the image size."
)
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
(image_feature, self.image_newline[None]), dim=0
)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_features
)
hidden_states = self.language_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=None,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.language_model.lm_head(hidden_states)
return logits, speculative_logits
@@ -107,23 +107,27 @@ class FlashCausalLMBatch(Batch):
         )

     @classmethod
-    def from_pb(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> "FlashCausalLMBatch":
+    def batch_tokenized_inputs(cls, requests, tokenizer):
         batch_inputs = []
         max_truncation = 0
-        for r in pb.requests:
+        for r in requests:
             batch_inputs.append(r.inputs)
             max_truncation = max(max_truncation, r.truncate)

         batch_tokenized_inputs = tokenizer(
             batch_inputs, truncation=True, max_length=max_truncation
         )["input_ids"]
+        return batch_tokenized_inputs
+
+    @classmethod
+    def from_pb(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "FlashCausalLMBatch":
+        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)

         position_ids = []
         speculative_ids = []
         cu_seqlen_prefill = [0]
...
@@ -67,7 +67,8 @@ class FlashLlama(FlashCausalLM):
         if config.quantize in ["gptq", "awq"]:
             weights._set_gptq_params(model_id, revision)

-        model = FlashLlamaForCausalLM(config, weights)
+        prefix = ""
+        model = FlashLlamaForCausalLM(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model=model,
...
@@ -6,8 +6,7 @@ import numpy as np
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase, AutoTokenizer
-from transformers.models.llama import LlamaTokenizerFast
+from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
 from typing import Optional, Tuple, Type

 from text_generation_server.pb import generate_pb2
@@ -66,17 +65,19 @@ class FlashMistralBatch(FlashCausalLMBatch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
-        sliding_window, sliding_window_blocks = get_sliding_windows()
-
-        batch_inputs = []
-        max_truncation = 0
-        for r in pb.requests:
-            batch_inputs.append(r.inputs)
-            max_truncation = max(max_truncation, r.truncate)
-
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs, truncation=True, max_length=max_truncation
-        )["input_ids"]
+        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
+        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+
+    @classmethod
+    def from_tokenized(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        batch_tokenized_inputs,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "FlashCausalLMBatch":
+        sliding_window, sliding_window_blocks = get_sliding_windows()

         position_ids = []
         cu_seqlen_prefill = [0]
@@ -301,14 +302,15 @@ class FlashMistralBatch(FlashCausalLMBatch):
 class BaseFlashMistral(FlashCausalLM):
     def __init__(
         self,
-        config_cls,
         model_cls,
         model_id: str,
+        config_cls=AutoConfig,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        tokenizer_class=AutoTokenizer,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -317,22 +319,13 @@ class BaseFlashMistral(FlashCausalLM):
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")

-        try:
-            tokenizer = LlamaTokenizerFast.from_pretrained(
-                model_id,
-                revision=revision,
-                padding_side="left",
-                truncation_side="left",
-                trust_remote_code=trust_remote_code,
-            )
-        except Exception:
-            tokenizer = AutoTokenizer.from_pretrained(
-                model_id,
-                revision=revision,
-                padding_side="left",
-                truncation_side="left",
-                trust_remote_code=trust_remote_code,
-            )
+        tokenizer = tokenizer_class.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )

         config = config_cls.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
@@ -341,10 +334,12 @@ class BaseFlashMistral(FlashCausalLM):
         config.use_medusa = use_medusa

         # Set context windows
-        if config.sliding_window is not None:
+        if getattr(config, "sliding_window", None) is not None:
             set_sliding_window(
                 config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
             )
+        else:
+            config.sliding_window = None

         torch.distributed.barrier(group=self.process_group)
@@ -353,17 +348,19 @@ class BaseFlashMistral(FlashCausalLM):
         if config.quantize in ["gptq", "awq"]:
             weights._set_gptq_params(model_id, revision)

-        model = model_cls(config, weights)
+        prefix = ""
+        model = model_cls(prefix, config, weights)

         self.cuda_graphs = {}

         torch.distributed.barrier(group=self.process_group)
-        super(BaseFlashMistral, self).__init__(
+        num_layers, num_kv_heads, head_size = self.get_layer_config(model)
+        super().__init__(
             model=model,
             tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
+            num_layers=num_layers,
+            num_kv_heads=num_kv_heads,
+            head_size=head_size,
             dtype=dtype,
             device=device,
             rank=rank,
@@ -371,6 +368,16 @@ class BaseFlashMistral(FlashCausalLM):
             sliding_window=config.sliding_window,
         )

+    def get_layer_config(self, model) -> Tuple[int, int, int]:
+        return (
+            len(model.model.layers),
+            model.model.num_key_value_heads,
+            model.model.head_size,
+        )
+
+    def max_past(self) -> int:
+        return self.model.max_past
+
     @property
     def batch_type(self) -> Type[FlashMistralBatch]:
         return FlashMistralBatch
@@ -485,11 +492,11 @@ class BaseFlashMistral(FlashCausalLM):
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices

-        if cu_seqlen_prefill is None and self.model.max_past is not None:
+        if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
             # in a circular buffer mode.
             # This makes sure the max_s for the decode pass is correct.
-            max_s = min(self.model.max_past, max_s)
+            max_s = min(self.max_past(), max_s)

         bs = input_ids.shape[0]
         padded_bs = bs
...
 import torch
+import torch
 import time

 from dataclasses import dataclass
@@ -20,29 +21,13 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models.vlm_causal_lm import split

 import re

 IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")

-def split(string):
-    parts = []
-    cursor = 0
-    for pattern in IMAGES.finditer(string):
-        start = pattern.start()
-        if start != cursor:
-            parts.append(string[cursor:start])
-        parts.append(pattern.group(1))
-        cursor = pattern.end()
-    if cursor != len(string):
-        parts.append(string[cursor:])
-    return parts

 tracer = trace.get_tracer(__name__)
@@ -93,10 +78,21 @@ class IdeficsCausalLMBatch(Batch):
     @classmethod
     def from_pb(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "IdeficsCausalLMBatch":
+        raise NotImplementedError
+
+    @classmethod
+    def from_pb_processor(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
         processor: ProcessorMixin,  # Hack
+        config,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "IdeficsCausalLMBatch":
@@ -127,10 +123,14 @@ class IdeficsCausalLMBatch(Batch):
                 padding_right_offset, stopping_criteria.max_new_tokens
             )

+        # TODO Check impact on idefics
         prompts = []
         for inp in inputs:
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
-            prompts.append(split(inp))
+            prompt = []
+            for chunk in split(inp):
+                prompt.append(chunk["content"])
+            prompts.append(prompt)

         # The processor replaces the call to tokenizer, and
         # a/ takes care of fetching images from the URL
@@ -141,7 +141,8 @@ class IdeficsCausalLMBatch(Batch):
             padding=True,
             truncation=True,
             max_length=max_truncation,
-            add_end_of_utterance_token=False,  # Already taken care of inside the prompts, so bypassing the processor's handling of this token
+            # TODO Check impact on idefics
+            # add_end_of_utterance_token=False,  # Already taken care of inside the prompts, so bypassing the processor's handling of this token
         ).to(device)
         for _ in pb.requests:
             input_len = tokenized_inputs["input_ids"].shape[1]
@@ -156,7 +157,7 @@ class IdeficsCausalLMBatch(Batch):
         max_input_length = input_lengths.max()

         input_ids = tokenized_inputs["input_ids"]
-        pixel_values = tokenized_inputs["pixel_values"]
+        pixel_values = tokenized_inputs.get("pixel_values", None)
         image_hidden_states = None
         # Allocate maximum attention_mask
         attention_mask = input_ids.new_zeros(
@@ -165,16 +166,19 @@ class IdeficsCausalLMBatch(Batch):
         # Copy tokenizer attention_mask into fully allocated attention_mask
         attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
         # Do the same for image_attention_mask
-        image_attention_mask = input_ids.new_zeros(
-            (
-                pb.size,
-                max_input_length + padding_right_offset,
-                tokenized_inputs["pixel_values"].size(1),
-            )
-        )
-        image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
-            "image_attention_mask"
-        ]
+        if pixel_values is None:
+            image_attention_mask = None
+        else:
+            image_attention_mask = input_ids.new_zeros(
+                (
+                    pb.size,
+                    max_input_length + padding_right_offset,
+                    pixel_values.size(1),
+                )
+            )
+            image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
+                "image_attention_mask"
+            ]

         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
@@ -677,19 +681,22 @@ class IdeficsCausalLM(Model):
         start = time.time_ns()
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-        if batch.input_ids.size(1) == 1:
-            # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),
-            # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension
-            # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated
-            # token need to attend to the encoder hidden states (i.e. the vision encoder)
-            # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
-            image_attention_mask = batch.image_attention_mask[
-                :, -(batch.padding_right_offset + 1)
-            ].unsqueeze(1)
+        if batch.image_attention_mask is None:
+            image_attention_mask = None
         else:
-            image_attention_mask = batch.image_attention_mask[
-                :, : -batch.padding_right_offset
-            ]
+            if batch.input_ids.size(1) == 1:
+                # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),
+                # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension
+                # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated
+                # token need to attend to the encoder hidden states (i.e. the vision encoder)
+                # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
+                image_attention_mask = batch.image_attention_mask[
+                    :, -(batch.padding_right_offset + 1)
+                ].unsqueeze(1)
+            else:
+                image_attention_mask = batch.image_attention_mask[
+                    :, : -batch.padding_right_offset
+                ]

         logits, speculative_logits, past, image_hidden_states = self.forward(
             input_ids=batch.input_ids,
...
import torch
from typing import Optional
from transformers import (
AutoProcessor,
)
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
class LlavaNext(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
super().__init__(
model_cls=LlavaNextForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
import re
import torch
import math
from PIL import Image
from io import BytesIO
import base64
from opentelemetry import trace
from typing import Optional, Tuple, List, Type, Dict
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
FlashMistralBatch,
)
from text_generation_server.models.cache_manager import (
get_cache_manager,
)
tracer = trace.get_tracer(__name__)
IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
def split(string) -> List[Dict[str, str]]:
parts = []
cursor = 0
for pattern in IMAGES.finditer(string):
start = pattern.start()
if start != cursor:
parts.append({"type": "text", "content": string[cursor:start]})
parts.append({"type": "image", "content": pattern.group(1)})
cursor = pattern.end()
if cursor != len(string):
parts.append({"type": "text", "content": string[cursor:]})
return parts
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (`tuple`):
The size of the input image in the format (width, height).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (width, height).
"""
if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists")
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
def get_number_of_features(height: int, width: int, config) -> int:
# From config
# Hardcoded for CLIP for now
# image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
image_grid_pinpoints = config.image_grid_pinpoints
image_size = config.vision_config.image_size
patch_size = config.vision_config.patch_size
assert image_size % patch_size == 0
npatches = image_size // patch_size
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
[height, width],
image_grid_pinpoints,
image_size,
)
height_of_patch = math.ceil(height / width * npatches)
unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
# They are only added after width
newline_features = height_of_patch * num_patch_width
# The base patch covers the entire image
base_features = npatches**2
return unpadded_features + newline_features + base_features
def load_data_uri(image_uri: str) -> Image.Image:
image_uri = image_uri.split(",")[-1]
content = base64.b64decode(image_uri)
image = Image.open(BytesIO(content))
return image
# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
# assert get_number_of_features(640, 640) == 2928
class VlmCausalLMBatch(FlashMistralBatch):
pixel_values: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
batch.pixel_values = None
batch.image_sizes = None
return batch
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
batch = super().filter(request_ids)
batch.pixel_values = None
batch.image_sizes = None
return batch
@classmethod
def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
batch_inputs = []
image_inputs = []
max_truncation = 0
for r in requests:
chunks = split(r.inputs)
full_text = ""
for chunk in chunks:
if chunk["type"] == "text":
full_text += chunk["content"]
elif chunk["type"] == "image":
image = chunk["content"]
# Should never receive URLs anymore, processing should be done
# On the rust layer.
# This avoid making n queries per TP
# if image.startswith("https://") or image.startswith("http://"):
# image = processor.image_processor.fetch_images(image)
if image.startswith("data:"):
image = load_data_uri(image)
else:
raise RuntimeError(
"Cannot process input image not starting with data:"
)
image_input = processor.image_processor(image, return_tensors="pt")
height, width = image_input["image_sizes"][0]
num_features = get_number_of_features(height, width, config)
full_text += "<image>" * num_features
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
if image_inputs:
image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
"image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
}
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs
@classmethod
def from_pb_processor(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
processor,
config,
dtype: torch.dtype,
device: torch.device,
) -> "VlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.pixel_values = None
batch.image_sizes = None
return batch
class VlmCausalLM(BaseFlashMistral):
@property
def batch_type(self) -> Type[VlmCausalLMBatch]:
return VlmCausalLMBatch
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.language_model.model.layers),
model.language_model.model.num_key_value_heads,
model.language_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.language_model, "max_past", None)
def forward(
self, batch: VlmCausalLMBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat(
[input_ids.unsqueeze(-1), speculative_ids], dim=1
).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (
position_ids.unsqueeze(-1).expand(B, new_length) + arange
).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
# Copy the block tables for all members
block_tables = (
block_tables.unsqueeze(1)
.expand(B, new_length, -1)
.reshape(B * new_length, -1)
.contiguous()
)
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
padded_bs = bs
if bs == 3:
padded_bs = 4
elif 3 < bs <= 8:
padded_bs = 8
elif bs > 8:
padded_bs = (bs + 7) // 8 * 8
# Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(padded_bs, None)
if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
image_sizes=batch.image_sizes,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.image_sizes is not None:
batch.image_sizes = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"].fill_(-1)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
# Replay the graph
cuda_graph["graph"].replay()
# Slice output to the correct shape
speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
@@ -13,6 +13,7 @@ from typing import List, Optional
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
+from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
@@ -78,13 +79,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         except ImportError:
             pass

-        if (
-            self.model.batch_type == IdeficsCausalLMBatch
-        ):  # Hack, i would rather use kwargs in the `from_pb` call
-            batch = self.model.batch_type.from_pb(
+        if self.model.batch_type in {
+            IdeficsCausalLMBatch,
+            VlmCausalLMBatch,
+        }:  # Hack, i would rather use kwargs in the `from_pb` call
+            batch = self.model.batch_type.from_pb_processor(
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,
+                self.model.model.config,
                 self.model.dtype,
                 self.model.device,
             )
@@ -100,13 +103,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     async def Prefill(self, request, context):
         start = time.time_ns()

-        if (
-            self.model.batch_type == IdeficsCausalLMBatch
-        ):  # Hack, i would rather use kwargs in the `from_pb` call
-            batch = self.model.batch_type.from_pb(
+        if self.model.batch_type in {
+            IdeficsCausalLMBatch,
+            VlmCausalLMBatch,
+        }:  # Hack, i would rather use kwargs in the `from_pb` call
+            batch = self.model.batch_type.from_pb_processor(
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,
+                self.model.model.config,
                 self.model.dtype,
                 self.model.device,
             )
...