Commit ffeba11a authored by mayp777

UPDATE

parent 29deb085
import cv2
import numpy as np
from skimage import transform as tf
def linear_interpolate(landmarks, start_idx, stop_idx):
start_landmarks = landmarks[start_idx]
stop_landmarks = landmarks[stop_idx]
delta = stop_landmarks - start_landmarks
for idx in range(1, stop_idx - start_idx):
landmarks[start_idx + idx] = start_landmarks + idx / float(stop_idx - start_idx) * delta
return landmarks
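# Warp an image into a reference frame using a similarity transform estimated from
# src -> dst point correspondences; returns the warped uint8 image and the transform.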
def warp_img(src, dst, img, std_size):
tform = tf.estimate_transform("similarity", src, dst)
warped = tf.warp(img, inverse_map=tform.inverse, output_shape=std_size)
warped = (warped * 255).astype("uint8")
return warped, tform
def apply_transform(transform, img, std_size):
warped = tf.warp(img, inverse_map=transform.inverse, output_shape=std_size)
warped = (warped * 255).astype("uint8")
return warped
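# Crop a (2*height) x (2*width) patch centred at the mean of the given landmarks,
# clipped to the image bounds; raises if the centre drifts too far from the image centre.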
def cut_patch(img, landmarks, height, width, threshold=5):
center_x, center_y = np.mean(landmarks, axis=0)
# Check for too much bias in height and width
if abs(center_y - img.shape[0] / 2) > height + threshold:
raise Exception("too much bias in height")
if abs(center_x - img.shape[1] / 2) > width + threshold:
raise Exception("too much bias in width")
# Calculate bounding box coordinates
y_min = int(round(np.clip(center_y - height, 0, img.shape[0])))
y_max = int(round(np.clip(center_y + height, 0, img.shape[0])))
x_min = int(round(np.clip(center_x - width, 0, img.shape[1])))
x_max = int(round(np.clip(center_x + width, 0, img.shape[1])))
# Cut the image
cutted_img = np.copy(img[y_min:y_max, x_min:x_max])
return cutted_img
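# VideoProcess: fills in missing per-frame landmarks by linear interpolation, aligns
# every frame to a fixed reference via a partial affine (similarity) transform, then
# crops the region spanned by the stable points and optionally resizes it (default 96x96).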
class VideoProcess:
def __init__(
self,
crop_width=128,
crop_height=128,
target_size=(224, 224),
reference_size=(224, 224),
stable_points=(0, 1),
start_idx=0,
stop_idx=2,
resize=(96, 96),
):
self.reference = np.array(([[51.64568, 0.70204943], [171.95107, 159.59505]]))
self.crop_width = crop_width
self.crop_height = crop_height
self.start_idx = start_idx
self.stop_idx = stop_idx
self.resize = resize
def __call__(self, video, landmarks):
# Pre-process landmarks: interpolate frames that are not detected
preprocessed_landmarks = self.interpolate_landmarks(landmarks)
# Exclude corner cases: no landmark in all frames or number of frames is less than window length
if not preprocessed_landmarks:
return
# Affine transformation and crop patch
sequence = self.crop_patch(video, preprocessed_landmarks)
assert sequence is not None, "crop an empty patch."
return sequence
def crop_patch(self, video, landmarks):
sequence = []
for frame_idx, frame in enumerate(video):
transformed_frame, transformed_landmarks = self.affine_transform(
frame, landmarks[frame_idx], self.reference
)
patch = cut_patch(
transformed_frame,
transformed_landmarks[self.start_idx : self.stop_idx],
self.crop_height // 2,
self.crop_width // 2,
)
if self.resize:
patch = cv2.resize(patch, self.resize)
sequence.append(patch)
return np.array(sequence)
def interpolate_landmarks(self, landmarks):
valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None]
if not valid_frames_idx:
return None
for idx in range(1, len(valid_frames_idx)):
if valid_frames_idx[idx] - valid_frames_idx[idx - 1] > 1:
landmarks = linear_interpolate(landmarks, valid_frames_idx[idx - 1], valid_frames_idx[idx])
valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None]
# Handle corner case: keep frames at the beginning or at the end that failed to be detected
if valid_frames_idx:
landmarks[: valid_frames_idx[0]] = [landmarks[valid_frames_idx[0]]] * valid_frames_idx[0]
landmarks[valid_frames_idx[-1] :] = [landmarks[valid_frames_idx[-1]]] * (
len(landmarks) - valid_frames_idx[-1]
)
assert all(lm is not None for lm in landmarks), "not every frame has landmark"
return landmarks
def affine_transform(
self,
frame,
landmarks,
reference,
target_size=(224, 224),
reference_size=(224, 224),
stable_points=(0, 1),
interpolation=cv2.INTER_LINEAR,
border_mode=cv2.BORDER_CONSTANT,
border_value=0,
):
stable_reference = self.get_stable_reference(reference, stable_points, reference_size, target_size)
transform = self.estimate_affine_transform(landmarks, stable_points, stable_reference)
transformed_frame, transformed_landmarks = self.apply_affine_transform(
frame, landmarks, transform, target_size, interpolation, border_mode, border_value
)
return transformed_frame, transformed_landmarks
def get_stable_reference(self, reference, stable_points, reference_size, target_size):
stable_reference = np.vstack([reference[x] for x in stable_points])
stable_reference[:, 0] -= (reference_size[0] - target_size[0]) / 2.0
stable_reference[:, 1] -= (reference_size[1] - target_size[1]) / 2.0
return stable_reference
def estimate_affine_transform(self, landmarks, stable_points, stable_reference):
return cv2.estimateAffinePartial2D(
np.vstack([landmarks[x] for x in stable_points]), stable_reference, method=cv2.LMEDS
)[0]
def apply_affine_transform(
self, frame, landmarks, transform, target_size, interpolation, border_mode, border_value
):
transformed_frame = cv2.warpAffine(
frame,
transform,
dsize=(target_size[0], target_size[1]),
flags=interpolation,
borderMode=border_mode,
borderValue=border_value,
)
transformed_landmarks = np.matmul(landmarks, transform[:, :2].transpose()) + transform[:, 2].transpose()
return transformed_frame, transformed_landmarks
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import warnings
import numpy as np
from ibug.face_detection import RetinaFacePredictor
warnings.filterwarnings("ignore")
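# Runs RetinaFace on every frame and keeps the bounding-box corners of the first
# detected face as a 2x2 pseudo-landmark array; frames without a detection get None.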
class LandmarksDetector:
def __init__(self, device="cuda:0", model_name="resnet50"):
self.face_detector = RetinaFacePredictor(
device=device, threshold=0.8, model=RetinaFacePredictor.get_model(model_name)
)
def __call__(self, video_frames):
landmarks = []
for frame in video_frames:
detected_faces = self.face_detector(frame, rgb=False)
if len(detected_faces) >= 1:
landmarks.append(np.reshape(detected_faces[0][:4], (2, 2)))
else:
landmarks.append(None)
return landmarks
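# Usage sketch (illustrative only, not part of the original files): chains the
# LandmarksDetector above with the VideoProcess class defined earlier to produce
# a cropped 96x96 patch sequence for one video. The OpenCV-based frame reading
# and the video path argument are assumptions made for this example.
def _demo_crop_video(video_path, device="cuda:0"):
    import cv2  # local import; the sketch assumes opencv-python is installed
    detector = LandmarksDetector(device=device)
    video_process = VideoProcess(resize=(96, 96))
    cap = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ok, frame = cap.read()  # BGR frames, matching rgb=False in the detector
        if not ok:
            break
        frames.append(frame)
    cap.release()
    video = np.stack(frames)                # (T, H, W, 3)
    landmarks = detector(video)             # one 2x2 corner array (or None) per frame
    return video_process(video, landmarks)  # (T, 96, 96, 3), or None if nothing detected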
import argparse
import os
parser = argparse.ArgumentParser(description="Merge labels")
parser.add_argument(
"--dataset",
type=str,
required=True,
help="Specify the dataset used in the experiment",
)
parser.add_argument(
"--subset",
type=str,
required=True,
help="Specify the subset of the dataset used in the experiment",
)
parser.add_argument(
"--root-dir",
type=str,
required=True,
help="Directory of saved mouth patches or embeddings",
)
parser.add_argument(
"--groups",
type=int,
required=True,
help="Number of threads for parallel processing",
)
parser.add_argument(
"--seg-duration",
type=int,
default=16,
help="Length of the segments",
)
args = parser.parse_args()
dataset = args.dataset
subset = args.subset
seg_duration = args.seg_duration
# Check that there is more than one group
assert args.groups > 1, "There is no need to use this script for merging when --groups is 1."
# Create the filename template for label files
label_template = os.path.join(
args.root_dir,
"labels",
f"{dataset}_{subset}_transcript_lengths_seg{seg_duration}s.{args.groups}",
)
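# Each parallel job wrote its labels to <template>.<job_index>.csv; read them back in job order.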
lines = []
for job_index in range(args.groups):
label_filename = f"{label_template}.{job_index}.csv"
assert os.path.exists(label_filename), f"{label_filename} does not exist."
with open(label_filename, "r") as file:
lines.extend(file.read().splitlines())
# Write the merged labels to a new file
dst_label_filename = os.path.join(
args.root_dir,
"labels",
f"{dataset}_{subset}_transcript_lengths_seg{seg_duration}s.csv",
)
with open(dst_label_filename, "w") as file:
file.write("\n".join(lines))
# Print the number of files and total duration in hours
total_duration = sum(int(line.split(",")[2]) for line in lines) / 3600.0 / 25.0  # frame counts at 25 fps -> hours
print(f"The completed set has {len(lines)} files with a total of {total_duration:.2f} hours.")
# Remove the label files for each job index
print("** Remove the temporary label files **")
for job_index in range(args.groups):
label_filename = f"{label_template}.{job_index}.csv"
if os.path.exists(label_filename):
os.remove(label_filename)
print("** Finish **")
import argparse
import glob
import math
import os
import shutil
import warnings
import ffmpeg
from data.data_module import AVSRDataLoader
from tqdm import tqdm
from utils import save_vid_aud_txt, split_file
warnings.filterwarnings("ignore")
# Argument Parsing
parser = argparse.ArgumentParser(description="LRS3 Preprocessing")
parser.add_argument(
"--data-dir",
type=str,
help="The directory for sequence.",
)
parser.add_argument(
"--detector",
type=str,
default="retinaface",
help="Face detector used in the experiment.",
)
parser.add_argument(
"--dataset",
type=str,
help="Specify the dataset name used in the experiment",
)
parser.add_argument(
"--root-dir",
type=str,
help="The root directory of cropped-face dataset.",
)
parser.add_argument(
"--subset",
type=str,
required=True,
help="Subset of the dataset used in the experiment.",
)
parser.add_argument(
"--seg-duration",
type=int,
default=16,
help="Length of the segment in seconds.",
)
parser.add_argument(
"--groups",
type=int,
default=1,
help="Number of threads to be used in parallel.",
)
parser.add_argument(
"--job-index",
type=int,
default=0,
help="Index to identify separate jobs (useful for parallel processing).",
)
args = parser.parse_args()
seg_duration = args.seg_duration
dataset = args.dataset
args.data_dir = os.path.normpath(args.data_dir)
vid_dataloader = AVSRDataLoader(modality="video", detector=args.detector, resize=(96, 96))
aud_dataloader = AVSRDataLoader(modality="audio")
# Set segment lengths (25 fps video, 16 kHz audio) and prepare the label file.
seg_vid_len = seg_duration * 25
seg_aud_len = seg_duration * 16000
label_filename = os.path.join(
args.root_dir,
"labels",
f"{dataset}_{args.subset}_transcript_lengths_seg{seg_duration}s.csv"
if args.groups <= 1
else f"{dataset}_{args.subset}_transcript_lengths_seg{seg_duration}s.{args.groups}.{args.job_index}.csv",
)
os.makedirs(os.path.dirname(label_filename), exist_ok=True)
print(f"Directory {os.path.dirname(label_filename)} created")
f = open(label_filename, "w")
# Step 2, extract mouth patches from segments.
dst_vid_dir = os.path.join(args.root_dir, dataset, dataset + f"_video_seg{seg_duration}s")
dst_txt_dir = os.path.join(args.root_dir, dataset, dataset + f"_text_seg{seg_duration}s")
if args.subset == "test":
filenames = glob.glob(os.path.join(args.data_dir, args.subset, "**", "*.mp4"), recursive=True)
elif args.subset == "train":
filenames = glob.glob(os.path.join(args.data_dir, "trainval", "**", "*.mp4"), recursive=True)
filenames.extend(glob.glob(os.path.join(args.data_dir, "pretrain", "**", "*.mp4"), recursive=True))
filenames.sort()
else:
raise NotImplementedError
unit = math.ceil(len(filenames) * 1.0 / args.groups)
filenames = filenames[args.job_index * unit : (args.job_index + 1) * unit]
for data_filename in tqdm(filenames):
try:
video_data = vid_dataloader.load_data(data_filename)
audio_data = aud_dataloader.load_data(data_filename)
except UnboundLocalError:
continue
if os.path.normpath(data_filename).split(os.sep)[-3] in ["trainval", "test"]:
dst_vid_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}.mp4"
dst_aud_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}.wav"
dst_txt_filename = f"{data_filename.replace(args.data_dir, dst_txt_dir)[:-4]}.txt"
trim_vid_data, trim_aud_data = video_data, audio_data
text_line_list = open(data_filename[:-4] + ".txt", "r").read().splitlines()[0].split(" ")
text_line = " ".join(text_line_list[2:])
content = text_line.replace("}", "").replace("{", "")
if trim_vid_data is None or trim_aud_data is None:
continue
video_length = len(trim_vid_data)
audio_length = trim_aud_data.size(1)
if video_length == 0 or audio_length == 0:
continue
if audio_length / video_length < 560.0 or audio_length / video_length > 720.0 or video_length < 12:
continue
save_vid_aud_txt(
dst_vid_filename,
dst_aud_filename,
dst_txt_filename,
trim_vid_data,
trim_aud_data,
content,
video_fps=25,
audio_sample_rate=16000,
)
in1 = ffmpeg.input(dst_vid_filename)
in2 = ffmpeg.input(dst_aud_filename)
out = ffmpeg.output(
in1["v"],
in2["a"],
dst_vid_filename[:-4] + ".m.mp4",
vcodec="copy",
acodec="aac",
strict="experimental",
loglevel="panic",
)
out.run()
os.remove(dst_aud_filename)
os.remove(dst_vid_filename)
shutil.move(dst_vid_filename[:-4] + ".m.mp4", dst_vid_filename)
basename = os.path.relpath(dst_vid_filename, start=os.path.join(args.root_dir, dataset))
f.write("{}\n".format(f"{dataset},{basename},{trim_vid_data.shape[0]},{len(content)}"))
continue
splitted = split_file(data_filename[:-4] + ".txt", max_frames=seg_vid_len)
for i in range(len(splitted)):
if len(splitted) == 1:
content, start, end, duration = splitted[i]
trim_vid_data, trim_aud_data = video_data, audio_data
else:
content, start, end, duration = splitted[i]
start_idx, end_idx = int(start * 25), int(end * 25)
try:
trim_vid_data, trim_aud_data = (
video_data[start_idx:end_idx],
audio_data[:, start_idx * 640 : end_idx * 640],
)
except TypeError:
continue
dst_vid_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}_{i:02d}.mp4"
dst_aud_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}_{i:02d}.wav"
dst_txt_filename = f"{data_filename.replace(args.data_dir, dst_txt_dir)[:-4]}_{i:02d}.txt"
if trim_vid_data is None or trim_aud_data is None:
continue
video_length = len(trim_vid_data)
audio_length = trim_aud_data.size(1)
if video_length == 0 or audio_length == 0:
continue
if audio_length / video_length < 560.0 or audio_length / video_length > 720.0 or video_length < 12:
continue
save_vid_aud_txt(
dst_vid_filename,
dst_aud_filename,
dst_txt_filename,
trim_vid_data,
trim_aud_data,
content,
video_fps=25,
audio_sample_rate=16000,
)
in1 = ffmpeg.input(dst_vid_filename)
in2 = ffmpeg.input(dst_aud_filename)
out = ffmpeg.output(
in1["v"],
in2["a"],
dst_vid_filename[:-4] + ".m.mp4",
vcodec="copy",
acodec="aac",
strict="experimental",
loglevel="panic",
)
out.run()
os.remove(dst_aud_filename)
os.remove(dst_vid_filename)
shutil.move(dst_vid_filename[:-4] + ".m.mp4", dst_vid_filename)
basename = os.path.relpath(dst_vid_filename, start=os.path.join(args.root_dir, dataset))
f.write("{}\n".format(f"{dataset},{basename},{trim_vid_data.shape[0]},{len(content)}"))
f.close()
tqdm
scikit-image
opencv-python
ffmpeg-python
## Face Detection
This repository relies on [ibug.face_detection](https://github.com/hhj1897/face_detection). You can install it either directly from its GitHub repository or from a compressed archive.
### Option 1. Install from github repositories
* [Git LFS](https://git-lfs.github.com/), needed for downloading the pretrained weights that are larger than 100 MB.
You can install *`Homebrew`* and then install *`git-lfs`* without sudo privileges.
```Shell
git clone https://github.com/hhj1897/face_detection.git
cd face_detection
git lfs pull
pip install -e .
cd ..
```
### Option 2. Install by using compressed files
If you run into over-quota issues with the repository above, you can instead download the [ibug.face_detection](https://www.doc.ic.ac.uk/~pm4115/tracker/face_detection.zip) archive, unzip it, and then run `pip install -e .` to install the package.
```Shell
wget https://www.doc.ic.ac.uk/~pm4115/tracker/face_detection.zip -O ./face_detection.zip
unzip -o ./face_detection.zip -d ./
cd face_detection
pip install -e .
cd ..
```
import os
import torchaudio
import torchvision
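# Split an LRS3 "pretrain" transcript (containing a "WORD START END ASDSCORE" table)
# into segments no longer than max_frames / fps seconds; returns a list of
# [text, start_time, end_time, duration] entries.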
def split_file(filename, max_frames=600, fps=25.0):
lines = open(filename).read().splitlines()
flag = 0
stack = []
res = []
tmp = 0
start_timestamp = 0.0
threshold = max_frames / fps
for line in lines:
if "WORD START END ASDSCORE" in line:
flag = 1
continue
if flag:
word, start, end, score = line.split(" ")
start, end, score = float(start), float(end), float(score)
if end < tmp + threshold:
stack.append(word)
last_timestamp = end
else:
res.append([" ".join(stack), start_timestamp, last_timestamp, last_timestamp - start_timestamp])
tmp = start
start_timestamp = start
stack = [word]
if stack:
res.append([" ".join(stack), start_timestamp, end, end - start_timestamp])
return res
def save_vid_txt(dst_vid_filename, dst_txt_filename, trim_video_data, content, video_fps=25):
# -- save video
save2vid(dst_vid_filename, trim_video_data, video_fps)
# -- save text
os.makedirs(os.path.dirname(dst_txt_filename), exist_ok=True)
f = open(dst_txt_filename, "w")
f.write(f"{content}")
f.close()
def save_vid_aud(
dst_vid_filename, dst_aud_filename, trim_vid_data, trim_aud_data, video_fps=25, audio_sample_rate=16000
):
# -- save video
save2vid(dst_vid_filename, trim_vid_data, video_fps)
# -- save audio
save2aud(dst_aud_filename, trim_aud_data, audio_sample_rate)
def save_vid_aud_txt(
dst_vid_filename,
dst_aud_filename,
dst_txt_filename,
trim_vid_data,
trim_aud_data,
content,
video_fps=25,
audio_sample_rate=16000,
):
# -- save video
save2vid(dst_vid_filename, trim_vid_data, video_fps)
# -- save audio
save2aud(dst_aud_filename, trim_aud_data, audio_sample_rate)
# -- save text
os.makedirs(os.path.dirname(dst_txt_filename), exist_ok=True)
f = open(dst_txt_filename, "w")
f.write(f"{content}")
f.close()
def save2vid(filename, vid, frames_per_second):
os.makedirs(os.path.dirname(filename), exist_ok=True)
torchvision.io.write_video(filename, vid, frames_per_second)
def save2aud(filename, aud, sample_rate):
os.makedirs(os.path.dirname(filename), exist_ok=True)
torchaudio.save(filename, aud, sample_rate)
import logging
from argparse import ArgumentParser
import sentencepiece as spm
import torch
import torchaudio
from transforms import get_data_module
logger = logging.getLogger(__name__)
def compute_word_level_distance(seq1, seq2):
return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())
def get_lightning_module(args):
sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))
if args.modality == "audiovisual":
from lightning_av import AVConformerRNNTModule
model = AVConformerRNNTModule(args, sp_model)
else:
from lightning import ConformerRNNTModule
model = ConformerRNNTModule(args, sp_model)
ckpt = torch.load(args.checkpoint_path, map_location=lambda storage, loc: storage)["state_dict"]
model.load_state_dict(ckpt)
model.eval()
return model
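# Decode the full test set, accumulating word-level edit distance to report WER.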
def run_eval(model, data_module):
total_edit_distance = 0
total_length = 0
dataloader = data_module.test_dataloader()
with torch.no_grad():
for idx, (batch, sample) in enumerate(dataloader):
actual = sample[0][-1]
predicted = model(batch)
total_edit_distance += compute_word_level_distance(actual, predicted)
total_length += len(actual.split())
if idx % 100 == 0:
logger.warning(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
logger.warning(f"Final WER: {total_edit_distance / total_length}")
return total_edit_distance / total_length
def parse_args():
parser = ArgumentParser()
parser.add_argument(
"--modality",
type=str,
help="Modality",
required=True,
)
parser.add_argument(
"--mode",
type=str,
help="Perform online or offline recognition.",
required=True,
)
parser.add_argument(
"--root-dir",
type=str,
help="Root directory to LRS3 audio-visual datasets.",
required=True,
)
parser.add_argument(
"--sp-model-path",
type=str,
help="Path to sentencepiece model.",
required=True,
)
parser.add_argument(
"--checkpoint-path",
type=str,
help="Path to a checkpoint model.",
required=True,
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = parse_args()
init_logger(args.debug)
model = get_lightning_module(args)
data_module = get_data_module(args, str(args.sp_model_path))
run_eval(model, data_module)
if __name__ == "__main__":
cli_main()
import itertools
import math
from collections import namedtuple
from typing import List, Tuple
import sentencepiece as spm
import torch
import torchaudio
from models.conformer_rnnt import conformer_rnnt
from models.emformer_rnnt import emformer_rnnt
from models.resnet import video_resnet
from models.resnet1d import audio_resnet
from pytorch_lightning import LightningModule
from schedulers import WarmupCosineScheduler
from torchaudio.models import Hypothesis, RNNTBeamSearch
_expected_spm_vocab_size = 1023
Batch = namedtuple("Batch", ["inputs", "input_lengths", "targets", "target_lengths"])
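# Convert beam-search hypotheses into (text, score, token ids) tuples, dropping
# unk/eos/pad tokens before SentencePiece decoding.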
def post_process_hypos(
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, float, List[int], List[int]]]:
tokens_idx = 0
score_idx = 3
post_process_remove_list = [
sp_model.unk_id(),
sp_model.eos_id(),
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[token_index for token_index in h[tokens_idx][1:] if token_index not in post_process_remove_list] for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ids = [h[tokens_idx][1:] for h in hypos]
hypos_score = [[math.exp(h[score_idx])] for h in hypos]
nbest_batch = list(zip(hypos_str, hypos_score, hypos_ids))
return nbest_batch
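# Single-modality RNN-T module: a video or audio ResNet front-end feeds a Conformer
# (offline) or Emformer (online) transducer; optimization is manual, with gradient
# clipping, loss scaling by the gathered batch sizes, and a warmup-cosine schedule.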
class ConformerRNNTModule(LightningModule):
def __init__(self, args=None, sp_model=None, pretrained_model_path=None):
super().__init__()
self.save_hyperparameters(args)
self.args = args
self.sp_model = sp_model
spm_vocab_size = self.sp_model.get_piece_size()
assert spm_vocab_size == _expected_spm_vocab_size, (
"The model returned by conformer_rnnt_base expects a SentencePiece model of "
f"vocabulary size {_expected_spm_vocab_size}, but the given SentencePiece model has a vocabulary size "
f"of {spm_vocab_size}. Please provide a correctly configured SentencePiece model."
)
self.blank_idx = spm_vocab_size
if args.modality == "video":
self.frontend = video_resnet()
if args.modality == "audio":
self.frontend = audio_resnet()
if args.mode == "online":
self.model = emformer_rnnt()
if args.mode == "offline":
self.model = conformer_rnnt()
# -- initialise
if args.pretrained_model_path:
ckpt = torch.load(args.pretrained_model_path, map_location=lambda storage, loc: storage)
tmp_ckpt = {
k.replace("encoder.frontend.", ""): v for k, v in ckpt.items() if k.startswith("encoder.frontend.")
}
self.frontend.load_state_dict(tmp_ckpt)
self.loss = torchaudio.transforms.RNNTLoss(reduction="sum")
self.optimizer = torch.optim.AdamW(
itertools.chain(*([self.frontend.parameters(), self.model.parameters()])),
lr=8e-4,
weight_decay=0.06,
betas=(0.9, 0.98),
)
self.automatic_optimization = False
def _step(self, batch, _, step_type):
if batch is None:
return None
prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
prepended_targets[:, 1:] = batch.targets
prepended_targets[:, 0] = self.blank_idx
prepended_target_lengths = batch.target_lengths + 1
features = self.frontend(batch.inputs)
output, src_lengths, _, _ = self.model(
features, batch.input_lengths, prepended_targets, prepended_target_lengths
)
loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
self.warmup_lr_scheduler = WarmupCosineScheduler(
self.optimizer,
10,
self.args.epochs,
len(self.trainer.datamodule.train_dataloader()) / self.trainer.num_devices / self.trainer.num_nodes,
)
self.lr_scheduler_interval = "step"
return (
[self.optimizer],
[{"scheduler": self.warmup_lr_scheduler, "interval": self.lr_scheduler_interval}],
)
def forward(self, batch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
x = self.frontend(batch.inputs.to(self.device))
hypotheses = decoder(x, batch.input_lengths.to(self.device), beam_width=20)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch, batch_idx):
opt = self.optimizers()
opt.zero_grad()
loss = self._step(batch, batch_idx, "train")
batch_size = batch.inputs.size(0)
batch_sizes = self.all_gather(batch_size)
loss *= batch_sizes.size(0) / batch_sizes.sum() # world size / batch size
self.manual_backward(loss)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
opt.step()
sch = self.lr_schedulers()
sch.step()
self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
return loss
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "val")
def test_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "test")
import itertools
import math
from collections import namedtuple
from typing import List, Tuple
import sentencepiece as spm
import torch
import torchaudio
from models.conformer_rnnt import conformer_rnnt
from models.emformer_rnnt import emformer_rnnt
from models.fusion import fusion_module
from models.resnet import video_resnet
from models.resnet1d import audio_resnet
from pytorch_lightning import LightningModule
from schedulers import WarmupCosineScheduler
from torchaudio.models import Hypothesis, RNNTBeamSearch
_expected_spm_vocab_size = 1023
AVBatch = namedtuple("AVBatch", ["audios", "videos", "audio_lengths", "video_lengths", "targets", "target_lengths"])
def post_process_hypos(
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, float, List[int], List[int]]]:
tokens_idx = 0
score_idx = 3
post_process_remove_list = [
sp_model.unk_id(),
sp_model.eos_id(),
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[token_index for token_index in h[tokens_idx][1:] if token_index not in post_process_remove_list] for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ids = [h[tokens_idx][1:] for h in hypos]
hypos_score = [[math.exp(h[score_idx])] for h in hypos]
nbest_batch = list(zip(hypos_str, hypos_score, hypos_ids))
return nbest_batch
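# Audio-visual variant: separate audio and video ResNet front-ends whose features are
# concatenated and passed through the fusion module before the transducer.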
class AVConformerRNNTModule(LightningModule):
def __init__(self, args=None, sp_model=None):
super().__init__()
self.save_hyperparameters(args)
self.args = args
self.sp_model = sp_model
spm_vocab_size = self.sp_model.get_piece_size()
assert spm_vocab_size == _expected_spm_vocab_size, (
"The model returned by conformer_rnnt_base expects a SentencePiece model of "
f"vocabulary size {_expected_spm_vocab_size}, but the given SentencePiece model has a vocabulary size "
f"of {spm_vocab_size}. Please provide a correctly configured SentencePiece model."
)
self.blank_idx = spm_vocab_size
self.audio_frontend = audio_resnet()
self.video_frontend = video_resnet()
self.fusion = fusion_module()
frontend_params = [self.video_frontend.parameters(), self.audio_frontend.parameters()]
fusion_params = [self.fusion.parameters()]
if args.mode == "online":
self.model = emformer_rnnt()
if args.mode == "offline":
self.model = conformer_rnnt()
self.loss = torchaudio.transforms.RNNTLoss(reduction="sum")
self.optimizer = torch.optim.AdamW(
itertools.chain(*([self.model.parameters()] + frontend_params + fusion_params)),
lr=8e-4,
weight_decay=0.06,
betas=(0.9, 0.98),
)
self.automatic_optimization = False
def _step(self, batch, _, step_type):
if batch is None:
return None
prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
prepended_targets[:, 1:] = batch.targets
prepended_targets[:, 0] = self.blank_idx
prepended_target_lengths = batch.target_lengths + 1
video_features = self.video_frontend(batch.videos)
audio_features = self.audio_frontend(batch.audios)
output, src_lengths, _, _ = self.model(
self.fusion(torch.cat([video_features, audio_features], dim=-1)),
batch.video_lengths,
prepended_targets,
prepended_target_lengths,
)
loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
self.warmup_lr_scheduler = WarmupCosineScheduler(
self.optimizer,
10,
self.args.epochs,
len(self.trainer.datamodule.train_dataloader()) / self.trainer.num_devices / self.trainer.num_nodes,
)
self.lr_scheduler_interval = "step"
return (
[self.optimizer],
[{"scheduler": self.warmup_lr_scheduler, "interval": self.lr_scheduler_interval}],
)
def forward(self, batch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
video_features = self.video_frontend(batch.videos.to(self.device))
audio_features = self.audio_frontend(batch.audios.to(self.device))
hypotheses = decoder(
self.fusion(torch.cat([video_features, audio_features], dim=-1)),
batch.video_lengths.to(self.device),
beam_width=20,
)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch, batch_idx):
opt = self.optimizers()
opt.zero_grad()
loss = self._step(batch, batch_idx, "train")
batch_size = batch.videos.size(0)
batch_sizes = self.all_gather(batch_size)
loss *= batch_sizes.size(0) / batch_sizes.sum() # world size / batch size
self.manual_backward(loss)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
opt.step()
sch = self.lr_schedulers()
sch.step()
self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
return loss
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "val")
def test_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "test")
import os
import torchaudio
import torchvision
from torch.utils.data import Dataset
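# Parse label CSVs of the form "dataset,relative_path,length,..." into absolute video
# paths and their lengths (in frames).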
def _load_list(args, *filenames):
output = []
length = []
for filename in filenames:
filepath = os.path.join(args.root_dir, "labels", filename)
for line in open(filepath).read().splitlines():
dataset, rel_path, input_length = line.split(",")[0], line.split(",")[1], line.split(",")[2]
path = os.path.normpath(os.path.join(args.root_dir, dataset, rel_path[:-4] + ".mp4"))
length.append(int(input_length))
output.append(path)
return output, length
def load_video(path):
"""
rtype: torch, T x C x H x W
"""
vid = torchvision.io.read_video(path, pts_unit="sec", output_format="THWC")[0]
vid = vid.permute((0, 3, 1, 2))
return vid
def load_audio(path):
"""
rtype: torch, T x 1
"""
waveform, sample_rate = torchaudio.load(path, normalize=True)
return waveform.transpose(1, 0)
def load_transcript(path):
transcript_path = path.replace("video_seg", "text_seg")[:-4] + ".txt"
return open(transcript_path).read().splitlines()[0]
def load_item(path, modality):
if modality == "video":
return (load_video(path), load_transcript(path))
if modality == "audio":
return (load_audio(path), load_transcript(path))
if modality == "audiovisual":
return (load_audio(path), load_video(path), load_transcript(path))
class LRS3(Dataset):
def __init__(
self,
args,
subset: str = "train",
) -> None:
if subset is not None and subset not in ["train", "val", "test"]:
raise ValueError("When `subset` is not None, it must be one of ['train', 'val', 'test'].")
self.args = args
if subset == "train":
self.files, self.lengths = _load_list(self.args, "lrs3_train_transcript_lengths_seg16s.csv")
if subset == "val":
self.files, self.lengths = _load_list(self.args, "lrs3_test_transcript_lengths_seg16s.csv")
if subset == "test":
self.files, self.lengths = _load_list(self.args, "lrs3_test_transcript_lengths_seg16s.csv")
def __getitem__(self, n):
path = self.files[n]
return load_item(path, self.args.modality)
def __len__(self) -> int:
return len(self.files)
from torchaudio.prototype.models import conformer_rnnt_model
# https://pytorch.org/audio/master/_modules/torchaudio/prototype/models/rnnt.html#conformer_rnnt_model
def conformer_rnnt():
return conformer_rnnt_model(
input_dim=512,
encoding_dim=1024,
time_reduction_stride=1,
conformer_input_dim=256,
conformer_ffn_dim=1024,
conformer_num_layers=16,
conformer_num_heads=4,
conformer_depthwise_conv_kernel_size=31,
conformer_dropout=0.1,
num_symbols=1024,
symbol_embedding_dim=256,
num_lstm_layers=2,
lstm_hidden_dim=512,
lstm_layer_norm=True,
lstm_layer_norm_epsilon=1e-5,
lstm_dropout=0.3,
joiner_activation="tanh",
)
from torchaudio.models.rnnt import emformer_rnnt_model
# https://pytorch.org/audio/master/_modules/torchaudio/models/rnnt.html#emformer_rnnt_base
def emformer_rnnt():
return emformer_rnnt_model(
input_dim=512,
encoding_dim=1024,
num_symbols=1024,
segment_length=64,
right_context_length=0,
time_reduction_input_dim=128,
time_reduction_stride=1,
transformer_num_heads=4,
transformer_ffn_dim=2048,
transformer_num_layers=20,
transformer_dropout=0.1,
transformer_activation="gelu",
transformer_left_context_length=30,
transformer_max_memory_size=0,
transformer_weight_init_scale_strategy="depthwise",
transformer_tanh_on_mem=True,
symbol_embedding_dim=512,
num_lstm_layers=3,
lstm_layer_norm=True,
lstm_layer_norm_epsilon=1e-3,
lstm_dropout=0.3,
)
import torch
class FeedForwardModule(torch.nn.Module):
r"""Positionwise feed forward layer.
Args:
input_dim (int): input dimension.
hidden_dim (int): hidden dimension.
output_dim (int): output dimension.
dropout (float, optional): dropout probability. (Default: 0.0)
"""
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.0) -> None:
super().__init__()
self.sequential = torch.nn.Sequential(
torch.nn.LayerNorm(input_dim),
torch.nn.Linear(input_dim, hidden_dim, bias=True),
torch.nn.SiLU(),
torch.nn.Dropout(dropout),
torch.nn.Linear(hidden_dim, output_dim, bias=True),
torch.nn.Dropout(dropout),
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
r"""
Args:
input (torch.Tensor): with shape `(*, input_dim)`.
Returns:
torch.Tensor: output, with shape `(*, output_dim)`.
"""
return self.sequential(input)
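# Projects the concatenated video+audio features (2 x 512 = 1024-d) down to the 512-d
# input expected by the transducer encoder.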
def fusion_module(input_dim=1024, hidden_dim=3072, output_dim=512, dropout=0.1):
return FeedForwardModule(input_dim, hidden_dim, output_dim, dropout)
import torch.nn as nn
def conv3x3(in_planes, out_planes, stride=1):
"""conv3x3.
:param in_planes: int, number of channels in the input sequence.
:param out_planes: int, number of channels produced by the convolution.
:param stride: int, stride of the convolution.
"""
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
def downsample_basic_block(inplanes, outplanes, stride):
"""downsample_basic_block.
:param inplanes: int, number of channels in the input sequence.
:param outplanes: int, number of channels produced by the convolution.
:param stride: int, stride of the 1x1 convolution.
"""
return nn.Sequential(
nn.Conv2d(
inplanes,
outplanes,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm2d(outplanes),
)
class BasicBlock(nn.Module):
expansion = 1
def __init__(
self,
inplanes,
planes,
stride=1,
downsample=None,
relu_type="swish",
):
"""__init__.
:param inplanes: int, number of channels in the input sequence.
:param planes: int, number of channels produced by the convolution.
:param stride: int, stride of the convolution.
:param downsample: nn.Module or None, module applied to the residual branch to match shapes.
:param relu_type: str, type of activation function.
"""
super(BasicBlock, self).__init__()
assert relu_type in ["relu", "prelu", "swish"]
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
if relu_type == "relu":
self.relu1 = nn.ReLU(inplace=True)
self.relu2 = nn.ReLU(inplace=True)
elif relu_type == "prelu":
self.relu1 = nn.PReLU(num_parameters=planes)
self.relu2 = nn.PReLU(num_parameters=planes)
elif relu_type == "swish":
self.relu1 = nn.SiLU(inplace=True)
self.relu2 = nn.SiLU(inplace=True)
else:
raise NotImplementedError
# --------
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
"""forward.
:param x: torch.Tensor, input tensor with input size (B, C, T, H, W).
"""
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu2(out)
return out
class ResNet(nn.Module):
def __init__(
self,
block,
layers,
relu_type="swish",
):
super(ResNet, self).__init__()
self.inplanes = 64
self.relu_type = relu_type
self.downsample_block = downsample_basic_block
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AdaptiveAvgPool2d(1)
def _make_layer(self, block, planes, blocks, stride=1):
"""_make_layer.
:param block: torch.nn.Module, class of blocks.
:param planes: int, number of channels produced by the convolution.
:param blocks: int, number of layers in a block.
:param stride: int, stride of the first block in the layer.
"""
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = self.downsample_block(
inplanes=self.inplanes,
outplanes=planes * block.expansion,
stride=stride,
)
layers = []
layers.append(
block(
self.inplanes,
planes,
stride,
downsample,
relu_type=self.relu_type,
)
)
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
block(
self.inplanes,
planes,
relu_type=self.relu_type,
)
)
return nn.Sequential(*layers)
def forward(self, x):
"""forward.
:param x: torch.Tensor, input tensor with input size (B, C, T, H, W).
"""
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
return x
# -- auxiliary functions
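# Fold the time dimension into the batch dimension so each frame can be processed
# independently by the 2D ResNet trunk.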
def threeD_to_2D_tensor(x):
n_batch, n_channels, s_time, sx, sy = x.shape
x = x.transpose(1, 2)
return x.reshape(n_batch * s_time, n_channels, sx, sy)
class Conv3dResNet(nn.Module):
"""Conv3dResNet module"""
def __init__(self, backbone_type="resnet", relu_type="swish"):
"""__init__.
:param backbone_type: str, the type of a visual front-end.
:param relu_type: str, activation function used in the visual front-end.
"""
super(Conv3dResNet, self).__init__()
self.backbone_type = backbone_type
self.frontend_nout = 64
self.trunk = ResNet(
BasicBlock,
[2, 2, 2, 2],
relu_type=relu_type,
)
# -- frontend3D
if relu_type == "relu":
frontend_relu = nn.ReLU(True)
elif relu_type == "prelu":
frontend_relu = nn.PReLU(self.frontend_nout)
elif relu_type == "swish":
frontend_relu = nn.SiLU(inplace=True)
self.frontend3D = nn.Sequential(
nn.Conv3d(
in_channels=1,
out_channels=self.frontend_nout,
kernel_size=(5, 7, 7),
stride=(1, 2, 2),
padding=(2, 3, 3),
bias=False,
),
nn.BatchNorm3d(self.frontend_nout),
frontend_relu,
nn.MaxPool3d(
kernel_size=(1, 3, 3),
stride=(1, 2, 2),
padding=(0, 1, 1),
),
)
def forward(self, xs_pad):
"""forward.
:param xs_pad: torch.Tensor, batch of padded input sequences.
"""
# -- include Channel dimension
xs_pad = xs_pad.transpose(2, 1)
B, C, T, H, W = xs_pad.size()
xs_pad = self.frontend3D(xs_pad)
Tnew = xs_pad.shape[2]  # output should be B x C2 x Tnew x H x W
xs_pad = threeD_to_2D_tensor(xs_pad)
xs_pad = self.trunk(xs_pad)
xs_pad = xs_pad.view(B, Tnew, xs_pad.size(1))
return xs_pad
def video_resnet():
return Conv3dResNet()
import torch.nn as nn
def conv3x3(in_planes, out_planes, stride=1):
"""conv3x3.
:param in_planes: int, number of channels in the input sequence.
:param out_planes: int, number of channels produced by the convolution.
:param stride: int, stride of the convolution.
"""
return nn.Conv1d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
def downsample_basic_block(inplanes, outplanes, stride):
"""downsample_basic_block.
:param inplanes: int, number of channels in the input sequence.
:param outplanes: int, number of channels produced by the convolution.
:param stride: int, stride of the 1x1 convolution.
"""
return nn.Sequential(
nn.Conv1d(
inplanes,
outplanes,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm1d(outplanes),
)
class BasicBlock1D(nn.Module):
expansion = 1
def __init__(
self,
inplanes,
planes,
stride=1,
downsample=None,
relu_type="relu",
):
"""__init__.
:param inplanes: int, number of channels in the input sequence.
:param planes: int, number of channels produced by the convolution.
:param stride: int, stride of the convolution.
:param downsample: nn.Module or None, module applied to the residual branch to match shapes.
:param relu_type: str, type of activation function.
"""
super(BasicBlock1D, self).__init__()
assert relu_type in ["relu", "prelu", "swish"]
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm1d(planes)
# type of ReLU is an input option
if relu_type == "relu":
self.relu1 = nn.ReLU(inplace=True)
self.relu2 = nn.ReLU(inplace=True)
elif relu_type == "prelu":
self.relu1 = nn.PReLU(num_parameters=planes)
self.relu2 = nn.PReLU(num_parameters=planes)
elif relu_type == "swish":
self.relu1 = nn.SiLU(inplace=True)
self.relu2 = nn.SiLU(inplace=True)
else:
raise NotImplementedError
# --------
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm1d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
"""forward.
:param x: torch.Tensor, input tensor with input size (B, C, T)
"""
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu2(out)
return out
class ResNet1D(nn.Module):
def __init__(
self,
block,
layers,
relu_type="swish",
a_upsample_ratio=1,
):
"""__init__.
:param block: torch.nn.Module, class of blocks.
:param layers: List, customised layers in each block.
:param relu_type: str, type of activation function.
:param a_upsample_ratio: int, ratio controlling the temporal resolution of \
the front-end output features; a_upsample_ratio=1 produces features at 25 fps.
"""
super(ResNet1D, self).__init__()
self.inplanes = 64
self.relu_type = relu_type
self.downsample_block = downsample_basic_block
self.a_upsample_ratio = a_upsample_ratio
self.conv1 = nn.Conv1d(
in_channels=1,
out_channels=self.inplanes,
kernel_size=80,
stride=4,
padding=38,
bias=False,
)
self.bn1 = nn.BatchNorm1d(self.inplanes)
if relu_type == "relu":
self.relu = nn.ReLU(inplace=True)
elif relu_type == "prelu":
self.relu = nn.PReLU(num_parameters=self.inplanes)
elif relu_type == "swish":
self.relu = nn.SiLU(inplace=True)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool1d(
kernel_size=20 // self.a_upsample_ratio,
stride=20 // self.a_upsample_ratio,
)
def _make_layer(self, block, planes, blocks, stride=1):
"""_make_layer.
:param block: torch.nn.Module, class of blocks.
:param planes: int, number of channels produced by the convolution.
:param blocks: int, number of layers in a block.
:param stride: int, stride of the first block in the layer.
"""
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = self.downsample_block(
inplanes=self.inplanes,
outplanes=planes * block.expansion,
stride=stride,
)
layers = []
layers.append(
block(
self.inplanes,
planes,
stride,
downsample,
relu_type=self.relu_type,
)
)
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
block(
self.inplanes,
planes,
relu_type=self.relu_type,
)
)
return nn.Sequential(*layers)
def forward(self, x):
"""forward.
:param x: torch.Tensor, input tensor with input size (B, C, T)
"""
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
return x
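# Raw-waveform front-end. The overall stride of ResNet1D (4 x 2 x 2 x 2 x 20 = 640
# samples) matches 16 kHz / 25 fps, so its output frame rate lines up with the video features.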
class Conv1dResNet(nn.Module):
"""Conv1dResNet"""
def __init__(self, relu_type="swish", a_upsample_ratio=1):
"""__init__.
:param relu_type: str, Activation function used in an audio front-end.
:param a_upsample_ratio: int, ratio controlling the temporal resolution of \
the front-end output features; a_upsample_ratio=1 produces features at 25 fps.
"""
super(Conv1dResNet, self).__init__()
self.a_upsample_ratio = a_upsample_ratio
self.trunk = ResNet1D(BasicBlock1D, [2, 2, 2, 2], relu_type=relu_type, a_upsample_ratio=a_upsample_ratio)
def forward(self, xs_pad):
"""forward.
:param xs_pad: torch.Tensor, batch of padded input sequences (B, Tmax, idim)
"""
B, T, C = xs_pad.size()
xs_pad = xs_pad[:, : T // 640 * 640, :]
xs_pad = xs_pad.transpose(1, 2)
xs_pad = self.trunk(xs_pad)
# -- from B x C x T to B x T x C
xs_pad = xs_pad.transpose(1, 2)
return xs_pad
def audio_resnet():
return Conv1dResNet()
import math
import torch
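# Learning-rate schedule: linear warmup over `warmup_epochs`, then cosine decay to zero
# over the remaining steps.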
class WarmupCosineScheduler(torch.optim.lr_scheduler._LRScheduler):
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_epochs: int,
total_epochs: int,
steps_per_epoch: int,
last_epoch=-1,
verbose=False,
):
self.warmup_steps = warmup_epochs * steps_per_epoch
self.total_steps = total_epochs * steps_per_epoch
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
if self._step_count < self.warmup_steps:
return [self._step_count / self.warmup_steps * base_lr for base_lr in self.base_lrs]
else:
decay_steps = self.total_steps - self.warmup_steps
return [
0.5 * base_lr * (1 + math.cos(math.pi * (self._step_count - self.warmup_steps) / decay_steps))
for base_lr in self.base_lrs
]
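# Usage sketch (illustrative only, not part of the training code): the Lightning modules
# above warm up for 10 epochs and decay over args.epochs; the tiny optimizer and
# steps_per_epoch below are placeholders for this example.
if __name__ == "__main__":
    _params = [torch.nn.Parameter(torch.zeros(1))]
    _optimizer = torch.optim.AdamW(_params, lr=8e-4)
    _scheduler = WarmupCosineScheduler(_optimizer, warmup_epochs=10, total_epochs=55, steps_per_epoch=100)
    for _ in range(55 * 100):
        _optimizer.step()   # no-op here (no gradients); shown only for step ordering
        _scheduler.step()   # update the learning rate after every optimization step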
import logging
import os
from argparse import ArgumentParser
import sentencepiece as spm
from average_checkpoints import ensemble
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from transforms import get_data_module
def get_trainer(args):
seed_everything(1)
checkpoint = ModelCheckpoint(
dirpath=os.path.join(args.exp_dir, args.exp_name) if args.exp_dir else None,
monitor="monitoring_step",
mode="max",
save_last=True,
filename="{epoch}",
save_top_k=10,
)
lr_monitor = LearningRateMonitor(logging_interval="step")
callbacks = [
checkpoint,
lr_monitor,
]
return Trainer(
sync_batchnorm=True,
default_root_dir=args.exp_dir,
max_epochs=args.epochs,
num_nodes=args.num_nodes,
devices=args.gpus,
accelerator="gpu",
strategy=DDPStrategy(find_unused_parameters=False),
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
)
def get_lightning_module(args):
sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))
if args.modality == "audiovisual":
from lightning_av import AVConformerRNNTModule
model = AVConformerRNNTModule(args, sp_model)
else:
from lightning import ConformerRNNTModule
model = ConformerRNNTModule(args, sp_model)
return model
def parse_args():
parser = ArgumentParser()
parser.add_argument(
"--modality",
type=str,
help="Modality",
required=True,
)
parser.add_argument(
"--mode",
type=str,
help="Perform online or offline recognition.",
required=True,
)
parser.add_argument(
"--root-dir",
type=str,
help="Root directory to LRS3 audio-visual datasets.",
required=True,
)
parser.add_argument(
"--sp-model-path",
type=str,
help="Path to SentencePiece model.",
required=True,
)
parser.add_argument(
"--pretrained-model-path",
type=str,
help="Path to Pretraned model.",
)
parser.add_argument(
"--exp-dir",
default="./exp",
type=str,
help="Directory to save checkpoints and logs to. (Default: './exp')",
)
parser.add_argument(
"--exp-name",
type=str,
help="Experiment name",
)
parser.add_argument(
"--num-nodes",
default=4,
type=int,
help="Number of nodes to use for training. (Default: 4)",
)
parser.add_argument(
"--gpus",
default=8,
type=int,
help="Number of GPUs per node to use for training. (Default: 8)",
)
parser.add_argument(
"--epochs",
default=55,
type=int,
help="Number of epochs to train for. (Default: 55)",
)
parser.add_argument(
"--resume-from-checkpoint",
default=None,
type=str,
help="Path to the checkpoint to resume from",
)
parser.add_argument(
"--debug",
action="store_true",
help="Whether to use debug level for logging",
)
return parser.parse_args()
def init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = parse_args()
init_logger(args.debug)
model = get_lightning_module(args)
data_module = get_data_module(args, str(args.sp_model_path))
trainer = get_trainer(args)
trainer.fit(model, data_module)
ensemble(args)
if __name__ == "__main__":
cli_main()
#!/usr/bin/env python3
"""Trains a SentencePiece model on transcripts across LRS3 pretrain and trainval.
- `[lrs3_path]` is the directory path for the LRS3 cropped face dataset.
Example:
python train_spm.py --lrs3-path [lrs3_path]
"""
import io
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import sentencepiece as spm
def get_transcript_text(transcript_path):
return [open(transcript_path).read().splitlines()[0].lower()]
def get_transcripts(dataset_path):
transcript_paths = dataset_path.glob("*/*.txt")
merged_transcripts = []
for path in transcript_paths:
merged_transcripts += get_transcript_text(path)
return merged_transcripts
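# Trains a unigram SentencePiece model with vocab_size=1023, matching
# _expected_spm_vocab_size in the Lightning modules (the transducer reserves the extra
# index 1023 for the RNN-T blank symbol).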
def train_spm(input):
model_writer = io.BytesIO()
spm.SentencePieceTrainer.train(
sentence_iterator=iter(input),
model_writer=model_writer,
vocab_size=1023,
model_type="unigram",
input_sentence_size=-1,
character_coverage=1.0,
bos_id=0,
pad_id=1,
eos_id=2,
unk_id=3,
)
return model_writer.getvalue()
def parse_args():
default_output_path = "./spm_unigram_1023.model"
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--lrs3-path",
type=pathlib.Path,
help="Path to LRS3 datasets.",
required=True,
)
parser.add_argument(
"--output-file",
default=pathlib.Path(default_output_path),
type=pathlib.Path,
help=f"File to save model to. (Default: '{default_output_path}')",
)
return parser.parse_args()
def run_cli():
args = parse_args()
root = args.lrs3_path / "LRS3_text_seg16s"
splits = ["pretrain", "trainval"]
merged_transcripts = []
for split in splits:
path = pathlib.Path(root) / split
merged_transcripts += get_transcripts(path)
model = train_spm(merged_transcripts)
with open(args.output_file, "wb") as f:
f.write(model)
if __name__ == "__main__":
run_cli()