"vscode:/vscode.git/clone" did not exist on "067b9dcabc594b5c31662ad80965ea6933eba510"
Commit d4644793 authored by Pingchuan Ma, committed by Facebook GitHub Bot

Update avsr recipe (#3493)

Summary:
This PR includes a few changes to the AV-ASR recipe: better results, a faster face detector (MediaPipe), renamed variables, a streamlined dataloader, and a few illustrated examples. These changes improve the usability of the recipe.

Pull Request resolved: https://github.com/pytorch/audio/pull/3493

Reviewed By: mthrok

Differential Revision: D47758072

Pulled By: mpc001

fbshipit-source-id: 4533587776f3a7a74f3f11b0ece773a0934bacdc
parent 56e22664
<p align="center"><img width="160" src="doc/lip_white.png" alt="logo"></p>
<h1 align="center">RNN-T ASR/VSR/AV-ASR Examples</h1>
<p align="center"><img width="160" src="https://download.pytorch.org/torchaudio/doc-assets/avsr/lip_white.png" alt="logo"></p>
<h1 align="center">Real-time ASR/VSR/AV-ASR Examples</h1>
This repository contains sample implementations of training and evaluation pipelines for RNNT based automatic, visual, and audio-visual (ASR, VSR, AV-ASR) models on LRS3. This repository includes both streaming/non-streaming modes. We follow the same training pipeline as [AutoAVSR](https://arxiv.org/abs/2303.14307).
<div align="center">
[📘Introduction](#introduction) |
[📊Training](#training) |
[🔮Evaluation](#evaluation)
</div>
## Introduction
This directory contains the training recipe for real-time audio, visual, and audio-visual speech recognition (ASR, VSR, AV-ASR) models, which is an extension of [Auto-AVSR](https://arxiv.org/abs/2303.14307). The recipe supports both streaming (online) and non-streaming (offline) modes and follows the same training pipeline as Auto-AVSR, with models trained and evaluated on LRS3.
Please refer to [this tutorial]() for real-time AV-ASR inference from microphone and camera.
## Preparation
1. Set up the environment.
```
conda create -y -n autoavsr python=3.8
conda activate autoavsr
```
2. Install PyTorch (`torch`, `torchvision`, `torchaudio`) by following the [official instructions](https://pytorch.org/get-started/), along with all other necessary packages:
```Shell
pip install torch torchvision torchaudio pytorch-lightning sentencepiece
```
3. Preprocess LRS3. See the instructions in the [data_prep](./data_prep) folder.
4. Generate a sentencepiece model for encoding targets (`[sp_model_path]`) with `train_spm.py`.
## Usage
### Training
```Shell
python train.py --exp-dir=[exp_dir] \
--exp-name=[exp_name] \
--modality=[modality] \
--mode=[mode] \
--root-dir=[root-dir] \
--sp-model-path=[sp_model_path] \
--num-nodes=[num_nodes] \
--gpus=[gpus]
```
- `exp-dir` and `exp-name`: Checkpoints and logs will be saved under `[exp_dir]/[exp_name]`.
- `modality`: Type of the input modality. Valid values are: `video`, `audio`, and `audiovisual`.
- `mode`: Type of the mode. Valid values are: `online` and `offline`.
- `root-dir`: Path to the root directory where all preprocessed files will be stored.
- `sp-model-path`: Path to the sentencepiece model. Default: `./spm_unigram_1023.model`, which can be produced using `train_spm.py`.
- `num-nodes`: The number of machines used. Default: 4.
- `gpus`: The number of gpus in each machine. Default: 8.
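
For example, a single-node run that trains the streaming audio-visual model could look like the following; the dataset path and experiment name are placeholders:

```Shell
python train.py --exp-dir=./exp \
--exp-name=avsr_online \
--modality=audiovisual \
--mode=online \
--root-dir=[path_to_preprocessed_lrs3] \
--sp-model-path=./spm_unigram_1023.model \
--num-nodes=1 \
--gpus=8
```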
### Evaluation
```Shell
python eval.py --modality=[modality] \
--mode=[mode] \
--root-dir=[dataset_path] \
--sp-model-path=[sp_model_path] \
--checkpoint-path=[checkpoint_path]
```
- `modality`: Type of the input modality. Valid values are: `video`, `audio`, and `audiovisual`.
- `mode`: Type of the mode. Valid values are: `online` and `offline`.
- `root-dir`: Path to the root directory where all preprocessed files will be stored.
- `sp-model-path`: Path to the sentencepiece model. Default: `./spm_unigram_1023.model`.
- `checkpoint-path`: Path to a pretrained model checkpoint.
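
For example, to evaluate the streaming audio-visual model trained above on an averaged checkpoint such as `model_avg_10.pth` produced by the `ensemble` helper (paths are placeholders):

```Shell
python eval.py --modality=audiovisual \
--mode=online \
--root-dir=[path_to_preprocessed_lrs3] \
--sp-model-path=./spm_unigram_1023.model \
--checkpoint-path=./exp/avsr_online/model_avg_10.pth
```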
## Results
The table below contains WER for AV-ASR models that were trained from scratch [offline evaluation].
| Model | Training dataset (hours) | WER [%] | Params (M) |
|:--------------------:|:------------------------:|:-------:|:----------:|
| Non-streaming models | | | |
| AV-ASR | LRS3 (438) | 3.9 | 50 |
| Streaming models | | | |
| AV-ASR | LRS3 (438) | 3.9 | 40 |
......@@ -23,9 +23,6 @@ def average_checkpoints(last):
def ensemble(args):
last = [os.path.join(args.exp_dir, args.exp_name, f"epoch={n}.ckpt") for n in range(args.epochs - 10, args.epochs)]
model_path = os.path.join(args.exp_dir, args.exp_name, "model_avg_10.pth")
torch.save({"state_dict": average_checkpoints(last)}, model_path)
......@@ -110,52 +110,19 @@ class LRS3DataModule(LightningDataModule):
self.num_workers = num_workers
def train_dataloader(self):
datasets = [LRS3(self.args, subset="train")]
if not self.train_dataset_lengths:
self.train_dataset_lengths = [dataset._lengthlist for dataset in datasets]
dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(
dataset,
lengths,
self.max_frames,
self.train_num_buckets,
batch_size=self.batch_size,
)
for dataset, lengths in zip(datasets, self.train_dataset_lengths)
]
dataset = LRS3(self.args, subset="train")
dataset = CustomBucketDataset(
dataset, dataset.lengths, self.max_frames, self.train_num_buckets, batch_size=self.batch_size
)
dataset = TransformDataset(dataset, self.train_transform)
dataloader = torch.utils.data.DataLoader(
dataset, num_workers=self.num_workers, batch_size=None, shuffle=self.train_shuffle
)
return dataloader
def val_dataloader(self):
datasets = [LRS3(self.args, subset="val")]
if not self.val_dataset_lengths:
self.val_dataset_lengths = [dataset._lengthlist for dataset in datasets]
dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(
dataset,
lengths,
self.max_frames,
1,
batch_size=self.batch_size,
)
for dataset, lengths in zip(datasets, self.val_dataset_lengths)
]
)
dataset = LRS3(self.args, subset="val")
dataset = CustomBucketDataset(dataset, dataset.lengths, self.max_frames, 1, batch_size=self.batch_size)
dataset = TransformDataset(dataset, self.val_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=self.num_workers)
return dataloader
......
# Pre-process LRS3

We provide a pre-processing pipeline in this repository for detecting and cropping full-face regions of interest (ROIs) as well as corresponding audio waveforms for LRS3.
## Introduction

Before feeding the raw stream into our model, each video sequence has to undergo a specific pre-processing procedure with three critical steps. The first step is to perform face detection. Following that, each individual frame is aligned to a reference frame, commonly known as the mean face, in order to normalize rotation and size differences across frames. The final step of the pre-processing module is to crop the face region from the aligned face image.
<div align="center">
<table style="display: inline-table;">
<tr><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/original.gif", width="144"></td><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/detected.gif" width="144"></td><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/transformed.gif" width="144"></td><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/cropped.gif" width="144"></td></tr>
<tr><td>0. Original</td> <td>1. Detection</td> <td>2. Transformation</td> <td>3. Face ROIs</td> </tr>
</table>
</div>
## Preparation
1. Install all dependencies:
```Shell
pip install -r requirements.txt
```
2. Install the [retinaface](./tools) or [mediapipe](https://pypi.org/project/mediapipe/) face tracker. If you have already installed a tracker, you can skip this step.
## Preprocessing

To pre-process the LRS3 dataset, please follow these steps; complete example invocations are shown after the parameter lists below.

1. Download the LRS3 dataset from the official website.
2. Run the following command to preprocess the dataset:
```Shell
python preprocess_lrs3.py \
--data-dir=[data_dir] \
--detector=[detector] \
--dataset=[dataset] \
--root-dir=[root_dir] \
--subset=[subset] \
--seg-duration=[seg_duration] \
--groups=[n] \
--job-index=[j]
```
- `data-dir`: Path to the directory containing video files.
- `detector`: Type of face detector. Valid values are: `mediapipe` and `retinaface`. Default: `retinaface`.
- `dataset`: Name of the dataset. Valid value is: `lrs3`.
- `root-dir`: Path to the root directory where all preprocessed files will be stored.
- `subset`: Name of the subset. Valid values are: `train` and `test`.
- `seg-duration`: Length of the maximal segment in seconds. Default: `16`.
- `groups`: Number of groups to split the dataset into.
- `job-index`: Job index for the current group. Valid values are an integer within the range of `[0, n)`.
3. Run the following command to merge all labels:
```Shell
python merge.py \
--root-dir=[root_dir] \
--dataset=[dataset] \
--subset=[subset] \
--seg-duration=[seg_duration] \
--groups=[n]
```
- `root-dir`: Path to the root directory where all preprocessed files will be stored.
- `dataset`: Name of the dataset. Valid values are: `lrs2` and `lrs3`.
- `subset`: The subset name of the dataset. For LRS2, valid values are `train`, `val`, and `test`. For LRS3, valid values are `train` and `test`.
- `seg-duration`: Length of the maximal segment in seconds. Default: `16`.
- `groups`: Number of groups to split the dataset into.
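
For example, to preprocess the training subset with the mediapipe detector split across 4 parallel jobs and then merge the resulting label lists (the paths and the number of groups are placeholders; whether the jobs can run concurrently on one machine depends on your hardware):

```Shell
# Run one preprocessing job per group; job indices range over [0, groups).
for job in 0 1 2 3; do
python preprocess_lrs3.py --data-dir=[LRS3_dir] --detector=mediapipe --dataset=lrs3 \
--root-dir=[root_dir] --subset=train --seg-duration=16 --groups=4 --job-index=${job} &
done
wait

# Merge the per-job label lists into a single csv.
python merge.py --root-dir=[root_dir] --dataset=lrs3 --subset=train --seg-duration=16 --groups=4
```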
......@@ -19,6 +19,12 @@ class AVSRDataLoader:
self.landmarks_detector = LandmarksDetector(device="cuda:0")
self.video_process = VideoProcess(resize=resize)
if detector == "mediapipe":
from detectors.mediapipe.detector import LandmarksDetector
from detectors.mediapipe.video_process import VideoProcess
self.landmarks_detector = LandmarksDetector()
self.video_process = VideoProcess(resize=resize)
def load_data(self, data_filename, transform=True):
if self.modality == "audio":
......
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import warnings
import mediapipe as mp
import numpy as np
import torchvision
warnings.filterwarnings("ignore")
class LandmarksDetector:
def __init__(self):
self.mp_face_detection = mp.solutions.face_detection
self.short_range_detector = self.mp_face_detection.FaceDetection(
min_detection_confidence=0.5, model_selection=0
)
self.full_range_detector = self.mp_face_detection.FaceDetection(min_detection_confidence=0.5, model_selection=1)
def __call__(self, video_frames):
landmarks = self.detect(video_frames, self.full_range_detector)
if all(element is None for element in landmarks):
landmarks = self.detect(video_frames, self.short_range_detector)
assert any(l is not None for l in landmarks), "Cannot detect any frames in the video"
return landmarks
def detect(self, filename, detector):
video_frames = torchvision.io.read_video(filename, pts_unit="sec")[0].numpy()
landmarks = []
for frame in video_frames:
results = detector.process(frame)
if not results.detections:
landmarks.append(None)
continue
face_points = []
max_id, max_size = 0, 0
for idx, detected_faces in enumerate(results.detections):
bboxC = detected_faces.location_data.relative_bounding_box
ih, iw, ic = frame.shape
bbox = int(bboxC.xmin * iw), int(bboxC.ymin * ih), int(bboxC.width * iw), int(bboxC.height * ih)
# bbox is (xmin, ymin, width, height); track the largest detected face by width + height.
bbox_size = bbox[2] + bbox[3]
if bbox_size > max_size:
max_id, max_size = idx, bbox_size
lmx = [[int(bboxC.xmin * iw), int(bboxC.ymin * ih)], [int(bboxC.width * iw), int(bboxC.height * ih)]]
face_points.append(lmx)
landmarks.append(np.reshape(np.array(face_points[max_id]), (2, 2)))
return landmarks
import cv2
import numpy as np
from skimage import transform as tf
def linear_interpolate(landmarks, start_idx, stop_idx):
start_landmarks = landmarks[start_idx]
stop_landmarks = landmarks[stop_idx]
delta = stop_landmarks - start_landmarks
for idx in range(1, stop_idx - start_idx):
landmarks[start_idx + idx] = start_landmarks + idx / float(stop_idx - start_idx) * delta
return landmarks
def warp_img(src, dst, img, std_size):
tform = tf.estimate_transform("similarity", src, dst)
warped = tf.warp(img, inverse_map=tform.inverse, output_shape=std_size)
warped = (warped * 255).astype("uint8")
return warped, tform
def apply_transform(transform, img, std_size):
warped = tf.warp(img, inverse_map=transform.inverse, output_shape=std_size)
warped = (warped * 255).astype("uint8")
return warped
def cut_patch(img, landmarks, height, width, threshold=5):
center_x, center_y = np.mean(landmarks, axis=0)
# Check for too much bias in height and width
if abs(center_y - img.shape[0] / 2) > height + threshold:
raise Exception("too much bias in height")
if abs(center_x - img.shape[1] / 2) > width + threshold:
raise Exception("too much bias in width")
# Calculate bounding box coordinates
y_min = int(round(np.clip(center_y - height, 0, img.shape[0])))
y_max = int(round(np.clip(center_y + height, 0, img.shape[0])))
x_min = int(round(np.clip(center_x - width, 0, img.shape[1])))
x_max = int(round(np.clip(center_x + width, 0, img.shape[1])))
# Cut the image
cutted_img = np.copy(img[y_min:y_max, x_min:x_max])
return cutted_img
class VideoProcess:
def __init__(
self,
crop_width=128,
crop_height=128,
target_size=(224, 224),
reference_size=(224, 224),
stable_points=(0, 1),
start_idx=0,
stop_idx=2,
resize=(96, 96),
):
self.reference = np.array(([[51.64568, 0.70204943], [171.95107, 159.59505]]))
self.crop_width = crop_width
self.crop_height = crop_height
self.start_idx = start_idx
self.stop_idx = stop_idx
self.resize = resize
def __call__(self, video, landmarks):
# Pre-process landmarks: interpolate frames that are not detected
preprocessed_landmarks = self.interpolate_landmarks(landmarks)
# Exclude corner cases: no landmark in all frames or number of frames is less than window length
if not preprocessed_landmarks:
return
# Affine transformation and crop patch
sequence = self.crop_patch(video, preprocessed_landmarks)
assert sequence is not None, "crop an empty patch."
return sequence
def crop_patch(self, video, landmarks):
sequence = []
for frame_idx, frame in enumerate(video):
transformed_frame, transformed_landmarks = self.affine_transform(
frame, landmarks[frame_idx], self.reference
)
patch = cut_patch(
transformed_frame,
transformed_landmarks[self.start_idx : self.stop_idx],
self.crop_height // 2,
self.crop_width // 2,
)
if self.resize:
patch = cv2.resize(patch, self.resize)
sequence.append(patch)
return np.array(sequence)
def interpolate_landmarks(self, landmarks):
valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None]
if not valid_frames_idx:
return None
for idx in range(1, len(valid_frames_idx)):
if valid_frames_idx[idx] - valid_frames_idx[idx - 1] > 1:
landmarks = linear_interpolate(landmarks, valid_frames_idx[idx - 1], valid_frames_idx[idx])
valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None]
# Handle corner case: keep frames at the beginning or at the end that failed to be detected
if valid_frames_idx:
landmarks[: valid_frames_idx[0]] = [landmarks[valid_frames_idx[0]]] * valid_frames_idx[0]
landmarks[valid_frames_idx[-1] :] = [landmarks[valid_frames_idx[-1]]] * (
len(landmarks) - valid_frames_idx[-1]
)
assert all(lm is not None for lm in landmarks), "not every frame has landmark"
return landmarks
def affine_transform(
self,
frame,
landmarks,
reference,
target_size=(224, 224),
reference_size=(224, 224),
stable_points=(0, 1),
interpolation=cv2.INTER_LINEAR,
border_mode=cv2.BORDER_CONSTANT,
border_value=0,
):
stable_reference = self.get_stable_reference(reference, stable_points, reference_size, target_size)
transform = self.estimate_affine_transform(landmarks, stable_points, stable_reference)
transformed_frame, transformed_landmarks = self.apply_affine_transform(
frame, landmarks, transform, target_size, interpolation, border_mode, border_value
)
return transformed_frame, transformed_landmarks
def get_stable_reference(self, reference, stable_points, reference_size, target_size):
stable_reference = np.vstack([reference[x] for x in stable_points])
stable_reference[:, 0] -= (reference_size[0] - target_size[0]) / 2.0
stable_reference[:, 1] -= (reference_size[1] - target_size[1]) / 2.0
return stable_reference
def estimate_affine_transform(self, landmarks, stable_points, stable_reference):
return cv2.estimateAffinePartial2D(
np.vstack([landmarks[x] for x in stable_points]), stable_reference, method=cv2.LMEDS
)[0]
def apply_affine_transform(
self, frame, landmarks, transform, target_size, interpolation, border_mode, border_value
):
transformed_frame = cv2.warpAffine(
frame,
transform,
dsize=(target_size[0], target_size[1]),
flags=interpolation,
borderMode=border_mode,
borderValue=border_value,
)
transformed_landmarks = np.matmul(landmarks, transform[:, :2].transpose()) + transform[:, 2].transpose()
return transformed_frame, transformed_landmarks
import argparse
import os
parser = argparse.ArgumentParser(description="Merge labels")
parser.add_argument(
"--dataset",
type=str,
required=True,
help="Specify the dataset used in the experiment",
)
parser.add_argument(
"--subset",
type=str,
required=True,
help="Specify the subset of the dataset used in the experiment",
)
parser.add_argument(
"--root-dir",
type=str,
required=True,
help="Directory of saved mouth patches or embeddings",
)
parser.add_argument(
"--groups",
type=int,
required=True,
help="Number of threads for parallel processing",
)
parser.add_argument(
"--seg-duration",
type=int,
default=16,
help="Length of the segments",
)
args = parser.parse_args()
dataset = args.dataset
subset = args.subset
......@@ -45,7 +43,9 @@ assert args.groups > 1, "There is no need to use this script for merging when --
# Create the filename template for label files
label_template = os.path.join(
args.root_dir, "labels", f"{dataset}_{subset}_transcript_lengths_seg{seg_duration}s.{args.groups}"
args.root_dir,
"labels",
f"{dataset}_{subset}_transcript_lengths_seg{seg_duration}s.{args.groups}",
)
lines = []
......@@ -58,15 +58,17 @@ for job_index in range(args.groups):
# Write the merged labels to a new file
dst_label_filename = os.path.join(
args.root_dir, dataset, f"{dataset}_{subset}_transcript_lengths_seg{seg_duration}s.csv"
args.root_dir,
"labels",
f"{dataset}_{subset}_transcript_lengths_seg{seg_duration}s.csv",
)
with open(dst_label_filename, "w") as file:
file.write("\n".join(lines))
# Print the number of files and total duration in hours
total_duration = sum(int(line.split(",")[2]) for line in lines) / 3600.0 / 25.0  # column 2 stores frame counts at 25 fps; convert to hours
print(f"The completed set has {len(lines)} files with a total of {total_duration:.2f} hours.")
# Remove the label files for each job index
print("** Remove the temporary label files **")
......
import argparse
import glob
import math
import os
import shutil
import warnings
import ffmpeg
......@@ -12,52 +12,60 @@ from utils import save_vid_aud_txt, split_file
warnings.filterwarnings("ignore")
# Argument Parsing
parser = argparse.ArgumentParser(description="LRS3 Preprocessing")
parser.add_argument(
"--data-dir",
type=str,
help="The directory for sequence.",
)
parser.add_argument(
"--detector",
type=str,
default="retinaface",
help="Face detector used in the experiment.",
)
parser.add_argument(
"--dataset",
type=str,
help="Specify the dataset name used in the experiment",
)
parser.add_argument(
"--root-dir",
type=str,
help="The root directory of cropped-face dataset.",
)
parser.add_argument(
"--subset",
type=str,
required=True,
help="Subset of the dataset used in the experiment.",
)
parser.add_argument(
"--seg-duration",
type=int,
default=16,
help="Length of the segment in seconds.",
)
parser.add_argument(
"--groups",
type=int,
default=1,
help="Number of threads to be used in parallel.",
)
parser.add_argument(
"--job-index",
type=int,
default=0,
help="Index to identify separate jobs (useful for parallel processing).",
)
args = parser.parse_args()
seg_duration = args.seg_duration
dataset = args.dataset
args.data_dir = os.path.normpath(args.data_dir)
vid_dataloader = AVSRDataLoader(modality="video", detector=detector, resize=(96, 96))
vid_dataloader = AVSRDataLoader(modality="video", detector=args.detector, resize=(96, 96))
aud_dataloader = AVSRDataLoader(modality="audio")
# Step 2, extract mouth patches from segments.
seg_vid_len = seg_duration * 25
......@@ -66,9 +74,9 @@ seg_aud_len = seg_duration * 16000
label_filename = os.path.join(
args.root_dir,
"labels",
f"{dataset}_{args.folder}_transcript_lengths_seg{seg_duration}s.csv"
f"{dataset}_{args.subset}_transcript_lengths_seg{seg_duration}s.csv"
if args.groups <= 1
else f"{dataset}_{args.folder}_transcript_lengths_seg{seg_duration}s.{args.groups}.{args.job_index}.csv",
else f"{dataset}_{args.subset}_transcript_lengths_seg{seg_duration}s.{args.groups}.{args.job_index}.csv",
)
os.makedirs(os.path.dirname(label_filename), exist_ok=True)
print(f"Directory {os.path.dirname(label_filename)} created")
......@@ -77,9 +85,9 @@ f = open(label_filename, "w")
# Step 2, extract mouth patches from segments.
dst_vid_dir = os.path.join(args.root_dir, dataset, dataset + f"_video_seg{seg_duration}s")
dst_txt_dir = os.path.join(args.root_dir, dataset, dataset + f"_text_seg{seg_duration}s")
if args.folder == "test":
filenames = glob.glob(os.path.join(args.data_dir, args.folder, "**", "*.mp4"), recursive=True)
elif args.folder == "train":
if args.subset == "test":
filenames = glob.glob(os.path.join(args.data_dir, args.subset, "**", "*.mp4"), recursive=True)
elif args.subset == "train":
filenames = glob.glob(os.path.join(args.data_dir, "trainval", "**", "*.mp4"), recursive=True)
filenames.extend(glob.glob(os.path.join(args.data_dir, "pretrain", "**", "*.mp4"), recursive=True))
filenames.sort()
......@@ -96,7 +104,7 @@ for data_filename in tqdm(filenames):
except UnboundLocalError:
continue
if os.path.normpath(data_filename).split(os.sep)[-3] in ["trainval", "test", "main"]:
if os.path.normpath(data_filename).split(os.sep)[-3] in ["trainval", "test"]:
dst_vid_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}.mp4"
dst_aud_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}.wav"
dst_txt_filename = f"{data_filename.replace(args.data_dir, dst_txt_dir)[:-4]}.txt"
......
tqdm
scikit-image
opencv-python
ffmpeg-python
## Face Detection
We provide [ibug.face_detection](https://github.com/hhj1897/face_detection) in this repository. You can install it directly from the GitHub repository or from a compressed archive.
### Option 1. Install from the GitHub repository

Prerequisite: [Git LFS](https://git-lfs.github.com/) is needed to download the pretrained weights that are larger than 100 MB. You can install `Homebrew` and then install `git-lfs` without sudo privileges.

Install *`ibug.face_detection`*:
```Shell
git clone https://github.com/hhj1897/face_detection.git
cd face_detection
git lfs pull
pip install -e .
cd ..
```
### Option 2. Install by using compressed files
If you are experiencing over-quota issues with the repository above, you can download [ibug.face_detection](https://www.doc.ic.ac.uk/~pm4115/tracker/face_detection.zip) as a compressed archive, unzip it, and then run `pip install -e .` to install the package:
```Shell
wget https://www.doc.ic.ac.uk/~pm4115/tracker/face_detection.zip -O ./face_detection.zip
unzip -o ./face_detection.zip -d ./
cd face_detection
pip install -e .
cd ..
```
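
After either option, a quick sanity check that the package can be imported (assuming the install above completed):

```Shell
python -c "import ibug.face_detection; print('ibug.face_detection imported successfully')"
```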
......@@ -16,7 +16,7 @@ def compute_word_level_distance(seq1, seq2):
def get_lightning_module(args):
sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))
if args.md == "av":
if args.modality == "audiovisual":
from lightning_av import AVConformerRNNTModule
model = AVConformerRNNTModule(args, sp_model)
......@@ -49,7 +49,7 @@ def run_eval(model, data_module):
def parse_args():
parser = ArgumentParser()
parser.add_argument(
"--md",
"--modality",
type=str,
help="Modality",
required=True,
......@@ -69,20 +69,15 @@ def parse_args():
parser.add_argument(
"--sp-model-path",
type=str,
help="Path to SentencePiece model.",
help="Path to sentencepiece model.",
required=True,
)
parser.add_argument(
"--checkpoint-path",
type=str,
help="Path to checkpoint model.",
help="Path to a checkpoint model.",
required=True,
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
......
......@@ -57,9 +57,9 @@ class ConformerRNNTModule(LightningModule):
)
self.blank_idx = spm_vocab_size
if args.md == "v":
if args.modality == "video":
self.frontend = video_resnet()
if args.md == "a":
if args.modality == "audio":
self.frontend = audio_resnet()
if args.mode == "online":
......@@ -116,33 +116,13 @@ class ConformerRNNTModule(LightningModule):
[{"scheduler": self.warmup_lr_scheduler, "interval": self.lr_scheduler_interval}],
)
def forward(self, batch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
x = self.frontend(batch.inputs.to(self.device))
hypotheses = decoder(x, batch.input_lengths.to(self.device), beam_width=20)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch, batch_idx):
opt = self.optimizers()
opt.zero_grad()
loss = self._step(batch, batch_idx, "train")
......@@ -157,7 +137,7 @@ class ConformerRNNTModule(LightningModule):
sch = self.lr_schedulers()
sch.step()
self.log("monitoring_step", self.global_step)
self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
return loss
......
......@@ -116,7 +116,7 @@ class AVConformerRNNTModule(LightningModule):
[{"scheduler": self.warmup_lr_scheduler, "interval": self.lr_scheduler_interval}],
)
def forward(self, batch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
video_features = self.video_frontend(batch.videos.to(self.device))
audio_features = self.audio_frontend(batch.audios.to(self.device))
......@@ -127,27 +127,7 @@ class AVConformerRNNTModule(LightningModule):
)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch, batch_idx):
opt = self.optimizers()
opt.zero_grad()
loss = self._step(batch, batch_idx, "train")
......@@ -162,7 +142,7 @@ class AVConformerRNNTModule(LightningModule):
sch = self.lr_schedulers()
sch.step()
self.log("monitoring_step", self.global_step)
self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
return loss
......
......@@ -40,12 +40,12 @@ def load_transcript(path):
return open(transcript_path).read().splitlines()[0]
def load_item(path, modality):
if modality == "video":
return (load_video(path), load_transcript(path))
if md == "a":
if modality == "audio":
return (load_audio(path), load_transcript(path))
if md == "av":
if modality == "audiovisual":
return (load_audio(path), load_video(path), load_transcript(path))
......@@ -62,15 +62,15 @@ class LRS3(Dataset):
self.args = args
if subset == "train":
self._filelist, self._lengthlist = _load_list(self.args, "lrs3_train_transcript_lengths_seg16s.csv")
self.files, self.lengths = _load_list(self.args, "lrs3_train_transcript_lengths_seg16s.csv")
if subset == "val":
self._filelist, self._lengthlist = _load_list(self.args, "lrs3_test_transcript_lengths_seg16s.csv")
self.files, self.lengths = _load_list(self.args, "lrs3_test_transcript_lengths_seg16s.csv")
if subset == "test":
self._filelist, self._lengthlist = _load_list(self.args, "lrs3_test_transcript_lengths_seg16s.csv")
self.files, self.lengths = _load_list(self.args, "lrs3_test_transcript_lengths_seg16s.csv")
def __getitem__(self, n):
path = self.files[n]
return load_item(path, self.args.modality)
def __len__(self) -> int:
return len(self.files)
......@@ -32,5 +32,5 @@ class FeedForwardModule(torch.nn.Module):
return self.sequential(input)
def fusion_module(input_dim=1024, hidden_dim=3072, output_dim=512, dropout=0.1):
return FeedForwardModule(input_dim, hidden_dim, output_dim, dropout)
......@@ -14,7 +14,7 @@ def get_trainer(args):
seed_everything(1)
checkpoint = ModelCheckpoint(
dirpath=os.path.join(args.exp_dir, args.exp_name) if args.exp_dir else None,
monitor="monitoring_step",
mode="max",
save_last=True,
......@@ -36,13 +36,12 @@ def get_trainer(args):
strategy=DDPStrategy(find_unused_parameters=False),
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
)
def get_lightning_module(args):
sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))
if args.md == "av":
if args.modality == "audiovisual":
from lightning_av import AVConformerRNNTModule
model = AVConformerRNNTModule(args, sp_model)
......@@ -56,7 +55,7 @@ def get_lightning_module(args):
def parse_args():
parser = ArgumentParser()
parser.add_argument(
"--md",
"--modality",
type=str,
help="Modality",
required=True,
......@@ -86,19 +85,20 @@ def parse_args():
)
parser.add_argument(
"--exp-dir",
default="./exp",
type=str,
help="Directory to save checkpoints and logs to. (Default: './exp')",
)
parser.add_argument(
"--experiment-name",
"--exp-name",
type=str,
help="Experiment name",
)
parser.add_argument(
"--num-nodes",
default=4,
type=int,
help="Number of nodes to use for training. (Default: 8)",
help="Number of nodes to use for training. (Default: 4)",
)
parser.add_argument(
"--gpus",
......@@ -113,9 +113,16 @@ def parse_args():
help="Number of epochs to train for. (Default: 55)",
)
parser.add_argument(
"--resume-from-checkpoint", default=None, type=str, help="Path to the checkpoint to resume from"
"--resume-from-checkpoint",
default=None,
type=str,
help="Path to the checkpoint to resume from",
)
parser.add_argument(
"--debug",
action="store_true",
help="Whether to use debug level for logging",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
......
......@@ -55,28 +55,28 @@ def _extract_features(video_pipeline, audio_pipeline, samples, args):
raw_videos = []
raw_audios = []
for sample in samples:
if args.md == "v":
if args.modality == "visual":
raw_videos.append(sample[0])
if args.md == "a":
if args.modality == "audio":
raw_audios.append(sample[0])
if args.md == "av":
if args.modality == "audiovisual":
length = min(len(sample[0]) // 640, len(sample[1]))
raw_audios.append(sample[0][: length * 640])
raw_videos.append(sample[1][:length])
if args.md == "v" or args.md == "av":
if args.modality == "visual" or args.modality == "audiovisual":
videos = torch.nn.utils.rnn.pad_sequence(raw_videos, batch_first=True)
videos = video_pipeline(videos)
video_lengths = torch.tensor([elem.shape[0] for elem in videos], dtype=torch.int32)
if args.md == "a" or args.md == "av":
if args.modality == "audio" or args.modality == "audiovisual":
audios = torch.nn.utils.rnn.pad_sequence(raw_audios, batch_first=True)
audios = audio_pipeline(audios)
audio_lengths = torch.tensor([elem.shape[0] // 640 for elem in audios], dtype=torch.int32)
if args.md == "v":
if args.modality == "visual":
return videos, video_lengths
if args.md == "a":
if args.modality == "audio":
return audios, audio_lengths
if args.md == "av":
if args.modality == "audiovisual":
return audios, videos, audio_lengths, video_lengths
......@@ -100,17 +100,17 @@ class TrainTransform:
def __call__(self, samples: List):
targets, target_lengths = _extract_labels(self.sp_model, samples)
if self.args.md == "a":
if self.args.modality == "audio":
audios, audio_lengths = _extract_features(
self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
)
return Batch(audios, audio_lengths, targets, target_lengths)
if self.args.md == "v":
if self.args.modality == "visual":
videos, video_lengths = _extract_features(
self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
)
return Batch(videos, video_lengths, targets, target_lengths)
if self.args.md == "av":
if self.args.modality == "audiovisual":
audios, videos, audio_lengths, video_lengths = _extract_features(
self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
)
......@@ -135,17 +135,17 @@ class ValTransform:
def __call__(self, samples: List):
targets, target_lengths = _extract_labels(self.sp_model, samples)
if self.args.md == "a":
if self.args.modality == "audio":
audios, audio_lengths = _extract_features(
self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
)
return Batch(audios, audio_lengths, targets, target_lengths)
if self.args.md == "v":
if self.args.modality == "visual":
videos, video_lengths = _extract_features(
self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
)
return Batch(videos, video_lengths, targets, target_lengths)
if self.args.md == "av":
if self.args.modality == "audiovisual":
audios, videos, audio_lengths, video_lengths = _extract_features(
self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
)
......