Unverified Commit 0a995d54 authored by Congcong Chen's avatar Congcong Chen Committed by GitHub
Browse files

[Model] New model support for Phi-4-multimodal-instruct (#14119)

parent ade3f7d9
......@@ -410,7 +410,7 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
- * `Phi3ForCausalLM`
* Phi-4, Phi-3
* `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc.
* `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc.
* ✅︎
* ✅︎
- * `Phi3SmallForCausalLM`
......@@ -856,6 +856,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
- * `Phi4MMForCausalLM`
* Phi-4-multimodal
* T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup>
* `microsoft/Phi-4-multimodal-instruct`, etc.
* ✅︎
*
*
- * `PixtralForConditionalGeneration`
* Pixtral
* T + I<sup>+</sup>
......
......@@ -37,3 +37,4 @@ depyf==0.18.0 # required for profiling and debugging with compilation config
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
watchfiles # required for http server to monitor the updates of TLS files
python-json-logger # Used by logging as per examples/other/logging_configuration.md
scipy # Required for phi-4-multimodal-instruct
\ No newline at end of file
......@@ -272,6 +272,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
trust_remote_code=True),
"Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
trust_remote_code=True),
"PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501
tokenizer_mode="mistral"),
"QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL",
......
......@@ -2284,9 +2284,9 @@ class LoRAConfig:
return hash_str
def __post_init__(self):
# Setting the maximum rank to 256 should be able to satisfy the vast
# Setting the maximum rank to 512 should be able to satisfy the vast
# majority of applications.
possible_max_ranks = (8, 16, 32, 64, 128, 256)
possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512)
possible_lora_extra_vocab_size = (0, 256, 512)
if self.max_lora_rank not in possible_max_ranks:
raise ValueError(
......
......@@ -395,6 +395,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return f"<|image_{current_count}|>"
if model_type == "phi4mm":
return "<|endoftext10|>" # 200010 (see vocab.json in hf model)
if model_type in ("minicpmo", "minicpmv"):
return "(<image>./</image>)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
......@@ -424,6 +426,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
elif modality == "audio":
if model_type == "ultravox":
return "<|audio|>"
if model_type == "phi4mm":
return "<|endoftext11|>" # 200011 (see vocab.json in hf model)
if model_type == "qwen2_audio":
return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>")
......
# SPDX-License-Identifier: Apache-2.0
import math
import re
from functools import lru_cache
from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import numpy as np
import scipy.signal
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig
from transformers.utils import logging
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.inputs.data import TokenInputs, token_inputs
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalInputs, NestedTensors
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .interfaces import SupportsLoRA, SupportsMultiModal
from .phi4mm_audio import AudioEmbedding
from .utils import maybe_prefix
from .vision_siglip_navit import get_siglip_vision_model
# <|endoftext10|> (see vocab.json in hf model)
_IMAGE_PLACEHOLDER_TOKEN_ID = 200010
# Token id of <|endoftext11|>, used as the audio placeholder
# (see vocab.json in hf model).
_AUDIO_PLACEHOLDER_TOKEN_ID = 200011

_AUDIO_MAX_SOUNDFILE_SIZE = 241_000
DUMMY_SAMPLING_FREQUENCY = 16_000  # Hz (i.e. 16 kHz)

DYNAMIC_HD = 16
# Prompt-level placeholders such as <|audio_1|> / <|image_1|>; the captured
# group is the integer index.
AUDIO_TOKEN_PATTERN = r"<\|audio_(\d+)\|>"
IMAGE_TOKEN_PATTERN = r"<\|image_(\d+)\|>"

# Default vision encoder name, used when hf_config.img_processor is None.
SIGLIP_NAME = "siglip-so400m-patch14-448"
# Per-encoder preprocessing parameters, keyed by vision encoder name.
VISION_ENCODER_TO_PROCESSING_CONFIG = {
    'siglip-so400m-patch14-448': {
        'dynamic_hd': 16,  # maximum number of HD crops per image
        'vit_image_size': 448,  # side length of one ViT input tile (pixels)
        'vit_patch_size': 14,  # side length of one ViT patch (pixels)
        'token_compression_factor': 2,
    },
}

logger = logging.get_logger(__name__)
# This is a workaround to prevent text (user input) + audio + image
# from being used in the same prompt.
# It includes the token id for "\n" and the tokens listed in
# added_tokens_decoder in the tokenizer_config.json file.
NON_USER_INPUT_TOKENS = {
    198, 200010, 200011, 199999, 200018, 200019, 200020, 200021, 200022,
    200023, 200024, 200025, 200026, 200027, 200028
}
def get_max_dummy_image(ctx: InputContext):
    """Build the dummy image used to size the worst-case image input.

    The vision encoder name is read from ``hf_config.img_processor`` and
    falls back to ``SIGLIP_NAME`` when it is ``None``. The dummy image is
    requested with the ViT tile size and with ``tile size * dynamic_hd`` as
    the second argument.
    """
    encoder_name = ctx.get_hf_config().img_processor
    if encoder_name is None:
        encoder_name = SIGLIP_NAME
    processing_cfg = VISION_ENCODER_TO_PROCESSING_CONFIG[encoder_name]

    tile_size = processing_cfg['vit_image_size']
    # One tile per allowed HD crop, laid end to end.
    longest_side = tile_size * processing_cfg['dynamic_hd']
    return dummy_image_for_phi4mm(tile_size, longest_side)
# image token length
def get_max_phi4mm_image_tokens(ctx: InputContext):
    """Return the number of image tokens taken by the max dummy image.

    The preprocessing parameters are looked up by vision encoder name
    (``SIGLIP_NAME`` when ``hf_config.img_processor`` is ``None``) and passed
    to ``_compute_num_image_tokens`` together with the dummy image.
    """
    largest_image = get_max_dummy_image(ctx)

    encoder_name = ctx.get_hf_config().img_processor
    if encoder_name is None:
        encoder_name = SIGLIP_NAME
    processing_cfg = VISION_ENCODER_TO_PROCESSING_CONFIG[encoder_name]

    return _compute_num_image_tokens(
        largest_image,
        processing_cfg['dynamic_hd'],
        processing_cfg['vit_image_size'],
        processing_cfg['vit_patch_size'],
        processing_cfg['token_compression_factor'],
    )
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
                              image_size):
    """Pick the crop grid whose width/height ratio is closest to the image's.

    Args:
        aspect_ratio: width / height of the source image.
        target_ratios: iterable of ``(w_crops, h_crops)`` candidates, visited
            in order.
        width, height: source image size in pixels.
        image_size: side length of one square crop in pixels.

    Returns:
        The selected candidate, or ``(1, 1)`` when nothing is selected.
    """
    image_area = width * height
    closest = (1, 1)
    smallest_gap = float('inf')

    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, closest = gap, candidate
            continue
        # On an exact tie, move to the later candidate only when the image
        # covers more than half of that candidate's total pixel area.
        if gap == smallest_gap and image_area > (
                0.5 * image_size * image_size * candidate[0] * candidate[1]):
            closest = candidate

    return closest
def _find_target_aspect_ratio(image, image_size, max_num, min_num):
orig_width, orig_height = image.size
w_crop_num = math.ceil(orig_width / float(image_size))
h_crop_num = math.ceil(orig_height / float(image_size))
if w_crop_num * h_crop_num > max_num:
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set((i, j) for i in range(1, max_num + 1)
for j in range(1, max_num + 1)
if i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
logger.debug("target_aspect_ratio: %s", target_aspect_ratio)
else:
target_width = image_size * w_crop_num
target_height = image_size * h_crop_num
target_aspect_ratio = (w_crop_num, h_crop_num)
return target_aspect_ratio, target_height, target_width
def _get_padding_size(image, target_height, target_width):
orig_width, orig_height = image.size
ratio_width = target_width / orig_width
ratio_height = target_height / orig_height
if ratio_width < ratio_height:
padding_width = 0
padding_height = target_height - int(orig_height * ratio_width)
else:
padding_width = target_width - int(orig_width * ratio_height)
padding_height = 0
return padding_height, padding_width
def dynamic_preprocess(image,
                       min_num=1,
                       max_num=12,
                       image_size=384,
                       mask_size=27):
    """Resize and pad an image onto its crop grid and build its patch mask.

    Args:
        image: input image exposing ``.size`` as ``(width, height)``.
        min_num, max_num: bounds on the number of crops.
        image_size: side length of one square crop in pixels.
        mask_size: mask cells per crop side.

    Returns:
        ``(resized_img, attention_mask)``: the image resized with its aspect
        ratio preserved and padded with white on the right/bottom, plus a 2D
        mask that is 0 over whole padded patches.

    Raises:
        ValueError: if either side of the resized image is below 10 pixels.
    """
    target_aspect_ratio, target_height, target_width = (
        _find_target_aspect_ratio(image, image_size, max_num, min_num))
    padding_height, padding_width = _get_padding_size(
        image, target_height, target_width)

    # (width, height) after scaling by the smaller target/original ratio.
    src_width, src_height = image.size
    scale_w = target_width / src_width
    scale_h = target_height / src_height
    new_size = ((target_width, int(src_height * scale_w))
                if scale_w < scale_h else
                (int(src_width * scale_h), target_height))

    mask_rows = int(mask_size * target_aspect_ratio[1])
    mask_cols = int(mask_size * target_aspect_ratio[0])
    attention_mask = torch.ones((mask_rows, mask_cols))
    # Only whole 14-pixel patches of padding are masked out.
    if padding_width >= 14:
        attention_mask[:, -math.floor(padding_width / 14):] = 0
    if padding_height >= 14:
        attention_mask[-math.floor(padding_height / 14):, :] = 0
    assert attention_mask.sum(
    ) > 0, f'attention mask is empty {attention_mask}'

    too_short = min(new_size[1], target_height) < 10
    too_narrow = min(new_size[0], target_width) < 10
    if too_short or too_narrow:
        raise ValueError(f'the aspect ratio is very extreme {new_size}')

    # torchvision expects [height, width] for resize.
    image = T.functional.resize(image, [new_size[1], new_size[0]])
    resized_img = T.functional.pad(image,
                                   [0, 0, padding_width, padding_height],
                                   fill=[255, 255, 255])
    return resized_img, attention_mask
def pad_to_max_num_crops(images, max_crops=5):
    """Zero-pad a stack of crops along dim 0 up to ``max_crops`` entries.

    Args:
        images: tensor of shape ``B x C x H x W`` with ``B <= max_crops``.
        max_crops: number of crops the result should hold.

    Returns:
        ``images`` itself when ``B >= max_crops``; otherwise a new tensor of
        shape ``max_crops x C x H x W`` whose trailing crops are all zeros.
    """
    num_crops, channels, height, width = images.shape
    if max_crops > num_crops:
        # The channel count comes from the input rather than being fixed at
        # 3, so the padding always concatenates cleanly; new_zeros keeps the
        # input's dtype and device.
        pad = images.new_zeros(
            (max_crops - num_crops, channels, height, width))
        images = torch.cat([images, pad], dim=0)
    return images
def pad_mask_to_max_num_crops(masks, max_crops=5):
    """Pad a stack of masks along dim 0 with all-ones masks.

    Args:
        masks: tensor of shape ``B x H x W``.
        max_crops: number of masks the result should hold.

    Returns:
        ``masks`` itself when ``B >= max_crops``; otherwise a tensor of shape
        ``max_crops x H x W`` whose trailing entries are filled with ones.
    """
    num_masks, height, width = masks.shape
    if not max_crops > num_masks:
        return masks
    filler = torch.ones(max_crops - num_masks,
                        height,
                        width,
                        dtype=masks.dtype,
                        device=masks.device)
    return torch.cat([masks, filler], dim=0)
def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size):
    """Turn a list of images into padded crop tensors for the vision encoder.

    Each image is converted to RGB, resized/padded onto its crop grid by
    ``dynamic_preprocess``, normalized, and split into square crops of side
    ``vit_resolution``. A bicubically downscaled "global" view of the whole
    padded image is prepended as crop 0.

    Args:
        images: iterable of images exposing ``.convert('RGB')``
            (presumably PIL images - confirm against callers).
        dynamic_hd_size: maximum number of HD crops per image.
        vit_resolution: side length of one crop in pixels.
        vit_patch_size: side length of one ViT patch in pixels.

    Returns:
        dict with keys:
            ``pixel_values``: stacked crops, global view first, each image
                padded with zero crops up to the batch's largest crop count.
            ``image_sizes``: long tensor of ``[height, width]`` of each
                padded HD image.
            ``image_attention_mask``: stacked per-crop masks, padded with
                all-ones masks.
            ``num_img_tokens``: list with the token count of each image.
    """

    # Basic settings.
    img_processor = T.Compose([
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Dynamic HD
    base_resolution = vit_resolution
    images = [image.convert('RGB') for image in images]
    # cover 384 and 448 resolution
    mask_resolution = base_resolution // vit_patch_size
    elems, image_attention_masks = [], []
    for im in images:
        elem, attention_mask = dynamic_preprocess(im,
                                                  max_num=dynamic_hd_size,
                                                  image_size=base_resolution,
                                                  mask_size=mask_resolution)
        elems.append(elem)
        image_attention_masks.append(attention_mask)
    hd_images = [img_processor(im) for im in elems]
    # Global view: the whole padded image squeezed into a single crop.
    global_image = [
        torch.nn.functional.interpolate(
            im.unsqueeze(0).float(),
            size=(base_resolution, base_resolution),
            mode='bicubic',
        ).to(im.dtype) for im in hd_images
    ]
    shapes = [[im.size(1), im.size(2)] for im in hd_images]
    mask_shapes = [[mask.size(0), mask.size(1)]
                   for mask in image_attention_masks]
    # The global view has no padding, so its mask is all ones.
    global_attention_mask = [
        torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images
    ]
    # Cut every HD image into (num_crops, 3, base_resolution, base_resolution)
    # in row-major crop order.
    hd_images_reshape = [
        im.reshape(1, 3, h // base_resolution, base_resolution,
                   w // base_resolution, base_resolution).permute(
                       0, 2, 4, 1, 3, 5).reshape(-1, 3, base_resolution,
                                                 base_resolution).contiguous()
        for im, (h, w) in zip(hd_images, shapes)
    ]
    # Same tiling for the masks: (num_crops, mask_resolution, mask_resolution).
    attention_masks_reshape = [
        mask.reshape(1, h // mask_resolution, mask_resolution,
                     w // mask_resolution, mask_resolution).permute(
                         0, 1, 3, 2, 4).reshape(-1, mask_resolution,
                                                mask_resolution).contiguous()
        for mask, (h, w) in zip(image_attention_masks, mask_shapes)
    ]
    # NOTE token compression is hard coded here, and odd numbers seems to fail
    # Keep every second mask cell (factor-2 compression) and rearrange the
    # per-crop masks so they can be stitched back into one 2D map below.
    downsample_attention_masks = [
        mask[:, 0::2,
             0::2].reshape(1, h // mask_resolution, w // mask_resolution,
                           mask_resolution // 2 + mask_resolution % 2,
                           mask_resolution // 2 + mask_resolution % 2).permute(
                               0, 1, 3, 2, 4)
        for mask, (h, w) in zip(attention_masks_reshape, mask_shapes)
    ]
    downsample_attention_masks = [
        mask.reshape(mask.size(1) * mask.size(2),
                     mask.size(3) * mask.size(4))
        for mask in downsample_attention_masks
    ]
    # NOTE hard coded number of tokens
    # 256 and 16 look like the global view's tokens and its per-row
    # separators for a 448/14 encoder with factor-2 compression, and the 1 a
    # single separator token - TODO confirm. The two mask sums count the
    # unmasked HD cells and the unmasked cells of column 0 (one per row).
    num_img_tokens = [
        256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16
        for mask in downsample_attention_masks
    ]

    # Prepend the global view (and its mask) as crop 0 of every image.
    hd_images_reshape = [
        torch.cat([_global_image] + [_im], dim=0)
        for _global_image, _im in zip(global_image, hd_images_reshape)
    ]
    hd_masks_reshape = [
        torch.cat([_global_mask] + [_mask],
                  dim=0) for _global_mask, _mask in zip(
                      global_attention_mask, attention_masks_reshape)
    ]
    # Pad every image to the largest crop count so they can be stacked.
    max_crops = max([img.size(0) for img in hd_images_reshape])
    image_transformed = [
        pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape
    ]
    image_transformed = torch.stack(image_transformed, dim=0)
    mask_transformed = [
        pad_mask_to_max_num_crops(mask, max_crops) \
        for mask in hd_masks_reshape
    ]
    mask_transformed = torch.stack(mask_transformed, dim=0)

    returned_input_image_embeds = image_transformed
    returned_image_sizes = torch.tensor(shapes, dtype=torch.long)
    returned_image_attention_mask = mask_transformed
    returned_num_img_tokens = num_img_tokens

    data = {
        "pixel_values": returned_input_image_embeds,
        "image_sizes": returned_image_sizes,
        "image_attention_mask": returned_image_attention_mask,
        "num_img_tokens": returned_num_img_tokens,
    }
    return data
class Phi4MMImageEncoder(nn.Module):
    """Image embedding.

    Runs image crops through the SigLIP vision model, compresses the patch
    features with 2x2 average pooling, lays the crops of each image back out
    on their 2D grid with learnable separator embeddings, and projects the
    result to the language model's hidden size.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig],
                 prefix: str = "",
                 model_dir: str = "") -> None:
        """Build the vision tower, separator parameters and projection MLP.

        ``quant_config``, ``prefix`` and ``model_dir`` are accepted but not
        read by this constructor.
        """
        super().__init__()

        # n_embed or hidden_size
        hidden_size = config.n_embd if hasattr(
            config, 'n_embd') else config.hidden_size
        if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'):
            embd_drop = config.embd_pdrop if hasattr(
                config, 'embd_pdrop') else config.embed_pdrop
            self.drop = nn.Dropout(embd_drop)
        else:
            self.drop = None

        # layer_idx to output the img features
        if isinstance(config.img_processor, dict):
            self.layer_idx = config.img_processor.get('layer_idx', -2)
            self.type_feature = config.img_processor.get(
                'type_feature', 'patch')
        else:
            self.layer_idx = -2
            self.type_feature = 'patch'

        self.img_processor = get_siglip_vision_model(
            _flash_attn_2_enabled=True)

        pe_weight = self.img_processor.embeddings.position_embedding.weight
        L, D = pe_weight.size()
        H = int(math.sqrt(L))
        assert H**2 == L, f'position embedding size {L} is not square'
        if H % 2 != 0:
            # Odd feature maps get one extra row/column so that the 2x2
            # pooling below divides them evenly.
            self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1))
            H += 1
        image_dim_out = D
        # ((448/14)//2)**2
        self.num_img_tokens = (H // 2)**2
        self.base_feat_height_target = H

        self.image_dim_out = image_dim_out
        self.img_sizes = None
        self.image_attention_mask = None

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = True
        self.with_learnable_separator = True
        self.hd_transform_order = "sub_glb"
        self.freeze_img_processor = False
        self.crop_size = 448

        # image token compression
        self.image_token_compression_cls = 'avg_pool_2d'
        self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2)
        self.base_feat_height_reduction = 1
        # Pooling halves the feature-map side expected by forward().
        self.base_feat_height_target = self.base_feat_height_target // 2

        # with_hd_transform and with_learnable_separator should have same value
        assert self.use_hd_transform == self.with_learnable_separator, \
        'use_hd_transform and with_learnable_separator should have same value'
        assert self.use_hd_transform, \
            'learnable separator is only for hd transform'
        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(
            torch.zeros([
                1, 1, self.image_dim_out * self.base_feat_height_reduction**2
            ]))
        self.sub_GN = nn.Parameter(
            torch.zeros([
                1, 1, 1,
                self.image_dim_out * self.base_feat_height_reduction**2
            ]))

        # Two-layer MLP: Linear -> GELU -> Linear.
        dim_projection = hidden_size
        depth = 2
        layers = [
            nn.Linear(image_dim_out * self.base_feat_height_reduction**2,
                      dim_projection)
        ]
        for _ in range(1, depth):
            layers.extend(
                [nn.GELU(),
                 nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.vocab_size = config.vocab_size
        self.img_features = None

        self.use_out_place_operations = False

    def get_img_features(self,
                         img_embeds: torch.FloatTensor,
                         attention_mask=None) -> torch.FloatTensor:
        """Run the vision model and return compressed patch features.

        The hidden state at ``self.layer_idx`` is reshaped to a square 2D
        map, optionally reflection-padded, average-pooled, and flattened back
        to ``(num_crops, tokens, channels)``.

        Raises:
            NotImplementedError: if ``self.type_feature`` is not ``"patch"``.
        """
        LAYER_IDX = self.layer_idx
        TYPE_FEATURE = self.type_feature

        img_processor_output = self.img_processor(
            img_embeds,
            output_hidden_states=True,
            patch_attention_mask=attention_mask)
        img_feature = img_processor_output.hidden_states[LAYER_IDX]

        if TYPE_FEATURE == "patch":
            patch_feature = img_feature

            use_token_compression = self.image_token_compression is not None
            # The padding module only exists when the feature map was odd.
            use_padding = getattr(self, 'img_processor_padding',
                                  None) is not None
            if use_token_compression or use_padding:
                # reshape to 2D tensor
                width = int(math.sqrt(patch_feature.size(1)))
                patch_feature = patch_feature.view(-1, width, width,
                                                   patch_feature.size(-1))
                # convert to NCHW
                patch_feature = patch_feature.permute(0, 3, 1, 2)

                if use_padding:
                    patch_feature = self.img_processor_padding(patch_feature)
                if use_token_compression:
                    patch_feature = self.image_token_compression(patch_feature)
                # convert to NHWC
                patch_feature = patch_feature.permute(0, 2, 3, 1)
                patch_feature = patch_feature.view(
                    -1,
                    patch_feature.size(1) * patch_feature.size(2),
                    patch_feature.size(-1))

            return patch_feature

        raise NotImplementedError

    def forward(self, pixel_values: torch.FloatTensor,
                image_sizes: torch.Tensor,
                image_attention_mask: torch.Tensor) -> torch.FloatTensor:
        """
        process image and return vision embeddings.

        pixel_values: (num_images, num_crops, c, h, w)
        image_sizes: [[h1, w1], [h2, w2]]
        image_attention_mask: num_images x num_crops x 32 x 32
        output: list with one projected tensor per image, each of shape
            (1, num_img_tokens, hidden_size)
        """

        # eg
        # pixel_values: torch.Size([1, 7, 3, 448, 448])
        # image_sizes: tensor([[ 896, 1344]], device='cuda:0')
        # output: torch.Size([1, 1841, 3072])

        if isinstance(self.img_projection, nn.Sequential):
            target_device = self.img_projection[0].bias.device
            target_dtype = self.img_projection[0].bias.dtype
        else:  # It's a single nn.Linear layer
            target_device = self.img_projection.bias.device
            target_dtype = self.img_projection.bias.dtype

        img_sizes = image_sizes
        num_images, num_crops, c, h, w = pixel_values.shape
        bs = num_images
        # Fold the crop dimension into the batch for the vision model.
        pixel_values = pixel_values.flatten(0, 1)

        img_features = self.get_img_features(
            pixel_values,
            image_attention_mask.type(torch.BoolTensor).flatten(
                0, 1).to(target_device))

        base_feat_height_target = self.base_feat_height_target
        base_resolution = self.crop_size
        base_feat_height_reduction = self.base_feat_height_reduction

        base_feat_height = base_feat_width = int(np.sqrt(
            img_features.shape[1]))
        assert base_feat_height == base_feat_height_target \
            and base_feat_width == base_feat_height_target, (
                f'base_feat_height: {base_feat_height}, '
                f'base_feat_width: {base_feat_width}, '
                f'expect {base_feat_height_target} features for hd transform')

        # bs x max_num_crops x (24x24) x C
        img_features = img_features.view(bs, -1,
                                         base_feat_height * base_feat_width,
                                         self.image_dim_out)
        C = self.image_dim_out
        H = base_feat_height

        output_imgs = []
        output_len = []
        # training is tensor, inference is list
        if isinstance(img_sizes, torch.Tensor):
            img_sizes = img_sizes.view(-1, 2)
        for _bs in range(bs):
            # Image size in pixels -> crop grid (rows h, columns w).
            h, w = img_sizes[_bs]
            h = h // base_resolution
            w = w // base_resolution
            B_ = h * w

            # 1 x (24x24) x 1024
            global_img_feature = img_features[_bs, :1]

            # 1 x 12 x 12 x 4096
            glb_img = global_img_feature.reshape(1, H, H, C).reshape(
                1, H // base_feat_height_reduction, base_feat_height_reduction,
                H // base_feat_height_reduction, base_feat_height_reduction,
                C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape(
                    1, H // base_feat_height_reduction,
                    H // base_feat_height_reduction,
                    base_feat_height_reduction * base_feat_height_reduction *
                    C).contiguous()
            # One separator embedding is appended to the end of every row.
            temp_glb_GN = self.sub_GN.repeat(1,
                                             H // base_feat_height_reduction,
                                             1, 1)

            # 1 x 156 x 4096
            glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(
                1, -1,
                base_feat_height_reduction * base_feat_height_reduction * C)

            # (max_num_crops-1) x (12x12) x C
            sub_img = img_features[_bs, 1:]
            # 16x574x1024
            # get rid of padding sub_img
            sub_img = sub_img[:B_]

            # (num_crops, 12, 2, 12, 2, 1024) ->
            # (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024)
            sub_img = sub_img.reshape(B_, H, H, C).reshape(
                B_, H // base_feat_height_reduction,
                base_feat_height_reduction, H // base_feat_height_reduction,
                base_feat_height_reduction,
                C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape(
                    B_, -1, base_feat_height_reduction *
                    base_feat_height_reduction * C).contiguous()
            # Stitch the crops back together into one 2D feature map.
            sub_img = sub_img.reshape(
                1, h, w, base_feat_height // base_feat_height_reduction,
                base_feat_width // base_feat_height_reduction,
                -1).permute(0, 1, 3, 2, 4, 5).reshape(
                    1, h * base_feat_height // base_feat_height_reduction,
                    w * base_feat_width // base_feat_height_reduction,
                    base_feat_height_reduction * base_feat_height_reduction *
                    C)

            if image_attention_mask is not None and len(
                    image_attention_mask) > 0:
                # Stitch the 2x-subsampled crop masks the same way, then use
                # the first column/row to find how much of the map is real.
                reshaped_image_attention_mask = image_attention_mask[
                    _bs, 1:B_ + 1, 0::2, 0::2].reshape(
                        1, h, w,
                        base_feat_height // base_feat_height_reduction,
                        base_feat_width // base_feat_height_reduction).permute(
                            0, 1, 3, 2, 4).reshape(
                                1, h * base_feat_height //
                                base_feat_height_reduction, w *
                                base_feat_width // base_feat_height_reduction)
                useful_height = int(
                    reshaped_image_attention_mask[0, :, 0].sum().item())
                useful_width = int(
                    reshaped_image_attention_mask[0, 0, :].sum().item())
                # Drop features that only cover padding.
                sub_img = sub_img[:, :useful_height, :useful_width]
                temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1)
                temp_len = int(
                    image_attention_mask[_bs, :B_ + 1, 0::2, 0::2].sum().item(
                    )) + (useful_height +
                          1) + base_feat_height // base_feat_height_reduction
            else:
                temp_sub_GN = self.sub_GN.repeat(
                    1, h * base_feat_height // base_feat_height_reduction, 1,
                    1)
                temp_len = int((h * w + 1) * self.num_img_tokens + 1 +
                               (h + 1) * base_feat_height //
                               base_feat_height_reduction)

            sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(
                1, -1,
                base_feat_height_reduction * base_feat_height_reduction * C)
            # (1, num_img_tokens, 1024*4)

            # glb + sub
            if self.hd_transform_order == 'glb_sub':
                output_imgs.append(
                    torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
            elif self.hd_transform_order == 'sub_glb':
                output_imgs.append(
                    torch.cat([sub_img, self.glb_GN, glb_img], dim=1))
            else:
                raise NotImplementedError(
                    f'hd_transform_order = {self.hd_transform_order}, '
                    'not implemented')

            #temp_len = int((h*w+1)*144 + 1 + (h+1)*12)
            # Sanity check: the predicted token count must match the
            # sequence that was actually assembled.
            assert temp_len == output_imgs[-1].shape[1], (
                f'temp_len: {temp_len}, output_imgs[-1].shape[1]: '
                f'{output_imgs[-1].shape[1]}')

            output_len.append(temp_len)

        img_set_tensor = []
        for _output_img in output_imgs:
            img_feature_proj = self.img_projection(
                _output_img.to(target_device).to(target_dtype))
            img_set_tensor.append(img_feature_proj)

        return img_set_tensor
class Phi4MMAudioFeatureInputs(TypedDict):
    """Audio input variant tagged ``"audio_features"``."""
    # Discriminator that selects this variant of Phi4MMAudioInputs.
    type: Literal["audio_features"]
    # One-element tuple holding the nested feature tensors; the string below
    # records the expected shape.
    data: Tuple[NestedTensors]
    """Shape: `((batch_size, num_audios, 80, M), )"""
class Phi4MMAudioEmbeddingInputs(TypedDict):
    """Audio input variant tagged ``"audio_embeds"``."""
    # Discriminator that selects this variant of Phi4MMAudioInputs.
    type: Literal["audio_embeds"]
    # Nested embedding tensors; the string below records the expected shape.
    data: NestedTensors
    """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
# Tagged union of the two audio input forms; the "type" key tells them apart.
Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs]
def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
    """Create a Mel filter-bank the same as SpeechLib FbankFC.

    Args:
        sample_rate (int): Sample rate in Hz. number > 0 [scalar]
        n_fft (int): FFT size. int > 0 [scalar]
        n_mels (int): Mel filter size. int > 0 [scalar]
        fmin (float): lowest frequency (in Hz). If None use 0.0.
            float >= 0 [scalar]
        fmax: highest frequency (in Hz). If None use sample_rate / 2.
            float >= 0 [scalar]

    Returns
        out (numpy.ndarray): Mel transform matrix
            [shape=(n_mels, 1 + n_fft/2)]
    """
    num_bins = int(n_fft // 2 + 1)
    if fmax is None:
        fmax = sample_rate / 2
    if fmin is None:
        fmin = 0
    assert fmin >= 0, "fmin cannot be negative"
    assert (fmin < fmax <=
            sample_rate / 2), "fmax must be between (fmin, samplerate / 2]"

    def hz_to_mel(freq):
        return 1127.0 * np.log(1.0 + freq / 700.0)

    def bin_to_mel(index):
        return 1127.0 * np.log(1.0 + index * sample_rate / (n_fft * 700.0))

    def hz_to_bin(freq):
        return int((freq * n_fft / sample_rate) + 0.5)

    # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
    first_bin = hz_to_bin(fmin) + 1
    stop_bin = max(hz_to_bin(fmax), first_bin)

    # Spec 2: SpeechLib uses triangles in Mel space
    mel_low, mel_high = hz_to_mel(fmin), hz_to_mel(fmax)
    edges = np.linspace(mel_low, mel_high, n_mels + 2)
    spacing = (mel_high - mel_low) / (n_mels + 1)

    # The mel position of a bin does not depend on the filter, so it is
    # evaluated once per bin and shared by every triangle.
    bin_positions = [(index, bin_to_mel(index))
                     for index in range(first_bin, stop_bin)]

    weights = np.zeros((n_mels, num_bins), dtype=np.float32)
    triangles = zip(edges[:-2], edges[1:-1], edges[2:])
    for row, (lower, peak, upper) in enumerate(triangles):
        for index, position in bin_positions:
            if lower < position < upper:
                weights[row, index] = 1.0 - abs(peak - position) / spacing

    return weights
class LogFbankProcessor:
    """Log Mel filter-bank feature extractor for 8 kHz / 16 kHz audio."""

    def __init__(self):
        # With "fillzero", 8 kHz audio is analysed at 8 kHz and the upper
        # half of the spectrum is zero-filled; "resample" would upsample it
        # to 16 kHz first.
        self._eightk_method = "fillzero"
        # Transposed so that (frames, bins) @ _mel -> (frames, 80).
        self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T

        self._hamming400 = np.hamming(400)  # for 16k audio
        self._hamming200 = np.hamming(200)  # for 8k audio

    def extract_spectrogram(self, wav, fs):
        """Extract a magnitude spectrogram from a waveform.

        Args:
            wav (1D array): waveform of the input. Arrays with more
                dimensions are squeezed, and a remaining 2D array is
                averaged over axis 1.
            fs (int): sampling rate of the waveform in Hz, at least 8000.
                Rates above 16000 are decimated by ``fs // 16000``.

        Output:
            spec (2D array): float32 matrix with one row per frame and 257
                columns (FFT magnitudes). For 8 kHz input processed with the
                "fillzero" method the upper columns are zero.

        Raises:
            RuntimeError: if the sample rate is not supported.
            ValueError: if the waveform is shorter than one analysis window.
        """
        if wav.ndim > 1:
            wav = np.squeeze(wav)

        # by default, we extract the mean if stereo
        if len(wav.shape) == 2:
            wav = wav.mean(1)

        # Resample to 16000 or 8000 if needed
        # NOTE(review): the decimation factor is an integer floor, so a rate
        # that is not an exact multiple (e.g. 44100) is decimated by the
        # floored factor and then labelled 16000/8000. Left unchanged because
        # compute_logfbank_output_size mirrors this arithmetic.
        if fs > 16000:
            wav = scipy.signal.resample_poly(wav, 1, fs // 16000)
            fs = 16000
        elif 8000 < fs < 16000:
            wav = scipy.signal.resample_poly(wav, 1, fs // 8000)
            fs = 8000
        elif fs < 8000:
            raise RuntimeError(f"Unsupported sample rate {fs}")

        if fs == 8000:
            if self._eightk_method == "resample":
                # Input audio is 8 kHz. Convert to 16 kHz before feature
                # extraction
                wav = scipy.signal.resample_poly(wav, 2, 1)
                fs = 16000
            # Do nothing here for fillzero method
        elif fs != 16000:
            # Input audio is not a supported sample rate.
            raise RuntimeError(
                f"Input data using an unsupported sample rate: {fs}")

        preemphasis = 0.97

        if fs == 8000:
            n_fft = 256
            win_length = 200
            hop_length = 80
            fft_window = self._hamming200
        elif fs == 16000:
            n_fft = 512
            win_length = 400
            hop_length = 160
            fft_window = self._hamming400

        # Spec 1: SpeechLib cut remaining sample insufficient for a hop
        n_batch = (wav.shape[0] - win_length) // hop_length + 1
        if n_batch < 1:
            # Without a single full window the frame array below would be
            # empty and np.roll(..., axis=1) would fail with an opaque axis
            # error; report the actual problem instead.
            raise ValueError(
                f"Waveform too short for feature extraction: got "
                f"{wav.shape[0]} samples at {fs} Hz, need at least "
                f"{win_length}")
        # Here we don't use stride_tricks since the input array may not satisfy
        # memory layout requirement and we need writeable output
        # Here we only use list of views before copy to destination
        # so it is more efficient than broadcasting
        y_frames = np.array(
            [
                wav[_stride:_stride + win_length]
                for _stride in range(0, hop_length * n_batch, hop_length)
            ],
            dtype=np.float32,
        )

        # Spec 2: SpeechLib applies preemphasis within each batch
        y_frames_prev = np.roll(y_frames, 1, axis=1)
        # The roll wraps the last sample into column 0; replace it with the
        # frame's own first sample.
        y_frames_prev[:, 0] = y_frames_prev[:, 1]
        # 32768 rescales the samples (presumably from [-1, 1) floats to the
        # 16-bit integer range - confirm against callers).
        y_frames = (y_frames - preemphasis * y_frames_prev) * 32768

        S = np.fft.rfft(fft_window * y_frames, n=n_fft,
                        axis=1).astype(np.complex64)

        if fs == 8000:
            # Need to pad the output to look like 16 kHz data but with zeros in
            # the 4 to 8 kHz bins.
            frames, bins = S.shape
            padarray = np.zeros((frames, bins))
            S = np.concatenate((S[:, 0:-1], padarray),
                               axis=1)  # Nyquist bin gets set to zero

        spec = np.abs(S).astype(np.float32)
        return spec

    def extract_features(self, wav, fs):
        """Extract log filterbank features from waveform.

        Args:
            wav (1D array): waveform of the input
            fs (int): sampling rate of the waveform in Hz; see
                ``extract_spectrogram`` for how other rates are handled.

        Output:
            log_fbank (2D array): a TxD matrix of log Mel filterbank features.
                D=80, and T is the number of frames.
        """
        spec = self.extract_spectrogram(wav, fs)
        spec_power = spec**2

        # Clip at 1.0 so the log never goes below 0.
        fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
        log_fbank = np.log(fbank_power).astype(np.float32)

        return log_fbank
@lru_cache
def audio_feature_extractor() -> LogFbankProcessor:
    """Return the shared ``LogFbankProcessor`` used for audio features.

    The function takes no arguments, so the LRU cache keeps a single
    instance and every call after the first returns it.
    """
    processor = LogFbankProcessor()
    return processor
def _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size,
                              vit_patch_size, token_compression_factor):
    """
    Count the tokens an image occupies after the image encoder, leaving out
    output features that contain only padding pixels.

    for siglip, vit_image_size=448, vit_patch_size=14, so output will be
    32x32 feature map
    NOTE right now, Phi4MM uses hard-coded token_compression_factor=2

    The total is the sum of: global-view tokens, one separator token, the
    non-padded HD tokens, one newline token per HD row and one per
    global-view row.
    """
    assert vit_image_size % vit_patch_size == 0, \
        "vit_image_size must be divisible by vit_patch_size"
    assert vit_image_size // vit_patch_size % token_compression_factor == 0, \
        "vit_image_size // vit_patch_size must be divisible by "\
        "token_compression_factor"

    target_aspect_ratio, target_height, target_width = (
        _find_target_aspect_ratio(image,
                                  vit_image_size,
                                  dynamic_hd_size,
                                  min_num=1))
    assert target_aspect_ratio[
        0] * vit_image_size == target_width, \
        f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}"
    assert target_aspect_ratio[
        1] * vit_image_size == target_height, \
        f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}"
    assert (target_height % vit_image_size == 0
            and target_width % vit_image_size == 0)

    padding_height, padding_width = _get_padding_size(image, target_height,
                                                      target_width)
    assert padding_width == 0 or padding_height == 0, \
        "padding_width or padding_height must be 0"

    # Feature-map size of the full canvas; whole padded patches are then
    # removed from whichever side carries the padding.
    valid_cols = target_width // vit_patch_size
    valid_rows = target_height // vit_patch_size
    if padding_width >= vit_patch_size:
        assert padding_height == 0, "padding_height not 0"
        valid_cols -= math.floor(padding_width / vit_patch_size)
    elif padding_height >= vit_patch_size:
        assert padding_width == 0, "padding_width not 0"
        valid_rows -= math.floor(padding_height / vit_patch_size)
    # else: small padding shorter than a vit patch, nothing to remove

    # NOTE it's possible that the non-padding feature is not divisible,
    # in which case the partial group still yields one token (round up).
    feat_width, leftover_cols = divmod(valid_cols, token_compression_factor)
    if leftover_cols != 0:
        feat_width += 1
    feat_height, leftover_rows = divmod(valid_rows, token_compression_factor)
    if leftover_rows != 0:
        feat_height += 1

    global_side = (vit_image_size // vit_patch_size) // token_compression_factor

    token_counts = (
        global_side**2,  # global image tokens
        1,  # separator between global and HD parts
        feat_width * feat_height,  # HD patch tokens
        feat_height,  # one newline per HD row
        global_side,  # one newline per global-view row
    )
    return sum(token_counts)
def compute_logfbank_output_size(wav_length: int, fs: int) -> Tuple[int, int]:
    """
    Compute the output size of the `extract_features` method.

    Args:
        wav_length (int): Length of the input waveform in samples.
        fs (int): Sampling rate of the waveform in Hz, at least 8000.

    Returns:
        tuple (int, int): Output size as (T, D), where:
            T: Number of time frames.
            D: Number of Mel filterbank bins (80).

    Raises:
        RuntimeError: if ``fs`` is below 8000.
        ValueError: if the waveform is too short to yield a single frame.
    """

    # Resample to 16000 or 8000 if needed
    if fs > 16000:
        # The extractor decimates with scipy.signal.resample_poly(wav, 1, k),
        # which returns ceil(len / k) samples, so round up here as well;
        # flooring can come out one frame short.
        decimation = fs // 16000
        wav_length = -(-wav_length // decimation)
        fs = 16000
    elif 8000 <= fs < 16000:
        # Counting an 8 kHz signal at twice its length with the 16 kHz
        # window/hop gives the same frame count as the 8 kHz window/hop
        # (200/80): (2L - 400) // 160 == (L - 200) // 80.
        wav_length *= 2
        fs = 16000
    elif fs < 8000:
        raise RuntimeError(f"Unsupported sample rate {fs}")

    # Spectrogram parameters for 16 kHz
    win_length = 400  # Frame length in samples
    hop_length = 160  # Frame shift in samples
    mel_bins = 80  # Number of mel filterbank bins

    # Calculate number of frames (T)
    T = (wav_length - win_length) // hop_length + 1
    if T < 1:
        raise ValueError("Waveform too short for given parameters.")

    # Return time frames (T) and mel bins (D)
    return T, mel_bins
def _get_audio_embed_sizes(audios, ctx: InputContext):
    """
    Compute how many embedding positions each audio clip occupies.

    Args:
        audios (List[Tuple[np.ndarray, int]]): Audio clips, each given as a
            `(waveform, sample_rate)` pair.
        ctx (InputContext): Input context; supplies the HF config.

    Returns:
        List[int]: One embedding size per clip, in input order.
    """

    def _embed_size(waveform, sample_rate):
        # Frame count first, then the config-driven compression.
        num_frames, _ = compute_logfbank_output_size(len(waveform),
                                                     sample_rate)
        return _compute_audio_embed_size(ctx.get_hf_config(), num_frames)

    return [
        _embed_size(waveform, sample_rate) for waveform, sample_rate in audios
    ]
def _get_audio_id_to_input_ids(audios, ctx: InputContext, prompt_str=""):
    """
    Map every `<|audio_{idx}|>` tag in the prompt to its placeholder ids.

    Each tag is mapped to the audio placeholder token id repeated as many
    times as the matching audio's embedding size.

    Args:
        audios (List[Tuple[np.ndarray, int]]): Audio clips, each given as a
            `(waveform, sample_rate)` pair.
        ctx (InputContext): Input context.
        prompt_str (str): The prompt string to scan for audio tags.

    Returns:
        Dict[str, List[int]]: Tag string -> repeated placeholder token ids.
        Empty when no audio is supplied.
    """
    if len(audios) == 0:
        return {}

    embed_sizes = _get_audio_embed_sizes(audios, ctx)
    found_ids = [
        int(matched) for matched in re.findall(AUDIO_TOKEN_PATTERN, prompt_str)
    ]

    assert len(found_ids) == len(
        embed_sizes), "Number of audio tokens and audio features do not match"
    # Tags must be numbered 1..N in the order they appear.
    expected_ids = tuple(range(1, len(found_ids) + 1))
    assert tuple(found_ids) == expected_ids, "Audio ids are not in order!"

    tag_to_ids = {}
    for audio_id, embed_size in zip(found_ids, embed_sizes):
        tag_to_ids[f"<|audio_{audio_id}|>"] = (
            [_AUDIO_PLACEHOLDER_TOKEN_ID] * embed_size)
    return tag_to_ids
def _count_image_tokens(images, ctx: InputContext):
    """
    Return the number of placeholder tokens each image expands to.

    The preprocessing parameters come from the entry of
    `VISION_ENCODER_TO_PROCESSING_CONFIG` selected by the HF config's
    `img_processor` (falling back to `SIGLIP_NAME` when it is None).
    """
    encoder_name = ctx.get_hf_config().img_processor
    if encoder_name is None:
        encoder_name = SIGLIP_NAME
    processing_cfg = VISION_ENCODER_TO_PROCESSING_CONFIG[encoder_name]

    # Read every setting up front so a missing key fails even for an empty
    # image list.
    hd_args = tuple(processing_cfg[key] for key in (
        'dynamic_hd',
        'vit_image_size',
        'vit_patch_size',
        'token_compression_factor',
    ))

    token_counts = []
    for image in images:
        token_counts.append(_compute_num_image_tokens(image, *hd_args))
    return token_counts
def _get_image_id_to_input_ids(images, prompt, ctx: InputContext):
    """
    Map every `<|image_{idx}|>` tag in the prompt to its placeholder ids.

    Each tag is mapped to the image placeholder token id repeated once per
    token the matching image expands to. Returns an empty dict when no
    image is supplied.
    """
    if len(images) == 0:
        return {}

    found_ids = [
        int(matched) for matched in re.findall(IMAGE_TOKEN_PATTERN, prompt)
    ]

    assert len(found_ids) == len(
        set(found_ids)), "Duplicate image tokens in prompt"
    assert len(images) == len(
        found_ids), "Number of images and image tokens in prompt do not match"
    # NOTE the following assertion is not strictly necessary
    expected_ids = tuple(range(1, len(found_ids) + 1))
    assert tuple(found_ids) == expected_ids, "Image ids are not in order"

    token_counts = _count_image_tokens(images, ctx)
    tag_to_ids = {}
    for image_id, num_tokens in zip(found_ids, token_counts):
        tag_to_ids[f"<|image_{image_id}|>"] = (
            [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_tokens)
    return tag_to_ids
def input_processor_for_phi4mm(ctx: InputContext,
                               inputs: DecoderOnlyInputs) -> TokenInputs:
    """
    Expand image/audio placeholders in the prompt to their full length.

    The returned prompt token ids contain one placeholder token per
    image/audio embedding position. This will become the `input_ids` in
    `forward` for the model.

    Two input flavours are handled:
      * a non-empty string prompt (e.g. offline inference): the
        `<|image_N|>` / `<|audio_N|>` tags are parsed from the string and
        `prompt_token_ids` is ignored;
      * an already tokenized prompt (e.g. OAI server): every single
        placeholder token id is repeated in place.

    Args:
        ctx (InputContext): Input context.
        inputs (DecoderOnlyInputs): The inputs (e.g. prompt, prompt_token_ids)
            to process.

    Returns:
        TokenInputs: Processed inputs. Inputs without audio/image data are
        returned unchanged.

    Raises:
        ValueError: If text, audio and image all appear in the same prompt.
        AssertionError: If neither prompt form is usable, or if the number
            of placeholder tokens in `prompt_token_ids` does not match the
            number of supplied images/audios.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if (multi_modal_data is None or
        ("audio" not in multi_modal_data and "image" not in multi_modal_data)):
        # pure text input, so no need to do pre-processing
        return inputs

    prompt_str = inputs.get("prompt")
    prompt_token_ids = inputs.get("prompt_token_ids")

    # for offline_inference, we will get str input and we parse MM special
    # tokens from it (ignore prompt_token_ids)
    # for OAI server, we will get prompt_token_ids, where MM special tokens
    # are already parsed

    # Start from "nothing to expand" so both modalities are always bound,
    # whichever of them is present in `multi_modal_data`.
    audio_id_to_input_ids = {}
    audio_embed_sizes = []
    if 'audio' in multi_modal_data:
        audios = multi_modal_data["audio"]
        if not isinstance(audios, list):
            audios = [audios]
        if prompt_str is not None:
            audio_id_to_input_ids = _get_audio_id_to_input_ids(
                audios, ctx, prompt_str=prompt_str)
        elif prompt_token_ids is not None:
            audio_embed_sizes = _get_audio_embed_sizes(audios, ctx)

    image_id_to_input_ids = {}
    image_token_counts = []
    if 'image' in multi_modal_data:
        # PIL Image or list of PIL Images
        images = multi_modal_data["image"]
        if not isinstance(images, list):
            images = [images]
        if prompt_str is not None:
            image_id_to_input_ids = _get_image_id_to_input_ids(
                images, prompt_str, ctx)
        elif prompt_token_ids is not None:
            image_token_counts = _count_image_tokens(images, ctx)

    if prompt_str:
        # String prompt: tokenize it manually. The `*_id_to_input_ids` dicts
        # map a placeholder string (e.g. `<|audio_1|>`) to the placeholder
        # tokens for that item.
        pattern = r"(<\|image_\d+\|>|<\|audio_\d+\|>)"
        prompt_chunk_strings = [
            s for s in re.split(pattern, prompt_str) if s != ""
        ]

        # Create the new input_ids with the placeholder image and audio
        # tokens inserted
        tokenizer = cached_tokenizer_from_config(ctx.model_config)
        input_ids = []
        has_image, has_audio, has_user_text_input = False, False, False
        for prompt_chunk_string in prompt_chunk_strings:
            if re.match(IMAGE_TOKEN_PATTERN, prompt_chunk_string):
                input_ids.extend(image_id_to_input_ids[prompt_chunk_string])
                has_image = True
            elif re.match(AUDIO_TOKEN_PATTERN, prompt_chunk_string):
                input_ids.extend(audio_id_to_input_ids[prompt_chunk_string])
                has_audio = True
            else:
                curr_token_ids = tokenizer(prompt_chunk_string).input_ids
                if not has_user_text_input:
                    has_user_text_input = any(
                        token_id not in NON_USER_INPUT_TOKENS
                        for token_id in curr_token_ids)
                input_ids.extend(curr_token_ids)

        if has_audio and has_image and has_user_text_input:
            raise ValueError(
                "Phi4MMForCausalLM does not support text + audio + image" +
                " inputs in the same prompt")
    else:
        # Handle the case where the prompt is already tokenized
        assert prompt_token_ids is not None, \
            "If string prompt isn't provided, prompt_token_ids must be"

        # Build the expanded sequence in one pass. Re-slicing the list for
        # every placeholder would copy the whole prompt each time.
        input_ids = []
        # only needed for later assertion
        img_cnt, audio_cnt, user_text_input_cnt = 0, 0, 0
        for token_id in prompt_token_ids:
            if token_id == _AUDIO_PLACEHOLDER_TOKEN_ID:
                # A placeholder without matching data is kept as-is; the
                # count assertions below report the mismatch.
                token_count = (audio_embed_sizes[audio_cnt]
                               if audio_cnt < len(audio_embed_sizes) else 1)
                audio_cnt += 1
            elif token_id == _IMAGE_PLACEHOLDER_TOKEN_ID:
                token_count = (image_token_counts[img_cnt]
                               if img_cnt < len(image_token_counts) else 1)
                img_cnt += 1
            else:
                if token_id not in NON_USER_INPUT_TOKENS:
                    user_text_input_cnt += 1
                input_ids.append(token_id)
                continue
            input_ids.extend([token_id] * token_count)

        if audio_cnt > 0 and img_cnt > 0 and user_text_input_cnt > 0:
            raise ValueError(
                "Phi4MMForCausalLM does not support text + audio + image" +
                " inputs in the same prompt")
        # If the below assertion fails, it might be that input pure-text
        # messages contain image/audio special tokens literally
        # (<|endoftext10|>, <|endoftext11|>).
        assert (img_cnt == len(image_token_counts)), (
            f"Number of image tokens in prompt_token_ids ({img_cnt}) "
            f"does not match number of images ({len(image_token_counts)})")
        assert (audio_cnt == len(audio_embed_sizes)), (
            f"Number of audio tokens in prompt_token_ids ({audio_cnt}) "
            f"does not match number of audios ({len(audio_embed_sizes)})")

    # NOTE: Create a defensive copy of the original inputs
    return token_inputs(
        prompt_token_ids=input_ids,
        prompt=prompt_str,
        multi_modal_data=multi_modal_data,
    )
def _compute_audio_embed_size(hf_config, audio_frames):
"""
Compute the audio embedding size based on the audio frames and
compression rate.
"""
compression_rate = hf_config.embd_layer['audio_embd_layer'][
'compression_rate']
# NOTE: this is a hard-coded value but might be configurable in the future
qformer_compression_rate = 1
integer = audio_frames // compression_rate
remainder = audio_frames % compression_rate
result = integer if remainder == 0 else integer + 1
integer = result // qformer_compression_rate
remainder = result % qformer_compression_rate
result = integer if remainder == 0 else integer + 1 # qformer compression
return result
def get_max_phi4mm_audio_tokens(ctx: InputContext) -> int:
    """Return the fixed maximum audio token count (`ctx` is not consulted)."""
    max_audio_tokens = 10000
    return max_audio_tokens
def dummy_audio_for_phi4mm(audio_count: int) -> list:
    """
    Create dummy audio data for the Phi4MM model, which is used for profiling.

    Args:
        audio_count (int): Number of audio samples.

    Returns:
        list: `audio_count` entries, each a `(waveform, sampling_frequency)`
        tuple whose waveform is an all-zero array of length
        `_AUDIO_MAX_SOUNDFILE_SIZE`. All entries share the same array
        object.
    """
    dummy_audio = np.full((_AUDIO_MAX_SOUNDFILE_SIZE, ), 0.0)
    return [(dummy_audio, DUMMY_SAMPLING_FREQUENCY)] * audio_count
def dummy_image_for_phi4mm(width: int, height: int):
    """Return a black RGB image of the given size, used as dummy input."""
    return Image.new('RGB', (width, height), color='black')
def dummy_data_for_phi4mm(ctx: InputContext, seq_len: int,
                          mm_counts: Mapping[str, int]) -> DummyData:
    """
    Create dummy sequence (input_ids) and multi-modal data for the Phi4MM
    model, which is used for profiling.

    Only one modality is emitted: audio when the audio placeholders would
    occupy strictly more tokens than the image placeholders, otherwise
    images. The sequence is that modality's placeholder tokens followed by
    0s up to `seq_len`.

    Args:
        ctx (InputContext): Input context.
        seq_len (int): Length of the sequence.
        mm_counts (Mapping[str, int]): Multi-modal counts; must contain the
            keys "audio" and "image".

    Returns:
        DummyData: Dummy sequence data together with the dummy audio or
        image data.

    Raises:
        RuntimeError: If `seq_len` cannot hold all audio and image
            placeholder tokens together.
    """
    audio_count = mm_counts["audio"]
    audio_frames, _ = compute_logfbank_output_size(_AUDIO_MAX_SOUNDFILE_SIZE,
                                                   DUMMY_SAMPLING_FREQUENCY)
    audio_feature_size = _compute_audio_embed_size(ctx.get_hf_config(),
                                                   audio_frames)
    total_audio_tokens = audio_feature_size * audio_count

    image_count = mm_counts["image"]
    dummy_image = get_max_dummy_image(ctx)
    max_image_tokens = get_max_phi4mm_image_tokens(ctx)
    total_image_tokens = image_count * max_image_tokens

    required_tokens = total_audio_tokens + total_image_tokens
    if seq_len - required_tokens < 0:
        raise RuntimeError(
            f"Phi4MM cannot process {audio_count} audios and {image_count} "
            f"images in a prompt, please increase max_model_len to be at "
            f"least {required_tokens}"
            " or reduce audio/image limit by --limit-mm-per-prompt.")

    if total_audio_tokens > total_image_tokens:
        seq_data = SequenceData.from_prompt_token_counts(
            (_AUDIO_PLACEHOLDER_TOKEN_ID, total_audio_tokens),
            (0, seq_len - total_audio_tokens),
        )
        mm_data = {
            "audio": dummy_audio_for_phi4mm(audio_count),
        }
    else:
        seq_data = SequenceData.from_prompt_token_counts(
            (_IMAGE_PLACEHOLDER_TOKEN_ID, total_image_tokens),
            (0, seq_len - total_image_tokens),
        )
        mm_data = {
            "image": [dummy_image] * image_count,
        }
    return DummyData(seq_data, mm_data)
def input_mapper_for_phi4mm_audio(ctx: InputContext,
                                  data: object) -> MultiModalInputs:
    """
    Build the MultiModalInputs for the audio side of the Phi4MM model.

    For every audio clip the features are extracted and paired with the
    audio embed length (the latter is what the audio placeholder token is
    repeated by in the input prompt IDs). These pairs are used, downstream,
    in `_audio_features_to_embeddings` (via `_process_audio_input`).

    Each entry in `data` is a tuple of the audio data and the sampling
    frequency (e.g. from soundfile.read).

    Args:
        ctx (InputContext): Input context.
        data (object): A single audio tuple or a list of them.

    Returns:
        MultiModalInputs: Holds the key "audio_features" with one
        `(features, [embed_size])` pair per clip; empty when `data` is an
        empty list.

    Raises:
        NotImplementedError: If an entry is not a tuple.
    """
    audio_inputs = data if isinstance(data, list) else [data]
    if len(audio_inputs) == 0:
        return MultiModalInputs()

    feature_len_pairs = []
    for audio_input in audio_inputs:
        if not isinstance(audio_input, tuple):
            raise NotImplementedError(
                f"Unsupported data type: {type(audio_input)}")
        waveform, sample_rate = audio_input

        extractor = audio_feature_extractor()
        features = extractor.extract_features(waveform, sample_rate)
        # Extractors without a `stride` attribute count as stride 1.
        stride = getattr(extractor, "stride", 1)
        audio_frames = len(features) * stride
        embed_size = _compute_audio_embed_size(ctx.get_hf_config(),
                                               audio_frames)
        feature_len_pairs.append((features, [embed_size]))

    return MultiModalInputs({"audio_features": feature_len_pairs})
def input_mapper_for_phi4mm_image(ctx: InputContext, data: object):
    """
    Build the MultiModalInputs for the image side of the Phi4MM model.

    `data` is a single image or a list of images (PIL images). The images
    are run through `preprocess` with the settings of the configured vision
    encoder (`SIGLIP_NAME` when the HF config names none), and four entries
    of its result are forwarded. An empty list yields an empty
    MultiModalInputs.
    """
    images = data if isinstance(data, list) else [data]
    if len(images) == 0:
        return MultiModalInputs()

    encoder_name = ctx.get_hf_config().img_processor
    if encoder_name is None:
        encoder_name = SIGLIP_NAME
    processing_cfg = VISION_ENCODER_TO_PROCESSING_CONFIG[encoder_name]

    processed = preprocess(
        images,
        processing_cfg['dynamic_hd'],
        processing_cfg['vit_image_size'],
        processing_cfg['vit_patch_size'],
    )

    forwarded_keys = (
        "pixel_values",
        "image_sizes",
        "image_attention_mask",
        "num_img_tokens",
    )
    return MultiModalInputs({key: processed[key] for key in forwarded_keys})
def cat_with_pad(tensors, dim, padding_value=0):
    """
    Concatenate tensors along `dim`, padding every other dim to its max.

    Args:
        tensors: Non-empty sequence of tensors with the same number of
            dimensions; sizes may differ in any dimension.
        dim (int): Dimension to concatenate along.
        padding_value: Fill value for positions not covered by any input.

    Returns:
        A new tensor (same dtype/device as `tensors[0]`) whose size along
        `dim` is the sum of the inputs' sizes and whose size along every
        other dimension is the maximum over the inputs.
    """
    ndim = tensors[0].dim()
    assert all(
        t.dim() == ndim for t in
        tensors[1:]), "All tensors must have the same number of dimensions"

    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)

    output = tensors[0].new_full(out_size, padding_value)

    index = 0
    for t in tensors:
        # Create a slice list where every dimension except dim is full slice
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        # Update only the concat dimension slice
        slices[dim] = slice(index, index + t.shape[dim])

        # Index with a tuple: a list of slices is a deprecated non-tuple
        # sequence index in PyTorch.
        output[tuple(slices)] = t
        index += t.shape[dim]

    return output
@MULTIMODAL_REGISTRY.register_input_mapper("audio",
                                           input_mapper_for_phi4mm_audio)
@MULTIMODAL_REGISTRY.register_input_mapper("image",
                                           input_mapper_for_phi4mm_image)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "audio", get_max_phi4mm_audio_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "image", get_max_phi4mm_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi4mm)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi4mm)
class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
    """
    Implements the Phi-4-multimodal-instruct model in VLLM.

    The module combines an image encoder (`vision_encoder`), an audio
    embedding module (`embed_tokens_extend`) and a `LlamaModel` language
    model with its LM head. Image/audio embeddings are written into the
    text embeddings at the placeholder-token positions before the language
    model runs.
    """
    # LoRA specific attributes
    packed_modules_mapping = {
        "qkv_proj": [
            "qkv_proj",
        ],
        "gate_up_proj": [
            "gate_up_proj",
        ],
    }
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj"
    ]
    # Phi4MMForCausalLM does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        Build all sub-modules from `vllm_config`.

        Asserts that a multimodal config is present and that neither
        tensor nor pipeline parallelism is in use.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        assert multimodal_config, "multimodal_config is required"
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config
        self.lora_config = lora_config

        # Tensor/Pipeline parallel not supported for now.
        assert get_tensor_model_parallel_world_size(
        ) == 1, "tensor parallel is not supported"
        assert get_pp_group(
        ).world_size == 1, "pipeline parallel is not supported"

        self.vision_encoder = Phi4MMImageEncoder(
            config,
            quant_config,
            prefix="model.vision_embed_tokens",
            model_dir=config._name_or_path)

        # The audio embedding settings either live in their own dict or
        # only the embedding class is taken from the top-level entry.
        if isinstance(config.embd_layer["audio_embd_layer"], dict):
            embedding_config = {
                "embedding_cls":
                config.embd_layer["audio_embd_layer"]["embedding_cls"],
                **config.embd_layer["audio_embd_layer"],
            }
        else:
            embedding_config = {
                "embedding_cls": self.config.embd_layer["embedding_cls"]
            }

        self.embed_tokens_extend = AudioEmbedding(config, **embedding_config)
        self.model = LlamaModel(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=(
                DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size),
            quant_config=quant_config,
        )
        if config.tie_word_embeddings:
            self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def _audio_features_to_embeddings(
        self,
        input_ids: torch.Tensor,
        input_features: List[torch.Tensor],
        audio_input_sizes: torch.Tensor,
        audio_projection_mode: str,
    ) -> torch.Tensor:
        """
        Convert audio features to embeddings, which are used as input to the
        model (via `inputs_embeds`).

        Args:
            input_ids (torch.Tensor): Input IDs (the prompt in this case).
            input_features (list[torch.Tensor]): Input features (the audio
                embeddings).
            audio_input_sizes: Audio input sizes (the audio embed lengths
                to use for padding the audio placeholder token in the input
                prompt IDs).
            audio_projection_mode (str): Forwarded to the audio embedding
                module; `forward` passes 'speech' or 'vision'.

        Returns:
            torch.Tensor: Output of the audio embedding module, cast to
            the dtype of the audio projection's bias.
        """
        # The audio projection can either be a single linear or Sequential,
        # so handle both cases
        if isinstance(self.embed_tokens_extend.audio_projection,
                      nn.Sequential):
            target_dtype = self.embed_tokens_extend.audio_projection[
                0].bias.dtype
        else:
            target_dtype = self.embed_tokens_extend.audio_projection.bias.dtype

        # Give every feature tensor a leading dim of size 1 and the
        # projection dtype.
        audio_input = [
            feature.unsqueeze(0).to(target_dtype)
            for feature in input_features
        ]
        kwargs = {
            "wte": self.model.embed_tokens,
            'audio_projection_mode': audio_projection_mode
        }
        audio_embeddings = self.embed_tokens_extend(input_ids, audio_input,
                                                    audio_input_sizes,
                                                    **kwargs)
        audio_embeddings = audio_embeddings.to(target_dtype)
        return audio_embeddings

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[Phi4MMAudioInputs]:
        """
        Parse and validate the audio input to the model. This handles both
        audio features and audio embeddings, but only the former is used for
        now.

        Args:
            kwargs (object): Keyword arguments; `audio_features` and
                `audio_embeds` are read from them.

        Returns:
            Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs,
            or None when neither key is given. `audio_features` takes
            precedence when both are present.

        Raises:
            ValueError: If the supplied value is neither a tensor nor a
                list.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            if not isinstance(audio_features, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio features. "
                                 f"Got type: {type(audio_features)}")

            return Phi4MMAudioFeatureInputs(type="audio_features",
                                            data=audio_features)

        if audio_embeds is not None:
            if not isinstance(audio_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio embeds. "
                                 f"Got type: {type(audio_embeds)}")

            return Phi4MMAudioEmbeddingInputs(type="audio_embeds",
                                              data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(self, input_ids: torch.Tensor,
                             audio_input: Phi4MMAudioInputs,
                             audio_projection_mode: str) -> NestedTensors:
        """
        Create the audio embeddings from the audio input, where the audio input
        is pairs of audio features and audio embed lengths.  The audio input is
        created by `input_mapper_for_phi4mm_audio`.

        Args:
            input_ids (torch.Tensor): Input IDs (the prompt in this case,
                before the audio token replication).
            audio_input (Phi4MMAudioInputs): Audio input.
            audio_projection_mode (str): Projection mode forwarded to
                `_audio_features_to_embeddings`.

        Returns:
            NestedTensors: Audio embeddings
        """
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        audio_features = audio_input["data"]
        # The data is nested two levels deep: presumably the outer level is
        # the batch dim (e.g. multiple examples) and the inner level is the
        # multi-audio dim (e.g. multiple audios in the same example).
        # Flatten both into one list of (feature, length) pairs.
        audio_feature = [i[0] for j in audio_features for i in j]
        audio_feature_len = [i[1].item() for j in audio_features for i in j]
        # Add the batch dim via `unsqueeze`, drop it again from the result
        return self._audio_features_to_embeddings(
            input_ids.unsqueeze(0),
            audio_feature,
            audio_feature_len,
            audio_projection_mode,
        ).squeeze(0)

    def _parse_and_validate_image_input(self,
                                        **kwargs: object) -> Optional[Dict]:
        """
        Collect the image inputs from `kwargs` and merge them over the batch.

        Returns:
            None when no `pixel_values` are given, otherwise a dict with
            the keys 'pixel_values', 'image_sizes', 'image_attention_mask'
            and 'num_img_tokens' (the latter as a flat Python list).

        Raises:
            AssertionError: If companion inputs are missing or
                `pixel_values` has an unexpected number of dims.
            ValueError: If an input is neither a list nor a tensor.
        """
        pixel_values: Optional[Dict] = kwargs.get("pixel_values")
        if pixel_values is None:
            return None

        image_sizes = kwargs.get("image_sizes")
        image_attention_mask = kwargs.get("image_attention_mask")
        num_img_tokens = kwargs.get("num_img_tokens")
        assert image_sizes is not None and image_attention_mask is not None\
            and num_img_tokens is not None, "Missing image inputs"

        if isinstance(pixel_values, list):
            assert pixel_values[0].dim() == 5, "Incorrect image inputs"
            # list len is batch_size.
            # each tensor has dimension: num_img_per_example, num_hd_patches,
            # channels, height, width.
            # need to pad along num_hd_patches.
            # mask size num_img_per_prompt, num_hd_patches, feat_h, feat_w.
            pixel_values = cat_with_pad(pixel_values, dim=0)
        elif isinstance(pixel_values, torch.Tensor):
            # dimension: batch_size, num_img_per_example, num_hd_patches,
            # channels, height, width.
            # we flatten first 2 dims to make it a single large batch for
            # SigLIP Encoder.
            assert pixel_values.dim() == 6, "Incorrect image inputs"
            pixel_values = pixel_values.flatten(0, 1)
        else:
            raise ValueError("Incorrect pixel_values inputs")

        if isinstance(image_attention_mask, list):
            image_attention_mask = cat_with_pad(image_attention_mask, dim=0)
        elif isinstance(image_attention_mask, torch.Tensor):
            image_attention_mask = image_attention_mask.flatten(0, 1)
        else:
            raise ValueError("Incorrect image_attention_mask inputs")

        if isinstance(image_sizes, list):
            image_sizes = torch.cat(image_sizes, dim=0)
        elif isinstance(image_sizes, torch.Tensor):
            image_sizes = image_sizes.flatten(0, 1)
        else:
            raise ValueError("Incorrect image_sizes inputs")

        if isinstance(num_img_tokens, list):
            num_img_tokens = [
                n for num_tensor in num_img_tokens
                for n in num_tensor.tolist()
            ]
        elif isinstance(num_img_tokens, torch.Tensor):
            num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
        else:
            raise ValueError("Incorrect num_img_tokens inputs")

        return {
            'pixel_values': pixel_values,
            'image_sizes': image_sizes,
            'image_attention_mask': image_attention_mask,
            'num_img_tokens': num_img_tokens,
        }

    def merge_image_features_to_inputs_embeds(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        image_set_tensors: List[torch.Tensor],
    ):
        """
        Write image features into the image-placeholder positions.

        Args:
            input_ids: Token ids; positions equal to
                `_IMAGE_PLACEHOLDER_TOKEN_ID` are the targets.
            inputs_embeds: Embeddings to merge into (not modified in
                place; `index_put` returns a new tensor).
            image_set_tensors: Image features, each of shape
                (1, N_tokens, C).

        Returns:
            The embeddings with the placeholder positions replaced.
        """
        position_tuple = (input_ids == _IMAGE_PLACEHOLDER_TOKEN_ID).nonzero(
            as_tuple=True)

        assert all(t.shape[0] == 1 for t in image_set_tensors
                   ), 'img_set_tensor should have shape (1, N_tokens, C)'
        # Shape: (merged_N_tokens, C)
        image_set_tensor = torch.cat(image_set_tensors, dim=1).squeeze(0)
        image_set_tensor = image_set_tensor.to(inputs_embeds.dtype).to(
            inputs_embeds.device)
        merged_embeds = inputs_embeds.index_put(
            indices=position_tuple,
            values=image_set_tensor,
            accumulate=False,
        )
        return merged_embeds

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> None:
        """
        Load checkpoint weights after renaming them to this module's layout.

        Checkpoint names under `model.embed_tokens_extend.*` are mapped to
        the `embed_tokens_extend.*` / `vision_encoder.*` attributes, and a
        trailing `.base_layer.weight` (LoRA injection) becomes `.weight`.
        Loading is non-strict; missing and unexpected keys are only logged
        at debug level.
        """
        weights = dict(weights)
        adjusted_weights = {}

        # NOTE vision-speech tasks use a separate projection layer
        audio_proj_4v = \
            "model.embed_tokens_extend.audio_embed.audio_projection.vision"

        for name, weight in weights.items():
            if name.startswith(audio_proj_4v):
                name = name.replace(
                    audio_proj_4v,
                    "embed_tokens_extend.audio_projection_for_vision")

            # Order matters: the more specific speech-projection prefix is
            # rewritten before the generic audio_embed prefix.
            name = (name.replace(
                "model.embed_tokens_extend.audio_embed."\
                    "audio_projection.speech.",
                "embed_tokens_extend.audio_projection.",
            ).replace(
                "model.embed_tokens_extend.audio_embed.",
                "embed_tokens_extend.",
            ).replace("model.embed_tokens_extend.image_embed.",
                      "vision_encoder."))
            # NOTE: this is deal with LoRA injection, where `base_layer`
            # remains as the original layer in the model
            if name.endswith(".base_layer.weight"):
                name = name.replace(".base_layer.weight", ".weight")
            adjusted_weights[name] = weight

        missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
                                                             strict=False)
        logger.debug("*** missing keys:")
        for key in missing_keys:
            logger.debug(key)
        logger.debug("**** unexpected keys:")
        for key in unexpected_keys:
            logger.debug(key)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """
        Run the language model, merging audio/image embeddings when given.

        With audio and/or image inputs in `kwargs`, `inputs_embeds` is
        built and `input_ids` is dropped; text-only inputs are passed
        through as `input_ids`. When both audio and images are present the
        audio uses the 'vision' projection mode, otherwise 'speech'.

        Returns:
            torch.Tensor: Hidden states produced by the language model.
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            # Each entry in this is a pair of audio_features and audio_embed
            # lengths
            audio_input = self._parse_and_validate_audio_input(**kwargs)
            image_inputs = self._parse_and_validate_image_input(**kwargs)

            has_audio = audio_input is not None
            has_image = image_inputs is not None

            if has_audio:
                audio_projection_mode = 'vision' if has_image else 'speech'
                inputs_embeds = self._process_audio_input(
                    input_ids, audio_input, audio_projection_mode)

            if has_image:
                dtype = self.vision_encoder.img_processor.embeddings.\
                    patch_embedding.weight.dtype
                pixel_values = image_inputs['pixel_values'].to(dtype)
                image_sizes = image_inputs['image_sizes']
                image_attention_mask = image_inputs['image_attention_mask']
                image_set_tensors = self.vision_encoder(
                    pixel_values, image_sizes, image_attention_mask)
                # Without audio there are no embeddings yet, so start from
                # the plain token embeddings.
                if not has_audio:
                    inputs_embeds = self.model.embed_tokens(input_ids)

                inputs_embeds = self.merge_image_features_to_inputs_embeds(
                    input_ids, inputs_embeds, image_set_tensors)

            if has_image or has_audio:
                # multi-modal input, we have set inputs_embeds properly in
                # previous steps
                input_ids = None
            else:
                # text-only, we keep using original input_ids
                inputs_embeds = None

        hidden_states = self.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from hidden states via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on `logits` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com)
# but implemented by the Phi-Speech team
#!/usr/bin/env python3
import abc
import math
from functools import partial
from typing import Callable, Dict, List, Literal, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl, CheckpointWrapper, checkpoint_wrapper, offload_wrapper)
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullyShardedDataParallel)
from torch.utils.checkpoint import checkpoint
from transformers import PretrainedConfig
from vllm.model_executor.models.phi4mm_utils import (
AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer,
MultiHeadedAttention, NemoConvSubsampling, T5RelativeAttentionLogitBias,
adaptive_enc_mask, attn_checkpointing, embedding_checkpoint_wrapper,
get_offset, repeat, unfold_tensor, validate_checkpointing_config)
# Token id of the audio placeholder; the trailing comment names the
# corresponding vocabulary entry.
_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|>
def encoder_checkpoint_wrapper(
    activation_checkpointing: Union[str, Dict],
    layer_cls: type,
    idx: int = 0,
) -> Callable:
    """
    Return the activation-checkpoint wrapper to apply to a layer.

    Args:
        activation_checkpointing: Either a string ("" disables
            checkpointing, "offload" selects the offload wrapper, any
            other non-empty string the plain checkpoint wrapper) or a dict
            with the optional keys "module", "interval", "offload" and
            "reentrant".
        layer_cls: Class of the layer about to be wrapped; its name is
            matched against the configured target.
        idx: Index of the layer; with a dict config only indices that are
            a multiple of "interval" are wrapped.

    Returns:
        A callable taking a module and returning it wrapped, or an
        identity function when the layer is not selected.

    Raises:
        ValueError: If the config is neither a str nor a dict.
    """
    validate_checkpointing_config(activation_checkpointing)

    def _identity(x):
        return x

    if isinstance(activation_checkpointing, str):
        if not activation_checkpointing:
            return _identity
        if activation_checkpointing == "offload":
            return offload_wrapper
        return partial(checkpoint_wrapper)

    if not isinstance(activation_checkpointing, dict):
        raise ValueError("Invalid activation_checkpointing config")

    options = activation_checkpointing
    target_layer_cls = options.get("module", "transformer")
    # The two well-known aliases expand to the class names they cover; any
    # other value is matched as given.
    if target_layer_cls.lower() == "transformer":
        target_layer_cls = (
            "EncoderLayer",
            "ConformerEncoderLayer",
        )
    elif target_layer_cls.lower() == "attention":
        target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention")

    interval = options.get("interval", 1)
    use_offload = options.get("offload", False)
    if options.get("reentrant", True):
        impl = CheckpointImpl.REENTRANT
    else:
        impl = CheckpointImpl.NO_REENTRANT

    selected = (idx % interval == 0
                and layer_cls.__name__ in target_layer_cls)
    if not selected:
        return _identity
    if use_offload:
        return offload_wrapper
    return partial(checkpoint_wrapper, checkpoint_impl=impl)
class ConformerEncoderLayer(nn.Module):
"""ConformerEncoder Layer module.
for more details see conformer paper:
https://arxiv.org/abs/2005.08100
This module implement the Conformer block layer.
Args:
d_model: int
attention dim.
ext_pw_out_channel: int
if > 0, ext_pw_out_channel is a dim channel size
for the last pointwise conv after swish activation.
depthwise_seperable_out_channel: int
if set different to 0, the number of
depthwise_seperable_out_channel will be used as a
channel_out of the second conv1d layer.
otherwise, it equal to 0, the second conv1d layer is skipped.
depthwise_multiplier: int
number of input_dim channels duplication. this value
will be used to compute the hidden channels of the Conv1D.
n_head: int
the number of heads for multihead attention module.
d_ffn: int
output size of the feed_forward blocks.
ext_pw_kernel_size: int
kernel size of the conv pointwise of the conformer.
kernel_size: int
kernel size.
dropout_rate: float
dropout rate.
causal: bool, optional
if set to True, convolution have no access
to future frames. default False.
batch_norm: bool, optional
if set to True, apply batchnorm before activation
in ConvModule layer of the conformer.
default False
activation: str, optional
activation function name,
one of ["relu", "swish", "sigmoid"],
sigmoid activation is only used with "glu_in_fnn=True",
default "relu".
chunk_se: int, optional
0 for offline SE.
1 for streaming SE, where mean is computed
by accumulated history until current chunk_se.
2 for streaming SE, where mean is computed
by only the current chunk.
default 0.
chunk_size: int, optional
chunk_size for cnn. default 18
conv_activation: str, optional
activation function used in ConvModule part
of the conformer, default "relu".
conv_glu_type: str, optional
activation function used for the glu inside
the ConvModule part of the conformer.
default: "sigmoid".
bias_in_glu: bool, optional
if set to True, use additive bias in the weight module
before GLU.
linear_glu_in_convm: bool, optional
if set to True, use GLULinear module,
otherwise, used GLUPointWiseConv module.
default to False.
attention_innner_dim: int, optional
if equal to -1, attention dim for linears k/q/v is
equal to d_model. otherwise attention_innner_dim is used.
default -1.
attention_glu_type: str, optional
activation function for glu used in the multihead attention,
default "swish".
activation_checkpointing: str, optional
a dictionarry of {"module","interval","offload"}, where
"module": str
accept ["transformer", "attention"] to select
which module should do activation checkpointing.
"interval": int, default 1,
interval of applying activation checkpointing,
interval = 1 means that we apply checkpointing
on every layer (if activation), otherwise,
we apply it every x interval.
"offload": bool, default False,
if set to True, we offload activation to cpu and
reload it during backward, otherwise,
we recalculate activation in backward.
default "".
export: bool, optional
if set to True, it remove the padding from convolutional layers
and allow the onnx conversion for inference.
default False.
use_pt_scaled_dot_product_attention: bool, optional
if set to True, use pytorch's scaled dot product attention
implementation in training.
attn_group_sizes: int, optional
the number of groups to use for attention, default 1
(Multi-Head Attention),
1 = typical Multi-Head Attention,
1 < attn_group_sizes < attention_heads = Grouped-Query Attention
attn_group_sizes = attenion_heads = Multi-Query Attention
"""
def __init__(
    self,
    d_model=512,
    ext_pw_out_channel=0,
    depthwise_seperable_out_channel=256,
    depthwise_multiplier=1,
    n_head=4,
    d_ffn=2048,
    ext_pw_kernel_size=1,
    kernel_size=3,
    dropout_rate=0.1,
    causal=False,
    batch_norm=False,
    activation="relu",
    chunk_se=0,
    chunk_size=18,
    conv_activation="relu",
    conv_glu_type="sigmoid",
    bias_in_glu=True,
    linear_glu_in_convm=False,
    attention_innner_dim=-1,
    attention_glu_type="swish",
    activation_checkpointing="",
    export=False,
    use_pt_scaled_dot_product_attention=False,
    attn_group_sizes: int = 1,
):
    """Create the sub-modules of a single encoder layer.

    The modules are built in this order: input feed-forward,
    self-attention, convolution module, output feed-forward and the two
    layer norms.
    """
    super().__init__()

    # The two feed-forward modules share one configuration.
    ffn_kwargs = dict(
        d_model=d_model,
        d_inner=d_ffn,
        dropout_rate=dropout_rate,
        activation=activation,
        bias_in_glu=bias_in_glu,
    )
    self.feed_forward_in = FeedForward(**ffn_kwargs)

    # The attention module is passed through whatever wrapper
    # encoder_checkpoint_wrapper returns for this checkpointing setting.
    wrap_attention = encoder_checkpoint_wrapper(
        activation_checkpointing,
        MultiHeadedAttention,
    )
    attention = MultiHeadedAttention(
        n_head,
        d_model,
        dropout_rate,
        attention_innner_dim,
        attention_glu_type,
        bias_in_glu,
        use_pt_scaled_dot_product_attention=(
            use_pt_scaled_dot_product_attention),
        group_size=attn_group_sizes,
    )
    self.self_attn = wrap_attention(attention)

    self.conv = ConvModule(
        d_model,
        ext_pw_out_channel,
        depthwise_seperable_out_channel,
        ext_pw_kernel_size,
        kernel_size,
        depthwise_multiplier,
        dropout_rate,
        causal,
        batch_norm,
        chunk_se,
        chunk_size,
        conv_activation,
        conv_glu_type,
        bias_in_glu,
        linear_glu_in_convm,
        export=export,
    )

    self.feed_forward_out = FeedForward(**ffn_kwargs)

    self.layer_norm_att = nn.LayerNorm(d_model)
    self.layer_norm = nn.LayerNorm(d_model)
def forward(
    self,
    x,
    pos_k,
    pos_v,
    mask,
    relative_attention_bias: Optional[Tensor] = None,
):
    """Run one encoder layer.

    Applies, each as a residual update: a half-weighted input
    feed-forward, self-attention on the layer-normed activations, the
    convolution module and a half-weighted output feed-forward. A final
    layer norm produces the output.

    Args:
        x: torch.Tensor
            input feature of shape (batch, max_time_in, size)
        pos_k: torch.Tensor
            positional key embedding, forwarded to the attention module.
        pos_v: torch.Tensor
            positional value embedding, forwarded to the attention
            module.
        mask: torch.Tensor
            mask for x (batch, max_time_in)
        relative_attention_bias: Optional[torch.Tensor]
            bias added to attention logits w.r.t. relative positions
            (1, n_head, time1, time2)

    Returns:
        Tuple ``(out, pos_k, pos_v, mask)``; the last three entries are
        the arguments passed in, unchanged.
    """
    x = x + 0.5 * self.feed_forward_in(x)

    attn_input = self.layer_norm_att(x)
    attn_output = self.self_attn(
        attn_input,
        attn_input,
        attn_input,
        pos_k,
        pos_v,
        mask,
        relative_attention_bias=relative_attention_bias,
    )
    x = x + attn_output

    x = x + self.conv(x)
    x = x + 0.5 * self.feed_forward_out(x)

    return self.layer_norm(x), pos_k, pos_v, mask
class TransformerEncoderBase(abc.ABC, nn.Module):
"""The Base class for Transformer based encoders
Please set causal = True in streaming model
Args:
input_size: int
input feature dimension.
chunk_size: int, list(int)
Number of frames for each chunk
This variable can take 2 forms:
int: Used for inference, or single chunk size training
list(int) : Used only for variable chunk size training
Some examples for the 2 cases:
chunk_size = 12
chunk_size = [6, 8, 12, 24]
left_chunk: int, list(int)
Number of chunks used for masking in streaming mode.
This variable can take 2 forms:
int: Used for inference, or single chunk size training
list(int) : Used only for variable chunk size training. When
chunk_size is a list, left_chunk must be a list with same length.
Some examples for the 2 cases:
left_chunk = 6
left_chunk = [12, 9, 6, 3]
attention_dim: int, optional
attention dimension. default 256.
attention_heads: int, optional
the number of heads. default 4
input_layer: str, optional
input layer type before Conformer. Only "nemo_conv" is
supported here; any other value raises ValueError.
default "nemo_conv"
cnn_out: int, optional
the number of CNN channels before Conformer.
default -1.
cnn_layer_norm: bool, optional
layer norm between Conformer and the first CNN.
default False.
time_reduction: int, optional
time reduction factor
default 4
dropout_rate: float, optional
dropout rate. default 0.0
padding_idx: int, optional
padding index for input_layer=embed
default -1
relative_attention_bias_args: dict, optional
use more efficient scalar bias-based relative multihead attention
(Q*K^T + B) implemented in cmb.basics.embedding.
[T5/ALiBi]RelativeAttentionLogitBias
usage: relative_attention_bias_args={"type": t5/alibi}
additional method-specific arguments can be provided (see
transformer_base.py)
positional_dropout_rate: float, optional
dropout rate after positional encoding. default 0.0
nemo_conv_settings: dict, optional
A dictionary of settings for NeMo Subsampling.
default None
conv2d_extra_padding: str, optional
Add extra padding in conv2d subsampling layers. Choices are
(feat, feat_time, none, True).
if True or feat_time, the extra padding is added into non full
supraframe utts in batch.
Default: none
attention_group_size: int, optional
the number of groups to use for attention, default 1
(Multi-Head Attention),
1 = typical Multi-Head Attention,
1 < attention_group_size < attention_heads = Grouped-Query
Attention
attention_group_size = attention_heads = Multi-Query Attention
"""
def __init__(
    self,
    input_size,
    chunk_size,
    left_chunk,
    attention_dim=256,
    attention_heads=4,
    input_layer="nemo_conv",
    cnn_out=-1,
    cnn_layer_norm=False,
    time_reduction=4,
    dropout_rate=0.0,
    padding_idx=-1,
    relative_attention_bias_args=None,
    positional_dropout_rate=0.0,
    nemo_conv_settings=None,
    conv2d_extra_padding: Literal["feat", "feat_time", "none",
                                  True] = "none",
    attention_group_size=1,
    encoder_embedding_config=None,
):
    """Store the encoder configuration and build the front-end modules.

    Builds the NeMo convolutional subsampling layer (``self.embed``), the
    absolute positional encoding (``self.pos_emb``) and the T5 relative
    attention bias layer (``self.relative_attention_bias_layer``).

    Raises:
        ValueError: if ``input_layer`` is not "nemo_conv".
        AssertionError: if ``nemo_conv_settings`` tries to set
            "subsampling_factor", "feat_in" or "feat_out", or if
            ``attention_group_size`` does not divide ``attention_heads``.
        NotImplementedError: if the relative attention bias type is not
            "t5" (including when no bias arguments are given).
    """
    super().__init__()
    self.input_size = input_size
    self.input_layer = input_layer
    self.chunk_size = chunk_size
    self.left_chunk = left_chunk
    self.attention_dim = attention_dim
    self.num_heads = attention_heads
    self.attention_group_size = attention_group_size
    self.time_reduction = time_reduction
    self.nemo_conv_settings = nemo_conv_settings
    self.encoder_embedding_config = encoder_embedding_config

    if self.input_layer == "nemo_conv":
        default_nemo_conv_settings = {
            "subsampling": "dw_striding",
            "subsampling_factor": self.time_reduction,
            "feat_in": input_size,
            "feat_out": attention_dim,
            "conv_channels": 256,
            "subsampling_conv_chunking_factor": 1,
            "activation": nn.ReLU(),
            "is_causal": False,
        }
        # Override any of the defaults with the incoming, user settings
        if nemo_conv_settings:
            default_nemo_conv_settings.update(nemo_conv_settings)
            # These three keys are derived from the constructor arguments
            # and must not be overridden through the settings dictionary.
            for i in ["subsampling_factor", "feat_in", "feat_out"]:
                assert (
                    i not in nemo_conv_settings
                ), f"{i} should be specified outside of the NeMo dictionary"

        self.embed = NemoConvSubsampling(**default_nemo_conv_settings, )
    else:
        raise ValueError("unknown input_layer: " + input_layer)

    self.pos_emb = AbsolutePositionalEncoding(attention_dim,
                                              positional_dropout_rate)

    self.relative_attention_bias_type = (
        relative_attention_bias_args.get("type")
        if relative_attention_bias_args else None)
    if self.relative_attention_bias_type == "t5":
        assert (self.num_heads % self.attention_group_size == 0
                ), "attention_group_size must divide n_head"
        self.relative_attention_bias_layer = T5RelativeAttentionLogitBias(
            self.num_heads // self.attention_group_size,
            max_distance=relative_attention_bias_args.get(
                "t5_bias_max_distance", 1000),
            symmetric=relative_attention_bias_args.get(
                "t5_bias_symmetric", False),
        )
    else:
        # Only the "t5" bias is implemented; say which type was requested
        # instead of raising a bare exception.
        raise NotImplementedError(
            "unsupported relative attention bias type: "
            f"{self.relative_attention_bias_type!r}")
def post_init(self, init_model_config):
    """Optionally load pretrained encoder weights and ensure that
    ``self.encoder_embedding`` exists.

    Args:
        init_model_config: mapping; the only key read here is
            "pretrained_speech_encoder_path".
    """
    checkpoint_path = init_model_config.get(
        "pretrained_speech_encoder_path", None)

    if checkpoint_path:
        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted checkpoints.
        full_state = torch.load(checkpoint_path, map_location="cpu")

        # Keep only entries whose name contains "encoder." and remove that
        # substring from the name.
        encoder_state = {
            name.replace("encoder.", ""): value
            for name, value in full_state.items() if "encoder." in name
        }

        if hasattr(self, "encoder_embedding"):
            del self.encoder_embedding
        self.load_state_dict(encoder_state)

    if not hasattr(self, "encoder_embedding"):
        self.encoder_embedding = MeanVarianceNormLayer(
            self.encoder_embedding_config["input_size"])
def compute_lens_change(self, feature_lens):
    """Return the feature length(s) after time reduction.

    This used to return a different lambda function for each case that
    computed the right thing. That does not work within Torchscript.
    If you really need this to be faster, create nn.Module()-s for all
    the cases and return one of them. Torchscript does support that.

    Args:
        feature_lens: int or torch.Tensor
            length(s) before subsampling.

    Returns:
        The reduced length(s). For the causal striding case an int (or a
        long tensor) is returned; otherwise the result of ``math.ceil``
        (int input) or ``torch.ceil`` (any other input). Nothing is
        returned when ``self.input_layer`` is not "nemo_conv".
    """
    if self.input_layer == "nemo_conv":
        # nemo_conv_settings defaults to None in the constructor; treat
        # that as "no overrides" so the defaults below apply instead of
        # failing on ``None.get``.
        settings = self.nemo_conv_settings or {}

        # Handle the special causal case
        subsampling_causal_cond = settings.get(
            "subsampling", "dw_striding") in (
                "dw_striding",
                "striding",
                "striding_conv1d",
            )
        is_causal = settings.get("is_causal", False)
        if is_causal and subsampling_causal_cond:
            lens_change = (torch.ceil(feature_lens /
                                      self.time_reduction).long()
                           if isinstance(feature_lens, Tensor) else
                           math.ceil(feature_lens / self.time_reduction))
            feature_lens_remainder = feature_lens % self.time_reduction
            # One extra frame unless the remainder is exactly 1.
            if isinstance(feature_lens, Tensor):
                lens_change[feature_lens_remainder != 1] += 1
            elif feature_lens_remainder != 1:
                lens_change += 1
            return lens_change
        ceil_func = (math.ceil
                     if isinstance(feature_lens, int) else torch.ceil)
        return ceil_func(feature_lens / self.time_reduction)
@abc.abstractmethod
def forward(self):
    """Abstract forward method; concrete encoders must override it."""
def _chunk_size_selection(self, chunk_size=None, left_chunk=None):
"""If chunk size is a list, we will randomly select a chunk size."""
if chunk_size is None:
chunk_size = self.chunk_size
if left_chunk is None:
left_chunk = self.left_chunk
if isinstance(chunk_size, list):
# Variable chunk size during training
chunk_size_index = int(
torch.randint(low=0, high=len(chunk_size), size=(1, )))
chunk_size_train_eff = chunk_size[chunk_size_index]
if not isinstance(left_chunk, list):
raise ValueError(
"Since chunk_size is a list, left_chunk must be a list")
if len(left_chunk) != len(chunk_size):
raise ValueError(
"The length of left_chunk must be the same as length of "\
"chunk_size."
)
left_chunk_train_eff = left_chunk[chunk_size_index]
else:
chunk_size_train_eff = chunk_size
left_chunk_train_eff = left_chunk
return chunk_size_train_eff, left_chunk_train_eff
def _get_embed_class(self, embed):
# pylint: disable=protected-access
is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper)
is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel)
embed_class = embed
if is_embed_using_act_chkpt:
embed_class = embed._checkpoint_wrapped_module
if is_embed_fsdp_wrapped:
embed_class = embed.module
return embed_class
def _forward_embeddings_core(self, input_tensor, masks):
    """Apply ``self.embed`` to the input and mask.

    Asserts that the (unwrapped) embedding module is a
    NemoConvSubsampling, then returns the ``(tensor, masks)`` pair the
    embedding module produces.
    """
    assert isinstance(self._get_embed_class(self.embed),
                      NemoConvSubsampling)
    embedded, embedded_masks = self.embed(input_tensor, masks)
    return embedded, embedded_masks
def _position_embedding(self, input_tensor):
pos_k = None
pos_v = None
if self.relative_attention_bias_layer is None:
input_tensor = self.pos_emb(
input_tensor) # default to add abs sinusoid embedding
return pos_k, pos_v
def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk):
    """Build the chunk-based attention mask, expanded along a leading
    batch dimension of size ``batch_size``."""
    eff_chunk_size, eff_left_chunk = self._chunk_size_selection(
        chunk_size, left_chunk)

    # Start index of every chunk: if chunk size is 18, this is
    # [0, 18, 36, ...].
    chunk_starts = np.arange(0, seq_len, eff_chunk_size)

    # avoid randomness when run evaluation or decoding
    if self.training and np.random.rand() > 0.5:
        # Either first or last chunk is not complete.
        # If only the last one is not complete, EOS is not effective
        chunk_starts = (seq_len - chunk_starts)[::-1][:-1]
        chunk_starts = np.insert(chunk_starts, 0, 0)

    mask = adaptive_enc_mask(seq_len,
                             chunk_starts,
                             left_window=eff_left_chunk)
    return mask.unsqueeze(0).expand([batch_size, -1, -1])
def forward_embeddings(self,
                       xs_pad,
                       masks,
                       chunk_size_nc=None,
                       left_chunk_nc=None):
    """Forwarding the inputs through the top embedding layers

    Args:
        xs_pad: torch.Tensor
            input tensor
        masks: torch.Tensor
            input mask
        chunk_size_nc: (optional, default is None) chunk size for
                        non-causal layers
        left_chunk_nc: (optional, default is None) # of left chunks for
                        non-causal layers

    Returns:
        ``(input_tensor, pos_k, pos_v, hs_mask, masks)``, with
        ``hs_mask_nc`` appended as a sixth element when ``chunk_size_nc``
        is given.

    Raises:
        ValueError: if the sequence length after time reduction is not
            positive.
    """
    # pylint: disable=R0915
    # get new lens.
    seq_len = int(self.compute_lens_change(xs_pad.shape[1]))
    if seq_len <= 0:
        raise ValueError(
            f"""The sequence length after time reduction is invalid:
{seq_len}. Your input feature is too short. Consider
filtering out the very short sentence from data
loader""", )

    batch_size = xs_pad.shape[0]
    # Masks follow the input's own device. ``.cuda()`` would target the
    # *current* CUDA device, which may differ from the input's device.
    device = xs_pad.device

    enc_streaming_mask = self._streaming_mask(seq_len, batch_size,
                                              self.chunk_size,
                                              self.left_chunk).to(device)

    input_tensor, masks = self._forward_embeddings_core(xs_pad, masks)

    # Combine the padding mask (if any) with the streaming mask.
    if masks is not None:
        hs_mask = masks & enc_streaming_mask
    else:
        hs_mask = enc_streaming_mask

    hs_mask_nc = None
    if chunk_size_nc is not None:
        enc_streaming_mask_nc = self._streaming_mask(
            seq_len, batch_size, chunk_size_nc, left_chunk_nc).to(device)
        if masks is not None:
            hs_mask_nc = masks & enc_streaming_mask_nc
        else:
            hs_mask_nc = enc_streaming_mask_nc

    pos_k, pos_v = self._position_embedding(input_tensor)

    if chunk_size_nc is None:
        return input_tensor, pos_k, pos_v, hs_mask, masks
    return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc
def get_offset(self):
    """Returns offset used when retaining inputs for decoding.

    This is essentially, how many additional frames have to be added to
    the front-end CNN input to ensure it can produce a single output.
    So if the "padding" parameter is 0, typically offset will be > 0.

    The value is computed by the module-level ``get_offset`` helper from
    ``self.input_layer`` and ``self.time_reduction``.
    """
    return get_offset(self.input_layer, self.time_reduction)
class ConformerEncoder(TransformerEncoderBase):
"""ConformerEncoder module.
see original paper for more details:
https://arxiv.org/abs/2005.08100
Please set causal = True in streaming model
Args:
input_size: int
input feature dimension.
chunk_size: int, list(int)
Number of frames for each chunk
This variable can take 2 forms:
int: Used for inference, or single chunk size training
list(int) : Used only for variable chunk size training
Some examples for the 2 cases:
chunk_size = 12
chunk_size = [6, 8, 12, 24]
left_chunk: int, list(int)
Number of chunks used for masking in streaming mode.
This variable can take 2 forms:
int: Used for inference, or single chunk size training
list(int) : Used only for variable chunk size training. When
chunk_size is a list, left_chunk must be a list with same length.
Some examples for the 2 cases:
left_chunk = 6
left_chunk = [12, 9, 6, 3]
left_chunk: int
number of chunks used for masking in streaming mode.
num_lang: int
This parameter is used to store the number of languages in the
lang_dict, only used for multiseed/multilingual models.
default None.
attention_dim: int, optional
attention dimension. default 256.
attention_heads: int, optional
the number of heads. default 4
linear_units:
the number of units of position-wise feed forward.
default 2048
num_blocks: int, optional
number of Conformer encoder layers. default 6
dropout_rate: float, optional
dropout rate. default 0.1
input_layer: str, optional
input layer type before Conformer. Only "nemo_conv" is
accepted by the base class.
default "nemo_conv"
causal: bool, optional
if set to True, convolution have no access
to future frames. default False.
batch_norm: bool, optional
if set to True, apply batchnorm before activation
in ConvModule layer of the conformer.
default False
cnn_out: int, optional
the number of CNN channels before Conformer.
default -1.
cnn_layer_norm: bool, optional
layer norm between Conformer and the first CNN.
default False.
ext_pw_out_channel: int, optional
the number of channel for CNN
before depthwise_seperable_CNN.
If 0 then use linear. default 0.
ext_pw_kernel_size: int, optional
kernel size of N before depthwise_seperable_CNN.
only work for ext_pw_out_channel > 0.
default 1
depthwise_seperable_out_channel: int, optional
the number of channel for
depthwise_seperable_CNN.
default 256.
depthwise_multiplier: int, optional
the number of multiplier for
depthwise_seperable_CNN.
default 1.
chunk_se: int, optional
0 for offline SE.
1 for streaming SE, where mean is computed
by accumulated history until current chunk_se.
2 for streaming SE, where mean is computed
by only the current chunk.
default 0.
kernel_size: int, optional
the number of kernels for depthwise_seperable_CNN.
default 3.
activation: str, optional
FeedForward block activation.
one of ["relu", "swish", "sigmoid"]
default "relu".
conv_activation: str, optional
activation function used in ConvModule part
of the conformer, default "relu".
conv_glu_type: str, optional
activation used use glu in depthwise_seperable_CNN,
default "sigmoid"
bias_in_glu: bool, optional
if set to True, use additive bias in the weight module
before GLU. default True
linear_glu_in_convm: bool, optional
if set to True, use GLULinear module,
otherwise, used GLUPointWiseConv module.
default to False.
attention_glu_type: str
only work for glu_in_attention !=0
default "swish".
export: bool, optional
if set to True, it remove the padding from convolutional layers
and allow the onnx conversion for inference.
default False.
activation_checkpointing: str, optional
a dictionary of {"module","interval","offload"}, where
"module": str
accept ["transformer", "attention"] to select
which module should do activation checkpointing.
"interval": int, default 1,
interval of applying activation checkpointing,
interval = 1 means that we apply checkpointing
on every layer (if activation), otherwise,
we apply it every x interval.
"offload": bool, default False,
if set to True, we offload activation to cpu and
reload it during backward, otherwise,
we recalculate activation in backward.
default "".
extra_layer_output_idx: int
the layer index to be exposed.
relative_attention_bias_args: dict, optional
use more efficient scalar bias-based relative multihead attention
(Q*K^T + B) implemented in cmb.basics.embedding.
[T5/ALiBi]RelativeAttentionLogitBias
usage: relative_attention_bias_args={"type": t5/alibi}
additional method-specific arguments can be provided (see
transformer_base.py)
time_reduction: int optional
time reduction factor
default 4
use_pt_scaled_dot_product_attention: whether to use pytorch scaled
dot product attention in training.
Default: False
nemo_conv_settings: dict, optional
A dictionary of settings for NeMo Subsampling.
default: None
usage: nemo_conv_settings=
{
"subsampling":
dw_striding/striding/dw_striding_conv1d/striding_conv1d,
"conv_channels": int,
"subsampling_conv_chunking_factor": int,
"is_causal": True/False
}
conv2d_extra_padding: str, optional
Add extra padding in conv2d subsampling layers. Choices are
(feat, feat_time, none, True)
Default: none
replication_pad_for_subsample_embedding: For batched-streaming
decoding, use "replication" padding for the cache at start of
utterance.
Default: False
attention_group_size: int, optional
the number of groups to use for attention, default 1
(Multi-Head Attention),
1 = typical Multi-Head Attention,
1 < attention_group_size < attention_heads = Grouped-Query
Attention
attention_group_size = attention_heads = Multi-Query Attention
"""
extra_multi_layer_output_idxs: List[int]
def __init__(  # pylint: disable-all
    self,
    input_size,
    chunk_size,
    left_chunk,
    num_lang=None,
    attention_dim=256,
    attention_heads=4,
    linear_units=2048,
    num_blocks=6,
    dropout_rate=0.1,
    input_layer="nemo_conv",
    causal=True,
    batch_norm=False,
    cnn_out=-1,
    cnn_layer_norm=False,
    ext_pw_out_channel=0,
    ext_pw_kernel_size=1,
    depthwise_seperable_out_channel=256,
    depthwise_multiplier=1,
    chunk_se=0,
    kernel_size=3,
    activation="relu",
    conv_activation="relu",
    conv_glu_type="sigmoid",
    bias_in_glu=True,
    linear_glu_in_convm=False,
    attention_glu_type="swish",
    export=False,
    extra_layer_output_idx=-1,
    extra_multi_layer_output_idxs=[],  # noqa
    activation_checkpointing="",
    relative_attention_bias_args=None,
    time_reduction=4,
    use_pt_scaled_dot_product_attention=False,
    nemo_conv_settings=None,
    conv2d_extra_padding: Literal["feat", "feat_time", "none",
                                  True] = "none",
    replication_pad_for_subsample_embedding=False,
    attention_group_size=1,
    encoder_embedding_config=None,
):
    """Build the Conformer encoder.

    Delegates the front-end set-up to the base class, wraps the embedding
    module with ``embedding_checkpoint_wrapper`` and creates
    ``num_blocks`` ConformerEncoderLayer modules, each passed through the
    wrapper returned by ``encoder_checkpoint_wrapper``.

    Raises:
        AssertionError: if ``attention_group_size`` does not divide
            ``attention_heads``.
    """
    super().__init__(
        input_size,
        chunk_size,
        left_chunk,
        attention_dim,
        attention_heads,
        input_layer,
        cnn_out,
        cnn_layer_norm,
        time_reduction,
        dropout_rate=dropout_rate,
        relative_attention_bias_args=relative_attention_bias_args,
        positional_dropout_rate=0.0,
        nemo_conv_settings=nemo_conv_settings,
        conv2d_extra_padding=conv2d_extra_padding,
        attention_group_size=attention_group_size,
        encoder_embedding_config=encoder_embedding_config,
    )
    self.num_blocks = num_blocks
    self.num_lang = num_lang
    self.kernel_size = kernel_size
    self.embed = embedding_checkpoint_wrapper(activation_checkpointing)(
        self.embed)
    self.replication_pad_for_subsample_embedding: bool = (
        replication_pad_for_subsample_embedding)
    assert (self.num_heads % attention_group_size == 0
            ), "attention_group_size must divide n_head"
    self.num_heads_k = self.num_heads // attention_group_size

    def _build_layer(i):
        """Create the i-th encoder layer, wrapped for checkpointing."""
        wrap = encoder_checkpoint_wrapper(activation_checkpointing,
                                          ConformerEncoderLayer, i)
        return wrap(
            ConformerEncoderLayer(
                d_model=attention_dim,
                ext_pw_out_channel=ext_pw_out_channel,
                depthwise_seperable_out_channel=(
                    depthwise_seperable_out_channel),
                depthwise_multiplier=depthwise_multiplier,
                n_head=attention_heads,
                d_ffn=linear_units,
                ext_pw_kernel_size=ext_pw_kernel_size,
                kernel_size=kernel_size,
                dropout_rate=dropout_rate,
                causal=causal,
                batch_norm=batch_norm,
                activation=activation,
                chunk_se=chunk_se,
                chunk_size=chunk_size,
                conv_activation=conv_activation,
                conv_glu_type=conv_glu_type,
                bias_in_glu=bias_in_glu,
                linear_glu_in_convm=linear_glu_in_convm,
                attention_glu_type=attention_glu_type,
                activation_checkpointing=attn_checkpointing(
                    activation_checkpointing, i),
                export=export,
                use_pt_scaled_dot_product_attention=(
                    use_pt_scaled_dot_product_attention),
                attn_group_sizes=attention_group_size,
            ))

    self.encoders = repeat(num_blocks, _build_layer)

    self.extra_layer_output_idx = extra_layer_output_idx
    # Store a private copy: the default argument is a single list object
    # shared by every call, so keeping a reference to it would let one
    # instance's mutations leak into all others.
    self.extra_multi_layer_output_idxs = list(
        extra_multi_layer_output_idxs)
    # Make a zeros scalar we can use in get_initial_state to determine
    # the device and the needed dtype:
    self.register_buffer("dev_type", torch.zeros(()), persistent=False)
def init_relative_attention_bias(self, input_tensor):
    """Return the relative attention bias for ``input_tensor``, or None
    when ``self.relative_attention_bias_layer`` is falsy."""
    bias_layer = self.relative_attention_bias_layer
    if not bias_layer:
        return None
    return bias_layer(input_tensor)
def calculate_hs_mask(self, xs_pad, device, mask):
    """Combine the streaming mask with a padding mask derived from
    ``mask``.

    Args:
        xs_pad: tensor whose first two dimensions are read as
            (batch, length).
        device: device the masks are placed on.
        mask: optional mask; it is summed over dim 1 to obtain one valid
            length per row. When None, only the streaming mask is
            returned.
    """
    batch_size, max_audio_length = xs_pad.shape[0], xs_pad.shape[1]

    streaming_mask = self._streaming_mask(max_audio_length, batch_size,
                                          self.chunk_size,
                                          self.left_chunk).to(device)
    if mask is None:
        return streaming_mask

    valid_lens = mask.sum(1)
    frame_index = torch.arange(0, max_audio_length,
                               device=device).expand(valid_lens.size(0), -1)
    # True where the frame index lies inside the valid length of its row;
    # the extra dimension lets it broadcast against the streaming mask.
    pad_mask = (frame_index < valid_lens.unsqueeze(1)).unsqueeze(1)
    return pad_mask & streaming_mask
@torch.jit.ignore
def forward(self, xs_pad, masks):
    """Conformer Forward function

    Args:
        xs_pad: torch.Tensor
            input tensor
        masks: torch.Tensor
            mask handed to ``forward_embeddings`` together with the input

    Returns:
        ``(encoded, masks)`` where ``masks`` is the mask returned by
        ``forward_embeddings``.
    """
    xs_pad = self.encoder_embedding(xs_pad)
    input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings(
        xs_pad, masks)

    orig_batch, seq_len, _ = input_tensor.shape
    max_seq_len = 500  # maximum position for absolute positional encoding

    # Sequences longer than max_seq_len are unfolded into chunks of
    # max_seq_len.
    unfolded = seq_len > max_seq_len
    chunk_pad_size = 0
    if unfolded:
        # the unfold op will drop residual frames, pad it to the multiple
        # of max_seq_len
        chunk_pad_size = (-seq_len) % max_seq_len
        if chunk_pad_size > 0:
            input_tensor = F.pad(input_tensor, (0, 0, 0, chunk_pad_size),
                                 "constant", 0)
        input_tensor = unfold_tensor(input_tensor, max_seq_len)

        masks_unfold = None
        if masks is not None:
            # revise hs_mask here because the previous calculated hs_mask
            # did not consider extra pad
            pad_mask = F.pad(masks.squeeze(1), (0, chunk_pad_size),
                             "constant", False)
            # unfold op does not support bool tensor, so go through float
            pad_mask = pad_mask.unsqueeze(-1).float()
            masks_unfold = unfold_tensor(pad_mask,
                                         max_seq_len).squeeze(-1).bool()

        # calculate hs_mask based on the unfolded pad mask
        hs_mask = self.calculate_hs_mask(input_tensor, input_tensor.device,
                                         masks_unfold)

    relative_attention_bias = self.init_relative_attention_bias(
        input_tensor)

    if (self.extra_layer_output_idx == -1
            and relative_attention_bias is None):
        # Simplified path: run the whole stack in one call.
        input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v,
                                         hs_mask)
    else:
        for layer in self.encoders:
            input_tensor, _, _, _ = layer(
                input_tensor,
                pos_k,
                pos_v,
                hs_mask,
                relative_attention_bias=relative_attention_bias,
            )

    if unfolded:
        input_tensor = input_tensor.reshape(orig_batch, -1,
                                            input_tensor.shape[-1])
        # if we ever padded before unfolding, we need to remove the padding
        if chunk_pad_size > 0:
            input_tensor = input_tensor[:, :-chunk_pad_size, :]

    return input_tensor, masks
def gradient_checkpointing_enable(self):
    # Intentionally a no-op: the method exists so callers can invoke it,
    # but it changes nothing on this encoder.
    pass
class WindowQformer(nn.Module):
    """Window-level Qformer.

    Splits the input sequence into non-overlapping windows of
    ``window_size`` frames and lets a set of learned queries attend to
    each window through a stack of transformer decoder layers.
    """

    def __init__(
        self,
        window_size: int = 8,
        num_queries: int = 1,
        num_blocks: int = 2,
        attention_dim: int = 512,
        attention_heads: int = 8,
        linear_units: int = 2048,
        dropout_rate: float = 0.0,
        normalize_before: bool = True,
    ):
        super().__init__()

        decoder_layers = []
        for _ in range(num_blocks):
            decoder_layers.append(
                nn.TransformerDecoderLayer(
                    d_model=attention_dim,
                    nhead=attention_heads,
                    dim_feedforward=linear_units,
                    dropout=dropout_rate,
                    activation="relu",
                    batch_first=True,
                    norm_first=normalize_before,  # TODO need to verify
                ))
        self.decoders = nn.ModuleList(decoder_layers)

        self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim))

        # A final layer norm is only used together with pre-norm layers.
        if normalize_before:
            self.after_norm = nn.LayerNorm(attention_dim, eps=1e-12)
        else:
            self.after_norm = None

        self.window_size = window_size
        self.gradient_checkpointing_enable = False

    def enable_gradient_checkpointing(self):
        """Turn on checkpointing of the decoder layers (training only)."""
        self.gradient_checkpointing_enable = True

    def disable_gradient_checkpointing(self):
        """Turn off checkpointing of the decoder layers."""
        self.gradient_checkpointing_enable = False

    def forward(self, audio_embed, mask, embed_len=None):
        """Compress each window of the input into the learned queries.

        Args:
            audio_embed: tensor of shape N x T x D.
            mask: passed to every decoder layer as ``memory_mask``.
            embed_len: optional lengths; floor-divided by the window size
                when given.

        Returns:
            ``(out, embed_len)`` with ``out`` viewed as N x T' x (rest),
            where T' is the number of windows.
        """
        window = self.window_size

        # N x T x D => N x D x T
        feats = audio_embed.transpose(1, 2)

        # Zero-pad the time axis up to a multiple of the window size.
        remainder = feats.shape[-1] % window
        if remainder > 0:
            feats = F.pad(feats, (0, window - remainder), "constant", 0)

        # N x D x 1 x T => N x DK x T'
        windows = F.unfold(
            feats[..., None, :],
            kernel_size=(1, window),
            stride=(1, window),
        )
        bsz, _, slen = windows.shape

        # N x D x K x T' => N x T' x K x D => NT' x K x D
        windows = windows.view(bsz, -1, window, slen)
        windows = windows.permute(0, 3, 2, 1).contiguous()
        windows = windows.view(bsz * slen, window, -1)

        # One copy of the learned queries per window.
        q = self.queries.expand(bsz * slen, -1, -1)

        use_checkpoint = self.gradient_checkpointing_enable and self.training
        for layer in self.decoders:
            if use_checkpoint:
                q = checkpoint(
                    layer.__call__,
                    q,
                    windows,
                    None,
                    mask,
                    use_reentrant=True,
                )
            else:
                q = layer(tgt=q,
                          memory=windows,
                          tgt_mask=None,
                          memory_mask=mask)

        if self.after_norm is not None:
            q = self.after_norm(q)

        if embed_len is not None:
            embed_len = embed_len // window

        # N x T' x D
        return q.view(bsz, slen, -1), embed_len
class AudioEmbedding(nn.Module):
"""Audio embedding."""
def __init__(self, config: PretrainedConfig, **kwargs) -> None:
    """Build the audio encoder, optional compression module and the
    projection into the text model's hidden size.

    Args:
        config: model configuration. Reads ``n_embd``/``hidden_size``,
            ``embd_pdrop``/``embed_pdrop`` (optional), ``audio_processor``
            and ``vocab_size``.
        **kwargs: optional settings read here: "freeze_audio_processor",
            "downsample_rate", "use_qformer", "qformer_config",
            "use_conv_downsample", "nemo_conv_settings",
            "enable_gradient_checkpointing" and "projection_cls"
            ("linear" or "mlp").

    Raises:
        NotImplementedError: if ``config.audio_processor`` is not a dict
            named "cascades", or ``projection_cls`` is unknown.
    """
    super().__init__()
    self.config = config
    # n_embed or hidden_size for text LM
    hidden_size = (config.n_embd
                   if hasattr(config, "n_embd") else config.hidden_size)

    if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"):
        embd_drop = (config.embd_pdrop if hasattr(config, "embd_pdrop")
                     else config.embed_pdrop)
        self.drop = nn.Dropout(embd_drop)
    else:
        self.drop = None

    audio_dim_out = (
        None  # Set this variable according to the actual audio processor
    )
    self.layer_idx = -2

    if (isinstance(config.audio_processor, dict)
            and config.audio_processor.get("name", None) == "cascades"):
        encoder_config = config.audio_processor.get("config", None)
        assert encoder_config is not None
        self.encoder = ConformerEncoder(**encoder_config)

        # fake initialization, create encoder_embedding layer only so that
        # in decoding, all parameters can be loaded in
        # from_pretrained_function in training, we do post init after
        # from_pretrained function to make sure the correct initialization
        self.encoder.post_init({})

        audio_dim_out = encoder_config["attention_dim"]
        n_mels = encoder_config["input_size"]
    else:
        raise NotImplementedError("")

    assert (audio_dim_out
            is not None), "Remember to set values for audio_dim_out"
    self.audio_dim_out = audio_dim_out
    self.audio_dim_in = n_mels

    self.freeze_audio_processor = kwargs.get("freeze_audio_processor",
                                             False)

    self.downsample_rate = kwargs.get("downsample_rate", 1)

    if kwargs.get("use_qformer", False):
        # Copy before adding "attention_dim" so the caller's dictionary
        # is not modified.
        qformer_config = dict(kwargs.get("qformer_config", {}))
        qformer_config["attention_dim"] = audio_dim_out
        self.qformer = WindowQformer(**qformer_config)
    else:
        self.qformer = None

    if kwargs.get("use_conv_downsample", False):
        assert (self.qformer is None
                ), "don't support use qformer and conv downsample together"
        nemo_conv_settings = kwargs.get("nemo_conv_settings", {})
        default_nemo_conv_settings = {
            "subsampling": "dw_striding",
            "subsampling_factor": self.downsample_rate,
            "feat_in": audio_dim_out,
            "feat_out": audio_dim_out,
            "conv_channels": 256,
            "subsampling_conv_chunking_factor": 1,
            "activation": nn.ReLU(),
            "is_causal": False,
        }
        # Override any of the defaults with the incoming, user settings
        if nemo_conv_settings:
            default_nemo_conv_settings.update(nemo_conv_settings)
            # These keys are derived from the arguments above and must
            # not be overridden through the settings dictionary.
            for i in ["subsampling_factor", "feat_in", "feat_out"]:
                assert (
                    i not in nemo_conv_settings
                ), f"{i} should be specified outside of the NeMo dictionary"

        self.conv_ds = NemoConvSubsampling(**default_nemo_conv_settings, )
    else:
        self.conv_ds = None

    enable_gradient_checkpointing = kwargs.get(
        "enable_gradient_checkpointing", False)
    if enable_gradient_checkpointing:
        self.encoder.gradient_checkpointing_enable()

        if self.qformer:
            self.qformer.enable_gradient_checkpointing()

    # get_audio_features reads this attribute for every projection type,
    # so it must always exist. The "linear" projection takes
    # audio_dim_out features as input, i.e. no frame stacking (rate 1);
    # the "mlp" branch below overrides it.
    self.linear_downsample_rate = 1

    projection_cls = kwargs.get("projection_cls", "linear")
    if projection_cls == "linear":
        self.audio_projection = nn.Linear(audio_dim_out, hidden_size)
    elif projection_cls == "mlp":
        # follow llava-v1.5's implementation
        # (do not use image_projection and image_proj_norm)
        dim_projection = hidden_size
        depth = 2
        self.linear_downsample_rate = (1 if (self.qformer or self.conv_ds)
                                       else self.downsample_rate)
        layers = [
            nn.Linear(audio_dim_out * self.linear_downsample_rate,
                      dim_projection)
        ]
        for _ in range(1, depth):
            layers.extend(
                [nn.GELU(),
                 nn.Linear(dim_projection, dim_projection)])
        self.audio_projection = nn.Sequential(*layers)
        # NOTE vision-speech tasks use a separate projection layer
        layers = [
            nn.Linear(audio_dim_out * self.linear_downsample_rate,
                      dim_projection)
        ]
        for _ in range(1, depth):
            layers.extend(
                [nn.GELU(),
                 nn.Linear(dim_projection, dim_projection)])
        self.audio_projection_for_vision = nn.Sequential(*layers)
    else:
        raise NotImplementedError(
            f"projection_cls = {projection_cls}, not implemented")

    # TODO: audio sequence compression - Qformer
    self.vocab_size = config.vocab_size
    self.input_embeds = None
    self.audio_embed_sizes = None
def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None:
    """Store ``input_embeds`` on the module as ``self.input_embeds``."""
    self.input_embeds = input_embeds
def set_audio_embed_sizes(self,
                          audio_embed_sizes: torch.LongTensor) -> None:
    """Store ``audio_embed_sizes`` on the module as
    ``self.audio_embed_sizes``."""
    self.audio_embed_sizes = audio_embed_sizes
def get_audio_features(
    self,
    input_embeds: torch.FloatTensor,
    audio_attention_mask: torch.Tensor = None,
    audio_projection_mode: str = "speech",
):
    """Encode audio features and project them with the selected head.

    The input is run through ``self.encoder`` (without gradients when
    ``self.freeze_audio_processor`` is set), then optionally through
    ``self.qformer`` and ``self.conv_ds``, optionally frame-stacked by
    the linear downsample rate, and finally projected.

    Args:
        input_embeds: audio features handed to ``self.encoder``.
        audio_attention_mask: optional mask forwarded to the encoder.
        audio_projection_mode: ``"speech"`` uses ``self.audio_projection``;
            ``"vision"`` uses ``self.audio_projection_for_vision``.

    Returns:
        The projected audio features.

    Raises:
        ValueError: if ``audio_projection_mode`` is not one of the two
            supported values.
    """
    if self.freeze_audio_processor:
        with torch.no_grad():
            audio_features, masks = self.encoder(input_embeds,
                                                 audio_attention_mask)
    else:
        audio_features, masks = self.encoder(input_embeds,
                                             audio_attention_mask)

    if self.qformer is not None:
        audio_features, _ = self.qformer(audio_features, mask=None)

    if self.conv_ds is not None:
        if masks is not None:
            masks = masks.squeeze(1)
        audio_features, masks = self.conv_ds(audio_features, mask=masks)

    # The visible part of __init__ only assigns ``linear_downsample_rate``
    # for the "mlp" projection; fall back to 1 (no frame stacking) so the
    # "linear" projection does not fail with an AttributeError here.
    downsample_rate = getattr(self, "linear_downsample_rate", 1)
    if downsample_rate != 1:
        bs, seq_len, feat_dim = audio_features.size()
        remainder = seq_len % downsample_rate
        if remainder > 0:
            # Zero-pad the time axis up to a multiple of the rate so the
            # reshape below is exact.
            audio_features = F.pad(
                audio_features,
                (0, 0, 0, downsample_rate - remainder),
                "constant",
                0,
            )
        seq_len = audio_features.size(1)
        # Stack ``downsample_rate`` consecutive frames into one wider frame.
        audio_features = audio_features.view(
            bs,
            seq_len // downsample_rate,
            feat_dim * downsample_rate,
        )

    if audio_projection_mode == 'speech':
        audio_set_tensor = self.audio_projection(audio_features)
    elif audio_projection_mode == 'vision':
        audio_set_tensor = self.audio_projection_for_vision(audio_features)
    else:
        raise ValueError(
            f"audio_projection_mode = {audio_projection_mode} not "
            "implemented")

    return audio_set_tensor
def forward(
    self,
    input_ids: torch.LongTensor,
    input_embeds: torch.FloatTensor,
    audio_embed_sizes,
    **kwargs,
) -> torch.FloatTensor:
    """Embed the text tokens and splice audio features into placeholders.

    arguments:
        input_ids: input text ids (B, U)
        input_embeds: audio features (B, T, D)  B: num audios in a sequence
        audio_embed_sizes: number of placeholder positions each audio
            occupies; one entry per element of ``input_embeds``.
        **kwargs: ``audio_projection_mode`` (default ``"speech"``) selects
            the projection head; ``wte`` optionally supplies the token
            embedding layer to use instead of ``self.wte``.

    returns:
        Token embeddings in which every run of audio placeholder tokens
        has been overwritten by the corresponding projected audio features.

    NOTE: ``input_ids`` is clamped in place (through a view), so the
    caller's tensor is modified.
    """
    assert input_embeds is not None and len(input_embeds) == len(
        audio_embed_sizes)

    input_shape = input_ids.size()
    input_ids = input_ids.view(-1, input_shape[-1])

    # Locate the placeholders before clamping changes any token id.
    with torch.no_grad():
        positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero(
            as_tuple=False)

    if not isinstance(input_embeds, list):
        input_embeds = [input_embeds]

    audio_projection_mode = kwargs.get("audio_projection_mode", "speech")
    audio_set_tensor = [
        self.get_audio_features(
            input_embed, audio_projection_mode=audio_projection_mode)
        for input_embed in input_embeds
    ]

    with torch.no_grad():
        input_ids.clamp_min_(0).clamp_max_(self.vocab_size)

    if "wte" in kwargs:
        # we use the token embedding layer from the huggingface model, this
        # is REQUIRED to make sure we are using the loaded weights.
        hidden_states = kwargs["wte"](input_ids)
    else:
        # otherwise, we use token embedding in pretrained mixformer from
        # phi team
        hidden_states = self.wte(input_ids)

    if positions.shape[0] > 0:
        assert sum(audio_embed_sizes) == len(
            positions
        ), "please ensure the encoder outputs have the same length as"\
            " defined in input_ids!"
        idx = 0
        for i in range(len(audio_embed_sizes)):
            cnt = audio_embed_sizes[i]
            assert audio_set_tensor[i].shape[0] == 1
            # positions[idx] is the (row, column) of the first placeholder
            # of audio i; its features fill the next ``cnt`` columns.
            hidden_states[
                positions[idx, 0],
                positions[idx, 1]:positions[idx, 1] + cnt,
            ] = (audio_set_tensor[i][0, :audio_embed_sizes[i], :].to(
                hidden_states.dtype).to(hidden_states.device))
            idx += cnt
    elif self.training:
        # No placeholder in this batch: add a zero-weighted term from each
        # audio output (presumably so the audio branch still takes part in
        # the autograd graph — confirm with the training setup).
        # ``audio_set_tensor`` is a list, so each entry is handled on its
        # own rather than being indexed like a single tensor.
        for audio_features in audio_set_tensor:
            hidden_states[:, 0:1] = hidden_states[:, 0:1] + \
                0 * audio_features[:1, 0:1].to(hidden_states.dtype)\
                .to(hidden_states.device)

    if self.drop is not None:
        hidden_states = self.drop(hidden_states)

    return hidden_states
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com)
# but implemented by the Phi-Speech team
#!/usr/bin/env python3
import math
from functools import partial
from typing import Callable, Dict, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl, checkpoint_wrapper, offload_wrapper)
class Block(nn.Module):
    """Base module that simply records its input and output sizes."""

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size, self.output_size = input_size, output_size
def get_activation(name="relu"):
    """Build an activation module from its (case-insensitive) name.

    Args:
        name: str
            activation function name,
            one of ["relu", "gelu", "swish", "sigmoid"],
            default "relu".

    Returns:
        A freshly constructed activation module. Any name outside the
        list above yields ``nn.Identity()``.
    """
    # Factories are wrapped lazily so only the requested module is built.
    factories = {
        "relu": lambda: nn.ReLU(inplace=True),
        "gelu": lambda: nn.GELU(),
        "swish": lambda: Swish(),
        "sigmoid": lambda: torch.nn.Sigmoid(),
    }
    build = factories.get(name.lower(), lambda: nn.Identity())
    return build()
def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0):
    """Build the chunk-wise attention mask used for streaming encoders.

    Every frame may attend to all frames of its own chunk, of up to
    ``left_window`` chunks before it and of up to ``right_window`` chunks
    after it.

    Args:
        x_len (int): sequence length
        chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48].
            It also supports adaptive chunk size [0,10,15,45]
        left_window (int): how many left chunks can be seen
        right_window (int): how many right chunks can be seen. It is used for
            chunk overlap model.

    Returns:
        mask (torch.Tensor): boolean tensor of size [x_len, x_len]; entry
            (i, j) is True when frame i may see frame j, e.g.
            tensor([[True, True, False, False],
                    [False, True, True, False],
                    [False, False, True, True]])
    """
    num_chunks = len(chunk_start_idx)
    starts = torch.Tensor(chunk_start_idx).long()
    # Prepending 0 / appending x_len turns the starts into per-chunk
    # [begin, end) boundaries, e.g. [0, 0, 18, 36, 48] / [0, 18, 36, 48, x_len].
    chunk_begin = torch.nn.functional.pad(starts, (1, 0))
    chunk_end = torch.nn.functional.pad(starts, (0, 1), value=x_len)

    frames = torch.arange(0, x_len)
    frame_col = frames.unsqueeze(-1)  # [x_len, 1]
    # For each frame, the index of the boundary slot that contains it.
    chunk_of_frame = ((frame_col < chunk_end) &
                      (frame_col >= chunk_begin)).nonzero()[:, 1]

    columns = frames.unsqueeze(0).expand(x_len, -1)  # [x_len, x_len]

    first_visible = (chunk_of_frame - left_window).clamp(min=0)
    last_visible = (chunk_of_frame + right_window).clamp(max=num_chunks)

    visible_from = chunk_begin[first_visible].unsqueeze(-1)
    visible_until = chunk_end[last_visible].unsqueeze(-1)
    return (columns >= visible_from) & (columns < visible_until)
class Swish(nn.Module):
    """Swish activation: the input multiplied by its own sigmoid.

    From https://arxiv.org/pdf/2005.03191.pdf
    """

    def __init__(self) -> None:
        super().__init__()
        self.act_fn = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x * sigmoid(x)``.

        Args:
            x: torch.Tensor
                Input.
        """
        gate = self.act_fn(x)
        return x * gate
class GLU(nn.Module):
    """Implement Gated Linear Unit (GLU) module.

    The input is split in two halves along ``dim``; the first half is
    multiplied by the activation of the second half.

    Args:
        dim: int
            dimension along which the input is split, default -1.
        act_name: str
            gate activation, one of ["relu", "gelu", "swish", "sigmoid"]
            (case-insensitive); any other name uses the identity.
    """

    def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None:
        super().__init__()
        self.dim = dim
        self.act_name = act_name.lower()

        if self.act_name == "relu":
            # Not in-place: the gate is a view produced by ``chunk``, so an
            # in-place ReLU would overwrite the caller's input and is
            # rejected by autograd for multi-output views.
            self.act_fn = nn.ReLU()
        elif self.act_name == "gelu":
            self.act_fn = nn.GELU()
        elif self.act_name == "swish":
            self.act_fn = Swish()
        elif self.act_name == "sigmoid":
            self.act_fn = nn.Sigmoid()
        else:
            self.act_fn = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """GLU forward.

        Multiply the first half of the input by the activation of the
        second half.

        Args:
            x: torch.Tensor
                Input; its size along ``self.dim`` is split in two.
        """
        half_x, gate = x.chunk(2, dim=self.dim)
        return half_x * self.act_fn(gate)
# TODO: Abdel, this can be improved using GLU module
class GLUPointWiseConv(nn.Module):
    """GLUPointWiseConv module
    used for conformer architecture,
    for more details see:
    https://arxiv.org/pdf/2005.08100v1.pdf

    A 1-D convolution produces ``2 * output_dim`` channels; the first half
    is multiplied by the (optionally activated) second half.

    Args:
        input_dim: int
            input channel size.
        output_dim: int
            output channel size.
        kernel_size: int
            kernel size
        glu_type: str, optional
            gate activation, one of
            ["sigmoid", "relu", "gelu", "swish", "bilinear"];
            "bilinear" multiplies the two halves without any activation.
            default "sigmoid".
        bias_in_glu: bool, optional
            use additive bias in glu
        causal: bool, optional
            if set to True, the convolution pads ``kernel_size - 1`` frames
            on both sides, so the output is ``kernel_size - 1`` frames
            longer than the input and the caller is expected to trim it;
            otherwise ``(kernel_size - 1) // 2`` frames are padded.
            default False.

    Raises:
        ValueError: if ``glu_type`` is not one of the supported names.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        kernel_size,
        glu_type="sigmoid",
        bias_in_glu=True,
        causal=False,
    ):
        super().__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim
        self.bias_in_glu = bias_in_glu
        padding = (kernel_size - 1) if causal else (kernel_size - 1) // 2
        self.ext_pw_conv_1d = nn.Conv1d(
            input_dim,
            output_dim * 2,
            kernel_size,
            1,
            padding=padding,
        )

        if glu_type == "sigmoid":
            self.glu_act = nn.Sigmoid()
        elif glu_type == "relu":
            self.glu_act = nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = nn.GELU()
        elif glu_type == "swish":
            self.glu_act = Swish()
        elif glu_type == "bilinear":
            # forward() handles this type without an activation.
            self.glu_act = nn.Identity()
        else:
            # Report the requested name; ``self.glu_act`` is not set here.
            raise ValueError(f"Unsupported activation type {glu_type}")

        if bias_in_glu:
            self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1))
            self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1))

    def forward(self, x):
        """
        Args:
            x: torch.Tensor
                input tensor with the channel dimension last.

        Returns:
            torch.Tensor with ``output_dim`` channels in the last dimension.
        """
        # to be consistent with GLULinear, we assume the input always has the
        # #channel (#dim) in the last dimension of the tensor, so need to
        # switch the dimension first for 1D-Conv case
        x = x.permute([0, 2, 1])
        x = self.ext_pw_conv_1d(x)

        value = x[:, 0:self.output_dim, :]
        gate = x[:, self.output_dim:self.output_dim * 2, :]
        if self.bias_in_glu:
            value = value + self.b1
            gate = gate + self.b2
        if self.glu_type != "bilinear":
            gate = self.glu_act(gate)

        x = value * gate
        x = x.permute([0, 2, 1])
        return x
class DepthWiseSeperableConv1d(nn.Module):
    """Depthwise convolution optionally followed by a pointwise one.

    Used in the Convnet module of the conformer, for more details see:
    https://arxiv.org/pdf/2005.08100v1.pdf

    Args:
        input_dim: int
            input channel size.
        depthwise_seperable_out_channel: int
            when non-zero, the channel count produced by the second
            (kernel size 1) conv1d layer; when 0, that layer is skipped.
        kernel_size: int
            kernel size of the depthwise convolution.
        depthwise_multiplier: int
            number of output channels the depthwise convolution produces
            per input channel.
        padding: int, optional
            padding for the depthwise conv1d,
            default: 0.
    """

    def __init__(
        self,
        input_dim,
        depthwise_seperable_out_channel,
        kernel_size,
        depthwise_multiplier,
        padding=0,
    ):
        super().__init__()
        hidden_channels = input_dim * depthwise_multiplier

        # groups=input_dim makes every input channel convolve on its own.
        self.dw_conv = nn.Conv1d(
            in_channels=input_dim,
            out_channels=hidden_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            groups=input_dim,
        )

        if depthwise_seperable_out_channel == 0:
            self.pw_conv = nn.Identity()
        else:
            self.pw_conv = nn.Conv1d(
                in_channels=hidden_channels,
                out_channels=depthwise_seperable_out_channel,
                kernel_size=1,
                stride=1,
                padding=0,
            )
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel

    def forward(self, x):
        """
        Args:
            x: torch.Tensor
                input tensor, channels in dimension 1 as nn.Conv1d expects.
        """
        # pw_conv is an identity when no pointwise layer was requested.
        return self.pw_conv(self.dw_conv(x))
class ConvModule(nn.Module):
    """ConvModule Module for the conformer block.
    for more details see:
    https://arxiv.org/pdf/2005.08100v1.pdf

    Args:
        input_dim: int
            input channel size.
        ext_pw_out_channel: int
            if > 0, ext_pw_out_channel is a dim channel size
             for the last pointwise conv after swish activation.
        depthwise_seperable_out_channel: int
            if set different to 0, the number of
             depthwise_seperable_out_channel
             will be used as a channel_out of the second conv1d layer.
             otherwise, it equal to 0, the second conv1d layer is skipped.
        ext_pw_kernel_size: int
            kernel size of the conv pointwise of the conformer.
        kernel_size: int
            kernel size.
        depthwise_multiplier: int
            number of input_dim channels duplication. this value
             will be used to compute the hidden channels of the Conv1D.
        dropout_rate: float
            dropout rate.
        causal: bool, optional
            if set to True, convolution have no access
             to future frames. default False.
        batch_norm: bool, optional
            if set to True, apply batchnorm before activation.
            default False
        chunk_se: int, optional
            0 for offline SE.
            1 for streaming SE, where mean is computed
             by accumulated history until current chunk_se.
            2 for streaming SE, where mean is computed
             by only the current chunk.
        chunk_size: int, optional
            chunk size for cnn. default 18
        activation: str, optional
            activation function used in ConvModule,
            default: "relu".
        glu_type: str, optional
            activation function used for the glu,
            default: "sigmoid".
        bias_in_glu: bool, optional
            if set to True, use additive bias in the weight module
             before GLU.
        linear_glu_in_convm: bool, optional
            if set to True, use GLULinear module,
             otherwise, used GLUPointWiseConv module.
              default to False.
        export: bool, optional,
            if set to True, padding is equal to 0.  This is for inference,
             or onnx export.  Typically this is set by the export program or
             the decoder program, and it isn't present in your config file.
             default False

    NOTE(review): ``chunk_se`` and ``chunk_size`` are accepted but not read
    anywhere in this class as shown here.
    """

    def __init__(
        self,
        input_dim,
        ext_pw_out_channel,
        depthwise_seperable_out_channel,
        ext_pw_kernel_size,
        kernel_size,
        depthwise_multiplier,
        dropout_rate,
        causal=False,
        batch_norm=False,
        chunk_se=0,
        chunk_size=18,
        activation="relu",
        glu_type="sigmoid",
        bias_in_glu=True,
        linear_glu_in_convm=False,
        export=False,
    ):
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_dim)
        self.input_dim = input_dim
        self.ext_pw_out_channel = ext_pw_out_channel
        self.ext_pw_kernel_size = ext_pw_kernel_size
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
        self.glu_type = glu_type
        self.bias_in_glu = bias_in_glu
        self.linear_glu_in_convm = linear_glu_in_convm
        self.causal = causal

        # Must run after the attributes above are set (it reads them) and
        # before bn_layer is assigned below (it installs an Identity
        # placeholder for bn_layer).
        self._add_ext_pw_layer()

        self.batch_norm = batch_norm
        self.kernel_size = kernel_size

        if batch_norm:
            # Replaces the Identity placeholder from _add_ext_pw_layer().
            self.bn_layer = nn.BatchNorm1d(input_dim)

        self.act = get_activation(activation)
        self.dropout = nn.Dropout(dropout_rate)
        self.export = export

        # Causal mode pads kernel_size - 1 on both sides (forward() trims
        # the tail again), unless exporting, where no padding is applied.
        if causal:
            padding = 0 if export else kernel_size - 1
        else:
            padding = (kernel_size - 1) // 2

        self.dw_sep_conv_1d = DepthWiseSeperableConv1d(
            input_dim,
            depthwise_seperable_out_channel,
            kernel_size,
            depthwise_multiplier,
            padding=padding,
        )

        # ln2 only exists when the depthwise-separable conv changes the
        # channel count; forward() checks for it with hasattr().
        if depthwise_seperable_out_channel != 0:
            if input_dim != depthwise_seperable_out_channel:
                self.ln2 = nn.Linear(depthwise_seperable_out_channel,
                                     input_dim)
        else:
            if depthwise_multiplier != 1:
                self.ln2 = nn.Linear(input_dim * depthwise_multiplier,
                                     input_dim)

    def _add_ext_pw_layer(self):
        """
        This function is an extension of __init__ function
        and dedicated to the convolution module creation
        of the conformer.

        It creates ``ext_pw_conv_1d``, ``glu`` and (when channel counts
        differ) ``ln1`` if ``ext_pw_out_channel`` is non-zero; otherwise it
        creates the ``pw_conv_simplify_w`` / ``pw_conv_simplify_b``
        parameters used by forward().
        """
        # Every optional attribute gets a placeholder first so that it is
        # always defined.
        self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = (
            nn.Identity())  # jit hacks.
        self.squeeze_excitation = nn.Identity()  # jit.
        self.apply_ln1 = self.fix_len1 = False  # jit.

        if self.ext_pw_out_channel != 0:
            if self.causal:
                self.ext_pw_conv_1d = nn.Conv1d(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    1,
                    padding=(self.ext_pw_kernel_size - 1),
                )
                # fix_len1 tells forward() to drop the extra trailing
                # frames that this padding produces.
                if self.ext_pw_kernel_size > 1:
                    self.fix_len1 = True
                else:
                    self.fix_len1 = False
            else:
                self.ext_pw_conv_1d = nn.Conv1d(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    1,
                    padding=(self.ext_pw_kernel_size - 1) // 2,
                )
                self.fix_len1 = False

            if self.linear_glu_in_convm:
                self.glu = GLULinear(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.glu_type,
                    self.bias_in_glu,
                )
            else:
                self.glu = GLUPointWiseConv(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    self.glu_type,
                    self.bias_in_glu,
                    self.causal,
                )

            # ln1 maps ext_pw_out_channel back to input_dim; forward()
            # applies it after the GLU and again after ext_pw_conv_1d.
            if self.input_dim != self.ext_pw_out_channel:
                self.apply_ln1 = True
                self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim)
            else:
                self.apply_ln1 = False
        else:
            # Three learnable scale/shift pairs: indices 0 and 1 are used
            # before the depthwise conv, index 2 after it (see forward()).
            self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3))
            self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3))

    def forward(self, x):
        """ConvModule Forward.

        Args:
            x: torch.Tensor
                input tensor.
                (assumes channels are in the last dimension, since
                LayerNorm(input_dim) is applied first and the tensor is
                permuted before the Conv1d layers — TODO confirm shape
                with callers)
        """
        x = self.layer_norm(x)

        if self.ext_pw_out_channel != 0:
            x = self.glu(x)
            # Drop the trailing frames added by the causal padding.
            if self.causal and self.ext_pw_kernel_size > 1:
                x = x[:, :-(self.ext_pw_kernel_size - 1), :]
            if self.apply_ln1:
                x = self.ln1(x)
        else:
            x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0]
            x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1]
            x = x_0 + x_1

        # Move channels to dimension 1 for the Conv1d-based layers.
        x = x.permute([0, 2, 1])

        x = self.dw_sep_conv_1d(x)
        # Drop the trailing frames added by the causal padding.
        if self.causal and self.kernel_size > 1:
            x = x[:, :, :-(self.kernel_size - 1)]
        if hasattr(self, "ln2"):
            # nn.Linear works on the last dimension, hence the round trip.
            x = x.permute([0, 2, 1])
            x = self.ln2(x)
            x = x.permute([0, 2, 1])
        if self.batch_norm:
            x = self.bn_layer(x)
        x = self.act(x)

        if self.ext_pw_out_channel != 0:
            x = self.ext_pw_conv_1d(x)
            if self.fix_len1:
                x = x[:, :, :-(self.ext_pw_kernel_size - 1)]

            if self.apply_ln1:
                x = x.permute([0, 2, 1])
                x = self.ln1(x)
                x = x.permute([0, 2, 1])

            # Back to channels-last for the caller.
            x = x.permute([0, 2, 1])
        else:
            # unsqueeze/permute/squeeze together amount to swapping the
            # last two dimensions, i.e. back to channels-last.
            x = x.unsqueeze(1).permute([0, 1, 3, 2])
            x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2]
            x = x.squeeze(1)

        x = self.dropout(x)
        return x
class GLULinear(nn.Module):
    """Linear projection followed by a GLU gate.

    Args:
        input_dim: int
            input size
        output_dim: int
            output size; the linear layer produces twice this many
            features, which the GLU halves again.
        glu_type:
            activation function name used in the glu module.
            default "sigmoid".
        bias_in_glu: bool, optional
            If True, the linear layer has an additive bias. Default True.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        glu_type="sigmoid",
        bias_in_glu=True,
    ):
        super().__init__()
        gated_width = output_dim * 2
        self.linear = nn.Linear(input_dim, gated_width, bias=bias_in_glu)
        self.glu_act = GLU(dim=-1, act_name=glu_type)

    def forward(self, x):
        """GLULinear forward

        Args:
            x: torch.Tensor
                input tensor.
        """
        projected = self.linear(x)
        return self.glu_act(projected)
class FeedForward(nn.Module):
    """FeedForward Module.
    For more details see Conformer paper:
        https://arxiv.org/pdf/2005.08100.pdf

    Layer norm, then GLULinear (d_model -> d_inner), dropout,
    Linear (d_inner -> d_model) and dropout.

    Args:
        d_model: int
            input (and output) size.
        d_inner: int
            size of the inner, gated representation.
        dropout_rate: float,
            dropout rate.
        activation: str,
            name of the gate activation handed to GLULinear,
            default "sigmoid".
        bias_in_glu: bool, optional
            handed to GLULinear, default True.
    """

    def __init__(
        self,
        d_model,
        d_inner,
        dropout_rate,
        activation="sigmoid",
        bias_in_glu=True,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_inner

        self.layer_norm = nn.LayerNorm(d_model)
        # Layer order fixes the indices used in the state dict keys.
        stages = [
            GLULinear(d_model, d_inner, activation, bias_in_glu),
            nn.Dropout(dropout_rate),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout_rate),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """FeedForward forward function.

        Args:
            x: torch.Tensor
                input tensor.
        """
        normed = self.layer_norm(x)
        return self.net(normed)
#### positional encoding starts here
def _pre_hook(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""Perform pre-hook in load_state_dict for backward compatibility.
Note:
We saved self.pe until v.0.5.2 but we have omitted it later.
Therefore, we remove the item "pe" from `state_dict` for backward
compatibility.
"""
k = prefix + "pe"
if k in state_dict:
state_dict.pop(k)
class T5RelativeAttentionLogitBias(nn.Module):
    """
    This module implements the relative position bias described in Section
    2.1 of the T5 paper: https://arxiv.org/pdf/1910.10683.pdf

    The Huggingface implementation is used as a reference
    https://github.com/huggingface/transformers/blob/v4.30.0/src/
    transformers/models/t5/modeling_t5.py#L435

    Modifies attention as Q*K^T + B, where B is a learned scalar bias based
    on relative position of the query and key. It is HxNxN, where H is the
    number of heads, N is the sequence length.

    I've made these modifications to the original T5 bias:
    - Skipping of the bucketing step. Original T5 bias converted rel
      position distances into logarithmically increasing buckets. This is
      supposed to help with length generalization.
    - I just directly use rel position index as bias values, as we don't
      need length generalization (40s max is good enough for ASR encoder),
      and it keeps ONNX export simple.
    - I've also extended it so that biases can be asymmetric, the default
      implementation treats L->R and R->L the same. Asymmetric was found to
      yield better results in my experiments.

    Args:
        num_heads: int
            Number of attention heads
        num_buckets: int
            Number of buckets to use for relative attention bias. This is the
            size of the learnable bias parameter. Bucketing is not yet
            supported, so this defaults to -1 which means no bucketing is
            used (max_distance determines size of bias param).
        max_distance: int
            Maximum distance to use for relative attention bias. With
            num_buckets=-1, this directly controls the max size of the bias
            parameter. When num_buckets > 0 is supported, this will control
            the maximum distance for logarithmic bucketing after which all
            positions are in the same bucket.
        symmetric: bool
            Whether to use symmetric or asymmetric biases. symmetric=False uses
            2x number of bias params to distinguish L->R from R->L. This was
            found to be better for the encoder.

    Raises:
        NotImplementedError: if ``num_buckets`` is not negative.
    """

    def __init__(self,
                 num_heads,
                 num_buckets=-1,
                 max_distance=1000,
                 symmetric=False):
        super().__init__()
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.symmetric = symmetric
        self._skip_bucketing = self.num_buckets < 0
        if self._skip_bucketing:
            self.num_buckets = max_distance
        else:
            raise NotImplementedError(
                "T5 attention bias with bucketed positions is not yet tested")
        if not self.symmetric:
            self.num_buckets *= 2
        self.bias_values = nn.Embedding(self.num_buckets, self.num_heads)

    def forward(self, x):
        """Return the bias for a sequence as long as ``x.size(1)``.

        Args:
            x: tensor whose dimension 1 is the sequence length L; only its
                length and device are used.

        Returns:
            Tensor of size [1, H, L, L].
        """
        # instantiate bias compatible with shape of x
        maxpos = x.size(1)
        context_position = torch.arange(maxpos,
                                        device=x.device,
                                        dtype=torch.long)[:, None]
        memory_position = torch.arange(maxpos,
                                       device=x.device,
                                       dtype=torch.long)[None, :]
        relative_position = memory_position - context_position
        # clipping to a maximum distance using ops that play well with ONNX
        # export
        relative_position = relative_position.masked_fill(
            relative_position < -self.max_distance, -self.max_distance)
        relative_position = relative_position.masked_fill(
            relative_position > self.max_distance - 1, self.max_distance - 1)

        # mapping from relative position to index in the bias parameter
        if self._skip_bucketing:
            bias_idx = relative_position
        else:
            bias_idx = self._bucket_relative_position(relative_position)
        if self.symmetric:
            bias_idx = bias_idx.abs()
            # abs() turns the lower clip value -max_distance into
            # max_distance, one past the last embedding row; clip again.
            bias_idx = bias_idx.masked_fill(bias_idx > self.max_distance - 1,
                                            self.max_distance - 1)
        else:
            # shift [-max_distance, max_distance - 1] to start at zero
            bias_idx = bias_idx + self.num_buckets // 2

        t5_rel_att_bias = self.bias_values(bias_idx)  # [L, L, H]
        t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(
            0)  # [1, H, L, L]

        return t5_rel_att_bias

    def _bucket_relative_position(self, relative_position):
        # this is a placeholder (isn't tested, likely buggy) using HuggingFace
        # implem as a reference this also needs to be extended to support
        # asymmetric +/- ve positions
        # NOTE(review): currently unreachable because __init__ rejects
        # bucketing, and ``self.causal`` is not defined by this class.
        # A local bucket count is used so repeated calls do not shrink
        # ``self.num_buckets``.
        num_buckets = self.num_buckets
        relative_buckets = 0
        if not self.causal:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(
                torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position,
                                           torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in
        # positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact) /
            math.log(self.max_distance / max_exact) *
            (num_buckets - max_exact)).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )

        relative_buckets += torch.where(is_small, relative_position,
                                        relative_position_if_large)
        return relative_buckets
class AbsolutePositionalEncoding(nn.Module):
    """Absolute Positional encoding module.
    This module implement Absolute sinusoidal positional encoding
    from: https://arxiv.org/pdf/1706.03762.pdf

    Args:
        d_model: int
            Input embedding size.
        dropout_rate: float
            dropout rate
        max_len: int, optional
            Maximum input length sequence, Default 5000
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Construct an PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        # ``pe`` is a plain attribute, not a registered buffer; it is built
        # here for ``max_len`` positions and rebuilt on demand.
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
        # Older checkpoints may still carry a "pe" entry; drop it on load.
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, x):
        """Reset the positional encodings.

        Keeps the cached table when it already covers ``x.size(1)``
        positions (only moving it to the dtype/device of ``x``); otherwise
        rebuilds it for ``x.size(1)`` positions.

        Args:
            x: torch.Tensor
        """
        if self.pe is not None and self.pe.size(1) >= x.size(1):
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return
        pe = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32) *
            -(math.log(10000.0) / self.d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so the last frequency is left out for the cosines.
        pe[:, 1::2] = torch.cos(angles[:, :self.d_model // 2])
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        The input is scaled by sqrt(d_model) before the encoding is added,
        and dropout is applied to the sum.

        Args:
            x: torch.Tensor
                Input tensor. shape is (batch, time, ...)

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, :x.size(1)]
        return self.dropout(x)
#### forward embedding layers starts here
class MeanVarianceNormLayer(nn.Module):
    """Mean/variance normalization layer.

    Subtracts a stored mean from the input and multiplies the result by a
    stored inverse standard deviation. Both statistics are buffers that
    start at 0 and 1 respectively.
    Typically used as a very first layer in a model.

    Args:
        input_size: int
            layer input size.
    """

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        for buffer_name, initial in (
            ("global_mean", torch.zeros(input_size)),
            ("global_invstd", torch.ones(input_size)),
        ):
            self.register_buffer(buffer_name, initial)
        # Declarations for type checkers only; no value is assigned here.
        self.global_mean: Optional[Tensor]
        self.global_invstd: Optional[Tensor]

    def forward(self, input_: Tensor) -> Tensor:
        """MeanVarianceNormLayer Forward

        Args:
            input_: torch.Tensor
                input tensor.
        """
        centered = input_ - self.global_mean
        return centered * self.global_invstd
class CausalConv1D(nn.Conv1d):
    """
    A causal version of nn.Conv1d where each step would have limited access to
    locations on its right or left
    All arguments are the same as nn.Conv1d except padding.

    If padding is set None, then paddings are set automatically to make it a
    causal convolution where each location would not see any steps on its right.

    If padding is set as a list (size of 2), then padding[0] would be used as
    left padding and padding[1] as right padding.
    It would make it possible to control the number of steps to be accessible
    on the right and left.
    This mode is not supported when stride > 1. padding[0]+padding[1] should
    be equal to (kernel_size - 1).

    If padding is an int, the same amount is padded on both sides.

    The padding is applied manually in ``update_cache``; the underlying
    nn.Conv1d is always built with ``padding=0``.

    Attributes:
        cache_drop_size: number of trailing frames left out of the returned
            cache. It starts as None (treated as 0) and is presumably set
            by the owner of the layer for streaming — verify with callers.

    Raises:
        ValueError: for a strided convolution whose padding is not
            ``kernel_size - 1``, or for a padding value that is neither
            None, an int, nor a valid two-element list.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: Optional[Union[int, list]] = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        self.cache_drop_size = None
        if padding is None:
            self._left_padding = kernel_size - 1
            self._right_padding = stride - 1
        else:
            if stride != 1 and padding != kernel_size - 1:
                raise ValueError(
                    "No striding allowed for non-symmetric convolutions!")
            if isinstance(padding, int):
                self._left_padding = padding
                self._right_padding = padding
            elif (isinstance(padding, list) and len(padding) == 2
                  and padding[0] + padding[1] == kernel_size - 1):
                self._left_padding = padding[0]
                self._right_padding = padding[1]
            else:
                raise ValueError(f"Invalid padding param: {padding}!")

        self._max_cache_len = self._left_padding

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def update_cache(self, x, cache=None):
        """Pad ``x`` (or prepend ``cache``) and compute the next cache.

        Args:
            x: input with time in the last dimension.
            cache: frames from the previous call that replace the left
                padding, or None for non-streaming use.

        Returns:
            ``(new_x, next_cache)``: the tensor to convolve and the cache
            for the next call. ``next_cache`` is None when ``cache`` is
            None, otherwise it has the same length as ``cache``.
        """
        if cache is None:
            new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
            next_cache = cache
        else:
            new_x = F.pad(x, pad=(0, self._right_padding))
            new_x = torch.cat([cache, new_x], dim=-1)
            # cache_drop_size is None until someone sets it; treat that as
            # "drop nothing" instead of failing on ``None > 0``.
            drop_size = self.cache_drop_size or 0
            if drop_size > 0:
                next_cache = new_x[:, :, :-drop_size]
            else:
                next_cache = new_x
            next_cache = next_cache[:, :, -cache.size(-1):]
        return new_x, next_cache

    def forward(self, x, cache=None):
        """Apply the convolution.

        Returns:
            The output tensor when ``cache`` is None, otherwise a tuple
            ``(output, next_cache)``.
        """
        x, cache = self.update_cache(x, cache=cache)
        x = super().forward(x)
        if cache is None:
            return x
        else:
            return x, cache
class CausalConv2D(nn.Conv2d):
    """
    A causal version of nn.Conv2d where each location in the 2D matrix would
    have no access to locations on its right or down
    All arguments are the same as nn.Conv2d except padding which should be
    set as None

    The padding is applied manually in ``forward``; the underlying
    nn.Conv2d is always built with ``padding=0``.

    Raises:
        ValueError: if ``padding`` is anything other than None (which
            includes the default value 0).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: Union[str, int] = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        if padding is not None:
            raise ValueError(
                "Argument padding should be set to None for CausalConv2D.")
        self._left_padding = kernel_size - 1
        self._right_padding = stride - 1

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(
        self,
        x,
    ):
        """Pad the input and apply the convolution.

        In training mode both of the last two dimensions are padded; in
        eval mode only the last dimension is.
        """
        before, after = self._left_padding, self._right_padding
        # F.pad lists the last dimension first.
        if self.training:
            pad_spec = (before, after, before, after)
        else:
            pad_spec = (before, after, 0, 0)
        padded = F.pad(x, pad=pad_spec)
        return super().forward(padded)
class NemoConvSubsampling(torch.nn.Module):
    """Convolutional subsampling module, taken from NeMo ASR
    (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a
    34501479cf/nemo/collections/asr/parts/submodules/subsampling.py)

    Striding Subsampling: "Speech-Transformer: A No-Recurrence
    Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong
    et al. (https://ieeexplore.ieee.org/document/8462506)

    Compared with the EncoderConv2D (`input_layer: custom`), this is a
    much simplified approach, and uses no LayerNorm and far fewer Conv2Ds.
    Moreover, depthwise convolutions are used to reduce FLOPs, but the first
    layer is kept as a regular convolution so as not to degrade accuracy.

    `striding` and `dw_striding` are the same except that the latter uses
    depthwise convolutions after the first layer, whereas the former does not.

    Args:
        subsampling_factor (int): Time reduction factor. Must be even; the
            number of stride-2 stages is ``int(log2(subsampling_factor))``.
        feat_in (int): size of the input features
        feat_out (int): size of the output features
        subsampling (str): The subsampling technique, choose from
            {"striding", "dw_striding", "striding_conv1d",
            "dw_striding_conv1d"}
        conv_channels (int): Number of channels for the convolution layers,
            default is 256.
        subsampling_conv_chunking_factor (int): Input chunking factor which
            can be -1 (no chunking) 1 (auto) or a power of 2. Default is 1
        activation (Module): activation function, default is nn.ReLU().
            The same module instance is inserted after every stage.
        is_causal (bool): whether to use causal Conv1/2D, where each step will
            have limited access to locations on its right or left

    Raises:
        ValueError: if ``subsampling_factor`` is odd, if
            ``subsampling_conv_chunking_factor`` is not -1, 1 or a power
            of 2, or if ``subsampling`` is not one of the supported names.
    """

    def __init__(
            self,
            feat_in,
            feat_out,
            subsampling_factor=4,
            subsampling="dw_striding",
            conv_channels=256,
            subsampling_conv_chunking_factor=1,
            activation=nn.ReLU(),  # noqa: B008
            is_causal=False,
    ):
        super().__init__()
        self._subsampling = subsampling
        self._conv_channels = conv_channels
        self._feat_in = feat_in
        self._feat_out = feat_out

        if subsampling_factor % 2 != 0:
            raise ValueError("Sampling factor should be a multiply of 2!")
        # Number of stride-2 stages stacked in self.conv.
        self._sampling_num = int(math.log(subsampling_factor, 2))
        self.subsampling_factor = subsampling_factor
        self.is_causal = is_causal
        # Variants for which forward() applies the causal mask-length
        # correction ("dw_striding_conv1d" never builds causal convs).
        self.subsampling_causal_cond = subsampling in (
            "dw_striding",
            "striding",
            "striding_conv1d",
        )

        self._validate_chunking_factor(subsampling_conv_chunking_factor)
        self.subsampling_conv_chunking_factor = \
            subsampling_conv_chunking_factor

        in_channels = 1
        layers = []

        if subsampling == "dw_striding":
            self._stride = 2
            self._kernel_size = 3
            self._ceil_mode = False

            if self.is_causal:
                self._left_padding = self._kernel_size - 1
                self._right_padding = self._stride - 1
                self._max_cache_len = subsampling_factor + 1
            else:
                self._left_padding = (self._kernel_size - 1) // 2
                self._right_padding = (self._kernel_size - 1) // 2
                self._max_cache_len = 0

            # Layer 1: regular (non-depthwise) convolution.
            if self.is_causal:
                layers.append(
                    CausalConv2D(
                        in_channels=in_channels,
                        out_channels=conv_channels,
                        kernel_size=self._kernel_size,
                        stride=self._stride,
                        padding=None,
                    ))
            else:
                layers.append(
                    torch.nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=conv_channels,
                        kernel_size=self._kernel_size,
                        stride=self._stride,
                        padding=self._left_padding,
                    ))
            in_channels = conv_channels
            layers.append(activation)

            # Remaining stages: depthwise conv, pointwise conv, activation.
            # reset_parameters() and conv_split_by_channel() rely on this
            # exact 3-module layout after the first two modules.
            for _ in range(self._sampling_num - 1):
                if self.is_causal:
                    layers.append(
                        CausalConv2D(
                            in_channels=in_channels,
                            out_channels=in_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=None,
                            groups=in_channels,
                        ))
                else:
                    layers.append(
                        torch.nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=in_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
                            groups=in_channels,
                        ))

                layers.append(
                    torch.nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=conv_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=1,
                    ))
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == "striding":
            self._stride = 2
            self._kernel_size = 3
            self._ceil_mode = False

            if self.is_causal:
                self._left_padding = self._kernel_size - 1
                self._right_padding = self._stride - 1
                self._max_cache_len = subsampling_factor + 1
            else:
                self._left_padding = (self._kernel_size - 1) // 2
                self._right_padding = (self._kernel_size - 1) // 2
                self._max_cache_len = 0

            for _ in range(self._sampling_num):
                if self.is_causal:
                    layers.append(
                        CausalConv2D(
                            in_channels=in_channels,
                            out_channels=conv_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=None,
                        ))
                else:
                    layers.append(
                        torch.nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=conv_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
                        ))
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == "striding_conv1d":
            in_channels = feat_in

            self._stride = 2
            self._kernel_size = 5
            self._ceil_mode = False

            if self.is_causal:
                self._left_padding = self._kernel_size - 1
                self._right_padding = self._stride - 1
                self._max_cache_len = subsampling_factor + 1
            else:
                self._left_padding = (self._kernel_size - 1) // 2
                self._right_padding = (self._kernel_size - 1) // 2
                self._max_cache_len = 0

            for i in range(self._sampling_num):
                # The last stage projects directly to feat_out because the
                # conv1d variants have no trailing Linear layer.
                if self.is_causal:
                    layers.append(
                        CausalConv1D(
                            in_channels=in_channels,
                            out_channels=(feat_out if self._sampling_num == i +
                                          1 else conv_channels),
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=None,
                        ))
                else:
                    layers.append(
                        torch.nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=(feat_out if self._sampling_num == i +
                                          1 else conv_channels),
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
                        ))
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == "dw_striding_conv1d":
            in_channels = feat_in

            self._stride = 2
            self._kernel_size = 5
            self._ceil_mode = False

            # NOTE: this variant ignores is_causal and always pads
            # symmetrically.
            self._left_padding = (self._kernel_size - 1) // 2
            self._right_padding = (self._kernel_size - 1) // 2

            # Layer 1: depthwise conv followed by pointwise conv.
            layers.extend([
                torch.nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    kernel_size=self._kernel_size,
                    stride=self._stride,
                    padding=self._left_padding,
                    groups=in_channels,
                ),
                torch.nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=(feat_out if self._sampling_num == 1 else
                                  conv_channels),
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    groups=1,
                ),
            ])
            in_channels = conv_channels
            layers.append(activation)

            for i in range(self._sampling_num - 1):
                layers.extend([
                    torch.nn.Conv1d(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        kernel_size=self._kernel_size,
                        stride=self._stride,
                        padding=self._left_padding,
                        groups=in_channels,
                    ),
                    torch.nn.Conv1d(
                        in_channels=in_channels,
                        out_channels=(feat_out if self._sampling_num == i +
                                      2 else conv_channels),
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=1,
                    ),
                ])
                layers.append(activation)
                in_channels = conv_channels

        else:
            raise ValueError(f"Not valid sub-sampling: {subsampling}!")

        if subsampling in ["dw_striding", "striding"]:
            # The conv2d variants flatten (channels, reduced features) and
            # project them to feat_out with a Linear layer, so the reduced
            # feature length must be known up front.
            in_length = torch.tensor(feat_in, dtype=torch.float)
            out_length = calc_length(
                lengths=in_length,
                all_paddings=self._left_padding + self._right_padding,
                kernel_size=self._kernel_size,
                stride=self._stride,
                ceil_mode=self._ceil_mode,
                repeat_num=self._sampling_num,
            )
            self.out = torch.nn.Linear(conv_channels * int(out_length),
                                       feat_out)
            self.conv2d_subsampling = True
        elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
            self.out = None
            self.conv2d_subsampling = False
        else:
            # Defensive only: invalid names were already rejected above.
            raise ValueError(f"Not valid sub-sampling: {subsampling}!")

        self.conv = torch.nn.Sequential(*layers)

    @staticmethod
    def _validate_chunking_factor(chunking_factor):
        """Raise ValueError unless the factor is -1, 1 or a power of 2.

        A plain evenness test is not enough: values such as 0, 6 or -2 are
        even but are not valid chunking factors.
        """
        if chunking_factor in (-1, 1):
            return
        if (chunking_factor > 1 and chunking_factor == int(chunking_factor)
                and int(chunking_factor) & (int(chunking_factor) - 1) == 0):
            return
        raise ValueError(
            "subsampling_conv_chunking_factor should be -1, 1, or a "\
            "power of 2"
        )

    def get_sampling_frames(self):
        """Return ``[1, subsampling_factor]``."""
        return [1, self.subsampling_factor]

    def get_streaming_cache_size(self):
        """Return ``[0, subsampling_factor + 1]``."""
        return [0, self.subsampling_factor + 1]

    def forward(self, x, mask):
        """
        Forward method for NeMo subsampling.

        Args:
            x[Batch, Time, Filters]: torch.Tensor
                input tensor
            mask: torch.Tensor or None
                input mask; it is summed over dim 1 to obtain the number
                of valid frames per batch item.

        Returns:
            x: torch.Tensor
                Resulting tensor from subsampling (B, T //
                time_reduction_factor, feat_out)
            pad_mask: torch.Tensor
                tensor of padded hidden state sequences (B, 1, T //
                time_reduction_factor), or None when ``mask`` is None.
        """
        # conv2d variants add a singleton channel axis; conv1d variants
        # move the feature axis into the channel position.
        x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2)

        # split inputs if chunking_factor is set
        if (self.subsampling_conv_chunking_factor != -1
                and self.conv2d_subsampling):
            if self.subsampling_conv_chunking_factor == 1:
                # if subsampling_conv_chunking_factor is 1, we split only
                # if needed.
                # avoiding a bug / feature limiting indexing of tensors
                # to 2**31.
                # see https://github.com/pytorch/pytorch/issues/80020
                x_ceil = (2**31 / self._conv_channels * self._stride *
                          self._stride)
                need_to_split = torch.numel(x) > x_ceil
            else:
                # if subsampling_conv_chunking_factor > 1 we always split
                need_to_split = True

            if need_to_split:
                x, success = self.conv_split_by_batch(x)
                if not success:  # if unable to split by batch, try by channel
                    if self._subsampling == "dw_striding":
                        x = self.conv_split_by_channel(x)
                    else:
                        x = self.conv(x)  # try anyway
            else:
                x = self.conv(x)
        else:
            x = self.conv(x)

        # Flatten Channel and Frequency Axes
        if self.conv2d_subsampling:
            b, _, t, _ = x.size()
            x = self.out(x.transpose(1, 2).reshape(b, t, -1))
        # Transpose to Channel Last mode
        else:
            x = x.transpose(1, 2)

        if mask is None:
            return x, None

        max_audio_length = x.shape[1]
        feature_lens = mask.sum(1)
        padding_length = torch.ceil(feature_lens / self.subsampling_factor)
        if self.is_causal and self.subsampling_causal_cond:
            # Add one extra valid frame unless the remainder is exactly 1.
            feature_lens_remainder = feature_lens % self.subsampling_factor
            padding_length[feature_lens_remainder != 1] += 1
        pad_mask = torch.arange(0, max_audio_length, device=x.device).expand(
            padding_length.size(0), -1) < padding_length.unsqueeze(1)
        return x, pad_mask.unsqueeze(1)

    def reset_parameters(self):
        """Re-initialise conv and output weights (``dw_striding`` only)."""
        # initialize weights
        if self._subsampling == "dw_striding":
            with torch.no_grad():
                # init conv
                scale = 1.0 / self._kernel_size
                dw_max = (self._kernel_size**2)**-0.5
                pw_max = self._conv_channels**-0.5

                torch.nn.init.uniform_(self.conv[0].weight, -scale, scale)
                torch.nn.init.uniform_(self.conv[0].bias, -scale, scale)

                # After conv[0]/activation, modules repeat as
                # (depthwise, pointwise, activation).
                for idx in range(2, len(self.conv), 3):
                    torch.nn.init.uniform_(self.conv[idx].weight, -dw_max,
                                           dw_max)
                    torch.nn.init.uniform_(self.conv[idx].bias, -dw_max,
                                           dw_max)
                    torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max,
                                           pw_max)
                    torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max,
                                           pw_max)

                # init fc (80 * 64 = 5120 from https://github.com/kssteven418/
                # Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/
                # src/models/conformer_encoder.py#L487
                fc_scale = (self._feat_out * self._feat_in /
                            self._sampling_num)**-0.5
                torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale)
                torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale)

    def conv_split_by_batch(self, x):
        """Tries to split input by batch, run conv and concat results.

        Returns:
            ``(result, True)`` when the split succeeded, otherwise
            ``(x, False)`` with ``x`` unchanged (batch of 1, or the
            computed chunk size is 0).
        """
        b, _, _, _ = x.size()
        if b == 1:  # can't split if batch size is 1
            return x, False

        if self.subsampling_conv_chunking_factor > 1:
            cf = self.subsampling_conv_chunking_factor
        else:
            # avoiding a bug / feature limiting indexing of tensors to 2**31
            # see https://github.com/pytorch/pytorch/issues/80020
            x_ceil = 2**31 / self._conv_channels * self._stride * self._stride
            p = math.ceil(math.log(torch.numel(x) / x_ceil, 2))
            cf = 2**p

        new_batch_size = b // cf
        if new_batch_size == 0:  # input is too big
            return x, False

        return (
            torch.cat([
                self.conv(chunk)
                for chunk in torch.split(x, new_batch_size, 0)
            ]),
            True,
        )

    def conv_split_by_channel(self, x):
        """For dw convs, tries to split input by time, run conv and concat
        results"""
        x = self.conv[0](x)  # full conv2D
        x = self.conv[1](x)  # activation

        for i in range(self._sampling_num - 1):
            _, c, t, _ = x.size()

            if self.subsampling_conv_chunking_factor > 1:
                cf = self.subsampling_conv_chunking_factor
            else:
                # avoiding a bug / feature limiting indexing of tensors
                # to 2**31
                # see https://github.com/pytorch/pytorch/issues/80020
                p = math.ceil(math.log(torch.numel(x) / 2**31, 2))
                cf = 2**p

            # Chunk sizes are clamped to at least 1.
            new_c = int(c // cf)
            if new_c == 0:
                new_c = 1

            new_t = int(t // cf)
            if new_t == 0:
                new_t = 1

            x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c,
                                          x)  # conv2D, depthwise

            # splitting pointwise convs by time
            x = torch.cat(
                [
                    self.conv[i * 3 + 3](chunk)
                    for chunk in torch.split(x, new_t, 2)
                ],
                2,
            )  # conv2D, pointwise
            x = self.conv[i * 3 + 4](x)  # activation
        return x

    def channel_chunked_conv(self, conv, chunk_size, x):
        """Performs channel chunked convolution.

        ``x`` is split along dim 1 into pieces of ``chunk_size`` channels;
        each piece is convolved with the matching slice of ``conv``'s
        weight and bias, and the results are concatenated along dim 1.
        """
        ind = 0
        out_chunks = []
        for chunk in torch.split(x, chunk_size, 1):
            step = chunk.size()[1]

            if self.is_causal:
                # Apply the causal padding by hand, then convolve with
                # padding=0.
                chunk = nn.functional.pad(
                    chunk,
                    pad=(
                        self._kernel_size - 1,
                        self._stride - 1,
                        self._kernel_size - 1,
                        self._stride - 1,
                    ),
                )
                ch_out = nn.functional.conv2d(
                    chunk,
                    conv.weight[ind:ind + step, :, :, :],
                    bias=conv.bias[ind:ind + step],
                    stride=self._stride,
                    padding=0,
                    groups=step,
                )
            else:
                ch_out = nn.functional.conv2d(
                    chunk,
                    conv.weight[ind:ind + step, :, :, :],
                    bias=conv.bias[ind:ind + step],
                    stride=self._stride,
                    padding=self._left_padding,
                    groups=step,
                )
            out_chunks.append(ch_out)
            ind += step

        return torch.cat(out_chunks, 1)

    def change_subsampling_conv_chunking_factor(
            self, subsampling_conv_chunking_factor: int):
        """Validate and store a new chunking factor.

        Raises:
            ValueError: if the value is not -1, 1 or a power of 2.
        """
        self._validate_chunking_factor(subsampling_conv_chunking_factor)
        self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor
def calc_length(lengths,
                all_paddings,
                kernel_size,
                stride,
                ceil_mode,
                repeat_num=1):
    """Compute the output length after ``repeat_num`` conv/pool stages.

    Each stage maps ``L`` to ``round((L + all_paddings - kernel_size) /
    stride + 1)``, where ``round`` is ``torch.ceil`` when ``ceil_mode`` is
    truthy and ``torch.floor`` otherwise.

    Args:
        lengths: tensor of input lengths.
        all_paddings: total padding (both sides) applied per stage.
        kernel_size: kernel size of each stage.
        stride: stride of each stage.
        ceil_mode: choose ceil (True) or floor (False) rounding.
        repeat_num: number of identical stages to apply.

    Returns:
        ``lengths`` after all stages, cast to ``torch.int``.
    """
    rounding = torch.ceil if ceil_mode else torch.floor
    pad_minus_kernel: float = all_paddings - kernel_size
    unit: float = 1.0
    for _ in range(repeat_num):
        as_float = lengths.to(dtype=torch.float)
        lengths = rounding(torch.div(as_float + pad_minus_kernel, stride) + unit)
    return lengths.to(dtype=torch.int)
#### multihead attention starts here
class AttModule(nn.Module):
    """Base attention module: stores an export flag and passes inputs through."""

    def __init__(self):
        super().__init__()
        # Toggled through set_export(); starts disabled.
        self.export_mode = False

    def set_export(self, mode=True):
        """Store ``mode`` as the export flag."""
        self.export_mode = mode

    def forward(
        self,
        x: Tensor,
        memory: Optional[Tensor] = None,
        pos_emb: Optional[Tensor] = None,
        att_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        """Return the four arguments unchanged, in order.

        Args:
            x: torch.Tensor
                input tensor.
            memory: torch.Tensor, optional
                memory tensor.
            pos_emb: torch.Tensor, optional
                positional encoder embedding.
            att_mask: torch.Tensor, optional
                attention mask tensor.

        Returns:
            The tuple ``(x, memory, pos_emb, att_mask)``.
        """
        passthrough = (x, memory, pos_emb, att_mask)
        return passthrough
class AttBlock(Block, AttModule):
    """Combines ``Block`` and ``AttModule`` into a single base class."""

    def memory_dims(self, max_len=False):
        """Return ``(1, self.input_size)``; ``max_len`` is accepted but unused."""
        dims = (1, self.input_size)
        return dims
def masked_softmax(
    scores,
    mask: Optional[Tensor],
):
    """Softmax over the last dim of ``scores``, zeroing masked positions.

    Args:
        scores: attention logits, (batch, head, time1, time2).
        mask: optional tensor; after ``unsqueeze(1)`` it must broadcast
            against ``scores``. Entries equal to 0 are masked out.

    Returns:
        Attention weights with the shape of ``scores``. Masked entries are
        exactly 0.0, including rows in which every entry is masked.
    """
    if mask is None:
        return torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

    blocked = mask.unsqueeze(1).eq(0)  # (batch, 1, time1, time2)
    filled = scores.masked_fill(blocked, -torch.inf)
    weights = torch.softmax(filled, dim=-1)
    # A fully masked row yields NaN from softmax; overwrite with zeros.
    return weights.masked_fill(blocked, 0.0)  # (batch, head, time1, time2)
class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer with optional relative position embedding
    and GLU.

    Args:
        n_head: int
            the number of heads.
        n_feat: int
            input size features.
        dropout_rate: float
            dropout rate.
        attention_inner_dim: int, optional
            the attention dimension used in the class,
            it can be different from the input dimension n_feat.
            default: -1 (equal to n_feat).
        glu_type: str, optional
            accepted for interface compatibility; not used by this class.
        bias_in_glu: bool, optional
            accepted for interface compatibility; not used by this class.
        use_pt_scaled_dot_product_attention: bool, optional
            if set True, use pytorch scaled dot product attention
            (skipped when running under TorchScript).
            NOTE: this will NOT be used in ONNX decoding due to a lack of
            support. In that case, we use the original attention
            implementation, which shows no regression.
            default: False.
        n_value: int, optional
            if set to values other than -1, use a different dimension for
            value. With the default value (i.e. -1), it is backward compatible.
        group_size: int, optional. must divide `n_head`
            if group_size > 1:       GQA
            if group_size = 1:       MHA
            if group_size = n_head:  MQA

    Raises:
        ValueError: if ``use_pt_scaled_dot_product_attention`` is combined
            with ``group_size > 1``.
    """

    inv_sqrt_d_k: torch.jit.Final[float]
    h: torch.jit.Final[int]
    h_k: torch.jit.Final[int]
    g: torch.jit.Final[int]

    def __init__(
        self,
        n_head,
        n_feat,
        dropout_rate,
        attention_inner_dim=-1,
        glu_type="swish",
        bias_in_glu=True,
        use_pt_scaled_dot_product_attention=False,
        n_value=-1,
        group_size: int = 1,
    ):
        super().__init__()
        if n_value == -1:
            n_value = n_feat
        if attention_inner_dim == -1:
            attention_inner_dim = n_feat
        assert attention_inner_dim % n_head == 0

        # We assume d_v always equals d_k
        self.d_k = attention_inner_dim // n_head
        self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k)
        self.h = n_head
        assert n_head % group_size == 0, "group_size must divide n_head"
        self.g = group_size
        # Number of key/value heads; equals n_head when group_size == 1.
        self.h_k = n_head // group_size

        self.linear_q = nn.Linear(n_feat, attention_inner_dim)
        self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size)
        self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size)
        self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value)

        # Overwritten with the latest attention weights by the fallback
        # (non-SDPA) branch of forward().
        self.attn = torch.jit.Attribute(None, Optional[Tensor])
        self.dropout = nn.Dropout(p=dropout_rate)
        self.dropout_rate = dropout_rate
        self.use_pt_scaled_dot_product_attention = (
            use_pt_scaled_dot_product_attention)

        if use_pt_scaled_dot_product_attention and group_size > 1:
            raise ValueError("Cannot use PT Scaled Attention with GQA")

        # Torchscript eager quantization.  Note that these functions below are
        # NOOPs and have very little impact on performance unless quantization
        # is enabled.
        self.quant_q = torch.ao.quantization.QuantStub()
        self.quant_x = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.ffunc = torch.ao.nn.quantized.FloatFunctional()

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        pos_k: Tensor,
        pos_v: Tensor,
        mask: Optional[Tensor],
        relative_attention_bias: Optional[Tensor] = None,
    ):
        """Compute 'Scaled Dot Product Attention'.

        Args:
            query: torch.Tensor
                query tensor (batch, time1, size)
            key: torch.Tensor
                key tensor (batch, time2, size)
            value: torch.Tensor
                value tensor (batch, time2, size)
            pos_k: torch.Tensor
                key tensor used for relative positional embedding.
                Only used by the fallback (non-SDPA) branch.
            pos_v: torch.Tensor
                value tensor used for relative positional embedding.
                Only used by the fallback (non-SDPA) branch.
            mask: torch.Tensor
                mask tensor (batch, time1, time2)
            relative_attention_bias: torch.Tensor
                bias added to attention logits w.r.t. relative positions
                (1, n_head, time1, time2)

        Returns:
            torch.Tensor: output of ``linear_out``, (batch, time1, n_value).
        """
        n_batch = query.size(0)

        q = self.linear_q(query).view(n_batch, -1, self.h,
                                      self.d_k)  # (b, t, d)
        k = self.linear_k(key).view(n_batch, -1, self.h_k,
                                    self.d_k)  # (b, t, d)
        v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k)
        # The SDPA kernel applies its own scaling, so q is pre-scaled only
        # on the fallback path.
        q = (q.transpose(1, 2) if self.use_pt_scaled_dot_product_attention
             and not torch.jit.is_scripting() else q.transpose(1, 2) *
             self.inv_sqrt_d_k)
        k = k.transpose(1, 2)  # (batch, head_k, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head_k, time2, d_k)

        if (self.use_pt_scaled_dot_product_attention
                and not torch.jit.is_scripting()):
            attn_mask = None
            if mask is not None:
                # NOTE(review): the mask is cast to q.dtype and therefore
                # used additively by SDPA; a boolean keep-mask would not
                # mask anything here -- confirm what callers pass.
                mask = mask.unsqueeze(1)
                if relative_attention_bias is not None:
                    attn_mask = mask + relative_attention_bias
                else:
                    attn_mask = mask
                if mask.dtype != q.dtype:
                    attn_mask = attn_mask.to(q.dtype)
            elif relative_attention_bias is not None:
                # Keep the bias even without a mask, matching the fallback
                # branch which always adds it to the scores.
                attn_mask = relative_attention_bias.to(q.dtype)

            with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                                enable_math=True,
                                                enable_mem_efficient=True):
                x = torch.nn.functional.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=attn_mask,
                    # The functional API applies dropout whenever
                    # dropout_p > 0, so it must be disabled explicitly
                    # outside training.
                    dropout_p=self.dropout_rate if self.training else 0.0,
                )
        else:
            if self.h != self.h_k:
                q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k)
                A = torch.einsum("b g h t d, b h s d -> b h t s", q, k)
            else:
                A = torch.matmul(q, k.transpose(-2, -1))
            if pos_k is not None:
                if self.h != self.h_k:
                    B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k)
                else:
                    reshape_q = (q.contiguous().view(n_batch * self.h, -1,
                                                     self.d_k).transpose(0, 1)
                                 )  # (t1,nh,dk)
                    B = torch.matmul(reshape_q,
                                     pos_k.transpose(-2,
                                                     -1))  # pos_k: (t1,dk,t2)
                    B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0),
                                               pos_k.size(1))
                scores = A + B
            else:
                scores = A

            if relative_attention_bias is not None:
                scores = scores + relative_attention_bias

            attn = masked_softmax(scores, mask)  # (batch, head, time1, time2)

            self.attn = attn

            p_attn = self.dropout(attn)
            x = torch.matmul(p_attn.to(v.dtype),
                             v)  # (batch, head, time1, d_k)
            if pos_v is not None:
                reshape_attn = (p_attn.contiguous().view(
                    n_batch * self.h, pos_v.size(0),
                    pos_v.size(1)).transpose(0, 1))  # (t1, bh, t2)

                attn_v = (torch.matmul(reshape_attn, pos_v).transpose(
                    0, 1).contiguous().view(n_batch, self.h, pos_v.size(0),
                                            self.d_k))
                x = x + attn_v
        x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
                                                 self.h_k * self.d_k)
             )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)
def validate_checkpointing_config(activation_checkpointing):
    """Check that an activation-checkpointing config is well formed.

    Accepted values:
        * a str equal to "", "checkpoint" or "offload";
        * a dict whose "module" entry (default "transformer") is
          "transformer" or "attention".

    Raises:
        AssertionError: for a str or dict with an unsupported value.
        ValueError: for any other type.
    """
    if not isinstance(activation_checkpointing, (str, dict)):
        raise ValueError("activation_checkpointing has to be a str"\
                         " or dict.")

    if isinstance(activation_checkpointing, str):
        allowed_modes = ("", "checkpoint", "offload")
        assert activation_checkpointing in allowed_modes, \
            "activation_checkpointing has to be a dict or a str in "\
            "('', 'checkpoint', 'offload')."
        return

    allowed_modules = ("transformer", "attention")
    assert activation_checkpointing.get("module",
                                        "transformer") in allowed_modules, \
        "module in activation_checkpointing has to be in "\
        "('transformer', 'attention')."
def embedding_checkpoint_wrapper(
    activation_checkpointing: Union[str, Dict], ) -> Callable:
    """Pick the activation-checkpoint wrapper for the encoder embedding.

    The config is validated first with ``validate_checkpointing_config``.

    * str config: "" yields an identity function, "offload" yields
      ``offload_wrapper``, any other accepted value yields
      ``partial(checkpoint_wrapper)``.
    * dict config: a falsy "embed" entry yields an identity function; a
      truthy "offload" entry yields ``offload_wrapper``; otherwise
      ``checkpoint_wrapper`` is bound to the reentrant or non-reentrant
      implementation according to the "reentrant" entry.

    Raises:
        ValueError: if the config is neither a str nor a dict.
    """
    validate_checkpointing_config(activation_checkpointing)

    if isinstance(activation_checkpointing, str):
        if not activation_checkpointing:
            return lambda x: x
        if activation_checkpointing == "offload":
            return offload_wrapper
        return partial(checkpoint_wrapper)

    if isinstance(activation_checkpointing, dict):
        if not activation_checkpointing.get("embed", False):
            return lambda x: x
        if activation_checkpointing.get("offload", False):
            return offload_wrapper
        if activation_checkpointing.get("reentrant", False):
            impl = CheckpointImpl.REENTRANT
        else:
            impl = CheckpointImpl.NO_REENTRANT
        return partial(checkpoint_wrapper, checkpoint_impl=impl)

    raise ValueError("Invalid activation_checkpointing config")
def attn_checkpointing(activation_checkpointing: Union[str, Dict],
                       i) -> Union[str, Dict]:
    """Return the checkpointing config to use for attention layer ``i``.

    A str config always yields "". A dict config is returned unchanged
    when its "module" entry (default "transformer") is "attention" and
    ``i`` is a multiple of its "interval" entry (default 1); otherwise ""
    is returned.

    Raises:
        ValueError: if the config is neither a str nor a dict.
    """
    if isinstance(activation_checkpointing, str):
        return ""
    if not isinstance(activation_checkpointing, dict):
        raise ValueError("Invalid activation_checkpointing config")

    module_kind = activation_checkpointing.get("module", "transformer")
    every = activation_checkpointing.get("interval", 1)
    # The modulo is evaluated only for attention targets (short-circuit).
    selected = module_kind == "attention" and i % every == 0
    return activation_checkpointing if selected else ""
class MultiSequential(torch.nn.Sequential):
    """``torch.nn.Sequential`` variant whose layers take and return tuples."""

    @torch.jit.ignore
    def forward(self, *args):
        """Feed ``args`` through every layer, unpacking each layer's output
        as the positional arguments of the next one.

        Returns the output of the last layer, or ``args`` itself when the
        container is empty.
        """
        current = args
        for layer in self:
            current = layer(*current)
        return current
def repeat(repeat_num, module_gen_fn):
    """Build a MultiSequential from ``repeat_num`` generated modules.

    :param int repeat_num: how many modules to generate
    :param function module_gen_fn: called with each index in
        ``range(repeat_num)``; returns the module for that position
    :return: the generated modules, in index order
    :rtype: MultiSequential
    """
    generated = [module_gen_fn(idx) for idx in range(repeat_num)]
    return MultiSequential(*generated)
def get_offset(input_layer: str, time_reduction: int):
    """Look up the frame offset for a subsampling front-end.

    The offset is used for determining #frames of a subsampled feature.

    Args:
        input_layer (str): Type of an input layer
        time_reduction (int): time reduction factor for downsampling a feature
    Returns:
        int: offset; 0 for any combination not listed in the table below.
    """
    # (input layers, time reduction, offset), checked in order.
    offset_table = (
        (("conv2d", "nemo_conv"), 4, 3),
        (("conv2d", ), 6, 1),
        (("conv2d", "nemo_conv"), 8, 7),
    )
    for layer_names, reduction, offset in offset_table:
        if input_layer in layer_names and time_reduction == reduction:
            return offset
    return 0
def unfold_tensor(xs_pad, max_seq_len):
    """
    Cut a (N, T, D) tensor into non-overlapping windows of ``max_seq_len``
    frames, giving a (N * T', max_seq_len, D) tensor where
    T' = T // max_seq_len. Trailing frames that do not fill a whole window
    are dropped.

    Args:
        xs_pad: N, T, D
    """
    _, _, feat_dim = xs_pad.shape
    channels_first = xs_pad.transpose(1, 2)  # N, D, T
    # N x D x 1 x T => N x (D x max_seq_len) x T'
    windows = F.unfold(
        channels_first.unsqueeze(-2),
        kernel_size=(1, max_seq_len),
        stride=(1, max_seq_len),
    )
    batch, _, num_windows = windows.shape
    # N x D x max_seq_len x T'
    windows = windows.view(batch, -1, max_seq_len, num_windows)
    # N x T' x max_seq_len x D
    windows = windows.permute(0, 3, 2, 1).contiguous()
    # NT' x max_seq_len x D
    return windows.view(-1, max_seq_len, feat_dim)
......@@ -182,6 +182,7 @@ _MULTIMODAL_MODELS = {
"Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501
"Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
# [Encoder-decoder]
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
......
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Siglip model configuration"""
import math
import os
import warnings
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import (BaseModelOutput,
BaseModelOutputWithPooling)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (ModelOutput, add_start_docstrings,
add_start_docstrings_to_model_forward, logging,
replace_return_docstrings)
from vllm.platforms import _Backend
from .vision import get_vit_attn_backend
# Module-level logger, created through the `logging` helper imported from
# transformers.utils.
logger = logging.get_logger(__name__)

# Maps a pretrained Siglip checkpoint name to the URL of its config.json.
SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "google/siglip-base-patch16-224": (
        "https://huggingface.co/google/siglip-base-patch16-224/"
        "resolve/main/config.json"),
}
class SiglipTextConfig(PretrainedConfig):
    r"""
    Configuration container for a [`SiglipTextModel`] (Siglip text encoder).

    The constructor arguments define the text-encoder architecture. The
    defaults are meant to resemble the text encoder of
    [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224).
    The class derives from [`PretrainedConfig`]; see that class for the
    shared configuration behaviour.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Siglip text model.
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the feed-forward ("intermediate") layer in
            the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads per attention layer.
        max_position_embeddings (`int`, *optional*, defaults to 64):
            Maximum sequence length the model is expected to handle.
        hidden_act (`str` or `function`, *optional*, defaults to
            `"gelu_pytorch_tanh"`):
            Non-linear activation used in the encoder and pooler.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            Epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout ratio for the attention probabilities.
        pad_token_id (`int`, *optional*, defaults to 1):
            Id of the padding token in the vocabulary.
        bos_token_id (`int`, *optional*, defaults to 49406):
            Id of the beginning-of-sequence token in the vocabulary.
        eos_token_id (`int`, *optional*, defaults to 49407):
            Id of the end-of-sequence token in the vocabulary.
        _flash_attn_2_enabled (`bool`, *optional*, defaults to `True`):
            Stored unchanged on the config object.

    Example:

    ```python
    >>> from transformers import SiglipTextConfig, SiglipTextModel

    >>> # Build a config in the google/siglip-base-patch16-224 style
    >>> configuration = SiglipTextConfig()

    >>> # Build a randomly initialised model from that config
    >>> model = SiglipTextModel(configuration)

    >>> # Read the configuration back from the model
    >>> configuration = model.config
    ```"""

    model_type = "siglip_text_model"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=64,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        # This differs from `CLIPTokenizer`'s default and from openai/siglip
        # See https://github.com/huggingface/transformers/pull/24773#
        # issuecomment-1632287538
        pad_token_id=1,
        bos_token_id=49406,
        eos_token_id=49407,
        _flash_attn_2_enabled=True,
        **kwargs,
    ):
        # The special-token ids are handled by the PretrainedConfig base.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        # Architecture hyper-parameters are stored verbatim.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.attention_dropout = attention_dropout
        self._flash_attn_2_enabled = _flash_attn_2_enabled

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str,
                                                                  os.PathLike],
                        **kwargs) -> "PretrainedConfig":
        """Load the text config, unwrapping it from a full Siglip config.

        When the loaded dictionary has ``model_type == "siglip"`` its
        ``"text_config"`` entry is used. A warning is logged if the
        resulting ``model_type`` differs from this class's ``model_type``.
        """
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs)

        # get the text config dict if we are loading from SiglipConfig
        if config_dict.get("model_type") == "siglip":
            config_dict = config_dict["text_config"]

        type_mismatch = ("model_type" in config_dict
                         and hasattr(cls, "model_type")
                         and config_dict["model_type"] != cls.model_type)
        if type_mismatch:
            logger.warning(
                "You are using a model of type %s to instantiate a model of "
                "type %s. This is not supported for all configurations of "
                "models and can yield errors.", config_dict["model_type"],
                cls.model_type)

        return cls.from_dict(config_dict, **kwargs)
class SiglipVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a
[`SiglipVisionModel`]. It is used to instantiate a
Siglip vision encoder according to the specified arguments, defining the
model architecture. Instantiating a configuration with the defaults will
yield a similar configuration to that of the vision encoder of the Siglip
[google/siglip-base-patch16-224](https://huggingface.co/google/
siglip-base-patch16-224) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used
to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer
in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the
Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
Number of channels in the input images.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to
`"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the
encoder and pooler. If string, `"gelu"`, `"relu"`, `"selu"` and
`"gelu_new"` ``"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
Example:
```python
>>> from transformers import SiglipVisionConfig, SiglipVisionModel
>>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224
style configuration
>>> configuration = SiglipVisionConfig()
>>> # Initializing a SiglipVisionModel (with random weights) from the
google/siglip-base-patch16-224 style configuration
>>> model = SiglipVisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "siglip_vision_model"
def __init__(
    self,
    hidden_size=768,
    intermediate_size=3072,
    num_hidden_layers=12,
    num_attention_heads=12,
    num_channels=3,
    image_size=224,
    patch_size=16,
    hidden_act="gelu_pytorch_tanh",
    layer_norm_eps=1e-6,
    attention_dropout=0.0,
    _flash_attn_2_enabled=True,
    **kwargs,
):
    """Store the vision-tower hyper-parameters on the config object.

    Every named argument is saved verbatim as an attribute of the same
    name; any remaining keyword arguments are forwarded to the parent
    class constructor.
    """
    super().__init__(**kwargs)

    # Transformer encoder dimensions.
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads

    # Input image geometry.
    self.num_channels = num_channels
    self.patch_size = patch_size
    self.image_size = image_size

    # Regularisation, normalisation and activation settings.
    self.attention_dropout = attention_dropout
    self.layer_norm_eps = layer_norm_eps
    self.hidden_act = hidden_act

    # Read by the encoder layer to choose the attention implementation.
    self._flash_attn_2_enabled = _flash_attn_2_enabled
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str,
                                                              os.PathLike],
                    **kwargs) -> "PretrainedConfig":
    """Build this config from a checkpoint name or path.

    When the loaded dictionary describes a composite ``"siglip"`` model,
    only its nested ``"vision_config"`` entry is used. A warning is logged
    if the resulting ``model_type`` differs from this class' own.
    """
    cls._set_token_in_kwargs(kwargs)
    config_dict, kwargs = cls.get_config_dict(
        pretrained_model_name_or_path, **kwargs)

    # Composite checkpoints nest the vision settings one level down.
    if config_dict.get("model_type") == "siglip":
        config_dict = config_dict["vision_config"]

    if "model_type" in config_dict and hasattr(cls, "model_type"):
        loaded_type = config_dict["model_type"]
        if loaded_type != cls.model_type:
            logger.warning(
                "You are using a model of type %s to "
                "instantiate a model of type %s. This is not"
                " supported for all configurations of models and can yield"
                " errors.", loaded_type, cls.model_type)

    return cls.from_dict(config_dict, **kwargs)
class SiglipConfig(PretrainedConfig):
    r"""
    Composite configuration of a [`SiglipModel`]: it bundles one
    [`SiglipTextConfig`] and one [`SiglipVisionConfig`].

    Instantiating it with no arguments builds both sub-configurations from
    their defaults, which is intended to resemble the
    [google/siglip-base-patch16-224](
    https://huggingface.co/google/siglip-base-patch16-224) architecture.
    The class inherits from [`PretrainedConfig`]; see that class for the
    generic configuration options.

    Args:
        text_config (`dict`, *optional*):
            Keyword options used to build the [`SiglipTextConfig`].
        vision_config (`dict`, *optional*):
            Keyword options used to build the [`SiglipVisionConfig`].
        kwargs (*optional*):
            Extra keyword arguments forwarded to [`PretrainedConfig`].

    Example:
    ```python
    >>> from transformers import SiglipConfig, SiglipModel
    >>> # Default, google/siglip-base-patch16-224 style configuration
    >>> configuration = SiglipConfig()
    >>> # Model with random weights built from that configuration
    >>> model = SiglipModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    >>> # Building a SiglipConfig from separate text and vision configs
    >>> from transformers import SiglipTextConfig, SiglipVisionConfig
    >>> config_text = SiglipTextConfig()
    >>> config_vision = SiglipVisionConfig()
    >>> config = SiglipConfig.from_text_vision_configs(config_text,
    ...                                                config_vision)
    ```"""
    model_type = "siglip"

    def __init__(self, text_config=None, vision_config=None, **kwargs):
        """Create both sub-configs, falling back to defaults when omitted."""
        super().__init__(**kwargs)

        if text_config is None:
            logger.info(
                "`text_config` is `None`. Initializing the `SiglipTextConfig`"
                " with default values.")
            text_config = {}

        if vision_config is None:
            logger.info("`vision_config` is `None`. initializing the "
                        "`SiglipVisionConfig` with default values.")
            vision_config = {}

        self.text_config = SiglipTextConfig(**text_config)
        self.vision_config = SiglipVisionConfig(**vision_config)
        self.initializer_factor = 1.0

    @classmethod
    def from_text_vision_configs(cls, text_config: SiglipTextConfig,
                                 vision_config: SiglipVisionConfig, **kwargs):
        r"""
        Alternate constructor taking already-built sub-configurations.

        Both objects are converted with `to_dict()` and passed on to the
        regular constructor together with `kwargs`.

        Returns:
            [`SiglipConfig`]: An instance of a configuration object
        """
        text_options = text_config.to_dict()
        vision_options = vision_config.to_dict()
        return cls(text_config=text_options,
                   vision_config=vision_options,
                   **kwargs)
# coding=utf-8
# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Siglip model."""
# Reference checkpoint name; not used elsewhere in the visible part of this
# file (presumably consumed by docstring helpers -- TODO confirm).
_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"

# Checkpoint identifiers associated with this architecture.
SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/siglip-base-patch16-224",
    # See all SigLIP models at https://huggingface.co/models?filter=siglip
]
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official
# releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/
# truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std) # noqa
u = norm_cdf((b - mean) / std) # noqa
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
if tensor.dtype in [torch.float16, torch.bfloat16]:
# The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
og_dtype = tensor.dtype
tensor = tensor.to(torch.float32)
tensor.erfinv_()
tensor = tensor.to(og_dtype)
else:
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
if tensor.dtype == torch.float16:
# The `clamp_` op is not (yet?) defined in float16+cpu
tensor = tensor.to(torch.float32)
tensor.clamp_(min=a, max=b)
tensor = tensor.to(torch.float16)
else:
tensor.clamp_(min=a, max=b)
def trunc_normal_tf_(tensor: torch.Tensor,
                     mean: float = 0.0,
                     std: float = 1.0,
                     a: float = -2.0,
                     b: float = 2.0) -> torch.Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where
    the bounds [a, b] are applied when sampling the normal distribution with
    mean=0, std=1.0 and the result is subsequently scaled and shifted by the
    mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same ``tensor`` object, modified in place (matches the declared
        return annotation; previously nothing was returned).
    """
    with torch.no_grad():
        # Sample a standard truncated normal first, then scale and shift.
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)
    return tensor
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Initialise ``tensor`` in place with variance ``scale / denom``.

    Args:
        tensor: tensor to overwrite; its fan-in / fan-out are obtained from
            ``_calculate_fan_in_and_fan_out``.
        scale: numerator of the target variance.
        mode: which fan count is the denominator: ``"fan_in"``,
            ``"fan_out"`` or ``"fan_avg"`` (mean of the two).
        distribution: ``"truncated_normal"``, ``"normal"`` or ``"uniform"``.

    Raises:
        ValueError: if ``mode`` or ``distribution`` is not one of the
            supported names.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # Without this branch `denom` stays unbound and the failure shows up
        # below as a confusing UnboundLocalError.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        # A uniform on [-bound, bound] has variance bound**2 / 3.
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
    """Initialise ``tensor`` in place via fan-in scaled truncated normal.

    Thin wrapper around ``variance_scaling_`` with its default ``scale``.
    """
    variance_scaling_(
        tensor,
        mode="fan_in",
        distribution="truncated_normal",
    )
def default_flax_embed_init(tensor):
    """Initialise ``tensor`` in place via fan-in scaled normal.

    Thin wrapper around ``variance_scaling_`` with its default ``scale``.
    """
    variance_scaling_(
        tensor,
        mode="fan_in",
        distribution="normal",
    )
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with
# CLIP->Siglip
class SiglipVisionModelOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings
    of the pooling of the last hidden states.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`
            *optional* returned when model is initialized with
            `with_projection=True`):
            The image embeddings obtained by applying the projection layer to
            the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size,
            sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the
            model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when
            `output_hidden_states=True` is passed or when
            `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings,
            if the model has an embedding layer, + one for the output of each
            layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the
            optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when
            `output_attentions=True` is passed or when
            `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`.
            Attentions weights after the attention softmax, used to compute the
            weighted average in the self-attention heads.
    """
    # Every field defaults to None, including `last_hidden_state`, whose
    # annotation is not marked Optional.
    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with
# CLIP->Siglip
class SiglipTextModelOutput(ModelOutput):
    """
    Base class for text model's outputs that also contains a pooling of the
    last hidden states.

    Args:
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`
            *optional* returned when model is initialized with
            `with_projection=True`):
            The text embeddings obtained by applying the projection layer to
            the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size,
            sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the
            model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when
            `output_hidden_states=True` is passed or when
            `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the
            embeddings, if the model has an embedding layer, + one for the
            output of each layer) of shape `(batch_size, sequence_length,
            hidden_size)`.
            Hidden-states of the model at the output of each layer plus the
            optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when
            `output_attentions=True` is passed or when
            `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`.
            Attentions weights after the attention softmax, used to compute
            the weighted average in the self-attention heads.
    """
    # Every field defaults to None, including `last_hidden_state`, whose
    # annotation is not marked Optional.
    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPOutput with
# CLIP->Siglip
class SiglipOutput(ModelOutput):
    """
    Output container of the combined text + vision model.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when
            `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size,
            text_batch_size)`):
            The scaled dot product scores between `image_embeds` and
            `text_embeds`. This represents the image-text similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size,
            image_batch_size)`):
            The scaled dot product scores between `text_embeds` and
            `image_embeds`. This represents the text-image similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to
            the pooled output of [`SiglipTextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The image embeddings obtained by applying the projection layer to
            the pooled output of [`SiglipVisionModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`SiglipTextModel`].
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`SiglipVisionModel`].
    """
    loss: Optional[torch.FloatTensor] = None
    logits_per_image: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        """Flatten to a tuple, converting the nested sub-model outputs too.

        The two nested outputs are replaced by their own ``to_tuple()``
        result; every other entry is taken as-is, in ``self.keys()`` order.
        """
        nested_keys = ("text_model_output", "vision_model_output")
        flattened = []
        for key in self.keys():
            if key in nested_keys:
                flattened.append(getattr(self, key).to_tuple())
            else:
                flattened.append(self[key])
        return tuple(flattened)
class SiglipVisionEmbeddings(nn.Module):
    """Patch embeddings plus learned position embeddings for images.

    Position ids are derived per sample from ``patch_attention_mask`` by
    bucketing fractional patch coordinates onto the
    ``num_patches_per_side`` x ``num_patches_per_side`` position grid.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patches: kernel size equals stride.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        side = self.image_size // self.patch_size
        self.num_patches_per_side = side
        self.num_patches = side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)

    def forward(self, pixel_values: torch.FloatTensor,
                patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
        """Embed ``pixel_values`` and add the matching position embeddings.

        Args:
            pixel_values: image batch fed to the patch convolution; its
                dimensions 2 and 3 are divided by ``patch_size`` to size
                the position-id grid.
            patch_attention_mask: one 2-D mask per sample; the sums of its
                first column and first row give the number of valid patch
                rows and columns.

        Returns:
            Flattened patch embeddings with position embeddings added.
        """
        batch_size = pixel_values.size(0)
        embeddings = self.patch_embedding(pixel_values).flatten(2).transpose(
            1, 2)

        grid_h = pixel_values.size(2) // self.patch_size
        grid_w = pixel_values.size(3) // self.patch_size

        step = 1 / self.num_patches_per_side
        boundaries = torch.arange(step, 1.0, step)
        position_ids = torch.full(
            size=(batch_size, grid_h * grid_w),
            fill_value=0,
        )

        for sample_idx, sample_mask in enumerate(patch_attention_mask):
            rows = sample_mask[:, 0].sum()
            cols = sample_mask[0].sum()

            # Map each valid patch row / column onto a position-grid bucket.
            row_buckets = torch.bucketize(
                torch.linspace(0, 1 - 1 / rows, rows),
                boundaries,
                right=True)
            col_buckets = torch.bucketize(
                torch.linspace(0, 1 - 1 / cols, cols),
                boundaries,
                right=True)

            sample_ids = (row_buckets[:, None] * self.num_patches_per_side +
                          col_buckets).flatten()
            position_ids[sample_idx][sample_mask.view(-1).cpu()] = sample_ids

        position_ids = position_ids.to(self.position_embedding.weight.device)
        return embeddings + self.position_embedding(position_ids)
# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with
# CLIP->Siglip
class SiglipTextEmbeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings."""

    def __init__(self, config: SiglipTextConfig):
        super().__init__()
        hidden = config.hidden_size
        max_positions = config.max_position_embeddings

        self.token_embedding = nn.Embedding(config.vocab_size, hidden)
        self.position_embedding = nn.Embedding(max_positions, hidden)

        # Default position ids of shape (1, max_position_embeddings);
        # registered as a non-persistent buffer.
        default_ids = torch.arange(max_positions).expand((1, -1))
        self.register_buffer("position_ids", default_ids, persistent=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Return token (or supplied) embeddings plus position embeddings.

        ``inputs_embeds`` is used when given, otherwise ``input_ids`` are
        embedded. When ``position_ids`` is omitted, the first
        ``seq_length`` default positions are used.
        """
        if input_ids is not None:
            seq_length = input_ids.shape[-1]
        else:
            seq_length = inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        return inputs_embeds + self.position_embedding(position_ids)
class SiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`:"
                f" {self.embed_dim} and `num_heads`: {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        dim = self.embed_dim
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.q_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Returns the pair ``(attn_output, attn_weights)``; the weights are
        always returned, whatever the value of ``output_attentions``.
        An additive ``attention_mask`` must have shape
        ``(batch, 1, q_len, kv_len)``, otherwise ``ValueError`` is raised.
        """
        batch_size, q_len, _ = hidden_states.size()

        def split_heads(projected):
            # (batch, seq, embed) -> (batch, heads, seq, head_dim)
            return projected.view(batch_size, q_len, self.num_heads,
                                  self.head_dim).transpose(1, 2)

        query_states = split_heads(self.q_proj(hidden_states))
        key_states = split_heads(self.k_proj(hidden_states))
        value_states = split_heads(self.v_proj(hidden_states))

        k_v_seq_len = key_states.shape[-2]
        scores = torch.matmul(query_states, key_states.transpose(2, 3))
        attn_weights = scores * self.scale

        expected_weights = (batch_size, self.num_heads, q_len, k_v_seq_len)
        if attn_weights.size() != expected_weights:
            raise ValueError(
                f"Attention weights should be of size "
                f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
                f" {attn_weights.size()}")

        if attention_mask is not None:
            expected_mask = (batch_size, 1, q_len, k_v_seq_len)
            if attention_mask.size() != expected_mask:
                raise ValueError(f"Attention mask should be of size "
                                 f"{(batch_size, 1, q_len, k_v_seq_len)}, "
                                 f"but is {attention_mask.size()}")
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1,
                                             dtype=torch.float32)
        attn_weights = attn_weights.to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights,
                                             p=self.dropout,
                                             training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        expected_output = (batch_size, self.num_heads, q_len, self.head_dim)
        if attn_output.size() != expected_output:
            raise ValueError(
                f"`attn_output` should be of size "
                f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}")

        # Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq,
        # embed).
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
        return self.out_proj(attn_output), attn_weights
class SiglipFlashAttention2(SiglipAttention):
    """
    Siglip attention module backed by the `flash_attn` kernels. This module
    inherits from `SiglipAttention`, so the projection layers are the ones
    created by the parent class. Only the forward pass differs: it calls the
    public API of flash attention and unpads / re-pads the inputs when an
    attention mask is supplied.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_causal = False  # Hack to make sure we don't use a causal mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Run flash attention over `hidden_states`.

        Returns the pair `(attn_output, attn_weights)` where `attn_weights`
        is always `None`. `position_ids`, `use_cache` and `**kwargs` are
        accepted but not read by this method.
        """
        # Attention weights are never produced on this path, whatever the
        # caller requested.
        output_attentions = False
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads,
                                         self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_heads,
                                     self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_heads,
                                         self.head_dim).transpose(1, 2)
        kv_seq_len = key_states.shape[-2]
        # NOTE(review): `self.layer_idx` is not assigned in this class nor in
        # `SiglipAttention.__init__`, and `kv_seq_len` is not read again
        # below, so this branch looks like a leftover -- confirm it is never
        # reached with a non-None `past_key_value`.
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(
                kv_seq_len, self.layer_idx)
        # TODO: These transpose are quite inefficient but Flash Attention
        # requires the layout [batch_size, sequence_length, num_heads,
        # head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        # Dropout is only applied in training mode.
        dropout_rate = self.dropout if self.training else 0.0
        # In PEFT, usually we cast the layer norms in float32 for training
        # stability reasons therefore the input hidden states gets silently
        # casted in float32. Hence, we need cast them back in the correct
        # dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to
        # not cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # Target dtype priority: autocast dtype, then the recorded
            # pre-quantization dtype, then the dtype of the q_proj weights.
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype
            logger.warning_once(
                "The input hidden states seems to be silently casted in "
                "float32, this might be related to the fact you have upcasted "
                "embedding or layer norm layers in float32. We will cast "
                f"back the input in {target_dtype}.")
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)
        attn_output = self._flash_attention_forward(query_states,
                                                    key_states,
                                                    value_states,
                                                    attention_mask,
                                                    q_len,
                                                    dropout=dropout_rate)
        attn_output = attn_output.reshape(bsz, q_len,
                                          self.embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)
        # Always taken, since `output_attentions` was forced to False above.
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights

    def _flash_attention_forward(self,
                                 query_states,
                                 key_states,
                                 value_states,
                                 attention_mask,
                                 query_length,
                                 dropout=0.0,
                                 softmax_scale=None):
        """
        Calls the forward method of Flash Attention - if the input hidden
        states contain at least one padding token first unpad the input,
        then computes the attention scores and pad the final attention
        scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size
                `(batch_size, seq_len)` where 0 stands for the position
                of padding tokens and 1 for the position of non-padding
                tokens.
            query_length (`int`):
                Query sequence length; used to decide how to unpad the
                queries and to re-pad the output.
            dropout (`int`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 /
                sqrt(head_dim)
        """
        # Imported lazily so the module can be loaded without flash_attn.
        from flash_attn import flash_attn_func, flash_attn_varlen_func
        from flash_attn.bert_padding import pad_input  # noqa

        # TODO: Remove the `query_length != 1` check once Flash Attention for
        # RoCm is bumped to 2.1. For details, please see the comment in
        # LlamaFlashAttention2 __init__.
        causal = self.is_causal and query_length != 1
        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, \
                max_seq_lens = self._upad_input(
                    query_states, key_states, value_states, attention_mask,
                    query_length)
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )
            # Scatter the unpadded result back to the padded layout.
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
                                    query_length)
        else:
            attn_output = flash_attn_func(query_states,
                                          key_states,
                                          value_states,
                                          dropout,
                                          softmax_scale=softmax_scale,
                                          causal=causal)
        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
                    query_length):
        """Strip padded positions from the query, key and value tensors.

        Returns the unpadded `query_layer`, `key_layer` and `value_layer`,
        the query indices, the pair `(cu_seqlens_q, cu_seqlens_k)` and the
        pair `(max_seqlen_in_batch_q, max_seqlen_in_batch_k)`.
        """
        from flash_attn.bert_padding import index_first_axis, unpad_input
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
            attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
        # Collapse batch and sequence dims, then keep only the rows selected
        # by the mask.
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
                              head_dim), indices_k)
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
                                head_dim), indices_k)
        if query_length == kv_seq_len:
            # Queries and keys share the mask, so the key indices are reused.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
                                    head_dim), indices_k)
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # One query token per sequence: boundaries are simply 0..batch.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = \
                unpad_input(query_layer, attention_mask)
        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: linear, activation, linear."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Activation looked up by name in the shared activation registry.
        self.activation_fn = ACT2FN[config.hidden_act]

        hidden, inner = config.hidden_size, config.intermediate_size
        self.fc1 = nn.Linear(hidden, inner)
        self.fc2 = nn.Linear(inner, hidden)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1``, the activation, then ``fc2``."""
        expanded = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        return self.fc2(activated)
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with
# CLIP->Siglip
class SiglipEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer: self-attention then MLP, each
    wrapped in a residual connection."""

    def __init__(self, config: SiglipConfig):
        super().__init__()
        self.embed_dim = config.hidden_size

        # The config flag selects the attention implementation; a missing
        # flag means the plain implementation.
        use_flash = getattr(config, "_flash_attn_2_enabled", False)
        attention_cls = SiglipFlashAttention2 if use_flash else SiglipAttention
        self.self_attn = attention_cls(config)

        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where
                padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all
                attention layers. See `attentions` under returned tensors for
                more detail.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when
            `output_attentions` is truthy.
        """
        # Attention sub-block.
        attn_input = self.layer_norm1(hidden_states)
        attn_output, attn_weights = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block.
        mlp_output = self.mlp(self.layer_norm2(hidden_states))
        hidden_states = hidden_states + mlp_output

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states, )
class SiglipPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface
    for downloading and loading pretrained models.
    """
    config_class = SiglipConfig
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of `module` according to its type.

        The `isinstance` checks are order-sensitive: the Siglip-specific
        modules are tested before the generic `nn.Embedding`, `nn.Linear`,
        `nn.Conv2d` and `nn.LayerNorm` fallbacks.
        """
        if isinstance(module, SiglipVisionEmbeddings):
            # The width comes from the nested vision config when the model
            # holds a composite SiglipConfig, otherwise from the config
            # itself.
            width = (self.config.vision_config.hidden_size if isinstance(
                self.config, SiglipConfig) else self.config.hidden_size)
            nn.init.normal_(module.position_embedding.weight,
                            std=1 / np.sqrt(width))
        elif isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, SiglipAttention):
            # Also matches subclasses such as SiglipFlashAttention2.
            nn.init.normal_(module.q_proj.weight)
            nn.init.normal_(module.k_proj.weight)
            nn.init.normal_(module.v_proj.weight)
            nn.init.normal_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, SiglipMLP):
            nn.init.normal_(module.fc1.weight)
            nn.init.normal_(module.fc2.weight)
            # Biases get a tiny-variance normal rather than exact zeros.
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
            nn.init.normal_(module.probe.data)
            nn.init.normal_(module.attention.in_proj_weight.data)
            nn.init.zeros_(module.attention.in_proj_bias.data)
        elif isinstance(module, SiglipModel):
            # Both the logit scale and the logit bias start at zero.
            logit_scale_init = torch.tensor(0.0)
            module.logit_scale.data.fill_(logit_scale_init)
            module.logit_bias.data.zero_()
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
# Class-level documentation preamble. It is passed to `add_start_docstrings`
# on SiglipTextModel, SiglipVisionModel and SiglipModel in this file.
SIGLIP_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass
documentation for the generic methods the library implements for all
its model (such as downloading or saving, resizing the input embeddings,
pruning heads etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/
stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation
for all matter related to general usage and behavior.
Parameters:
config ([`SiglipConfig`]): Model configuration class with all the
parameters of the model.
Initializing with a config file does not load the weights
associated with the model, only the configuration. Check out
the [`~PreTrainedModel.from_pretrained`] method to load the
model weights.
"""
# Argument documentation for the text-only entry points. It is attached via
# `add_start_docstrings_to_model_forward` to SiglipTextTransformer.forward,
# SiglipTextModel.forward and SiglipModel.get_text_features.
SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)
`):
Indices of input sequence tokens in the vocabulary. Padding will
be ignored by default should you provide it.
Indices can be obtained using [`AutoTokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
for details. [What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size,
sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask
values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size,
sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position
embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention
layers. See `attentions` under returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See
`hidden_states` under returned tensors for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a
plain tuple.
"""
# Argument documentation for the vision-only entry points. It is attached via
# `add_start_docstrings_to_model_forward` to SiglipVisionTransformer.forward,
# SiglipVisionModel.forward and SiglipModel.get_image_features.
SIGLIP_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size,
num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you
provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`]
for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention
layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See
`hidden_states` under returned tensors for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a
plain tuple.
"""
# Argument documentation for the joint text + image forward pass. It is
# attached via `add_start_docstrings_to_model_forward` to SiglipModel.forward.
SIGLIP_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size,
sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding
will be ignored by default should you provide it.
Indices can be obtained using [`AutoTokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
for details. [What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`
, *optional*):
Mask to avoid performing attention on padding token indices. Mask
values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size,
sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position
embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size,
num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you
provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`]
for details.
return_loss (`bool`, *optional*):
Whether or not to return the contrastive loss.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention
layers. See `attentions` under returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See
`hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a
plain tuple.
"""
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with
# CLIP->Siglip
class SiglipEncoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` [`SiglipEncoderLayer`] blocks
    applied one after another.

    Args:
        config: SiglipConfig
    """

    def __init__(self, config: SiglipConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            SiglipEncoderLayer(config)
            for _ in range(config.num_hidden_layers))
        self.gradient_checkpointing = False

    # Ignore copy
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Pass `inputs_embeds` through every encoder layer in order.

        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size,
                sequence_length, hidden_size)`):
                Already-embedded input sequence.
            attention_mask (`torch.Tensor`, *optional*):
                Mask handed unchanged to every layer.
            output_attentions (`bool`, *optional*):
                Collect each layer's attention weights. Falls back to
                `config.output_attentions` when `None`.
            output_hidden_states (`bool`, *optional*):
                Collect the input of every layer plus the final output.
                Falls back to `config.output_hidden_states` when `None`.
            return_dict (`bool`, *optional*):
                Return a [`BaseModelOutput`] instead of a plain tuple.
                Falls back to `config.use_return_dict` when `None`.

        Returns:
            A [`BaseModelOutput`], or a tuple of the non-`None` entries
            among `(last_hidden_state, hidden_states, attentions)`.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # `None` means "not requested"; lists become tuples at the end.
        state_trace = [] if output_hidden_states else None
        attn_trace = [] if output_attentions else None

        hidden_states = inputs_embeds
        for layer in self.layers:
            if state_trace is not None:
                state_trace.append(hidden_states)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]
            if attn_trace is not None:
                attn_trace.append(layer_outputs[1])

        # The output of the last layer closes the hidden-state trace.
        if state_trace is not None:
            state_trace.append(hidden_states)

        encoder_states = (None
                          if state_trace is None else tuple(state_trace))
        all_attentions = None if attn_trace is None else tuple(attn_trace)

        if return_dict:
            return BaseModelOutput(last_hidden_state=hidden_states,
                                   hidden_states=encoder_states,
                                   attentions=all_attentions)
        candidates = (hidden_states, encoder_states, all_attentions)
        return tuple(item for item in candidates if item is not None)
