Commit d4644793 authored by Pingchuan Ma's avatar Pingchuan Ma Committed by Facebook GitHub Bot

Update avsr recipe (#3493)

Summary:
This PR includes a few changes to the AV-ASR recipe: better results, a faster face detector (MediaPipe), renamed variables, a streamlined dataloader, and a few illustrated examples. These changes improve the usability of the recipe.

Pull Request resolved: https://github.com/pytorch/audio/pull/3493

Reviewed By: mthrok

Differential Revision: D47758072

Pulled By: mpc001

fbshipit-source-id: 4533587776f3a7a74f3f11b0ece773a0934bacdc
parent 56e22664
<p align="center"><img width="160" src="https://download.pytorch.org/torchaudio/doc-assets/avsr/lip_white.png" alt="logo"></p>
<h1 align="center">Real-time ASR/VSR/AV-ASR Examples</h1>

<div align="center">
[📘Introduction](#introduction) |
[📊Training](#Training) |
[🔮Evaluation](#Evaluation)
</div>
## Introduction
This directory contains the training recipe for real-time audio, visual, and audio-visual speech recognition (ASR, VSR, AV-ASR) models, which is an extension of [Auto-AVSR](https://arxiv.org/abs/2303.14307).
Please refer to [this tutorial]() for real-time AV-ASR inference from microphone and camera.
## Preparation
1. Set up the environment.
```
conda create -y -n autoavsr python=3.8
conda activate autoavsr
```
2. Install PyTorch (pytorch, torchvision, torchaudio) from [source](https://pytorch.org/get-started/), along with all necessary packages:

```Shell
pip install torch torchvision torchaudio pytorch-lightning sentencepiece
```

3. Preprocess LRS3. See the instructions in the [data_prep](./data_prep) folder.

## Usage
### Training
```Shell
python train.py --exp-dir=[exp_dir] \
    --exp-name=[exp_name] \
    --modality=[modality] \
    --mode=[mode] \
    --root-dir=[root_dir] \
    --sp-model-path=[sp_model_path] \
    --num-nodes=[num_nodes] \
    --gpus=[gpus]
```
- `exp-dir` and `exp-name`: The checkpoints will be saved to the directory `[exp_dir]/[exp_name]`.
- `modality`: Type of the input modality. Valid values are: `video`, `audio`, and `audiovisual`.
- `mode`: Model mode. Valid values are: `online` and `offline`.
- `root-dir`: Path to the root directory where all preprocessed files are stored.
- `sp-model-path`: Path to the sentencepiece model. Default: `./spm_unigram_1023.model`, which can be produced using `train_spm.py` (see the sketch after this list).
- `num-nodes`: Number of machines used. Default: 4.
- `gpus`: Number of GPUs in each machine. Default: 8.
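`train_spm.py` itself is not part of this diff; as a rough, hypothetical sketch, a unigram SentencePiece model whose vocabulary size matches the default filename `spm_unigram_1023.model` could be produced along these lines (the transcript file name below is a placeholder, not the recipe's actual input):

```python
# Hypothetical sketch only: train_spm.py's actual options are not shown in this diff.
import sentencepiece as spm

spm.SentencePieceTrainer.train(
    input="lrs3_train_transcripts.txt",  # placeholder: one transcript per line
    model_prefix="spm_unigram_1023",     # writes spm_unigram_1023.model / .vocab
    vocab_size=1023,                     # matches the default model filename
    model_type="unigram",
    character_coverage=1.0,
)
```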
### Evaluation
```Shell
python eval.py --modality=[modality] \
    --mode=[mode] \
    --root-dir=[root_dir] \
    --sp-model-path=[sp_model_path] \
    --checkpoint-path=[checkpoint_path]
```
- `modality`: Type of the input modality. Valid values are: `video`, `audio`, and `audiovisual`.
- `mode`: Model mode. Valid values are: `online` and `offline`.
- `root-dir`: Path to the root directory where all preprocessed files are stored.
- `sp-model-path`: Path to the sentencepiece model. Default: `./spm_unigram_1023.model`.
- `checkpoint-path`: Path to a pretrained model checkpoint.
## Results
The table below contains WER for AV-ASR models that were trained from scratch [offline evaluation].

| Model                | Training dataset (hours) | WER [%] | Params (M) |
|:--------------------:|:------------------------:|:-------:|:----------:|
| Non-streaming models |                          |         |            |
| AV-ASR               | LRS3 (438)               | 3.9     | 50         |
| Streaming models     |                          |         |            |
| AV-ASR               | LRS3 (438)               | 3.9     | 40         |
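The WER numbers above are word-level: total word edits (substitutions, insertions, deletions) divided by the total number of reference words. A minimal sketch, assuming the word-level distance is computed with `torchaudio.functional.edit_distance`, which `eval.py`'s `compute_word_level_distance` helper (not shown in this diff) is assumed to resemble:

```python
# Minimal WER sketch; the recipe's compute_word_level_distance helper is assumed
# to behave like word_distance below.
import torchaudio.functional as F


def word_distance(ref: str, hyp: str) -> int:
    # Edit distance over words rather than characters.
    return F.edit_distance(ref.lower().split(), hyp.lower().split())


total_edits, total_words = 0, 0
for ref, hyp in [("hello world", "hello word"), ("good morning", "good morning")]:
    total_edits += word_distance(ref, hyp)
    total_words += len(ref.split())

print(f"WER: {100.0 * total_edits / total_words:.1f}%")  # 25.0% on this toy pair
```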
...@@ -23,9 +23,6 @@ def average_checkpoints(last):
def ensemble(args):
    last = [os.path.join(args.exp_dir, args.exp_name, f"epoch={n}.ckpt") for n in range(args.epochs - 10, args.epochs)]
    model_path = os.path.join(args.exp_dir, args.exp_name, "model_avg_10.pth")
    torch.save({"state_dict": average_checkpoints(last)}, model_path)
...@@ -110,52 +110,19 @@ class LRS3DataModule(LightningDataModule):
        self.num_workers = num_workers

    def train_dataloader(self):
        dataset = LRS3(self.args, subset="train")
        dataset = CustomBucketDataset(
            dataset, dataset.lengths, self.max_frames, self.train_num_buckets, batch_size=self.batch_size
        )
        dataset = TransformDataset(dataset, self.train_transform)
        dataloader = torch.utils.data.DataLoader(
            dataset, num_workers=self.num_workers, batch_size=None, shuffle=self.train_shuffle
        )
        return dataloader

    def val_dataloader(self):
        dataset = LRS3(self.args, subset="val")
        dataset = CustomBucketDataset(dataset, dataset.lengths, self.max_frames, 1, batch_size=self.batch_size)
        dataset = TransformDataset(dataset, self.val_transform)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=self.num_workers)
        return dataloader
......
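`CustomBucketDataset`, used by the dataloaders above, is not shown in this diff; the sketch below only illustrates the idea the streamlined dataloader relies on, grouping samples of similar length into batches with a bounded frame budget so that padding stays small (its real interface is an assumption here):

```python
# Illustrative length-bucketing only; CustomBucketDataset's real implementation
# is not part of this diff.
from typing import List


def bucket_batches(lengths: List[int], max_frames: int) -> List[List[int]]:
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])  # similar lengths end up adjacent
    batches, current, current_frames = [], [], 0
    for idx in order:
        if current and current_frames + lengths[idx] > max_frames:
            batches.append(current)
            current, current_frames = [], 0
        current.append(idx)
        current_frames += lengths[idx]
    if current:
        batches.append(current)
    return batches


print(bucket_batches([30, 120, 40, 200, 35], max_frames=250))  # [[0, 4, 2, 1], [3]]
```

Because the bucketed dataset already yields whole batches, the `DataLoader` above is created with `batch_size=None`.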
# Pre-process LRS3

We provide a pre-processing pipeline in this repository for detecting and cropping full-face regions of interest (ROIs) as well as the corresponding audio waveforms for LRS3.

## Introduction

Before feeding the raw stream into our model, each video sequence has to undergo a pre-processing procedure that involves three critical steps. The first step is to perform face detection. Following that, each individual frame is aligned to a reference frame, commonly known as the mean face, in order to normalize rotation and size differences across frames. The final step of the pre-processing module is to crop the face region from the aligned face image.
<div align="center">
<table style="display: inline-table;">
<tr><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/original.gif", width="144"></td><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/detected.gif" width="144"></td><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/transformed.gif" width="144"></td><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/cropped.gif" width="144"></td></tr>
<tr><td>0. Original</td> <td>1. Detection</td> <td>2. Transformation</td> <td>3. Face ROIs</td> </tr>
</table>
</div>
## Preparation
1. Install all dependency packages.

```Shell
pip install -r requirements.txt
```

2. Install the [retinaface](./tools) or [mediapipe](https://pypi.org/project/mediapipe/) tracker. If you have already installed a tracker, you can skip this step.
## Preprocessing LRS3

To pre-process the LRS3 dataset, please follow these steps:

1. Download the LRS3 dataset from the official website.

2. Run the following command to preprocess the dataset:
```Shell
python preprocess_lrs3.py \
    --data-dir=[data_dir] \
    --detector=[detector] \
    --dataset=[dataset] \
    --root-dir=[root] \
    --subset=[subset] \
    --seg-duration=[seg_duration] \
    --groups=[n] \
    --job-index=[j]
```
- `data-dir`: Path to the directory containing video files.
- `detector`: Type of face detector. Valid values are: `mediapipe` and `retinaface`. Default: `retinaface`.
- `dataset`: Name of the dataset. The valid value is `lrs3`.
- `root-dir`: Path to the root directory where all preprocessed files will be stored.
- `subset`: Name of the subset. Valid values are: `train` and `test`.
- `seg-duration`: Length of the maximal segment in seconds. Default: `16`.
- `groups`: Number of groups to split the dataset into.
- `job-index`: Job index of the current group; an integer in the range `[0, groups)`.

3. Run the following command to merge all labels:
```Shell
python merge.py \
    --root-dir=[root_dir] \
    --dataset=[dataset] \
    --subset=[subset] \
    --seg-duration=[seg_duration] \
    --groups=[n]
```
- `root-dir`: Path to the root directory where all preprocessed files will be stored.
- `dataset`: Name of the dataset. Valid values are: `lrs2` and `lrs3`.
- `subset`: The subset name of the dataset. For LRS2, valid values are `train`, `val`, and `test`. For LRS3, valid values are `train` and `test`.
- `seg-duration`: Length of the maximal segment in seconds. Default: `16`.
- `groups`: Number of groups to split the dataset into.
...@@ -19,6 +19,12 @@ class AVSRDataLoader:
            self.landmarks_detector = LandmarksDetector(device="cuda:0")
            self.video_process = VideoProcess(resize=resize)
        if detector == "mediapipe":
            from detectors.mediapipe.detector import LandmarksDetector
            from detectors.mediapipe.video_process import VideoProcess

            self.landmarks_detector = LandmarksDetector()
            self.video_process = VideoProcess(resize=resize)

    def load_data(self, data_filename, transform=True):
        if self.modality == "audio":
......
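A hypothetical usage example of the `AVSRDataLoader` shown above (the clip path and the import path are placeholders; they are not spelled out in this diff):

```python
# Hypothetical usage; the import path and "clip.mp4" are placeholders.
from data.data_module import AVSRDataLoader  # assumed module location

vid_loader = AVSRDataLoader(modality="video", detector="mediapipe", resize=(96, 96))
aud_loader = AVSRDataLoader(modality="audio")

video = vid_loader.load_data("clip.mp4")  # cropped face ROIs resized to 96x96
audio = aud_loader.load_data("clip.mp4")  # corresponding audio waveform
```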
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import warnings
import mediapipe as mp
import numpy as np
import torchvision
warnings.filterwarnings("ignore")
class LandmarksDetector:
    def __init__(self):
        self.mp_face_detection = mp.solutions.face_detection
        # Short-range model (model_selection=0) works best for faces within ~2 m,
        # full-range model (model_selection=1) for faces further away.
        self.short_range_detector = self.mp_face_detection.FaceDetection(
            min_detection_confidence=0.5, model_selection=0
        )
        self.full_range_detector = self.mp_face_detection.FaceDetection(min_detection_confidence=0.5, model_selection=1)

    def __call__(self, filename):
        # The detector is called with a video path and reads the frames itself.
        landmarks = self.detect(filename, self.full_range_detector)
        if all(element is None for element in landmarks):
            landmarks = self.detect(filename, self.short_range_detector)
            assert any(l is not None for l in landmarks), "Cannot detect any face in the video"
        return landmarks

    def detect(self, filename, detector):
        video_frames = torchvision.io.read_video(filename, pts_unit="sec")[0].numpy()
        landmarks = []
        for frame in video_frames:
            results = detector.process(frame)
            if not results.detections:
                landmarks.append(None)
                continue
            # Keep the largest detected face in the frame.
            face_points = []
            max_id, max_size = 0, 0
            for idx, detected_faces in enumerate(results.detections):
                bboxC = detected_faces.location_data.relative_bounding_box
                ih, iw, ic = frame.shape
                bbox = int(bboxC.xmin * iw), int(bboxC.ymin * ih), int(bboxC.width * iw), int(bboxC.height * ih)
                bbox_size = bbox[2] + bbox[3]  # width + height of the detection
                if bbox_size > max_size:
                    max_id, max_size = idx, bbox_size
                lmx = [[int(bboxC.xmin * iw), int(bboxC.ymin * ih)], [int(bboxC.width * iw), int(bboxC.height * ih)]]
                face_points.append(lmx)
            landmarks.append(np.reshape(np.array(face_points[max_id]), (2, 2)))
        return landmarks
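A brief, hypothetical usage example of the detector above (the path is a placeholder); it returns one 2x2 array per frame, holding the top-left corner and the width/height of the selected face box, or `None` for frames without a detection:

```python
# Hypothetical usage; "clip.mp4" is a placeholder path.
detector = LandmarksDetector()  # the class defined above
landmarks = detector("clip.mp4")
print(len(landmarks))   # one entry per video frame
print(landmarks[0])     # 2x2 array: [[xmin, ymin], [width, height]] or None
```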
import cv2
import numpy as np
from skimage import transform as tf
def linear_interpolate(landmarks, start_idx, stop_idx):
    # Linearly interpolate landmarks for the frames between two detected frames.
    start_landmarks = landmarks[start_idx]
    stop_landmarks = landmarks[stop_idx]
    delta = stop_landmarks - start_landmarks
    for idx in range(1, stop_idx - start_idx):
        landmarks[start_idx + idx] = start_landmarks + idx / float(stop_idx - start_idx) * delta
    return landmarks


def warp_img(src, dst, img, std_size):
    # Estimate a similarity transform mapping src points onto dst points and warp the image with it.
    tform = tf.estimate_transform("similarity", src, dst)
    warped = tf.warp(img, inverse_map=tform.inverse, output_shape=std_size)
    warped = (warped * 255).astype("uint8")
    return warped, tform


def apply_transform(transform, img, std_size):
    warped = tf.warp(img, inverse_map=transform.inverse, output_shape=std_size)
    warped = (warped * 255).astype("uint8")
    return warped


def cut_patch(img, landmarks, height, width, threshold=5):
    center_x, center_y = np.mean(landmarks, axis=0)
    # Check for too much bias in height and width
    if abs(center_y - img.shape[0] / 2) > height + threshold:
        raise Exception("too much bias in height")
    if abs(center_x - img.shape[1] / 2) > width + threshold:
        raise Exception("too much bias in width")
    # Calculate bounding box coordinates
    y_min = int(round(np.clip(center_y - height, 0, img.shape[0])))
    y_max = int(round(np.clip(center_y + height, 0, img.shape[0])))
    x_min = int(round(np.clip(center_x - width, 0, img.shape[1])))
    x_max = int(round(np.clip(center_x + width, 0, img.shape[1])))
    # Cut the image
    cutted_img = np.copy(img[y_min:y_max, x_min:x_max])
    return cutted_img
class VideoProcess:
    def __init__(
        self,
        crop_width=128,
        crop_height=128,
        target_size=(224, 224),
        reference_size=(224, 224),
        stable_points=(0, 1),
        start_idx=0,
        stop_idx=2,
        resize=(96, 96),
    ):
        # Reference points that the detected points are aligned to.
        self.reference = np.array(([[51.64568, 0.70204943], [171.95107, 159.59505]]))
        self.crop_width = crop_width
        self.crop_height = crop_height
        self.start_idx = start_idx
        self.stop_idx = stop_idx
        self.resize = resize

    def __call__(self, video, landmarks):
        # Pre-process landmarks: interpolate frames that are not detected
        preprocessed_landmarks = self.interpolate_landmarks(landmarks)
        # Exclude corner cases: no landmark in all frames or number of frames is less than window length
        if not preprocessed_landmarks:
            return
        # Affine transformation and crop patch
        sequence = self.crop_patch(video, preprocessed_landmarks)
        assert sequence is not None, "failed to crop a patch."
        return sequence

    def crop_patch(self, video, landmarks):
        sequence = []
        for frame_idx, frame in enumerate(video):
            transformed_frame, transformed_landmarks = self.affine_transform(
                frame, landmarks[frame_idx], self.reference
            )
            patch = cut_patch(
                transformed_frame,
                transformed_landmarks[self.start_idx : self.stop_idx],
                self.crop_height // 2,
                self.crop_width // 2,
            )
            if self.resize:
                patch = cv2.resize(patch, self.resize)
            sequence.append(patch)
        return np.array(sequence)

    def interpolate_landmarks(self, landmarks):
        valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None]
        if not valid_frames_idx:
            return None

        for idx in range(1, len(valid_frames_idx)):
            if valid_frames_idx[idx] - valid_frames_idx[idx - 1] > 1:
                landmarks = linear_interpolate(landmarks, valid_frames_idx[idx - 1], valid_frames_idx[idx])

        valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None]

        # Handle corner case: keep frames at the beginning or at the end that failed to be detected
        if valid_frames_idx:
            landmarks[: valid_frames_idx[0]] = [landmarks[valid_frames_idx[0]]] * valid_frames_idx[0]
            landmarks[valid_frames_idx[-1] :] = [landmarks[valid_frames_idx[-1]]] * (
                len(landmarks) - valid_frames_idx[-1]
            )

        assert all(lm is not None for lm in landmarks), "not every frame has landmark"
        return landmarks

    def affine_transform(
        self,
        frame,
        landmarks,
        reference,
        target_size=(224, 224),
        reference_size=(224, 224),
        stable_points=(0, 1),
        interpolation=cv2.INTER_LINEAR,
        border_mode=cv2.BORDER_CONSTANT,
        border_value=0,
    ):
        # Estimate a partial affine transform from the detected points to the reference
        # points, then warp both the frame and the landmarks with it.
        stable_reference = self.get_stable_reference(reference, stable_points, reference_size, target_size)
        transform = self.estimate_affine_transform(landmarks, stable_points, stable_reference)
        transformed_frame, transformed_landmarks = self.apply_affine_transform(
            frame, landmarks, transform, target_size, interpolation, border_mode, border_value
        )
        return transformed_frame, transformed_landmarks

    def get_stable_reference(self, reference, stable_points, reference_size, target_size):
        stable_reference = np.vstack([reference[x] for x in stable_points])
        stable_reference[:, 0] -= (reference_size[0] - target_size[0]) / 2.0
        stable_reference[:, 1] -= (reference_size[1] - target_size[1]) / 2.0
        return stable_reference

    def estimate_affine_transform(self, landmarks, stable_points, stable_reference):
        return cv2.estimateAffinePartial2D(
            np.vstack([landmarks[x] for x in stable_points]), stable_reference, method=cv2.LMEDS
        )[0]

    def apply_affine_transform(
        self, frame, landmarks, transform, target_size, interpolation, border_mode, border_value
    ):
        transformed_frame = cv2.warpAffine(
            frame,
            transform,
            dsize=(target_size[0], target_size[1]),
            flags=interpolation,
            borderMode=border_mode,
            borderValue=border_value,
        )
        transformed_landmarks = np.matmul(landmarks, transform[:, :2].transpose()) + transform[:, 2].transpose()
        return transformed_frame, transformed_landmarks
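A hypothetical end-to-end example chaining the Mediapipe detector with `VideoProcess` above (the path is a placeholder; reading frames with torchvision mirrors what the detector does internally):

```python
# Hypothetical usage; "clip.mp4" is a placeholder path.
import torchvision

from detectors.mediapipe.detector import LandmarksDetector      # as imported in AVSRDataLoader above
from detectors.mediapipe.video_process import VideoProcess

frames = torchvision.io.read_video("clip.mp4", pts_unit="sec")[0].numpy()

detector = LandmarksDetector()
video_process = VideoProcess(resize=(96, 96))

landmarks = detector("clip.mp4")            # one detection (or None) per frame
cropped = video_process(frames, landmarks)  # array of 96x96 face ROIs, one per frame
```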
import argparse
import os

parser = argparse.ArgumentParser(description="Merge labels")
parser.add_argument(
    "--dataset",
    type=str,
    required=True,
    help="Specify the dataset used in the experiment",
)
parser.add_argument(
    "--subset",
    type=str,
    required=True,
    help="Specify the subset of the dataset used in the experiment",
)
parser.add_argument(
    "--root-dir",
    type=str,
    required=True,
    help="Directory of saved mouth patches or embeddings",
)
parser.add_argument(
    "--groups",
    type=int,
    required=True,
    help="Number of threads for parallel processing",
)
parser.add_argument(
    "--seg-duration",
    type=int,
    default=16,
    help="Length of the segments",
)
args = parser.parse_args()

dataset = args.dataset
subset = args.subset
...@@ -45,7 +43,9 @@ assert args.groups > 1, "There is no need to use this script for merging when --
# Create the filename template for label files
label_template = os.path.join(
    args.root_dir,
    "labels",
    f"{dataset}_{subset}_transcript_lengths_seg{seg_duration}s.{args.groups}",
)

lines = []
...@@ -58,15 +58,17 @@ for job_index in range(args.groups):
# Write the merged labels to a new file
dst_label_filename = os.path.join(
    args.root_dir,
    "labels",
    f"{dataset}_{subset}_transcript_lengths_seg{seg_duration}s.csv",
)
with open(dst_label_filename, "w") as file:
    file.write("\n".join(lines))

# Print the number of files and total duration in hours
total_duration = sum(int(line.split(",")[2]) for line in lines) / 3600.0 / 25.0  # convert frame counts (25 fps) to hours
print(f"The completed set has {len(lines)} files with a total of {total_duration:.2f} hours.")

# Remove the label files for each job index
print("** Remove the temporary label files **")
......
import argparse
import glob
import math
import os
import shutil
import warnings

import ffmpeg
...@@ -12,52 +12,60 @@ from utils import save_vid_aud_txt, split_file

warnings.filterwarnings("ignore")

# Argument Parsing
parser = argparse.ArgumentParser(description="LRS3 Preprocessing")
parser.add_argument(
    "--data-dir",
    type=str,
    help="The directory for sequence.",
)
parser.add_argument(
    "--detector",
    type=str,
    default="retinaface",
    help="Face detector used in the experiment.",
)
parser.add_argument(
    "--dataset",
    type=str,
    help="Specify the dataset name used in the experiment",
)
parser.add_argument(
    "--root-dir",
    type=str,
    help="The root directory of cropped-face dataset.",
)
parser.add_argument(
    "--subset",
    type=str,
    required=True,
    help="Subset of the dataset used in the experiment.",
)
parser.add_argument(
    "--seg-duration",
    type=int,
    default=16,
    help="Length of the segment in seconds.",
)
parser.add_argument(
    "--groups",
    type=int,
    default=1,
    help="Number of threads to be used in parallel.",
)
parser.add_argument(
    "--job-index",
    type=int,
    default=0,
    help="Index to identify separate jobs (useful for parallel processing).",
)
args = parser.parse_args()

seg_duration = args.seg_duration
dataset = args.dataset
args.data_dir = os.path.normpath(args.data_dir)
vid_dataloader = AVSRDataLoader(modality="video", detector=args.detector, resize=(96, 96))
aud_dataloader = AVSRDataLoader(modality="audio")

# Step 2, extract mouth patches from segments.
seg_vid_len = seg_duration * 25
...@@ -66,9 +74,9 @@ seg_aud_len = seg_duration * 16000
label_filename = os.path.join(
    args.root_dir,
    "labels",
    f"{dataset}_{args.subset}_transcript_lengths_seg{seg_duration}s.csv"
    if args.groups <= 1
    else f"{dataset}_{args.subset}_transcript_lengths_seg{seg_duration}s.{args.groups}.{args.job_index}.csv",
)
os.makedirs(os.path.dirname(label_filename), exist_ok=True)
print(f"Directory {os.path.dirname(label_filename)} created")
...@@ -77,9 +85,9 @@ f = open(label_filename, "w")
# Step 2, extract mouth patches from segments.
dst_vid_dir = os.path.join(args.root_dir, dataset, dataset + f"_video_seg{seg_duration}s")
dst_txt_dir = os.path.join(args.root_dir, dataset, dataset + f"_text_seg{seg_duration}s")
if args.subset == "test":
    filenames = glob.glob(os.path.join(args.data_dir, args.subset, "**", "*.mp4"), recursive=True)
elif args.subset == "train":
    filenames = glob.glob(os.path.join(args.data_dir, "trainval", "**", "*.mp4"), recursive=True)
    filenames.extend(glob.glob(os.path.join(args.data_dir, "pretrain", "**", "*.mp4"), recursive=True))
filenames.sort()
...@@ -96,7 +104,7 @@ for data_filename in tqdm(filenames):
    except UnboundLocalError:
        continue

    if os.path.normpath(data_filename).split(os.sep)[-3] in ["trainval", "test"]:
        dst_vid_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}.mp4"
        dst_aud_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}.wav"
        dst_txt_filename = f"{data_filename.replace(args.data_dir, dst_txt_dir)[:-4]}.txt"
......
tqdm
scikit-image
opencv-python
ffmpeg-python
## Face Recognition

We provide [ibug.face_detection](https://github.com/hhj1897/face_detection) in this repository. You can install it directly from the GitHub repository or by using compressed files.

### Option 1. Install from the GitHub repository
* [Git LFS](https://git-lfs.github.com/), needed for downloading the pretrained weights that are larger than 100 MB. * [Git LFS](https://git-lfs.github.com/), needed for downloading the pretrained weights that are larger than 100 MB.
You could install *`Homebrew`* and then install *`git-lfs`* without sudo privileges.
```Shell
git clone https://github.com/hhj1897/face_detection.git
cd face_detection
...@@ -16,3 +14,15 @@ git lfs pull
pip install -e .
cd ..
```
### Option 2. Install by using compressed files
If you are experiencing over-quota issues with the repository above, you can download the compressed package [ibug.face_detection](https://www.doc.ic.ac.uk/~pm4115/tracker/face_detection.zip), unzip it, and then run `pip install -e .` to install it.
```Shell
wget https://www.doc.ic.ac.uk/~pm4115/tracker/face_detection.zip -O ./face_detection.zip
unzip -o ./face_detection.zip -d ./
cd face_detection
pip install -e .
cd ..
```
...@@ -16,7 +16,7 @@ def compute_word_level_distance(seq1, seq2):
def get_lightning_module(args):
    sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))
    if args.modality == "audiovisual":
        from lightning_av import AVConformerRNNTModule

        model = AVConformerRNNTModule(args, sp_model)
...@@ -49,7 +49,7 @@ def run_eval(model, data_module):
def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--modality",
        type=str,
        help="Modality",
        required=True,
...@@ -69,20 +69,15 @@ def parse_args():
    parser.add_argument(
        "--sp-model-path",
        type=str,
        help="Path to sentencepiece model.",
        required=True,
    )
    parser.add_argument(
        "--checkpoint-path",
        type=str,
        help="Path to a checkpoint model.",
        required=True,
    )
    parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
    return parser.parse_args()
......
...@@ -57,9 +57,9 @@ class ConformerRNNTModule(LightningModule):
        )
        self.blank_idx = spm_vocab_size

        if args.modality == "video":
            self.frontend = video_resnet()
        if args.modality == "audio":
            self.frontend = audio_resnet()
        if args.mode == "online":
...@@ -116,33 +116,13 @@ class ConformerRNNTModule(LightningModule):
            [{"scheduler": self.warmup_lr_scheduler, "interval": self.lr_scheduler_interval}],
        )

    def forward(self, batch):
        decoder = RNNTBeamSearch(self.model, self.blank_idx)
        x = self.frontend(batch.inputs.to(self.device))
        hypotheses = decoder(x, batch.input_lengths.to(self.device), beam_width=20)
        return post_process_hypos(hypotheses, self.sp_model)[0][0]

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self._step(batch, batch_idx, "train")
...@@ -157,7 +137,7 @@ class ConformerRNNTModule(LightningModule):
        sch = self.lr_schedulers()
        sch.step()

        self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
        return loss
......
...@@ -116,7 +116,7 @@ class AVConformerRNNTModule(LightningModule):
            [{"scheduler": self.warmup_lr_scheduler, "interval": self.lr_scheduler_interval}],
        )

    def forward(self, batch):
        decoder = RNNTBeamSearch(self.model, self.blank_idx)
        video_features = self.video_frontend(batch.videos.to(self.device))
        audio_features = self.audio_frontend(batch.audios.to(self.device))
...@@ -127,27 +127,7 @@ class AVConformerRNNTModule(LightningModule):
        )
        return post_process_hypos(hypotheses, self.sp_model)[0][0]

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self._step(batch, batch_idx, "train")
...@@ -162,7 +142,7 @@ class AVConformerRNNTModule(LightningModule):
        sch = self.lr_schedulers()
        sch.step()

        self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
        return loss
......
...@@ -40,12 +40,12 @@ def load_transcript(path):
    return open(transcript_path).read().splitlines()[0]


def load_item(path, modality):
    if modality == "video":
        return (load_video(path), load_transcript(path))
    if modality == "audio":
        return (load_audio(path), load_transcript(path))
    if modality == "audiovisual":
        return (load_audio(path), load_video(path), load_transcript(path))
...@@ -62,15 +62,15 @@ class LRS3(Dataset):
        self.args = args

        if subset == "train":
            self.files, self.lengths = _load_list(self.args, "lrs3_train_transcript_lengths_seg16s.csv")
        if subset == "val":
            self.files, self.lengths = _load_list(self.args, "lrs3_test_transcript_lengths_seg16s.csv")
        if subset == "test":
            self.files, self.lengths = _load_list(self.args, "lrs3_test_transcript_lengths_seg16s.csv")

    def __getitem__(self, n):
        path = self.files[n]
        return load_item(path, self.args.modality)

    def __len__(self) -> int:
        return len(self.files)
...@@ -32,5 +32,5 @@ class FeedForwardModule(torch.nn.Module):
        return self.sequential(input)


def fusion_module(input_dim=1024, hidden_dim=3072, output_dim=512, dropout=0.1):
    return FeedForwardModule(input_dim, hidden_dim, output_dim, dropout)
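An illustrative call of the fusion module above; the assumption that the 1024-dim input is the concatenation of 512-dim audio and 512-dim video features is not spelled out in this hunk, and the shapes below are placeholders:

```python
# Illustrative only; shapes are placeholders and the concatenation layout is an assumption.
import torch

fusion = fusion_module()               # FeedForwardModule(1024, 3072, 512, 0.1) from the hunk above
audio_feats = torch.randn(1, 50, 512)  # (batch, time, feature)
video_feats = torch.randn(1, 50, 512)

fused = fusion(torch.cat([audio_feats, video_feats], dim=-1))
print(fused.shape)  # expected torch.Size([1, 50, 512]) given the module's output dim
```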
...@@ -14,7 +14,7 @@ def get_trainer(args):
    seed_everything(1)
    checkpoint = ModelCheckpoint(
        dirpath=os.path.join(args.exp_dir, args.exp_name) if args.exp_dir else None,
        monitor="monitoring_step",
        mode="max",
        save_last=True,
...@@ -36,13 +36,12 @@ def get_trainer(args):
        strategy=DDPStrategy(find_unused_parameters=False),
        callbacks=callbacks,
        reload_dataloaders_every_n_epochs=1,
    )


def get_lightning_module(args):
    sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))
    if args.modality == "audiovisual":
        from lightning_av import AVConformerRNNTModule

        model = AVConformerRNNTModule(args, sp_model)
...@@ -56,7 +55,7 @@ def get_lightning_module(args):
def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--modality",
        type=str,
        help="Modality",
        required=True,
...@@ -86,19 +85,20 @@ def parse_args():
    )
    parser.add_argument(
        "--exp-dir",
        default="./exp",
        type=str,
        help="Directory to save checkpoints and logs to. (Default: './exp')",
    )
    parser.add_argument(
        "--exp-name",
        type=str,
        help="Experiment name",
    )
    parser.add_argument(
        "--num-nodes",
        default=4,
        type=int,
        help="Number of nodes to use for training. (Default: 4)",
    )
    parser.add_argument(
        "--gpus",
...@@ -113,9 +113,16 @@ def parse_args():
        help="Number of epochs to train for. (Default: 55)",
    )
    parser.add_argument(
        "--resume-from-checkpoint",
        default=None,
        type=str,
        help="Path to the checkpoint to resume from",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Whether to use debug level for logging",
    )
    return parser.parse_args()
......
...@@ -55,28 +55,28 @@ def _extract_features(video_pipeline, audio_pipeline, samples, args):
    raw_videos = []
    raw_audios = []
    for sample in samples:
        if args.modality == "video":
            raw_videos.append(sample[0])
        if args.modality == "audio":
            raw_audios.append(sample[0])
        if args.modality == "audiovisual":
            length = min(len(sample[0]) // 640, len(sample[1]))
            raw_audios.append(sample[0][: length * 640])
            raw_videos.append(sample[1][:length])

    if args.modality == "video" or args.modality == "audiovisual":
        videos = torch.nn.utils.rnn.pad_sequence(raw_videos, batch_first=True)
        videos = video_pipeline(videos)
        video_lengths = torch.tensor([elem.shape[0] for elem in videos], dtype=torch.int32)
    if args.modality == "audio" or args.modality == "audiovisual":
        audios = torch.nn.utils.rnn.pad_sequence(raw_audios, batch_first=True)
        audios = audio_pipeline(audios)
        audio_lengths = torch.tensor([elem.shape[0] // 640 for elem in audios], dtype=torch.int32)

    if args.modality == "video":
        return videos, video_lengths
    if args.modality == "audio":
        return audios, audio_lengths
    if args.modality == "audiovisual":
        return audios, videos, audio_lengths, video_lengths
...@@ -100,17 +100,17 @@ class TrainTransform:
    def __call__(self, samples: List):
        targets, target_lengths = _extract_labels(self.sp_model, samples)
        if self.args.modality == "audio":
            audios, audio_lengths = _extract_features(
                self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
            )
            return Batch(audios, audio_lengths, targets, target_lengths)
        if self.args.modality == "video":
            videos, video_lengths = _extract_features(
                self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
            )
            return Batch(videos, video_lengths, targets, target_lengths)
        if self.args.modality == "audiovisual":
            audios, videos, audio_lengths, video_lengths = _extract_features(
                self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
            )
...@@ -135,17 +135,17 @@ class ValTransform:
    def __call__(self, samples: List):
        targets, target_lengths = _extract_labels(self.sp_model, samples)
        if self.args.modality == "audio":
            audios, audio_lengths = _extract_features(
                self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
            )
            return Batch(audios, audio_lengths, targets, target_lengths)
        if self.args.modality == "video":
            videos, video_lengths = _extract_features(
                self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
            )
            return Batch(videos, video_lengths, targets, target_lengths)
        if self.args.modality == "audiovisual":
            audios, videos, audio_lengths, video_lengths = _extract_features(
                self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
            )
......