Commit 112bf76b authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #1826 canceled with stages
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from PIL import Image
from torch.nn import CrossEntropyLoss
from transformers import (
AutoConfig,
AutoModelForCausalLM,
MixtralConfig,
MixtralForCausalLM,
MixtralModel,
)
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import CausalLMOutputWithPast, MoeCausalLMOutputWithPast
from ..vita_arch import VITAMetaForCausalLM, VITAMetaModel
def load_balancing_loss_func(
gate_logits: torch.Tensor,
num_experts: torch.Tensor = None,
top_k=2,
attention_mask: Optional[torch.Tensor] = None,
) -> float:
r"""
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
experts is too unbalanced.
Args:
gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
shape [batch_size X sequence_length, num_experts].
attention_mask (`torch.Tensor`, None):
The attention_mask used in forward function
shape [batch_size X sequence_length] if not None.
num_experts (`int`, *optional*):
Number of experts
Returns:
The auxiliary loss.
"""
if gate_logits is None or not isinstance(gate_logits, tuple):
return 0
if isinstance(gate_logits, tuple):
compute_device = gate_logits[0].device
concatenated_gate_logits = torch.cat(
[layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
)
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
if attention_mask is None:
# Compute the percentage of tokens routed to each experts
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
# Compute the average probability of routing to these experts
router_prob_per_expert = torch.mean(routing_weights, dim=0)
else:
batch_size, sequence_length = attention_mask.shape
num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
# Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
expert_attention_mask = (
attention_mask[None, :, :, None, None]
.expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
.reshape(-1, top_k, num_experts)
.to(compute_device)
)
# Compute the percentage of tokens routed to each experts
tokens_per_expert = torch.sum(
expert_mask.float() * expert_attention_mask, dim=0
) / torch.sum(expert_attention_mask, dim=0)
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
router_per_expert_attention_mask = (
attention_mask[None, :, :, None]
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
.reshape(-1, num_experts)
.to(compute_device)
)
# Compute the average probability of routing to these experts
router_prob_per_expert = torch.sum(
routing_weights * router_per_expert_attention_mask, dim=0
) / torch.sum(router_per_expert_attention_mask, dim=0)
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
return overall_loss * num_experts
def custom_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, MixtralForCausalLM
>>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
# logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits if return_dict else outputs[-1],
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(
loss.device
) # make sure to reside in the same device
if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
MixtralForCausalLM.forward = custom_forward
class VITAMixtralConfig(MixtralConfig):
model_type = "vita-mixtral"
class VITAMixtralModel(VITAMetaModel, MixtralModel):
config_class = VITAMixtralConfig
def __init__(self, config: MixtralConfig):
super(VITAMixtralModel, self).__init__(config)
class VITAMixtralForCausalLM(MixtralForCausalLM, VITAMetaForCausalLM):
config_class = VITAMixtralConfig
def __init__(self, config):
super(MixtralForCausalLM, self).__init__(config)
self.model = VITAMixtralModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.router_aux_loss_coef = config.router_aux_loss_coef
self.num_experts = config.num_local_experts
self.num_experts_per_tok = config.num_experts_per_tok
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
audios: Optional[dict] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
) = self.prepare_inputs_labels_for_multimodal(
input_ids, position_ids, attention_mask, past_key_values, labels, images, audios
)
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
return_dict=return_dict,
)
def prepare_inputs_for_generation_original(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
output_router_logits=False,
**kwargs,
):
# Omit tokens covered by past_key_values
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
else:
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"output_router_logits": output_router_logits,
}
)
return model_inputs
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
attention_mask=None,
output_router_logits=False,
**kwargs,
):
images = kwargs.pop("images", None)
audios = kwargs.pop("audios", None)
_inputs = self.prepare_inputs_for_generation_original(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
output_router_logits=output_router_logits,
**kwargs,
)
if images is not None:
_inputs["images"] = images
if audios is not None:
_inputs["audios"] = audios
return _inputs
def expand2square(self, pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
def process_images(self, images, model_cfg):
vision_tower = self.get_vision_tower()
if not vision_tower.is_loaded:
vision_tower.load_model()
image_processor = vision_tower.image_processor
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
new_images = []
if image_aspect_ratio == "pad":
for image in images:
image = self.expand2square(
image, tuple(int(x * 255) for x in image_processor.image_mean)
)
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
new_images.append(image)
else:
return image_processor(images, return_tensors="pt")["pixel_values"]
if all(x.shape == new_images[0].shape for x in new_images):
new_images = torch.stack(new_images, dim=0)
return new_images
AutoConfig.register("vita-mixtral", VITAMixtralConfig)
AutoModelForCausalLM.register(VITAMixtralConfig, VITAMixtralForCausalLM)
import os
import yaml
from .clip.clip_encoder import CLIPVisionTower
from .eva_clip.eva_clip_encoder import EvaClipVisionTower
from .internvit.internvit_encoder import InternViTVisionTower
from .siglip.siglip_encoder import SiglipVisionTower, SiglipVisionTowerS2
from .whale.init_model import init_model
def build_vision_tower(vision_tower_cfg, **kwargs):
vision_tower = getattr(
vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)
)
use_s2 = getattr(vision_tower_cfg, "use_s2", False)
if "sig" in vision_tower.lower():
if use_s2:
return SiglipVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
else:
return SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
elif "eva" in vision_tower.lower():
if use_s2:
raise ValueError(f"Currently not supporting S2 for EVA-CLIP")
else:
return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
elif "clip" in vision_tower.lower():
if use_s2:
raise ValueError(f"Currently not supporting S2 for CLIP")
else:
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
elif "internvit" in vision_tower.lower():
if use_s2:
raise ValueError(f"Currently not supporting S2 for InternViT")
else:
return InternViTVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
else:
raise ValueError(f"Unknown vision tower: {vision_tower}")
def build_audio_encoder(audio_encoder_config, **kwargs):
with open(audio_encoder_config.mm_audio_encoder + "/train.yaml", "r") as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
configs["cmvn_file"] = audio_encoder_config.mm_audio_encoder + "/global_cmvn"
# configs['model_conf']['freeze_encoder'] = audio_encoder_config.freeze_audio_encoder
# configs['model_conf']['freeze_adpter'] = audio_encoder_config.freeze_audio_encoder_adapter
configs["model_conf"]["freeze_encoder"] = getattr(
audio_encoder_config, "freeze_audio_encoder", True
)
configs["model_conf"]["freeze_adpter"] = getattr(
audio_encoder_config, "freeze_audio_encoder_adapter", True
)
return init_model(configs)
import torch
import torch.nn as nn
from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel
class CLIPVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower
self.select_layer = -2
if not delay_load:
self.load_model()
else:
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
def load_model(self):
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
self.vision_tower.requires_grad_(False)
self.is_loaded = True
def feature_select(self, image_forward_outs):
image_features = image_forward_outs.hidden_states[self.select_layer]
image_features = image_features[:, 1:]
return image_features
@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(
image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
output_hidden_states=True,
)
image_feature = self.feature_select(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(
images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
import torch
import torch.nn as nn
from .eva_clip_processors import EvaClipImageTrainProcessor
from .eva_vit import Eva2LargePlusEncoder
class EvaClipVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_path = vision_tower
self.config = VisionTowerConfig()
if not delay_load:
self.load_model()
else:
self.cfg_only = self.config
def load_model(self):
self.image_processor = EvaClipImageTrainProcessor(self.config.image_size)
self.vision_tower = Eva2LargePlusEncoder(self.vision_tower_path)
self.vision_tower.requires_grad_(False)
self.is_loaded = True
@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_feature = self.vision_tower(
image.to(device=self.device, dtype=self.dtype).unsqueeze(0)
).to(image.dtype)
image_features.append(image_feature)
else:
image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(
images.dtype
)
return image_features
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
class VisionTowerConfig:
def __init__(self):
self.image_size = 336
self.patch_size = 14
self.hidden_size = 1024
"""
# Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
"""
from PIL import Image
from transformers.image_processing_utils import BatchFeature
from transformers.image_transforms import convert_to_rgb
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
class BaseProcessor:
def __init__(self):
self.transform = lambda x: x
return
def __call__(self, item):
return self.transform(item)
class EvaClipImageBaseProcessor(BaseProcessor):
def __init__(self, mean=None, std=None):
self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
self.normalize = transforms.Normalize(self.mean, self.std)
@property
def image_mean(self):
return self.mean
class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor):
def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
super().__init__(mean=mean, std=std)
self.transform = transforms.Compose(
[
convert_to_rgb,
transforms.Resize(
image_size,
interpolation=InterpolationMode.BICUBIC,
),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
]
)
self.image_size = image_size
def preprocess(self, images, return_tensors):
if isinstance(images, Image.Image):
images = [images]
else:
assert isinstance(images, list)
transformed_images = [self.transform(image).numpy() for image in images]
data = {"pixel_values": transformed_images}
return BatchFeature(data=data, tensor_type=return_tensors)
def __call__(self, item):
return self.transform(item)
@property
def crop_size(self):
return {"height": self.image_size, "width": self.image_size}
This diff is collapsed.
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import os
from typing import Union
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class InternVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
instantiate a vision encoder according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
num_channels (`int`, *optional*, defaults to 3):
Number of color channels in the input images (e.g., 3 for RGB).
patch_size (`int`, *optional*, defaults to 14):
The size (resolution) of each patch.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
qkv_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the queries and values in the self-attention layers.
hidden_size (`int`, *optional*, defaults to 3200):
Dimensionality of the encoder layers and the pooler layer.
num_attention_heads (`int`, *optional*, defaults to 25):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 12800):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
qk_normalization (`bool`, *optional*, defaults to `True`):
Whether to normalize the queries and keys in the self-attention layers.
num_hidden_layers (`int`, *optional*, defaults to 48):
Number of hidden layers in the Transformer encoder.
use_flash_attn (`bool`, *optional*, defaults to `True`):
Whether to use flash attention mechanism.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
The epsilon used by the layer normalization layers.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
drop_path_rate (`float`, *optional*, defaults to 0.0):
Dropout rate for stochastic depth.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
initializer_factor (`float`, *optional*, defaults to 0.1):
A factor for layer scale.
"""
model_type = "intern_vit_6b"
def __init__(
self,
num_channels=3,
patch_size=14,
image_size=224,
qkv_bias=False,
hidden_size=3200,
num_attention_heads=25,
intermediate_size=12800,
qk_normalization=True,
num_hidden_layers=48,
use_flash_attn=True,
hidden_act="gelu",
norm_type="rms_norm",
layer_norm_eps=1e-6,
dropout=0.0,
drop_path_rate=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=0.1,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.dropout = dropout
self.drop_path_rate = drop_path_rate
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.initializer_range = initializer_range
self.initializer_factor = initializer_factor
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.norm_type = norm_type
self.qkv_bias = qkv_bias
self.qk_normalization = qk_normalization
self.use_flash_attn = use_flash_attn
@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if "vision_config" in config_dict:
config_dict = config_dict["vision_config"]
if (
"model_type" in config_dict
and hasattr(cls, "model_type")
and config_dict["model_type"] != cls.model_type
):
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
# https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
import torch
import torch.nn as nn
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
try: # v1
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
except: # v2
from flash_attn.flash_attn_interface import (
flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func,
)
class FlashAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
super().__init__()
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def forward(
self,
qkv,
key_padding_mask=None,
causal=False,
cu_seqlens=None,
max_s=None,
need_weights=False,
):
"""Implements the multihead softmax attention.
Arguments
---------
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
if unpadded: (nnz, 3, h, d)
key_padding_mask: a bool tensor of shape (B, S)
"""
assert not need_weights
assert qkv.dtype in [torch.float16, torch.bfloat16]
assert qkv.is_cuda
if cu_seqlens is None:
batch_size = qkv.shape[0]
seqlen = qkv.shape[1]
if key_padding_mask is None:
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = seqlen
cu_seqlens = torch.arange(
0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device
)
output = flash_attn_unpadded_qkvpacked_func(
qkv,
cu_seqlens,
max_s,
self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal=causal,
)
output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
else:
nheads = qkv.shape[-2]
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
output_unpad = flash_attn_unpadded_qkvpacked_func(
x_unpad,
cu_seqlens,
max_s,
self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal=causal,
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
),
"b s (h d) -> b s h d",
h=nheads,
)
else:
assert max_s is not None
output = flash_attn_unpadded_qkvpacked_func(
qkv,
cu_seqlens,
max_s,
self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal=causal,
)
return output, None
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, CLIPImageProcessor
from .modeling_intern_vit import InternVisionModel
class InternViTVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower
self.select_layer = -1
self.scale_pix_shuffle = 0.5
if not delay_load:
self.load_model()
else:
self.cfg_only = AutoConfig.from_pretrained(
self.vision_tower_name, trust_remote_code=True
)
def load_model(self):
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
# self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, trust_remote_code=True)
self.vision_tower = InternVisionModel.from_pretrained(
self.vision_tower_name, trust_remote_code=True
)
self.vision_tower.requires_grad_(False)
self.is_loaded = True
def feature_select(self, image_forward_outs):
image_features = image_forward_outs.hidden_states[self.select_layer]
image_features = image_features[:, 1:]
return image_features
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
x = x.view(
n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor))
)
x = x.permute(0, 2, 1, 3).contiguous()
return x
@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(
image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
output_hidden_states=True,
)
image_feature = self.feature_select(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(
images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
h = w = int(image_features.shape[1] ** 0.5)
assert image_features.shape[1] == h * w
image_features = image_features.reshape(image_features.shape[0], h, w, -1)
image_features = self.pixel_shuffle(image_features * self.scale_pix_shuffle)
image_features = image_features.reshape(
image_features.shape[0], -1, image_features.shape[-1]
)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
return self.config.hidden_size * (int(1 / self.scale_pix_shuffle) ** 2)
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from timm.models.layers import DropPath
from .configuration_intern_vit import InternVisionConfig
try:
from .flash_attention import FlashAttention
has_flash_attn = True
except:
print("FlashAttention is not installed.")
has_flash_attn = False
logger = logging.get_logger(__name__)
class InternRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
try:
from apex.normalization import FusedRMSNorm
InternRMSNorm = FusedRMSNorm # noqa
logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm")
except ImportError:
# using the normal InternRMSNorm
pass
except Exception:
logger.warning("discovered apex but it failed to load, falling back to InternRMSNorm")
pass
NORM2FN = {
"rms_norm": InternRMSNorm,
"layer_norm": nn.LayerNorm,
}
class InternVisionEmbeddings(nn.Module):
def __init__(self, config: InternVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(
torch.randn(1, 1, self.embed_dim),
)
self.patch_embedding = nn.Conv2d(
in_channels=3,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
def _get_pos_embed(self, pos_embed, H, W):
target_dtype = pos_embed.dtype
pos_embed = (
pos_embed.float()
.reshape(1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1)
.permute(0, 3, 1, 2)
)
pos_embed = (
F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False)
.reshape(1, -1, H * W)
.permute(0, 2, 1)
.to(target_dtype)
)
return pos_embed
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height]
batch_size, _, height, width = patch_embeds.shape
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
position_embedding = torch.cat(
[
self.position_embedding[:, :1, :],
self._get_pos_embed(self.position_embedding[:, 1:, :], height, width),
],
dim=1,
)
embeddings = embeddings + position_embedding.to(target_dtype)
return embeddings
class InternAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: InternVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.use_flash_attn = config.use_flash_attn and has_flash_attn
if config.use_flash_attn and not has_flash_attn:
print("Warning: Flash Attention is not available, use_flash_attn is set to False.")
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
self.attn_drop = nn.Dropout(config.attention_dropout)
self.proj_drop = nn.Dropout(config.dropout)
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
if self.use_flash_attn:
self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
def _naive_attn(self, x):
B, N, C = x.shape
qkv = (
self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
if self.qk_normalization:
B_, H_, N_, D_ = q.shape
q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
attn = (q * self.scale) @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
qkv = self.qkv(x)
qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
if self.qk_normalization:
q, k, v = qkv.unbind(2)
q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
qkv = torch.stack([q, k, v], dim=2)
context, _ = self.inner_attn(
qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
)
outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
outs = self.proj_drop(outs)
return outs
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
x = (
self._naive_attn(hidden_states)
if not self.use_flash_attn
else self._flash_attn(hidden_states)
)
return x
class InternMLP(nn.Module):
def __init__(self, config: InternVisionConfig):
super().__init__()
self.config = config
self.act = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class InternVisionEncoderLayer(nn.Module):
def __init__(self, config: InternVisionConfig, drop_path_rate: float):
super().__init__()
self.embed_dim = config.hidden_size
self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type
self.attn = InternAttention(config)
self.mlp = InternMLP(config)
self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
def forward(
self,
hidden_states: torch.Tensor,
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
"""
Args:
hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
"""
hidden_states = hidden_states + self.drop_path1(
self.attn(self.norm1(hidden_states)) * self.ls1
)
hidden_states = hidden_states + self.drop_path2(
self.mlp(self.norm2(hidden_states)) * self.ls2
)
return hidden_states
class InternVisionEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`InternEncoderLayer`].
Args:
config (`InternConfig`):
The corresponding vision configuration for the `InternEncoder`.
"""
def __init__(self, config: InternVisionConfig):
super().__init__()
self.config = config
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
self.layers = nn.ModuleList(
[InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = True
def forward(
self,
inputs_embeds,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Embedded representation of the inputs. Should be float, not int tokens.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = torch.utils.checkpoint.checkpoint(encoder_layer, hidden_states)
else:
layer_outputs = encoder_layer(
hidden_states,
)
hidden_states = layer_outputs
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states] if v is not None)
return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states)
class InternVisionModel(PreTrainedModel):
main_input_name = "pixel_values"
config_class = InternVisionConfig
_no_split_modules = ["InternVisionEncoderLayer"]
def __init__(self, config: InternVisionConfig):
super().__init__(config)
self.config = config
self.embeddings = InternVisionEmbeddings(config)
self.encoder = InternVisionEncoder(config)
def resize_pos_embeddings(self, old_size, new_size, patch_size):
pos_emb = self.embeddings.position_embedding
_, num_positions, embed_dim = pos_emb.shape
cls_emb = pos_emb[:, :1, :]
pos_emb = (
pos_emb[:, 1:, :]
.reshape(1, old_size // patch_size, old_size // patch_size, -1)
.permute(0, 3, 1, 2)
)
pos_emb = F.interpolate(
pos_emb.float(), size=new_size // patch_size, mode="bicubic", align_corners=False
)
pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
self.embeddings.position_embedding = nn.Parameter(pos_emb)
self.embeddings.image_size = new_size
logger.info("Resized position embeddings from {} to {}".format(old_size, new_size))
def get_input_embeddings(self):
return self.embeddings
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_embeds: Optional[torch.FloatTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None and pixel_embeds is None:
raise ValueError("You have to specify pixel_values or pixel_embeds")
if pixel_embeds is not None:
hidden_states = pixel_embeds
else:
if len(pixel_values.shape) == 4:
hidden_states = self.embeddings(pixel_values)
else:
raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs.last_hidden_state
pooled_output = last_hidden_state[:, 0, :]
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment