"torchvision/transforms/v2/functional/_geometry.py" did not exist on "b030e9363eab6089cc580725ee703cf2f01f3765"
Commit ffeba11a authored by mayp777's avatar mayp777
Browse files

UPDATE

parent 29deb085
import cv2
import numpy as np
from skimage import transform as tf
def linear_interpolate(landmarks, start_idx, stop_idx):
start_landmarks = landmarks[start_idx]
stop_landmarks = landmarks[stop_idx]
delta = stop_landmarks - start_landmarks
for idx in range(1, stop_idx - start_idx):
landmarks[start_idx + idx] = start_landmarks + idx / float(stop_idx - start_idx) * delta
return landmarks
def warp_img(src, dst, img, std_size):
tform = tf.estimate_transform("similarity", src, dst)
warped = tf.warp(img, inverse_map=tform.inverse, output_shape=std_size)
warped = (warped * 255).astype("uint8")
return warped, tform
def apply_transform(transform, img, std_size):
warped = tf.warp(img, inverse_map=transform.inverse, output_shape=std_size)
warped = (warped * 255).astype("uint8")
return warped
def cut_patch(img, landmarks, height, width, threshold=5):
center_x, center_y = np.mean(landmarks, axis=0)
# Check for too much bias in height and width
if abs(center_y - img.shape[0] / 2) > height + threshold:
raise Exception("too much bias in height")
if abs(center_x - img.shape[1] / 2) > width + threshold:
raise Exception("too much bias in width")
# Calculate bounding box coordinates
y_min = int(round(np.clip(center_y - height, 0, img.shape[0])))
y_max = int(round(np.clip(center_y + height, 0, img.shape[0])))
x_min = int(round(np.clip(center_x - width, 0, img.shape[1])))
x_max = int(round(np.clip(center_x + width, 0, img.shape[1])))
# Cut the image
cutted_img = np.copy(img[y_min:y_max, x_min:x_max])
return cutted_img
class VideoProcess:
def __init__(
self,
crop_width=128,
crop_height=128,
target_size=(224, 224),
reference_size=(224, 224),
stable_points=(0, 1),
start_idx=0,
stop_idx=2,
resize=(96, 96),
):
self.reference = np.array(([[51.64568, 0.70204943], [171.95107, 159.59505]]))
self.crop_width = crop_width
self.crop_height = crop_height
self.start_idx = start_idx
self.stop_idx = stop_idx
self.resize = resize
def __call__(self, video, landmarks):
# Pre-process landmarks: interpolate frames that are not detected
preprocessed_landmarks = self.interpolate_landmarks(landmarks)
# Exclude corner cases: no landmark in all frames or number of frames is less than window length
if not preprocessed_landmarks:
return
# Affine transformation and crop patch
sequence = self.crop_patch(video, preprocessed_landmarks)
assert sequence is not None, "crop an empty patch."
return sequence
def crop_patch(self, video, landmarks):
sequence = []
for frame_idx, frame in enumerate(video):
transformed_frame, transformed_landmarks = self.affine_transform(
frame, landmarks[frame_idx], self.reference
)
patch = cut_patch(
transformed_frame,
transformed_landmarks[self.start_idx : self.stop_idx],
self.crop_height // 2,
self.crop_width // 2,
)
if self.resize:
patch = cv2.resize(patch, self.resize)
sequence.append(patch)
return np.array(sequence)
def interpolate_landmarks(self, landmarks):
valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None]
if not valid_frames_idx:
return None
for idx in range(1, len(valid_frames_idx)):
if valid_frames_idx[idx] - valid_frames_idx[idx - 1] > 1:
landmarks = linear_interpolate(landmarks, valid_frames_idx[idx - 1], valid_frames_idx[idx])
valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None]
# Handle corner case: keep frames at the beginning or at the end that failed to be detected
if valid_frames_idx:
landmarks[: valid_frames_idx[0]] = [landmarks[valid_frames_idx[0]]] * valid_frames_idx[0]
landmarks[valid_frames_idx[-1] :] = [landmarks[valid_frames_idx[-1]]] * (
len(landmarks) - valid_frames_idx[-1]
)
assert all(lm is not None for lm in landmarks), "not every frame has landmark"
return landmarks
def affine_transform(
self,
frame,
landmarks,
reference,
target_size=(224, 224),
reference_size=(224, 224),
stable_points=(0, 1),
interpolation=cv2.INTER_LINEAR,
border_mode=cv2.BORDER_CONSTANT,
border_value=0,
):
stable_reference = self.get_stable_reference(reference, stable_points, reference_size, target_size)
transform = self.estimate_affine_transform(landmarks, stable_points, stable_reference)
transformed_frame, transformed_landmarks = self.apply_affine_transform(
frame, landmarks, transform, target_size, interpolation, border_mode, border_value
)
return transformed_frame, transformed_landmarks
def get_stable_reference(self, reference, stable_points, reference_size, target_size):
stable_reference = np.vstack([reference[x] for x in stable_points])
stable_reference[:, 0] -= (reference_size[0] - target_size[0]) / 2.0
stable_reference[:, 1] -= (reference_size[1] - target_size[1]) / 2.0
return stable_reference
def estimate_affine_transform(self, landmarks, stable_points, stable_reference):
return cv2.estimateAffinePartial2D(
np.vstack([landmarks[x] for x in stable_points]), stable_reference, method=cv2.LMEDS
)[0]
def apply_affine_transform(
self, frame, landmarks, transform, target_size, interpolation, border_mode, border_value
):
transformed_frame = cv2.warpAffine(
frame,
transform,
dsize=(target_size[0], target_size[1]),
flags=interpolation,
borderMode=border_mode,
borderValue=border_value,
)
transformed_landmarks = np.matmul(landmarks, transform[:, :2].transpose()) + transform[:, 2].transpose()
return transformed_frame, transformed_landmarks
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import warnings
import numpy as np
from ibug.face_detection import RetinaFacePredictor
warnings.filterwarnings("ignore")
class LandmarksDetector:
def __init__(self, device="cuda:0", model_name="resnet50"):
self.face_detector = RetinaFacePredictor(
device=device, threshold=0.8, model=RetinaFacePredictor.get_model(model_name)
)
def __call__(self, video_frames):
landmarks = []
for frame in video_frames:
detected_faces = self.face_detector(frame, rgb=False)
if len(detected_faces) >= 1:
landmarks.append(np.reshape(detected_faces[0][:4], (2, 2)))
else:
landmarks.append(None)
return landmarks
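# Usage sketch (illustrative only; the variable names below are not part of the original files).
# Assumes `video` is a uint8 array of shape (T, H, W, C) with frames in BGR order, matching the
# rgb=False call above.
#
#   landmarks_detector = LandmarksDetector(device="cuda:0")
#   video_process = VideoProcess(resize=(96, 96))
#   landmarks = landmarks_detector(video)            # one (2, 2) corner array per frame, or None
#   patches = video_process(video, landmarks)        # stacked cropped patches, or None on failure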
import argparse
import os
parser = argparse.ArgumentParser(description="Merge labels")
parser.add_argument(
"--dataset",
type=str,
required=True,
help="Specify the dataset used in the experiment",
)
parser.add_argument(
"--subset",
type=str,
required=True,
help="Specify the subset of the dataset used in the experiment",
)
parser.add_argument(
"--root-dir",
type=str,
required=True,
help="Directory of saved mouth patches or embeddings",
)
parser.add_argument(
"--groups",
type=int,
required=True,
help="Number of threads for parallel processing",
)
parser.add_argument(
"--seg-duration",
type=int,
default=16,
help="Length of the segments",
)
args = parser.parse_args()
dataset = args.dataset
subset = args.subset
seg_duration = args.seg_duration
# Check that there is more than one group
assert args.groups > 1, "There is no need to use this script for merging when --groups is 1."
# Create the filename template for label files
label_template = os.path.join(
args.root_dir,
"labels",
f"{dataset}_{subset}_transcript_lengths_seg{seg_duration}s.{args.groups}",
)
lines = []
for job_index in range(args.groups):
label_filename = f"{label_template}.{job_index}.csv"
assert os.path.exists(label_filename), f"{label_filename} does not exist."
with open(label_filename, "r") as file:
lines.extend(file.read().splitlines())
# Write the merged labels to a new file
dst_label_filename = os.path.join(
args.root_dir,
"labels",
f"{dataset}_{subset}_transcript_lengths_seg{seg_duration}s.csv",
)
with open(dst_label_filename, "w") as file:
file.write("\n".join(lines))
# Print the number of files and total duration in hours
total_duration = sum(int(line.split(",")[2]) for line in lines) / 3600.0 / 25.0  # frame count at 25 fps -> hours
print(f"The completed set has {len(lines)} files with a total of {total_duration:.2f} hours.")
# Remove the label files for each job index
print("** Remove the temporary label files **")
for job_index in range(args.groups):
label_filename = f"{label_template}.{job_index}.csv"
if os.path.exists(label_filename):
os.remove(label_filename)
print("** Finish **")
import argparse
import glob
import math
import os
import shutil
import warnings
import ffmpeg
from data.data_module import AVSRDataLoader
from tqdm import tqdm
from utils import save_vid_aud_txt, split_file
warnings.filterwarnings("ignore")
# Argument Parsing
parser = argparse.ArgumentParser(description="LRS3 Preprocessing")
parser.add_argument(
"--data-dir",
type=str,
help="The directory for sequence.",
)
parser.add_argument(
"--detector",
type=str,
default="retinaface",
help="Face detector used in the experiment.",
)
parser.add_argument(
"--dataset",
type=str,
help="Specify the dataset name used in the experiment",
)
parser.add_argument(
"--root-dir",
type=str,
help="The root directory of cropped-face dataset.",
)
parser.add_argument(
"--subset",
type=str,
required=True,
help="Subset of the dataset used in the experiment.",
)
parser.add_argument(
"--seg-duration",
type=int,
default=16,
help="Length of the segment in seconds.",
)
parser.add_argument(
"--groups",
type=int,
default=1,
help="Number of threads to be used in parallel.",
)
parser.add_argument(
"--job-index",
type=int,
default=0,
help="Index to identify separate jobs (useful for parallel processing).",
)
args = parser.parse_args()
seg_duration = args.seg_duration
dataset = args.dataset
args.data_dir = os.path.normpath(args.data_dir)
vid_dataloader = AVSRDataLoader(modality="video", detector=args.detector, resize=(96, 96))
aud_dataloader = AVSRDataLoader(modality="audio")
# Set segment lengths and create the label file for this job.
seg_vid_len = seg_duration * 25
seg_aud_len = seg_duration * 16000
label_filename = os.path.join(
args.root_dir,
"labels",
f"{dataset}_{args.subset}_transcript_lengths_seg{seg_duration}s.csv"
if args.groups <= 1
else f"{dataset}_{args.subset}_transcript_lengths_seg{seg_duration}s.{args.groups}.{args.job_index}.csv",
)
os.makedirs(os.path.dirname(label_filename), exist_ok=True)
print(f"Directory {os.path.dirname(label_filename)} created")
f = open(label_filename, "w")
# Step 2, extract mouth patches from segments.
dst_vid_dir = os.path.join(args.root_dir, dataset, dataset + f"_video_seg{seg_duration}s")
dst_txt_dir = os.path.join(args.root_dir, dataset, dataset + f"_text_seg{seg_duration}s")
if args.subset == "test":
filenames = glob.glob(os.path.join(args.data_dir, args.subset, "**", "*.mp4"), recursive=True)
elif args.subset == "train":
filenames = glob.glob(os.path.join(args.data_dir, "trainval", "**", "*.mp4"), recursive=True)
filenames.extend(glob.glob(os.path.join(args.data_dir, "pretrain", "**", "*.mp4"), recursive=True))
filenames.sort()
else:
raise NotImplementedError
unit = math.ceil(len(filenames) * 1.0 / args.groups)
filenames = filenames[args.job_index * unit : (args.job_index + 1) * unit]
for data_filename in tqdm(filenames):
try:
video_data = vid_dataloader.load_data(data_filename)
audio_data = aud_dataloader.load_data(data_filename)
except UnboundLocalError:
continue
if os.path.normpath(data_filename).split(os.sep)[-3] in ["trainval", "test"]:
dst_vid_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}.mp4"
dst_aud_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}.wav"
dst_txt_filename = f"{data_filename.replace(args.data_dir, dst_txt_dir)[:-4]}.txt"
trim_vid_data, trim_aud_data = video_data, audio_data
text_line_list = open(data_filename[:-4] + ".txt", "r").read().splitlines()[0].split(" ")
text_line = " ".join(text_line_list[2:])
content = text_line.replace("}", "").replace("{", "")
if trim_vid_data is None or trim_aud_data is None:
continue
video_length = len(trim_vid_data)
audio_length = trim_aud_data.size(1)
if video_length == 0 or audio_length == 0:
continue
        if audio_length / video_length < 560.0 or audio_length / video_length > 720.0 or video_length < 12:  # expect ~640 audio samples (16 kHz / 25 fps) per video frame
continue
save_vid_aud_txt(
dst_vid_filename,
dst_aud_filename,
dst_txt_filename,
trim_vid_data,
trim_aud_data,
content,
video_fps=25,
audio_sample_rate=16000,
)
in1 = ffmpeg.input(dst_vid_filename)
in2 = ffmpeg.input(dst_aud_filename)
out = ffmpeg.output(
in1["v"],
in2["a"],
dst_vid_filename[:-4] + ".m.mp4",
vcodec="copy",
acodec="aac",
strict="experimental",
loglevel="panic",
)
out.run()
os.remove(dst_aud_filename)
os.remove(dst_vid_filename)
shutil.move(dst_vid_filename[:-4] + ".m.mp4", dst_vid_filename)
basename = os.path.relpath(dst_vid_filename, start=os.path.join(args.root_dir, dataset))
f.write("{}\n".format(f"{dataset},{basename},{trim_vid_data.shape[0]},{len(content)}"))
continue
splitted = split_file(data_filename[:-4] + ".txt", max_frames=seg_vid_len)
for i in range(len(splitted)):
if len(splitted) == 1:
content, start, end, duration = splitted[i]
trim_vid_data, trim_aud_data = video_data, audio_data
else:
content, start, end, duration = splitted[i]
start_idx, end_idx = int(start * 25), int(end * 25)
try:
trim_vid_data, trim_aud_data = (
video_data[start_idx:end_idx],
audio_data[:, start_idx * 640 : end_idx * 640],
)
except TypeError:
continue
dst_vid_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}_{i:02d}.mp4"
dst_aud_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}_{i:02d}.wav"
dst_txt_filename = f"{data_filename.replace(args.data_dir, dst_txt_dir)[:-4]}_{i:02d}.txt"
if trim_vid_data is None or trim_aud_data is None:
continue
video_length = len(trim_vid_data)
audio_length = trim_aud_data.size(1)
if video_length == 0 or audio_length == 0:
continue
        if audio_length / video_length < 560.0 or audio_length / video_length > 720.0 or video_length < 12:  # expect ~640 audio samples (16 kHz / 25 fps) per video frame
continue
save_vid_aud_txt(
dst_vid_filename,
dst_aud_filename,
dst_txt_filename,
trim_vid_data,
trim_aud_data,
content,
video_fps=25,
audio_sample_rate=16000,
)
in1 = ffmpeg.input(dst_vid_filename)
in2 = ffmpeg.input(dst_aud_filename)
out = ffmpeg.output(
in1["v"],
in2["a"],
dst_vid_filename[:-4] + ".m.mp4",
vcodec="copy",
acodec="aac",
strict="experimental",
loglevel="panic",
)
out.run()
os.remove(dst_aud_filename)
os.remove(dst_vid_filename)
shutil.move(dst_vid_filename[:-4] + ".m.mp4", dst_vid_filename)
basename = os.path.relpath(dst_vid_filename, start=os.path.join(args.root_dir, dataset))
f.write("{}\n".format(f"{dataset},{basename},{trim_vid_data.shape[0]},{len(content)}"))
f.close()
tqdm
scikit-image
opencv-python
ffmpeg-python
## Face Recognition
The preprocessing pipeline relies on [ibug.face_detection](https://github.com/hhj1897/face_detection). You can install it either directly from the GitHub repository or from a compressed archive.
### Option 1. Install from github repositories
* [Git LFS](https://git-lfs.github.com/) is needed to download the pretrained weights, which are larger than 100 MB.
You can install *`Homebrew`* and then install *`git-lfs`* without sudo privileges.
```Shell
git clone https://github.com/hhj1897/face_detection.git
cd face_detection
git lfs pull
pip install -e .
cd ..
```
### Option 2. Install by using compressed files
If you are experiencing over-quota issues with the repository above, you can download [ibug.face_detection](https://www.doc.ic.ac.uk/~pm4115/tracker/face_detection.zip) as a zip archive, unzip it, and then run `pip install -e .` to install the package.
```Shell
wget https://www.doc.ic.ac.uk/~pm4115/tracker/face_detection.zip -O ./face_detection.zip
unzip -o ./face_detection.zip -d ./
cd face_detection
pip install -e .
cd ..
```
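A quick, optional sanity check that the package is importable in the environment you will use for preprocessing:
```Shell
python -c "from ibug.face_detection import RetinaFacePredictor; print('ok')"
```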
import os
import torchaudio
import torchvision
def split_file(filename, max_frames=600, fps=25.0):
lines = open(filename).read().splitlines()
flag = 0
stack = []
res = []
tmp = 0
start_timestamp = 0.0
threshold = max_frames / fps
for line in lines:
if "WORD START END ASDSCORE" in line:
flag = 1
continue
if flag:
word, start, end, score = line.split(" ")
start, end, score = float(start), float(end), float(score)
if end < tmp + threshold:
stack.append(word)
last_timestamp = end
else:
res.append([" ".join(stack), start_timestamp, last_timestamp, last_timestamp - start_timestamp])
tmp = start
start_timestamp = start
stack = [word]
if stack:
res.append([" ".join(stack), start_timestamp, end, end - start_timestamp])
return res
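# Illustrative input for split_file (values are made up): each transcript file contains a header
# line "WORD START END ASDSCORE" followed by one word per line, e.g.
#   WORD START END ASDSCORE
#   HELLO 0.12 0.45 0.91
#   WORLD 0.50 0.83 0.88
# Words are grouped into segments of at most max_frames / fps seconds (24 s with the defaults),
# and each returned entry is [transcript, start_time, end_time, duration].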
def save_vid_txt(dst_vid_filename, dst_txt_filename, trim_video_data, content, video_fps=25):
# -- save video
save2vid(dst_vid_filename, trim_video_data, video_fps)
# -- save text
os.makedirs(os.path.dirname(dst_txt_filename), exist_ok=True)
f = open(dst_txt_filename, "w")
f.write(f"{content}")
f.close()
def save_vid_aud(
dst_vid_filename, dst_aud_filename, trim_vid_data, trim_aud_data, video_fps=25, audio_sample_rate=16000
):
# -- save video
save2vid(dst_vid_filename, trim_vid_data, video_fps)
# -- save audio
save2aud(dst_aud_filename, trim_aud_data, audio_sample_rate)
def save_vid_aud_txt(
dst_vid_filename,
dst_aud_filename,
dst_txt_filename,
trim_vid_data,
trim_aud_data,
content,
video_fps=25,
audio_sample_rate=16000,
):
# -- save video
save2vid(dst_vid_filename, trim_vid_data, video_fps)
# -- save audio
save2aud(dst_aud_filename, trim_aud_data, audio_sample_rate)
# -- save text
os.makedirs(os.path.dirname(dst_txt_filename), exist_ok=True)
f = open(dst_txt_filename, "w")
f.write(f"{content}")
f.close()
def save2vid(filename, vid, frames_per_second):
os.makedirs(os.path.dirname(filename), exist_ok=True)
torchvision.io.write_video(filename, vid, frames_per_second)
def save2aud(filename, aud, sample_rate):
os.makedirs(os.path.dirname(filename), exist_ok=True)
torchaudio.save(filename, aud, sample_rate)
import logging
from argparse import ArgumentParser
import sentencepiece as spm
import torch
import torchaudio
from transforms import get_data_module
logger = logging.getLogger(__name__)
def compute_word_level_distance(seq1, seq2):
return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())
def get_lightning_module(args):
sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))
if args.modality == "audiovisual":
from lightning_av import AVConformerRNNTModule
model = AVConformerRNNTModule(args, sp_model)
else:
from lightning import ConformerRNNTModule
model = ConformerRNNTModule(args, sp_model)
ckpt = torch.load(args.checkpoint_path, map_location=lambda storage, loc: storage)["state_dict"]
model.load_state_dict(ckpt)
model.eval()
return model
def run_eval(model, data_module):
total_edit_distance = 0
total_length = 0
dataloader = data_module.test_dataloader()
with torch.no_grad():
for idx, (batch, sample) in enumerate(dataloader):
actual = sample[0][-1]
predicted = model(batch)
total_edit_distance += compute_word_level_distance(actual, predicted)
total_length += len(actual.split())
if idx % 100 == 0:
logger.warning(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
logger.warning(f"Final WER: {total_edit_distance / total_length}")
return total_edit_distance / total_length
def parse_args():
parser = ArgumentParser()
parser.add_argument(
"--modality",
type=str,
help="Modality",
required=True,
)
parser.add_argument(
"--mode",
type=str,
help="Perform online or offline recognition.",
required=True,
)
parser.add_argument(
"--root-dir",
type=str,
help="Root directory to LRS3 audio-visual datasets.",
required=True,
)
parser.add_argument(
"--sp-model-path",
type=str,
help="Path to sentencepiece model.",
required=True,
)
parser.add_argument(
"--checkpoint-path",
type=str,
help="Path to a checkpoint model.",
required=True,
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = parse_args()
init_logger(args.debug)
model = get_lightning_module(args)
data_module = get_data_module(args, str(args.sp_model_path))
run_eval(model, data_module)
if __name__ == "__main__":
cli_main()
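# Example invocation (file and path names are hypothetical):
#   python eval.py --modality audiovisual --mode offline --root-dir /path/to/lrs3 \
#       --sp-model-path ./spm_unigram_1023.model --checkpoint-path ./exp/last.ckpt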
import itertools
import math
from collections import namedtuple
from typing import List, Tuple
import sentencepiece as spm
import torch
import torchaudio
from models.conformer_rnnt import conformer_rnnt
from models.emformer_rnnt import emformer_rnnt
from models.resnet import video_resnet
from models.resnet1d import audio_resnet
from pytorch_lightning import LightningModule
from schedulers import WarmupCosineScheduler
from torchaudio.models import Hypothesis, RNNTBeamSearch
_expected_spm_vocab_size = 1023
Batch = namedtuple("Batch", ["inputs", "input_lengths", "targets", "target_lengths"])
def post_process_hypos(
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, float, List[int], List[int]]]:
tokens_idx = 0
score_idx = 3
post_process_remove_list = [
sp_model.unk_id(),
sp_model.eos_id(),
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[token_index for token_index in h[tokens_idx][1:] if token_index not in post_process_remove_list] for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ids = [h[tokens_idx][1:] for h in hypos]
hypos_score = [[math.exp(h[score_idx])] for h in hypos]
nbest_batch = list(zip(hypos_str, hypos_score, hypos_ids))
return nbest_batch
class ConformerRNNTModule(LightningModule):
def __init__(self, args=None, sp_model=None, pretrained_model_path=None):
super().__init__()
self.save_hyperparameters(args)
self.args = args
self.sp_model = sp_model
spm_vocab_size = self.sp_model.get_piece_size()
assert spm_vocab_size == _expected_spm_vocab_size, (
"The model returned by conformer_rnnt_base expects a SentencePiece model of "
f"vocabulary size {_expected_spm_vocab_size}, but the given SentencePiece model has a vocabulary size "
f"of {spm_vocab_size}. Please provide a correctly configured SentencePiece model."
)
self.blank_idx = spm_vocab_size
if args.modality == "video":
self.frontend = video_resnet()
if args.modality == "audio":
self.frontend = audio_resnet()
if args.mode == "online":
self.model = emformer_rnnt()
if args.mode == "offline":
self.model = conformer_rnnt()
# -- initialise
if args.pretrained_model_path:
ckpt = torch.load(args.pretrained_model_path, map_location=lambda storage, loc: storage)
tmp_ckpt = {
k.replace("encoder.frontend.", ""): v for k, v in ckpt.items() if k.startswith("encoder.frontend.")
}
self.frontend.load_state_dict(tmp_ckpt)
self.loss = torchaudio.transforms.RNNTLoss(reduction="sum")
self.optimizer = torch.optim.AdamW(
itertools.chain(*([self.frontend.parameters(), self.model.parameters()])),
lr=8e-4,
weight_decay=0.06,
betas=(0.9, 0.98),
)
self.automatic_optimization = False
def _step(self, batch, _, step_type):
if batch is None:
return None
prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
prepended_targets[:, 1:] = batch.targets
prepended_targets[:, 0] = self.blank_idx
prepended_target_lengths = batch.target_lengths + 1
features = self.frontend(batch.inputs)
output, src_lengths, _, _ = self.model(
features, batch.input_lengths, prepended_targets, prepended_target_lengths
)
loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
self.warmup_lr_scheduler = WarmupCosineScheduler(
self.optimizer,
10,
self.args.epochs,
len(self.trainer.datamodule.train_dataloader()) / self.trainer.num_devices / self.trainer.num_nodes,
)
self.lr_scheduler_interval = "step"
return (
[self.optimizer],
[{"scheduler": self.warmup_lr_scheduler, "interval": self.lr_scheduler_interval}],
)
def forward(self, batch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
x = self.frontend(batch.inputs.to(self.device))
hypotheses = decoder(x, batch.input_lengths.to(self.device), beam_width=20)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch, batch_idx):
opt = self.optimizers()
opt.zero_grad()
loss = self._step(batch, batch_idx, "train")
batch_size = batch.inputs.size(0)
batch_sizes = self.all_gather(batch_size)
loss *= batch_sizes.size(0) / batch_sizes.sum() # world size / batch size
self.manual_backward(loss)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
opt.step()
sch = self.lr_schedulers()
sch.step()
self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
return loss
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "val")
def test_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "test")
import itertools
import math
from collections import namedtuple
from typing import List, Tuple
import sentencepiece as spm
import torch
import torchaudio
from models.conformer_rnnt import conformer_rnnt
from models.emformer_rnnt import emformer_rnnt
from models.fusion import fusion_module
from models.resnet import video_resnet
from models.resnet1d import audio_resnet
from pytorch_lightning import LightningModule
from schedulers import WarmupCosineScheduler
from torchaudio.models import Hypothesis, RNNTBeamSearch
_expected_spm_vocab_size = 1023
AVBatch = namedtuple("AVBatch", ["audios", "videos", "audio_lengths", "video_lengths", "targets", "target_lengths"])
def post_process_hypos(
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, float, List[int], List[int]]]:
tokens_idx = 0
score_idx = 3
post_process_remove_list = [
sp_model.unk_id(),
sp_model.eos_id(),
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[token_index for token_index in h[tokens_idx][1:] if token_index not in post_process_remove_list] for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ids = [h[tokens_idx][1:] for h in hypos]
hypos_score = [[math.exp(h[score_idx])] for h in hypos]
nbest_batch = list(zip(hypos_str, hypos_score, hypos_ids))
return nbest_batch
class AVConformerRNNTModule(LightningModule):
def __init__(self, args=None, sp_model=None):
super().__init__()
self.save_hyperparameters(args)
self.args = args
self.sp_model = sp_model
spm_vocab_size = self.sp_model.get_piece_size()
assert spm_vocab_size == _expected_spm_vocab_size, (
"The model returned by conformer_rnnt_base expects a SentencePiece model of "
f"vocabulary size {_expected_spm_vocab_size}, but the given SentencePiece model has a vocabulary size "
f"of {spm_vocab_size}. Please provide a correctly configured SentencePiece model."
)
self.blank_idx = spm_vocab_size
self.audio_frontend = audio_resnet()
self.video_frontend = video_resnet()
self.fusion = fusion_module()
frontend_params = [self.video_frontend.parameters(), self.audio_frontend.parameters()]
fusion_params = [self.fusion.parameters()]
if args.mode == "online":
self.model = emformer_rnnt()
if args.mode == "offline":
self.model = conformer_rnnt()
self.loss = torchaudio.transforms.RNNTLoss(reduction="sum")
self.optimizer = torch.optim.AdamW(
itertools.chain(*([self.model.parameters()] + frontend_params + fusion_params)),
lr=8e-4,
weight_decay=0.06,
betas=(0.9, 0.98),
)
self.automatic_optimization = False
def _step(self, batch, _, step_type):
if batch is None:
return None
prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
prepended_targets[:, 1:] = batch.targets
prepended_targets[:, 0] = self.blank_idx
prepended_target_lengths = batch.target_lengths + 1
video_features = self.video_frontend(batch.videos)
audio_features = self.audio_frontend(batch.audios)
output, src_lengths, _, _ = self.model(
self.fusion(torch.cat([video_features, audio_features], dim=-1)),
batch.video_lengths,
prepended_targets,
prepended_target_lengths,
)
loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
self.warmup_lr_scheduler = WarmupCosineScheduler(
self.optimizer,
10,
self.args.epochs,
len(self.trainer.datamodule.train_dataloader()) / self.trainer.num_devices / self.trainer.num_nodes,
)
self.lr_scheduler_interval = "step"
return (
[self.optimizer],
[{"scheduler": self.warmup_lr_scheduler, "interval": self.lr_scheduler_interval}],
)
def forward(self, batch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
video_features = self.video_frontend(batch.videos.to(self.device))
audio_features = self.audio_frontend(batch.audios.to(self.device))
hypotheses = decoder(
self.fusion(torch.cat([video_features, audio_features], dim=-1)),
batch.video_lengths.to(self.device),
beam_width=20,
)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch, batch_idx):
opt = self.optimizers()
opt.zero_grad()
loss = self._step(batch, batch_idx, "train")
batch_size = batch.videos.size(0)
batch_sizes = self.all_gather(batch_size)
loss *= batch_sizes.size(0) / batch_sizes.sum() # world size / batch size
self.manual_backward(loss)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
opt.step()
sch = self.lr_schedulers()
sch.step()
self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
return loss
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "val")
def test_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "test")
import os
import torchaudio
import torchvision
from torch.utils.data import Dataset
def _load_list(args, *filenames):
output = []
length = []
for filename in filenames:
filepath = os.path.join(args.root_dir, "labels", filename)
for line in open(filepath).read().splitlines():
dataset, rel_path, input_length = line.split(",")[0], line.split(",")[1], line.split(",")[2]
path = os.path.normpath(os.path.join(args.root_dir, dataset, rel_path[:-4] + ".mp4"))
length.append(int(input_length))
output.append(path)
return output, length
def load_video(path):
"""
rtype: torch, T x C x H x W
"""
vid = torchvision.io.read_video(path, pts_unit="sec", output_format="THWC")[0]
vid = vid.permute((0, 3, 1, 2))
return vid
def load_audio(path):
"""
rtype: torch, T x 1
"""
waveform, sample_rate = torchaudio.load(path, normalize=True)
return waveform.transpose(1, 0)
def load_transcript(path):
transcript_path = path.replace("video_seg", "text_seg")[:-4] + ".txt"
return open(transcript_path).read().splitlines()[0]
def load_item(path, modality):
if modality == "video":
return (load_video(path), load_transcript(path))
if modality == "audio":
return (load_audio(path), load_transcript(path))
if modality == "audiovisual":
return (load_audio(path), load_video(path), load_transcript(path))
class LRS3(Dataset):
def __init__(
self,
args,
subset: str = "train",
) -> None:
if subset is not None and subset not in ["train", "val", "test"]:
raise ValueError("When `subset` is not None, it must be one of ['train', 'val', 'test'].")
self.args = args
if subset == "train":
self.files, self.lengths = _load_list(self.args, "lrs3_train_transcript_lengths_seg16s.csv")
if subset == "val":
self.files, self.lengths = _load_list(self.args, "lrs3_test_transcript_lengths_seg16s.csv")
if subset == "test":
self.files, self.lengths = _load_list(self.args, "lrs3_test_transcript_lengths_seg16s.csv")
def __getitem__(self, n):
path = self.files[n]
return load_item(path, self.args.modality)
def __len__(self) -> int:
return len(self.files)
from torchaudio.prototype.models import conformer_rnnt_model
# https://pytorch.org/audio/master/_modules/torchaudio/prototype/models/rnnt.html#conformer_rnnt_model
def conformer_rnnt():
return conformer_rnnt_model(
input_dim=512,
encoding_dim=1024,
time_reduction_stride=1,
conformer_input_dim=256,
conformer_ffn_dim=1024,
conformer_num_layers=16,
conformer_num_heads=4,
conformer_depthwise_conv_kernel_size=31,
conformer_dropout=0.1,
num_symbols=1024,
symbol_embedding_dim=256,
num_lstm_layers=2,
lstm_hidden_dim=512,
lstm_layer_norm=True,
lstm_layer_norm_epsilon=1e-5,
lstm_dropout=0.3,
joiner_activation="tanh",
)
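# Note: input_dim=512 matches the feature size produced by the ResNet frontends
# (models/resnet.py and models/resnet1d.py), and num_symbols=1024 covers the 1023-piece
# SentencePiece vocabulary plus the blank index (1023) used by the lightning modules.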
from torchaudio.models.rnnt import emformer_rnnt_model
# https://pytorch.org/audio/master/_modules/torchaudio/models/rnnt.html#emformer_rnnt_base
def emformer_rnnt():
return emformer_rnnt_model(
input_dim=512,
encoding_dim=1024,
num_symbols=1024,
segment_length=64,
right_context_length=0,
time_reduction_input_dim=128,
time_reduction_stride=1,
transformer_num_heads=4,
transformer_ffn_dim=2048,
transformer_num_layers=20,
transformer_dropout=0.1,
transformer_activation="gelu",
transformer_left_context_length=30,
transformer_max_memory_size=0,
transformer_weight_init_scale_strategy="depthwise",
transformer_tanh_on_mem=True,
symbol_embedding_dim=512,
num_lstm_layers=3,
lstm_layer_norm=True,
lstm_layer_norm_epsilon=1e-3,
lstm_dropout=0.3,
)
import torch
class FeedForwardModule(torch.nn.Module):
r"""Positionwise feed forward layer.
Args:
input_dim (int): input dimension.
hidden_dim (int): hidden dimension.
dropout (float, optional): dropout probability. (Default: 0.0)
"""
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.0) -> None:
super().__init__()
self.sequential = torch.nn.Sequential(
torch.nn.LayerNorm(input_dim),
torch.nn.Linear(input_dim, hidden_dim, bias=True),
torch.nn.SiLU(),
torch.nn.Dropout(dropout),
torch.nn.Linear(hidden_dim, output_dim, bias=True),
torch.nn.Dropout(dropout),
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
r"""
Args:
input (torch.Tensor): with shape `(*, D)`.
Returns:
torch.Tensor: output, with shape `(*, D)`.
"""
return self.sequential(input)
def fusion_module(input_dim=1024, hidden_dim=3072, output_dim=512, dropout=0.1):
return FeedForwardModule(input_dim, hidden_dim, output_dim, dropout)
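# Note: in the audio-visual model (lightning_av.py) the 512-dim video features and 512-dim audio
# features are concatenated into a 1024-dim vector and projected back to 512 by this module
# before being fed to the RNN-T encoder.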
import torch.nn as nn
def conv3x3(in_planes, out_planes, stride=1):
"""conv3x3.
:param in_planes: int, number of channels in the input sequence.
:param out_planes: int, number of channels produced by the convolution.
:param stride: int, size of the convolving kernel.
"""
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
def downsample_basic_block(inplanes, outplanes, stride):
"""downsample_basic_block.
:param inplanes: int, number of channels in the input sequence.
:param outplanes: int, number of channels produced by the convolution.
:param stride: int, size of the convolving kernel.
"""
return nn.Sequential(
nn.Conv2d(
inplanes,
outplanes,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm2d(outplanes),
)
class BasicBlock(nn.Module):
expansion = 1
def __init__(
self,
inplanes,
planes,
stride=1,
downsample=None,
relu_type="swish",
):
"""__init__.
:param inplanes: int, number of channels in the input sequence.
:param planes: int, number of channels produced by the convolution.
:param stride: int, size of the convolving kernel.
:param downsample: boolean, if True, the temporal resolution is downsampled.
:param relu_type: str, type of activation function.
"""
super(BasicBlock, self).__init__()
assert relu_type in ["relu", "prelu", "swish"]
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
if relu_type == "relu":
self.relu1 = nn.ReLU(inplace=True)
self.relu2 = nn.ReLU(inplace=True)
elif relu_type == "prelu":
self.relu1 = nn.PReLU(num_parameters=planes)
self.relu2 = nn.PReLU(num_parameters=planes)
elif relu_type == "swish":
self.relu1 = nn.SiLU(inplace=True)
self.relu2 = nn.SiLU(inplace=True)
else:
raise NotImplementedError
# --------
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
"""forward.
:param x: torch.Tensor, input tensor with input size (B, C, T, H, W).
"""
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu2(out)
return out
class ResNet(nn.Module):
def __init__(
self,
block,
layers,
relu_type="swish",
):
super(ResNet, self).__init__()
self.inplanes = 64
self.relu_type = relu_type
self.downsample_block = downsample_basic_block
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AdaptiveAvgPool2d(1)
def _make_layer(self, block, planes, blocks, stride=1):
"""_make_layer.
:param block: torch.nn.Module, class of blocks.
:param planes: int, number of channels produced by the convolution.
:param blocks: int, number of layers in a block.
:param stride: int, size of the convolving kernel.
"""
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = self.downsample_block(
inplanes=self.inplanes,
outplanes=planes * block.expansion,
stride=stride,
)
layers = []
layers.append(
block(
self.inplanes,
planes,
stride,
downsample,
relu_type=self.relu_type,
)
)
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
block(
self.inplanes,
planes,
relu_type=self.relu_type,
)
)
return nn.Sequential(*layers)
def forward(self, x):
"""forward.
:param x: torch.Tensor, input tensor with input size (B, C, T, H, W).
"""
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
return x
# -- auxiliary functions
def threeD_to_2D_tensor(x):
n_batch, n_channels, s_time, sx, sy = x.shape
x = x.transpose(1, 2)
return x.reshape(n_batch * s_time, n_channels, sx, sy)
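# e.g. a frontend3D output of shape (B, C, T, H, W) = (8, 64, 32, H, W) is reshaped to
# (B * T, C, H, W) = (256, 64, H, W), so the 2D ResNet trunk treats every time step as an
# independent image before the result is folded back to (B, T, 512) in Conv3dResNet.forward.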
class Conv3dResNet(nn.Module):
"""Conv3dResNet module"""
def __init__(self, backbone_type="resnet", relu_type="swish"):
"""__init__.
:param backbone_type: str, the type of a visual front-end.
:param relu_type: str, activation function used in an audio front-end.
"""
super(Conv3dResNet, self).__init__()
self.backbone_type = backbone_type
self.frontend_nout = 64
self.trunk = ResNet(
BasicBlock,
[2, 2, 2, 2],
relu_type=relu_type,
)
# -- frontend3D
if relu_type == "relu":
frontend_relu = nn.ReLU(True)
elif relu_type == "prelu":
frontend_relu = nn.PReLU(self.frontend_nout)
elif relu_type == "swish":
frontend_relu = nn.SiLU(inplace=True)
self.frontend3D = nn.Sequential(
nn.Conv3d(
in_channels=1,
out_channels=self.frontend_nout,
kernel_size=(5, 7, 7),
stride=(1, 2, 2),
padding=(2, 3, 3),
bias=False,
),
nn.BatchNorm3d(self.frontend_nout),
frontend_relu,
nn.MaxPool3d(
kernel_size=(1, 3, 3),
stride=(1, 2, 2),
padding=(0, 1, 1),
),
)
def forward(self, xs_pad):
"""forward.
:param xs_pad: torch.Tensor, batch of padded input sequences.
"""
# -- include Channel dimension
xs_pad = xs_pad.transpose(2, 1)
B, C, T, H, W = xs_pad.size()
xs_pad = self.frontend3D(xs_pad)
        Tnew = xs_pad.shape[2]  # output should be B x C2 x Tnew x H x W
xs_pad = threeD_to_2D_tensor(xs_pad)
xs_pad = self.trunk(xs_pad)
xs_pad = xs_pad.view(B, Tnew, xs_pad.size(1))
return xs_pad
def video_resnet():
return Conv3dResNet()
import torch.nn as nn
def conv3x3(in_planes, out_planes, stride=1):
"""conv3x3.
:param in_planes: int, number of channels in the input sequence.
:param out_planes: int, number of channels produced by the convolution.
:param stride: int, size of the convolving kernel.
"""
return nn.Conv1d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
def downsample_basic_block(inplanes, outplanes, stride):
"""downsample_basic_block.
:param inplanes: int, number of channels in the input sequence.
:param outplanes: int, number of channels produced by the convolution.
:param stride: int, size of the convolving kernel.
"""
return nn.Sequential(
nn.Conv1d(
inplanes,
outplanes,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm1d(outplanes),
)
class BasicBlock1D(nn.Module):
expansion = 1
def __init__(
self,
inplanes,
planes,
stride=1,
downsample=None,
relu_type="relu",
):
"""__init__.
:param inplanes: int, number of channels in the input sequence.
:param planes: int, number of channels produced by the convolution.
:param stride: int, size of the convolving kernel.
:param downsample: boolean, if True, the temporal resolution is downsampled.
:param relu_type: str, type of activation function.
"""
super(BasicBlock1D, self).__init__()
assert relu_type in ["relu", "prelu", "swish"]
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm1d(planes)
# type of ReLU is an input option
if relu_type == "relu":
self.relu1 = nn.ReLU(inplace=True)
self.relu2 = nn.ReLU(inplace=True)
elif relu_type == "prelu":
self.relu1 = nn.PReLU(num_parameters=planes)
self.relu2 = nn.PReLU(num_parameters=planes)
elif relu_type == "swish":
self.relu1 = nn.SiLU(inplace=True)
self.relu2 = nn.SiLU(inplace=True)
else:
raise NotImplementedError
# --------
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm1d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
"""forward.
:param x: torch.Tensor, input tensor with input size (B, C, T)
"""
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu2(out)
return out
class ResNet1D(nn.Module):
def __init__(
self,
block,
layers,
relu_type="swish",
a_upsample_ratio=1,
):
"""__init__.
:param block: torch.nn.Module, class of blocks.
:param layers: List, customised layers in each block.
:param relu_type: str, type of activation function.
:param a_upsample_ratio: int, The ratio related to the \
temporal resolution of output features of the frontend. \
            a_upsample_ratio=1 produces features at 25 fps.
"""
super(ResNet1D, self).__init__()
self.inplanes = 64
self.relu_type = relu_type
self.downsample_block = downsample_basic_block
self.a_upsample_ratio = a_upsample_ratio
self.conv1 = nn.Conv1d(
in_channels=1,
out_channels=self.inplanes,
kernel_size=80,
stride=4,
padding=38,
bias=False,
)
self.bn1 = nn.BatchNorm1d(self.inplanes)
if relu_type == "relu":
self.relu = nn.ReLU(inplace=True)
elif relu_type == "prelu":
self.relu = nn.PReLU(num_parameters=self.inplanes)
elif relu_type == "swish":
self.relu = nn.SiLU(inplace=True)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool1d(
kernel_size=20 // self.a_upsample_ratio,
stride=20 // self.a_upsample_ratio,
)
def _make_layer(self, block, planes, blocks, stride=1):
"""_make_layer.
:param block: torch.nn.Module, class of blocks.
:param planes: int, number of channels produced by the convolution.
:param blocks: int, number of layers in a block.
:param stride: int, size of the convolving kernel.
"""
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = self.downsample_block(
inplanes=self.inplanes,
outplanes=planes * block.expansion,
stride=stride,
)
layers = []
layers.append(
block(
self.inplanes,
planes,
stride,
downsample,
relu_type=self.relu_type,
)
)
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
block(
self.inplanes,
planes,
relu_type=self.relu_type,
)
)
return nn.Sequential(*layers)
def forward(self, x):
"""forward.
:param x: torch.Tensor, input tensor with input size (B, C, T)
"""
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
return x
class Conv1dResNet(nn.Module):
"""Conv1dResNet"""
def __init__(self, relu_type="swish", a_upsample_ratio=1):
"""__init__.
:param relu_type: str, Activation function used in an audio front-end.
:param a_upsample_ratio: int, The ratio related to the \
temporal resolution of output features of the frontend. \
            a_upsample_ratio=1 produces features at 25 fps.
"""
super(Conv1dResNet, self).__init__()
self.a_upsample_ratio = a_upsample_ratio
self.trunk = ResNet1D(BasicBlock1D, [2, 2, 2, 2], relu_type=relu_type, a_upsample_ratio=a_upsample_ratio)
def forward(self, xs_pad):
"""forward.
:param xs_pad: torch.Tensor, batch of padded input sequences (B, Tmax, idim)
"""
B, T, C = xs_pad.size()
xs_pad = xs_pad[:, : T // 640 * 640, :]
xs_pad = xs_pad.transpose(1, 2)
xs_pad = self.trunk(xs_pad)
# -- from B x C x T to B x T x C
xs_pad = xs_pad.transpose(1, 2)
return xs_pad
def audio_resnet():
return Conv1dResNet()
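# Note on the temporal geometry: Conv1dResNet.forward truncates the waveform to a multiple of
# 640 samples, and the frontend strides (conv1 stride 4, three stride-2 stages, average pool of
# 20) multiply to 4 * 2 * 2 * 2 * 20 = 640, so 16 kHz audio yields 25 feature frames per second,
# aligned with the 25 fps video frontend.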
import math
import torch
class WarmupCosineScheduler(torch.optim.lr_scheduler._LRScheduler):
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_epochs: int,
total_epochs: int,
steps_per_epoch: int,
last_epoch=-1,
verbose=False,
):
self.warmup_steps = warmup_epochs * steps_per_epoch
self.total_steps = total_epochs * steps_per_epoch
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
if self._step_count < self.warmup_steps:
return [self._step_count / self.warmup_steps * base_lr for base_lr in self.base_lrs]
else:
decay_steps = self.total_steps - self.warmup_steps
return [
0.5 * base_lr * (1 + math.cos(math.pi * (self._step_count - self.warmup_steps) / decay_steps))
for base_lr in self.base_lrs
]
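# Behaviour sketch (numbers are illustrative): with warmup_epochs=10, total_epochs=55 and
# steps_per_epoch=1000, the learning rate ramps linearly from ~0 to base_lr over the first
# 10,000 steps, then follows a half-cosine decay towards 0 over the remaining 45,000 steps.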
import logging
import os
from argparse import ArgumentParser
import sentencepiece as spm
from average_checkpoints import ensemble
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from transforms import get_data_module
def get_trainer(args):
seed_everything(1)
checkpoint = ModelCheckpoint(
dirpath=os.path.join(args.exp_dir, args.exp_name) if args.exp_dir else None,
monitor="monitoring_step",
mode="max",
save_last=True,
filename="{epoch}",
save_top_k=10,
)
lr_monitor = LearningRateMonitor(logging_interval="step")
callbacks = [
checkpoint,
lr_monitor,
]
return Trainer(
sync_batchnorm=True,
default_root_dir=args.exp_dir,
max_epochs=args.epochs,
num_nodes=args.num_nodes,
devices=args.gpus,
accelerator="gpu",
strategy=DDPStrategy(find_unused_parameters=False),
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
)
def get_lightning_module(args):
sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))
if args.modality == "audiovisual":
from lightning_av import AVConformerRNNTModule
model = AVConformerRNNTModule(args, sp_model)
else:
from lightning import ConformerRNNTModule
model = ConformerRNNTModule(args, sp_model)
return model
def parse_args():
parser = ArgumentParser()
parser.add_argument(
"--modality",
type=str,
help="Modality",
required=True,
)
parser.add_argument(
"--mode",
type=str,
help="Perform online or offline recognition.",
required=True,
)
parser.add_argument(
"--root-dir",
type=str,
help="Root directory to LRS3 audio-visual datasets.",
required=True,
)
parser.add_argument(
"--sp-model-path",
type=str,
help="Path to SentencePiece model.",
required=True,
)
parser.add_argument(
"--pretrained-model-path",
type=str,
help="Path to Pretraned model.",
)
parser.add_argument(
"--exp-dir",
default="./exp",
type=str,
help="Directory to save checkpoints and logs to. (Default: './exp')",
)
parser.add_argument(
"--exp-name",
type=str,
help="Experiment name",
)
parser.add_argument(
"--num-nodes",
default=4,
type=int,
help="Number of nodes to use for training. (Default: 4)",
)
parser.add_argument(
"--gpus",
default=8,
type=int,
help="Number of GPUs per node to use for training. (Default: 8)",
)
parser.add_argument(
"--epochs",
default=55,
type=int,
help="Number of epochs to train for. (Default: 55)",
)
parser.add_argument(
"--resume-from-checkpoint",
default=None,
type=str,
help="Path to the checkpoint to resume from",
)
parser.add_argument(
"--debug",
action="store_true",
help="Whether to use debug level for logging",
)
return parser.parse_args()
def init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = parse_args()
init_logger(args.debug)
model = get_lightning_module(args)
data_module = get_data_module(args, str(args.sp_model_path))
trainer = get_trainer(args)
trainer.fit(model, data_module)
ensemble(args)
if __name__ == "__main__":
cli_main()
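# Example invocation (file and path names are hypothetical):
#   python train.py --modality audiovisual --mode offline --root-dir /path/to/lrs3 \
#       --sp-model-path ./spm_unigram_1023.model --exp-dir ./exp --exp-name avsr_offline \
#       --num-nodes 1 --gpus 8 --epochs 55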
#!/usr/bin/env python3
"""Trains a SentencePiece model on transcripts across LRS3 pretrain and trainval.
- `[lrs3_path]` is the directory path for the LRS3 cropped face dataset.
Example:
python train_spm.py --lrs3-path [lrs3_path]
"""
import io
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import sentencepiece as spm
def get_transcript_text(transcript_path):
return [open(transcript_path).read().splitlines()[0].lower()]
def get_transcripts(dataset_path):
transcript_paths = dataset_path.glob("*/*.txt")
merged_transcripts = []
for path in transcript_paths:
merged_transcripts += get_transcript_text(path)
return merged_transcripts
def train_spm(input):
model_writer = io.BytesIO()
spm.SentencePieceTrainer.train(
sentence_iterator=iter(input),
model_writer=model_writer,
vocab_size=1023,
model_type="unigram",
input_sentence_size=-1,
character_coverage=1.0,
bos_id=0,
pad_id=1,
eos_id=2,
unk_id=3,
)
return model_writer.getvalue()
def parse_args():
default_output_path = "./spm_unigram_1023.model"
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--lrs3-path",
type=pathlib.Path,
help="Path to LRS3 datasets.",
required=True,
)
parser.add_argument(
"--output-file",
default=pathlib.Path(default_output_path),
type=pathlib.Path,
help=f"File to save model to. (Default: '{default_output_path}')",
)
return parser.parse_args()
def run_cli():
args = parse_args()
root = args.lrs3_path / "LRS3_text_seg16s"
splits = ["pretrain", "trainval"]
merged_transcripts = []
for split in splits:
path = pathlib.Path(root) / split
merged_transcripts += get_transcripts(path)
model = train_spm(merged_transcripts)
with open(args.output_file, "wb") as f:
f.write(model)
if __name__ == "__main__":
run_cli()