Commit 112bf76b authored by chenzk

v1.0

import torch
import torch.nn as nn
from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
from vita.util.s2wrapper import forward as multiscale_forward
class SiglipVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower
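        # which hidden_states entry to use as visual features (-2: penultimate layer)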
self.select_layer = -2
if not delay_load:
self.load_model()
else:
self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name)
def load_model(self):
self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
self.image_processor.crop_size = self.image_processor.size
self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
self.vision_tower.requires_grad_(False)
self.is_loaded = True
def feature_select(self, image_forward_outs):
image_features = image_forward_outs.hidden_states[self.select_layer]
return image_features
@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(
image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
output_hidden_states=True,
)
image_feature = self.feature_select(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(
images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
class SiglipVisionTowerS2(SiglipVisionTower):
def __init__(self, vision_tower, args, delay_load=False):
self.s2_scales = getattr(args, "s2_scales", "384,768,1152")
self.s2_scales = list(map(int, self.s2_scales.split(",")))
self.s2_scales.sort()
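        # smallest scale is the tile size used for splitting; the largest becomes
        # the image processor's input resolution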
self.s2_split_size = self.s2_scales[0]
self.s2_image_size = self.s2_scales[-1]
super().__init__(vision_tower, args, delay_load)
self.multiscale_forward = multiscale_forward
if not delay_load:
self.image_processor.size["height"] = self.image_processor.size[
"width"
] = self.s2_image_size
self.image_processor.crop_size["height"] = self.image_processor.crop_size[
"width"
] = self.s2_image_size
def load_model(self):
self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
self.image_processor.crop_size = self.image_processor.size
self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
self.vision_tower.requires_grad_(False)
self.image_processor.size["height"] = self.image_processor.size[
"width"
] = self.s2_image_size
self.image_processor.crop_size["height"] = self.image_processor.crop_size[
"width"
] = self.s2_image_size
self.is_loaded = True
@torch.no_grad()
def forward_feature(self, images):
image_forward_outs = self.vision_tower(
images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
return image_features
@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_feature = self.multiscale_forward(
self.forward_feature,
image.unsqueeze(0),
img_sizes=self.s2_scales,
max_split_size=self.s2_split_size,
)
image_features.append(image_feature)
else:
image_features = self.multiscale_forward(
self.forward_feature,
images,
img_sizes=self.s2_scales,
max_split_size=self.s2_split_size,
)
return image_features
@property
def hidden_size(self):
return self.config.hidden_size * len(self.s2_scales)
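
# Minimal usage sketch; the checkpoint name and image path below are
# illustrative assumptions.
if __name__ == "__main__":
    from PIL import Image

    tower = SiglipVisionTower("google/siglip-so400m-patch14-384", args=None)
    image = Image.open("example.jpg").convert("RGB")
    pixels = tower.image_processor(images=image, return_tensors="pt")["pixel_values"]
    features = tower(pixels)  # (1, num_patches, hidden_size), e.g. (1, 729, 1152)
    print(features.shape)
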
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
class CNNAdapter(torch.nn.Module):
def __init__(
self,
enc_out_dim: int = 512,
llm_embed_dim: int = 4096,
kernel_size: int = 5,
):
super().__init__()
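        # causal left padding keeps the sequence length unchanged through each stride-1 conv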
self.left_padding1 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
self.conv1d1 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 1, 0)
self.bn1 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
self.relu1 = nn.ReLU()
self.left_padding2 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
self.conv1d2 = nn.Conv1d(2 * enc_out_dim, 4 * enc_out_dim, kernel_size, 1, 0)
self.bn2 = nn.BatchNorm1d(4 * enc_out_dim, eps=1e-3, momentum=0.99)
self.relu2 = nn.ReLU()
self.project = nn.Linear(4 * enc_out_dim, llm_embed_dim)
def forward(self, x, mask_pad):
"""
x: B, T, enc_out_dim
mask: (B, T) or (B, 1, T)
"""
x = x.transpose(1, 2) # B, channels, T
# mask batch padding
if mask_pad.size(2) > 0: # time > 0
x.masked_fill_(~mask_pad, 0.0)
x = self.left_padding1(x)
x = self.conv1d1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.left_padding2(x)
x = self.conv1d2(x)
x = self.bn2(x)
x = self.relu2(x)
x = x.transpose(1, 2)
x = self.project(x)
return x, mask_pad
class LinearAdapter(torch.nn.Module):
def __init__(
self,
enc_out_dim: int = 512,
llm_embed_dim: int = 4096,
):
super().__init__()
self.adpter = torch.nn.Linear(enc_out_dim, llm_embed_dim)
def forward(self, x, mask_pad):
return self.adpter(x), mask_pad
class CNNSubsampling(torch.nn.Module):
def __init__(
self,
enc_out_dim: int = 512,
llm_embed_dim: int = 4096,
kernel_size: int = 5,
activation_func: str = "relu",
norm: str = "batch",
):
super().__init__()
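        # two conv blocks (channels x4) when the LLM embedding is more than 4x wider
        # than the encoder output, otherwise a single block (channels x2); both
        # branches subsample the time axis by 2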
if enc_out_dim * 4 < llm_embed_dim:
self.left_padding1 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
self.conv1d1 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 1, 0)
self.bn1 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
self.relu1 = nn.ReLU()
self.left_padding2 = nn.ConstantPad1d((0, kernel_size - 1), 0.0)
self.conv1d2 = nn.Conv1d(2 * enc_out_dim, 4 * enc_out_dim, kernel_size, 2, 0)
self.bn2 = nn.BatchNorm1d(4 * enc_out_dim, eps=1e-3, momentum=0.99)
self.relu2 = nn.ReLU()
self.project = nn.Linear(4 * enc_out_dim, llm_embed_dim)
self.cnn_num = 2
else:
self.left_padding2 = nn.ConstantPad1d((0, kernel_size - 1), 0.0)
self.conv1d2 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 2, 0)
if norm == "batch":
self.bn2 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
elif norm == "layer":
self.bn2 = nn.LayerNorm(2 * enc_out_dim, eps=1e-3)
if activation_func == "gelu":
self.relu2 = nn.GELU()
else:
self.relu2 = nn.ReLU()
self.project = nn.Linear(2 * enc_out_dim, llm_embed_dim)
self.cnn_num = 1
def forward(self, x, mask_pad):
"""
x: B, T, enc_out_dim
mask: (B, T) or (B, 1, T)
"""
x = x.transpose(1, 2) # B, channels, T
# mask batch padding
if mask_pad.size(2) > 0: # time > 0
x.masked_fill_(~mask_pad, 0.0)
if self.cnn_num == 2:
x = self.left_padding1(x)
x = self.conv1d1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.left_padding2(x)
x = self.conv1d2(x)
if isinstance(self.bn2, nn.LayerNorm):
x = x.transpose(1, 2)
x = self.bn2(x)
if isinstance(self.bn2, nn.LayerNorm):
x = x.transpose(1, 2)
x = self.relu2(x)
x = x.transpose(1, 2)
x = self.project(x)
return x, mask_pad[:, :, 0::2]
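
# Minimal shape sketch for CNNSubsampling; the dimensions below are arbitrary
# illustrative values.
if __name__ == "__main__":
    adapter = CNNSubsampling(enc_out_dim=512, llm_embed_dim=4096, kernel_size=5)
    x = torch.randn(2, 100, 512)                    # (B, T, enc_out_dim)
    mask = torch.ones(2, 1, 100, dtype=torch.bool)  # (B, 1, T), True = valid frame
    y, y_mask = adapter(x, mask)
    print(y.shape, y_mask.shape)                    # (2, 50, 4096), (2, 1, 50)
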
import numpy as np
import torch
import json
import logging
import math
import sys
class GlobalCMVN(torch.nn.Module):
def __init__(self, mean: torch.Tensor, istd: torch.Tensor, norm_var: bool = True):
"""
Args:
mean (torch.Tensor): mean stats
istd (torch.Tensor): inverse std, std which is 1.0 / std
"""
super().__init__()
assert mean.shape == istd.shape
self.norm_var = norm_var
# The buffer can be accessed from this module using self.mean
self.register_buffer("mean", mean)
self.register_buffer("istd", istd)
def forward(self, x: torch.Tensor):
"""
Args:
x (torch.Tensor): (batch, max_len, feat_dim)
Returns:
(torch.Tensor): normalized feature
"""
x = x - self.mean
if self.norm_var:
x = x * self.istd
return x
def load_cmvn_json(json_cmvn_file):
with open(json_cmvn_file) as f:
cmvn_json = json.load(f)
avg = cmvn_json["mean_stat"]
var = cmvn_json["var_stat"]
count = cmvn_json["frame_num"]
for i in range(len(avg)):
avg[i] /= count
var[i] = var[i] / count - avg[i] * avg[i]
if var[i] < 1.0e-20:
var[i] = 1.0e-20
var[i] = 1.0 / math.sqrt(var[i])
cmvn = np.array([avg, var])
return cmvn
def load_cmvn_kaldi(kaldi_cmvn_file):
avg = []
var = []
with open(kaldi_cmvn_file, "r") as file:
# kaldi binary file start with '\0B'
if file.read(2) == "\0B":
logging.error(
"kaldi cmvn binary file is not supported, please "
)
sys.exit(1)
file.seek(0)
arr = file.read().split()
assert arr[0] == "["
assert arr[-2] == "0"
assert arr[-1] == "]"
feat_dim = int((len(arr) - 2 - 2) / 2)
for i in range(1, feat_dim + 1):
avg.append(float(arr[i]))
count = float(arr[feat_dim + 1])
for i in range(feat_dim + 2, 2 * feat_dim + 2):
var.append(float(arr[i]))
for i in range(len(avg)):
avg[i] /= count
var[i] = var[i] / count - avg[i] * avg[i]
if var[i] < 1.0e-20:
var[i] = 1.0e-20
var[i] = 1.0 / math.sqrt(var[i])
cmvn = np.array([avg, var])
return cmvn
def load_cmvn(filename, is_json):
if is_json:
file = load_cmvn_json(filename)
else:
file = load_cmvn_kaldi(filename)
return file[0], file[1]
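
# Minimal usage sketch; "global_cmvn.json" is a hypothetical path to stats in the
# JSON layout read by load_cmvn_json (mean_stat / var_stat / frame_num).
if __name__ == "__main__":
    mean, istd = load_cmvn("global_cmvn.json", is_json=True)
    cmvn = GlobalCMVN(torch.from_numpy(mean).float(), torch.from_numpy(istd).float())
    feats = torch.randn(4, 200, mean.shape[0])  # (batch, max_len, feat_dim)
    normalized = cmvn(feats)
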
# Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from .adapter import CNNAdapter, CNNSubsampling, LinearAdapter
from .cmvn import GlobalCMVN, load_cmvn
from .module.encoder.encoder import whaleEncoder
class audioEncoderProcessor:
def __init__(
self,
dataset_conf: dict = None,
):
self.dataset_conf = dataset_conf
    def process(self, wav_path):
        try:
            waveform, sample_rate = torchaudio.load(wav_path)
        except Exception:
            print(f"cannot open {wav_path}")
            raise
        target_rate = self.dataset_conf["resample_conf"]["resample_rate"]
        if sample_rate != target_rate:
            waveform = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=target_rate
            )(waveform)
            sample_rate = target_rate
        # kaldi fbank expects a 16-bit signal range
        waveform = waveform * (1 << 15)
        mat = kaldi.fbank(
            waveform,
            num_mel_bins=self.dataset_conf["fbank_conf"]["num_mel_bins"],
            frame_length=self.dataset_conf["fbank_conf"]["frame_length"],
            frame_shift=self.dataset_conf["fbank_conf"]["frame_shift"],
            dither=self.dataset_conf["fbank_conf"]["dither"],
            energy_floor=0.0,
            sample_frequency=sample_rate,
        )
        # mimic the encoder's 4x temporal subsampling followed by the adapter's
        # 2x stride to obtain the number of audio tokens
        attn_mask = torch.ones(mat.shape[0])
        attn_mask = attn_mask[2::2][2::2][0::2]
        return mat, attn_mask.shape[0]
class audioEncoder(torch.nn.Module):
def __init__(
self,
encoder: torch.nn.Module,
llm_path: str,
freeze_llm: bool = True,
enc_out_dim: int = 512,
llm_embed_dim: int = 4096,
kernel_size: int = 3,
IGNORE_ID: int = -100,
adpter_type: str = "cnn",
add_audio_bos_eos: bool = False,
task_num: int = 10,
task_before_audio: bool = False,
task_type: str = "prompt",
freeze_encoder: bool = False,
freeze_adpter: bool = False,
activation_func: str = "relu",
norm: str = "batch",
chat_template=None,
):
super().__init__()
self.encoder = encoder
self.enc_out_dim = enc_out_dim
self.llm_embed_dim = llm_embed_dim
self.IGNORE_ID = IGNORE_ID
self.add_audio_bos_eos = add_audio_bos_eos
self.task_before_audio = task_before_audio
self.task_type = task_type
self.freeze_encoder = freeze_encoder
self.freeze_adpter = freeze_adpter
if adpter_type == "cnn":
self.adpter = CNNAdapter(enc_out_dim, llm_embed_dim, kernel_size)
elif adpter_type == "linear":
self.adpter = LinearAdapter(enc_out_dim, llm_embed_dim)
elif adpter_type == "subsampling":
self.adpter = CNNSubsampling(
enc_out_dim, llm_embed_dim, kernel_size, activation_func, norm
)
if self.freeze_encoder:
self.encoder.eval()
for (name, param) in self.encoder.named_parameters():
param.requires_grad = False
if self.freeze_adpter:
self.adpter.eval()
for (name, param) in self.adpter.named_parameters():
param.requires_grad = False
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
) -> Dict[str, Optional[torch.Tensor]]:
speech = speech.to(next(self.parameters()).dtype)
# 1. Encoder
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
inputs_embeds, encoder_mask = self.adpter(encoder_out, encoder_mask) # B, T, D
attention_mask = encoder_mask.squeeze(1) # B, T
assert inputs_embeds.size(1) == attention_mask.size(1)
        # audio bos/eos (note: _add_bos_eos relies on self.task_embeddings and
        # self.task_ids, which are expected to be attached externally)
        if self.add_audio_bos_eos:
            inputs_embeds, attention_mask, _ = self._add_bos_eos(
                "audio", "/audio", inputs_embeds, attention_mask, None
            )
outputs = {
"inputs_embeds": inputs_embeds,
"attention_mask": attention_mask,
}
return outputs
def _add_bos_eos(self, bos, eos, inputs_embeds, attention_mask, target=None):
B = len(inputs_embeds)
bos_embed = self.task_embeddings(
torch.full([B, 1], self.task_ids[bos]).to(inputs_embeds.device)
) # B, 1, D
eos_embed = self.task_embeddings(
torch.full([B, 1], self.task_ids[eos]).to(inputs_embeds.device)
) # B, 1, D
bos_eos_target = torch.full([B, 2], self.IGNORE_ID).to(inputs_embeds.device) # B, 2
bos_eos_mask = torch.full([B, 1], True).to(inputs_embeds.device) # B, 1
inputs_embeds = torch.cat((bos_embed, inputs_embeds), 1) # B, (1+T), D
inputs_embeds = torch.cat((inputs_embeds, eos_embed), 1) # B, (1+T+1), D
attention_mask = torch.cat((bos_eos_mask, attention_mask), 1) # B, (1+T)
attention_mask = torch.cat((attention_mask, bos_eos_mask), 1) # B, (1+T+1)
if target is not None:
            target = torch.cat((target, bos_eos_target), 1)  # B, (T+2)
return inputs_embeds, attention_mask, target
def init_model(configs):
if configs["cmvn_file"] is not None:
mean, istd = load_cmvn(configs["cmvn_file"], configs["is_json_cmvn"])
global_cmvn = GlobalCMVN(torch.from_numpy(mean).float(), torch.from_numpy(istd).float())
else:
global_cmvn = None
input_dim = configs["input_dim"]
encoder = whaleEncoder(input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"])
model = audioEncoder(encoder=encoder, **configs["model_conf"])
processor = audioEncoderProcessor(dataset_conf=configs["dataset_conf"])
model.audio_processor = processor
return model
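
# Minimal sketch of the fbank front-end; the dataset_conf values and the wav
# path are illustrative assumptions, not the project's shipped config.
if __name__ == "__main__":
    dataset_conf = {
        "resample_conf": {"resample_rate": 16000},
        "fbank_conf": {"num_mel_bins": 80, "frame_length": 25,
                       "frame_shift": 10, "dither": 0.0},
    }
    processor = audioEncoderProcessor(dataset_conf=dataset_conf)
    feats, num_audio_tokens = processor.process("example.wav")
    print(feats.shape, num_audio_tokens)  # (num_frames, 80), roughly num_frames // 8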