class SiglipTextTransformer(nn.Module):
    """Text tower: token embeddings, encoder, final LayerNorm and a linear
    head applied to the last token's hidden state."""

    def __init__(self, config: SiglipTextConfig):
        super().__init__()
        self.config = config
        hidden = config.hidden_size
        self.embeddings = SiglipTextEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.final_layer_norm = nn.LayerNorm(hidden,
                                             eps=config.layer_norm_eps)
        self.head = nn.Linear(hidden, hidden)

    @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling,
                               config_class=SiglipTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        # Collapse any leading dimensions into a single batch dimension.
        seq_len = input_ids.size()[-1]
        input_ids = input_ids.view(-1, seq_len)

        hidden_states = self.embeddings(input_ids=input_ids,
                                        position_ids=position_ids)

        # note: SigLIP's text model does not use a causal mask, unlike the
        # original CLIP model.
        if attention_mask is not None:
            # [batch_size, seq_len] ->
            # [batch_size, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(
                attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = self.final_layer_norm(encoder_outputs[0])

        # Assuming "sticky" EOS tokenization, last token is always EOS.
        pooled_output = self.head(last_hidden_state[:, -1, :])

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]
@add_start_docstrings(
    """The text model from SigLIP without any head or projection on top.""",
    SIGLIP_START_DOCSTRING,
)
class SiglipTextModel(SiglipPreTrainedModel):
    # Thin PreTrainedModel wrapper that owns a SiglipTextTransformer and
    # forwards every call to it.
    config_class = SiglipTextConfig
    _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]

    def __init__(self, config: SiglipTextConfig):
        super().__init__(config)
        self.text_model = SiglipTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token-embedding module of the wrapped text tower."""
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        """Replace the token-embedding module of the wrapped text tower."""
        self.text_model.embeddings.token_embedding = value

    @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling,
                               config_class=SiglipTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, SiglipTextModel

        >>> model = SiglipTextModel.from_pretrained(
        ...     "google/siglip-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained(
        ...     "google/siglip-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's
        >>> # how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"],
        ...                    padding="max_length", return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> # pooled (EOS token) states
        >>> pooled_output = outputs.pooler_output
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
class SiglipVisionTransformer(nn.Module):
    """Vision tower: patch embeddings, encoder, post-LayerNorm and an
    attention-pooling head."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim,
                                           eps=config.layer_norm_eps)
        self.head = SiglipMultiheadAttentionPoolingHead(config)

    @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling,
                               config_class=SiglipVisionConfig)
    def forward(
        self,
        pixel_values,
        patch_attention_mask: Optional[torch.BoolTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        output_attentions = (output_attentions
                             if output_attentions is not None else
                             self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        batch_size = pixel_values.size(0)
        if patch_attention_mask is None:
            # Default: every patch of every image is valid. The grid is
            # (height // patch_size, width // patch_size) per image.
            patch_attention_mask = torch.ones(
                size=(
                    batch_size,
                    pixel_values.size(2) // self.config.patch_size,
                    pixel_values.size(3) // self.config.patch_size,
                ),
                dtype=torch.bool,
                device=pixel_values.device,
            )

        hidden_states = self.embeddings(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask)

        # Flatten the patch grid so the mask lines up with the sequence.
        patch_attention_mask = patch_attention_mask.view(batch_size, -1)

        # The call to `_upad_input` in `_flash_attention_forward` is
        # expensive. So when the `patch_attention_mask` is full of 1s
        # (i.e. attending to the whole sequence), avoid passing the
        # attention_mask, which is equivalent to attending to the full
        # sequence.
        if not torch.any(~patch_attention_mask):
            attention_mask = None
        elif getattr(self.config, "_flash_attn_2_enabled", False):
            # Read the flag with the same default as SiglipEncoderLayer so
            # a config without it selects the non-flash path in both
            # places instead of raising AttributeError here.
            attention_mask = patch_attention_mask
        else:
            attention_mask = _prepare_4d_attention_mask(
                patch_attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.post_layernorm(last_hidden_state)

        # The pooling head always receives the flattened boolean mask,
        # even when the encoder was run without one.
        pooled_output = self.head(
            hidden_state=last_hidden_state,
            attention_mask=patch_attention_mask,
        )

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling.

    A single learned probe vector attends over the whole sequence; the
    attended result goes through a LayerNorm + MLP residual block and is
    returned as one vector per batch element.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        hidden = config.hidden_size
        self.probe = nn.Parameter(torch.randn(1, 1, hidden))
        self.attention = torch.nn.MultiheadAttention(
            hidden, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(self, hidden_state, attention_mask):
        """Pool `hidden_state` into one vector per batch element.

        Args:
            hidden_state: Sequence features; the first dimension is the
                batch (the attention module is built with
                `batch_first=True`).
            attention_mask: Boolean mask that is inverted and used as
                `key_padding_mask`, so `True` marks positions to attend.

        Returns:
            The pooled features, i.e. index 0 along the probe dimension.
        """
        # One copy of the learned query per batch element.
        queries = self.probe.repeat(hidden_state.shape[0], 1, 1)

        pooled, _ = self.attention(query=queries,
                                   key=hidden_state,
                                   value=hidden_state,
                                   key_padding_mask=~attention_mask)

        pooled = pooled + self.mlp(self.layernorm(pooled))
        return pooled[:, 0]
@add_start_docstrings(
    """The vision model from SigLIP without any head or projection on top.""",
    SIGLIP_START_DOCSTRING,
)
class SiglipVisionModel(SiglipPreTrainedModel):
    # Thin PreTrainedModel wrapper that owns a SiglipVisionTransformer and
    # forwards every call to it.
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: SiglipVisionConfig):
        super().__init__(config)
        self.vision_model = SiglipVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding module of the wrapped vision tower."""
        return self.vision_model.embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling,
                               config_class=SiglipVisionConfig)
    def forward(
        self,
        pixel_values,
        patch_attention_mask: Optional[torch.BoolTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, SiglipVisionModel

        >>> model = SiglipVisionModel.from_pretrained(
        ...     "google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained(
        ...     "google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled features
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        return self.vision_model(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
@add_start_docstrings(SIGLIP_START_DOCSTRING)
class SiglipModel(SiglipPreTrainedModel):
    # Joint model: a text tower and a vision tower whose pooled outputs are
    # L2-normalised and compared with a learned scale and bias.
    config_class = SiglipConfig

    def __init__(self, config: SiglipConfig):
        super().__init__(config)

        text_config = config.text_config
        if not isinstance(text_config, SiglipTextConfig):
            raise ValueError("config.text_config is expected to be of type "
                             f"SiglipTextConfig but is of type"
                             f" {type(config.text_config)}.")

        vision_config = config.vision_config
        if not isinstance(vision_config, SiglipVisionConfig):
            raise ValueError("config.vision_config is expected to be of type "
                             "SiglipVisionConfig but is of type"
                             f" {type(config.vision_config)}.")

        self.text_model = SiglipTextTransformer(text_config)
        self.vision_model = SiglipVisionTransformer(vision_config)

        self.logit_scale = nn.Parameter(torch.randn(1))
        self.logit_bias = nn.Parameter(torch.randn(1))

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor`): The pooled output of the
            text tower (element 1 of its outputs).

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained(
        ...     "google/siglip-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained(
        ...     "google/siglip-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's
        >>> # how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"],
        ...                    padding="max_length", return_tensors="pt")
        >>> with torch.no_grad():
        ...     text_features = model.get_text_features(**inputs)
        ```"""
        # Use SigLIP model's config for some fields (if specified) instead
        # of those of vision & text components.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooled output for both tuple and dict returns.
        return text_outputs[1]

    @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor`): The pooled output of the
            vision tower (element 1 of its outputs).

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained(
        ...     "google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained(
        ...     "google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     image_features = model.get_image_features(**inputs)
        ```"""
        # Use SiglipModel's config for some fields (if specified) instead
        # of those of vision & text components.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooled output for both tuple and dict returns.
        return vision_outputs[1]

    @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SiglipOutput,
                               config_class=SiglipConfig)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SiglipOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained(
        ...     "google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained(
        ...     "google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
        >>> # important: we pass `padding=max_length` since the model was
        >>> # trained with this
        >>> inputs = processor(text=texts, images=image,
        ...                    padding="max_length", return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> logits_per_image = outputs.logits_per_image
        >>> # these are the probabilities
        >>> probs = torch.sigmoid(logits_per_image)
        >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
        31.9% that image 0 is 'a photo of 2 cats'
        ```"""
        # Use SigLIP model's config for some fields (if specified) instead of
        # those of vision & text components.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # normalized features
        image_embeds = vision_outputs[1]
        image_embeds = image_embeds / image_embeds.norm(
            p=2, dim=-1, keepdim=True)
        text_embeds = text_outputs[1]
        text_embeds = text_embeds / text_embeds.norm(
            p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        similarity = torch.matmul(text_embeds, image_embeds.t())
        logits_per_text = (similarity * self.logit_scale.exp() +
                           self.logit_bias)
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            raise NotImplementedError("SigLIP loss to be implemented")

        if return_dict:
            return SiglipOutput(
                loss=loss,
                logits_per_image=logits_per_image,
                logits_per_text=logits_per_text,
                text_embeds=text_embeds,
                image_embeds=image_embeds,
                text_model_output=text_outputs,
                vision_model_output=vision_outputs,
            )

        output = (logits_per_image, logits_per_text, text_embeds,
                  image_embeds, text_outputs, vision_outputs)
        if loss is None:
            return output
        return (loss, ) + output
def get_siglip_vision_model(_flash_attn_2_enabled=True, **kwargs):
    """Build the SigLIP vision transformer used as the image encoder.

    Args:
        _flash_attn_2_enabled: Request the flash-attention code path. It
            is forced to ``False`` when the detected ViT attention backend
            is not ``_Backend.FLASH_ATTN``.
        **kwargs: Extra ``SiglipVisionConfig`` fields. A key that matches
            one of the built-in defaults below overrides that default.

    Returns:
        The ``vision_model`` attribute of a freshly constructed
        ``SiglipVisionModel``.
    """
    siglip_vision_config = {
        "hidden_size": 1152,
        "image_size": 448,
        "intermediate_size": 4304,
        "model_type": "siglip_vision_model",
        "num_attention_heads": 16,
        "num_hidden_layers": 27,
        "patch_size": 14,
    }
    # Merge caller overrides into the defaults before the call. Unpacking
    # both mappings into one call would raise "got multiple values for
    # keyword argument" as soon as a caller names one of the defaults.
    siglip_vision_config.update(kwargs)

    # Detect attention implementation.
    attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
    if attn_backend != _Backend.FLASH_ATTN:
        _flash_attn_2_enabled = False

    model_config = SiglipVisionConfig(
        **siglip_vision_config,
        _flash_attn_2_enabled=_flash_attn_2_enabled)

    vision_model = SiglipVisionModel(model_config).vision_model

    return vision_model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment