Commit 77cdd160 authored by Pingchuan Ma, committed by Facebook GitHub Bot

Add LRS3 data preparation (#3421)

Summary:
This PR adds a data preparation recipe that uses the ultra face detector to extract full-face video. The resulting video output is then used as input for training and evaluating RNNT-based models for automatic speech recognition (ASR), visual speech recognition (VSR), and audio-visual ASR (AV-ASR) on the LRS3 dataset.

This PR also updates the word error rate (WER) for AV-ASR LRS3 models and improves the code readability.

Pull Request resolved: https://github.com/pytorch/audio/pull/3421

Reviewed By: mpc001

Differential Revision: D46799748

Pulled By: mthrok

fbshipit-source-id: 97af3feac0592b240617faaffa4c0ac8cef614a9
parent 18601691
<p align="center"><img width="160" src="doc/lip_white.png" alt="logo"></p>
<h1 align="center">RNN-T ASR/VSR/AV-ASR Examples</h1>

This repository contains sample implementations of training and evaluation pipelines for RNN-T-based automatic, visual, and audio-visual speech recognition (ASR, VSR, AV-ASR) models on LRS3, in both streaming and non-streaming modes. We follow the same training pipeline as [Auto-AVSR](https://arxiv.org/abs/2303.14307).

## Preparation

1. Set up the environment.
...@@ -18,20 +18,18 @@ pip install pytorch-lightning sentencepiece
3. Preprocess LRS3 into a cropped-face dataset using the [data_prep](./data_prep) folder.
4. `[sp_model_path]` is a SentencePiece model used to encode targets, which can be generated using `train_spm.py`.

### Training ASR or VSR model

- `[root_dir]` is the root directory of the LRS3 cropped-face dataset.
- `[modality]` is the input modality type, including `v`, `a`, and `av`.
- `[mode]` is the model type, including `online` and `offline`.

```Shell
python train.py --root-dir [root_dir] \
--sp-model-path ./spm_unigram_1023.model \
--exp-dir ./exp \
--num-nodes 8 \
...@@ -43,10 +41,7 @@ python train.py --dataset-path [dataset_path] \

### Training AV-ASR model

```Shell
python train.py --root-dir [root_dir] \
--sp-model-path ./spm_unigram_1023.model \
--exp-dir ./exp \
--num-nodes 8 \
...@@ -59,19 +54,17 @@ python train.py --dataset-path [dataset_path] \

```Shell
python eval.py --dataset-path [dataset_path] \
--sp-model-path ./spm_unigram_1023.model \
--md [modality] \
--mode [mode] \
--checkpoint-path [checkpoint_path]
```

The table below reports the WER of AV-ASR models under offline evaluation.

| Model | WER [%] | Params (M) |
|:-----------:|:------------:|:--------------:|
| Non-streaming models | | |
| AV-ASR | 4.0 | 50 |
| Streaming models | | |
| AV-ASR | 4.3 | 40 |
# Preprocessing LRS3
We provide a pre-processing pipeline to detect and crop full-face images in this repository. We provide a pre-processing pipeline to detect and crop full-face images in this repository.
## Prerequisites
Install all dependency packages.
```Shell
pip install -r requirements.txt
```
Install the [RetinaFace](./tools) tracker.
## Preprocessing
### Step 1. Pre-process the LRS3 dataset.
Please run the following script to pre-process the LRS3 dataset:
```Shell
python main.py \
--data-dir=[data_dir] \
--dataset=[dataset] \
--root-dir=[root_dir] \
--folder=[folder] \
--groups=[num_groups] \
--job-index=[job_index]
```
- `[data_dir]` is the directory of the original LRS3 dataset, and `[dataset]` is the dataset name used in the experiment (e.g. `lrs3`).
- `[root_dir]` is the directory where the cropped-face dataset is saved.
- `[folder]` can be set to `train` or `test`.
- `[num_groups]` and `[job_index]` are used to split the dataset across multiple parallel jobs, where `[job_index]` is an integer in [0, `[num_groups]`).
### Step 2. Merge the label list.
After completing Step 1, run the following script to merge all label files.
```Shell
python merge.py \
--dataset=[dataset] \
--root-dir=[root_dir] \
--subset=[subset] \
--groups=[num_groups]
```
- `[subset]` matches the `[folder]` used in Step 1 (`train` or `test`).
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import torch
import torchaudio
import torchvision
class AVSRDataLoader:
def __init__(self, modality, detector="retinaface", resize=None):
self.modality = modality
if modality == "video":
if detector == "retinaface":
from detectors.retinaface.detector import LandmarksDetector
from detectors.retinaface.video_process import VideoProcess
self.landmarks_detector = LandmarksDetector(device="cuda:0")
self.video_process = VideoProcess(resize=resize)
def load_data(self, data_filename, transform=True):
if self.modality == "audio":
audio, sample_rate = self.load_audio(data_filename)
audio = self.audio_process(audio, sample_rate)
return audio
if self.modality == "video":
landmarks = self.landmarks_detector(data_filename)
video = self.load_video(data_filename)
video = self.video_process(video, landmarks)
video = torch.tensor(video)
return video
def load_audio(self, data_filename):
waveform, sample_rate = torchaudio.load(data_filename, normalize=True)
return waveform, sample_rate
def load_video(self, data_filename):
return torchvision.io.read_video(data_filename, pts_unit="sec")[0].numpy()
def audio_process(self, waveform, sample_rate, target_sample_rate=16000):
if sample_rate != target_sample_rate:
waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
waveform = torch.mean(waveform, dim=0, keepdim=True)
return waveform
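For context, here is a minimal usage sketch of the loader above. The clip path `example.mp4` is hypothetical, and the video branch assumes a CUDA device is available for RetinaFace.

```python
# Minimal usage sketch; "example.mp4" is a hypothetical clip path.
if __name__ == "__main__":
    video_loader = AVSRDataLoader(modality="video", detector="retinaface", resize=(96, 96))
    audio_loader = AVSRDataLoader(modality="audio")

    # Note: if no face is detected in any frame, VideoProcess returns None and this call will fail.
    video = video_loader.load_data("example.mp4")  # uint8 tensor of cropped faces, (T, 96, 96, 3)
    audio = audio_loader.load_data("example.mp4")  # mono float waveform at 16 kHz, (1, num_samples)
    print(video.shape, audio.shape)
```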
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import warnings
import numpy as np
import torchvision
from ibug.face_detection import RetinaFacePredictor
warnings.filterwarnings("ignore")
class LandmarksDetector:
def __init__(self, device="cuda:0", model_name="resnet50"):
self.face_detector = RetinaFacePredictor(
device=device, threshold=0.8, model=RetinaFacePredictor.get_model(model_name)
)
def __call__(self, filename):
video_frames = torchvision.io.read_video(filename, pts_unit="sec")[0].numpy()
landmarks = []
for frame in video_frames:
detected_faces = self.face_detector(frame, rgb=False)
if len(detected_faces) >= 1:
landmarks.append(np.reshape(detected_faces[0][:4], (2, 2)))
else:
landmarks.append(None)
return landmarks
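The detector can also be called directly. A small sketch (hypothetical clip path): the result holds one entry per frame, either a (2, 2) array of bounding-box corners or `None`.

```python
# Sketch: count how many frames had a detected face in a hypothetical clip.
if __name__ == "__main__":
    detector = LandmarksDetector(device="cuda:0")
    landmarks = detector("example.mp4")
    detected = sum(lm is not None for lm in landmarks)
    print(f"{detected}/{len(landmarks)} frames with a detected face")
```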
import cv2
import numpy as np
from skimage import transform as tf
def linear_interpolate(landmarks, start_idx, stop_idx):
start_landmarks = landmarks[start_idx]
stop_landmarks = landmarks[stop_idx]
delta = stop_landmarks - start_landmarks
for idx in range(1, stop_idx - start_idx):
landmarks[start_idx + idx] = start_landmarks + idx / float(stop_idx - start_idx) * delta
return landmarks
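`linear_interpolate` fills the entries strictly between two detected frames with a linear blend of their landmarks. A small worked example with invented box corners:

```python
# Worked example: frames 1-3 lack detections; after interpolation, frame 2 is the
# midpoint between the boxes at frames 0 and 4.
import numpy as np

boxes = [
    np.array([[0.0, 0.0], [10.0, 10.0]]),
    None,
    None,
    None,
    np.array([[4.0, 4.0], [14.0, 14.0]]),
]
boxes = linear_interpolate(boxes, 0, 4)
print(boxes[2])  # [[ 2.  2.] [12. 12.]]
```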
def warp_img(src, dst, img, std_size):
tform = tf.estimate_transform("similarity", src, dst)
warped = tf.warp(img, inverse_map=tform.inverse, output_shape=std_size)
warped = (warped * 255).astype("uint8")
return warped, tform
def apply_transform(transform, img, std_size):
warped = tf.warp(img, inverse_map=transform.inverse, output_shape=std_size)
warped = (warped * 255).astype("uint8")
return warped
def cut_patch(img, landmarks, height, width, threshold=5):
center_x, center_y = np.mean(landmarks, axis=0)
# Check for too much bias in height and width
if abs(center_y - img.shape[0] / 2) > height + threshold:
raise Exception("too much bias in height")
if abs(center_x - img.shape[1] / 2) > width + threshold:
raise Exception("too much bias in width")
# Calculate bounding box coordinates
y_min = int(round(np.clip(center_y - height, 0, img.shape[0])))
y_max = int(round(np.clip(center_y + height, 0, img.shape[0])))
x_min = int(round(np.clip(center_x - width, 0, img.shape[1])))
x_max = int(round(np.clip(center_x + width, 0, img.shape[1])))
# Cut the image
cutted_img = np.copy(img[y_min:y_max, x_min:x_max])
return cutted_img
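`cut_patch` crops a window of 2*height by 2*width around the mean landmark position, raising if that centre drifts too far from the frame centre. A short sketch with assumed sizes:

```python
# Sketch: crop a 128x128 patch around the mean of two (x, y) points in a 224x224 frame.
import numpy as np

frame = np.zeros((224, 224, 3), dtype=np.uint8)
points = np.array([[100.0, 110.0], [120.0, 118.0]])
patch = cut_patch(frame, points, height=64, width=64)
print(patch.shape)  # (128, 128, 3)
```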
class VideoProcess:
def __init__(
self,
crop_width=128,
crop_height=128,
target_size=(224, 224),
reference_size=(224, 224),
stable_points=(0, 1),
start_idx=0,
stop_idx=2,
resize=(96, 96),
):
self.reference = np.array(([[51.64568, 0.70204943], [171.95107, 159.59505]]))
self.crop_width = crop_width
self.crop_height = crop_height
self.start_idx = start_idx
self.stop_idx = stop_idx
self.resize = resize
def __call__(self, video, landmarks):
# Pre-process landmarks: interpolate frames that are not detected
preprocessed_landmarks = self.interpolate_landmarks(landmarks)
# Exclude corner cases: no landmark in all frames or number of frames is less than window length
if not preprocessed_landmarks:
return
# Affine transformation and crop patch
sequence = self.crop_patch(video, preprocessed_landmarks)
assert sequence is not None, "crop an empty patch."
return sequence
def crop_patch(self, video, landmarks):
sequence = []
for frame_idx, frame in enumerate(video):
transformed_frame, transformed_landmarks = self.affine_transform(
frame, landmarks[frame_idx], self.reference
)
patch = cut_patch(
transformed_frame,
transformed_landmarks[self.start_idx : self.stop_idx],
self.crop_height // 2,
self.crop_width // 2,
)
if self.resize:
patch = cv2.resize(patch, self.resize)
sequence.append(patch)
return np.array(sequence)
def interpolate_landmarks(self, landmarks):
valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None]
if not valid_frames_idx:
return None
for idx in range(1, len(valid_frames_idx)):
if valid_frames_idx[idx] - valid_frames_idx[idx - 1] > 1:
landmarks = linear_interpolate(landmarks, valid_frames_idx[idx - 1], valid_frames_idx[idx])
valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None]
# Handle corner case: keep frames at the beginning or at the end that failed to be detected
if valid_frames_idx:
landmarks[: valid_frames_idx[0]] = [landmarks[valid_frames_idx[0]]] * valid_frames_idx[0]
landmarks[valid_frames_idx[-1] :] = [landmarks[valid_frames_idx[-1]]] * (
len(landmarks) - valid_frames_idx[-1]
)
assert all(lm is not None for lm in landmarks), "not every frame has landmark"
return landmarks
def affine_transform(
self,
frame,
landmarks,
reference,
target_size=(224, 224),
reference_size=(224, 224),
stable_points=(0, 1),
interpolation=cv2.INTER_LINEAR,
border_mode=cv2.BORDER_CONSTANT,
border_value=0,
):
stable_reference = self.get_stable_reference(reference, stable_points, reference_size, target_size)
transform = self.estimate_affine_transform(landmarks, stable_points, stable_reference)
transformed_frame, transformed_landmarks = self.apply_affine_transform(
frame, landmarks, transform, target_size, interpolation, border_mode, border_value
)
return transformed_frame, transformed_landmarks
def get_stable_reference(self, reference, stable_points, reference_size, target_size):
stable_reference = np.vstack([reference[x] for x in stable_points])
stable_reference[:, 0] -= (reference_size[0] - target_size[0]) / 2.0
stable_reference[:, 1] -= (reference_size[1] - target_size[1]) / 2.0
return stable_reference
def estimate_affine_transform(self, landmarks, stable_points, stable_reference):
return cv2.estimateAffinePartial2D(
np.vstack([landmarks[x] for x in stable_points]), stable_reference, method=cv2.LMEDS
)[0]
def apply_affine_transform(
self, frame, landmarks, transform, target_size, interpolation, border_mode, border_value
):
transformed_frame = cv2.warpAffine(
frame,
transform,
dsize=(target_size[0], target_size[1]),
flags=interpolation,
borderMode=border_mode,
borderValue=border_value,
)
transformed_landmarks = np.matmul(landmarks, transform[:, :2].transpose()) + transform[:, 2].transpose()
return transformed_frame, transformed_landmarks
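To illustrate `VideoProcess` end to end without running a detector, here is a hedged sketch on synthetic data: two hand-written bounding-box corners per frame stand in for RetinaFace output, and the frame with a missed detection is filled in by interpolation.

```python
# Synthetic sketch: dummy frames plus hand-written box corners in place of detector output.
import numpy as np

video = np.random.randint(0, 255, size=(4, 224, 224, 3), dtype=np.uint8)
box = np.array([[80.0, 70.0], [150.0, 160.0]])  # top-left / bottom-right corners
landmarks = [box, None, box + 2.0, box + 4.0]   # frame 1 simulates a missed detection

video_process = VideoProcess(resize=(96, 96))
cropped = video_process(video, landmarks)
print(cropped.shape)  # expected: (4, 96, 96, 3)
```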
import glob
import math
import os
import shutil
import warnings
import ffmpeg
from data.data_module import AVSRDataLoader
from tqdm import tqdm
from utils import save_vid_aud_txt, split_file
warnings.filterwarnings("ignore")
from argparse import ArgumentParser
def load_args(default_config=None):
parser = ArgumentParser(description="Preprocess LRS3 to crop full-face images")
# -- for benchmark evaluation
parser.add_argument(
"--data-dir",
type=str,
help="The directory for sequence.",
)
parser.add_argument(
"--dataset",
type=str,
help="Specify the dataset name used in the experiment",
)
parser.add_argument(
"--root-dir",
type=str,
help="The root directory of cropped-face dataset.",
)
parser.add_argument("--job-index", type=int, default=0, help="job index")
parser.add_argument(
"--groups",
type=int,
default=1,
help="specify the number of threads to be used",
)
parser.add_argument(
"--folder",
type=str,
default="test",
help="specify the set used in the experiment",
)
args = parser.parse_args()
return args
args = load_args()
seg_duration = 16
detector = "retinaface"
dataset = args.dataset
args.data_dir = os.path.normpath(args.data_dir)
vid_dataloader = AVSRDataLoader(modality="video", detector=detector, resize=(96, 96))
aud_dataloader = AVSRDataLoader(modality="audio")
# Step 1, set up the label file for this job.
seg_vid_len = seg_duration * 25
seg_aud_len = seg_duration * 16000
label_filename = os.path.join(
args.root_dir,
"labels",
f"{dataset}_{args.folder}_transcript_lengths_seg{seg_duration}s.csv"
if args.groups <= 1
else f"{dataset}_{args.folder}_transcript_lengths_seg{seg_duration}s.{args.groups}.{args.job_index}.csv",
)
os.makedirs(os.path.dirname(label_filename), exist_ok=True)
print(f"Directory {os.path.dirname(label_filename)} created")
f = open(label_filename, "w")
# Step 2, extract face patches from segments.
dst_vid_dir = os.path.join(args.root_dir, dataset, dataset + f"_video_seg{seg_duration}s")
dst_txt_dir = os.path.join(args.root_dir, dataset, dataset + f"_text_seg{seg_duration}s")
if args.folder == "test":
filenames = glob.glob(os.path.join(args.data_dir, args.folder, "**", "*.mp4"), recursive=True)
elif args.folder == "train":
filenames = glob.glob(os.path.join(args.data_dir, "trainval", "**", "*.mp4"), recursive=True)
filenames.extend(glob.glob(os.path.join(args.data_dir, "pretrain", "**", "*.mp4"), recursive=True))
filenames.sort()
else:
raise NotImplementedError
unit = math.ceil(len(filenames) * 1.0 / args.groups)
filenames = filenames[args.job_index * unit : (args.job_index + 1) * unit]
for data_filename in tqdm(filenames):
try:
video_data = vid_dataloader.load_data(data_filename)
audio_data = aud_dataloader.load_data(data_filename)
except UnboundLocalError:
continue
if os.path.normpath(data_filename).split(os.sep)[-3] in ["trainval", "test", "main"]:
dst_vid_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}.mp4"
dst_aud_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}.wav"
dst_txt_filename = f"{data_filename.replace(args.data_dir, dst_txt_dir)[:-4]}.txt"
trim_vid_data, trim_aud_data = video_data, audio_data
text_line_list = open(data_filename[:-4] + ".txt", "r").read().splitlines()[0].split(" ")
text_line = " ".join(text_line_list[2:])
content = text_line.replace("}", "").replace("{", "")
if trim_vid_data is None or trim_aud_data is None:
continue
video_length = len(trim_vid_data)
audio_length = trim_aud_data.size(1)
if video_length == 0 or audio_length == 0:
continue
if audio_length / video_length < 560.0 or audio_length / video_length > 720.0 or video_length < 12:
continue
save_vid_aud_txt(
dst_vid_filename,
dst_aud_filename,
dst_txt_filename,
trim_vid_data,
trim_aud_data,
content,
video_fps=25,
audio_sample_rate=16000,
)
in1 = ffmpeg.input(dst_vid_filename)
in2 = ffmpeg.input(dst_aud_filename)
out = ffmpeg.output(
in1["v"],
in2["a"],
dst_vid_filename[:-4] + ".m.mp4",
vcodec="copy",
acodec="aac",
strict="experimental",
loglevel="panic",
)
out.run()
os.remove(dst_aud_filename)
os.remove(dst_vid_filename)
shutil.move(dst_vid_filename[:-4] + ".m.mp4", dst_vid_filename)
basename = os.path.relpath(dst_vid_filename, start=os.path.join(args.root_dir, dataset))
f.write("{}\n".format(f"{dataset},{basename},{trim_vid_data.shape[0]},{len(content)}"))
continue
splitted = split_file(data_filename[:-4] + ".txt", max_frames=seg_vid_len)
for i in range(len(splitted)):
if len(splitted) == 1:
content, start, end, duration = splitted[i]
trim_vid_data, trim_aud_data = video_data, audio_data
else:
content, start, end, duration = splitted[i]
start_idx, end_idx = int(start * 25), int(end * 25)
try:
trim_vid_data, trim_aud_data = (
video_data[start_idx:end_idx],
audio_data[:, start_idx * 640 : end_idx * 640],
)
except TypeError:
continue
dst_vid_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}_{i:02d}.mp4"
dst_aud_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}_{i:02d}.wav"
dst_txt_filename = f"{data_filename.replace(args.data_dir, dst_txt_dir)[:-4]}_{i:02d}.txt"
if trim_vid_data is None or trim_aud_data is None:
continue
video_length = len(trim_vid_data)
audio_length = trim_aud_data.size(1)
if video_length == 0 or audio_length == 0:
continue
if audio_length / video_length < 560.0 or audio_length / video_length > 720.0 or video_length < 12:
continue
save_vid_aud_txt(
dst_vid_filename,
dst_aud_filename,
dst_txt_filename,
trim_vid_data,
trim_aud_data,
content,
video_fps=25,
audio_sample_rate=16000,
)
in1 = ffmpeg.input(dst_vid_filename)
in2 = ffmpeg.input(dst_aud_filename)
out = ffmpeg.output(
in1["v"],
in2["a"],
dst_vid_filename[:-4] + ".m.mp4",
vcodec="copy",
acodec="aac",
strict="experimental",
loglevel="panic",
)
out.run()
os.remove(dst_aud_filename)
os.remove(dst_vid_filename)
shutil.move(dst_vid_filename[:-4] + ".m.mp4", dst_vid_filename)
basename = os.path.relpath(dst_vid_filename, start=os.path.join(args.root_dir, dataset))
f.write("{}\n".format(f"{dataset},{basename},{trim_vid_data.shape[0]},{len(content)}"))
f.close()
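Each line written to the label file above has the form `dataset,relative_video_path,num_video_frames,num_characters`. A short hedged sketch of reading a row back (the example row itself is made up):

```python
# Sketch: parse one row of the generated label CSV (the row is hypothetical).
line = "lrs3,lrs3_video_seg16s/trainval/spk001/00001_00.mp4,142,87"
dataset_name, rel_path, num_frames, num_chars = line.strip().split(",")
duration_sec = int(num_frames) / 25.0  # videos are written at 25 fps
print(rel_path, duration_sec, int(num_chars))
```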
import os
from argparse import ArgumentParser
def load_args(default_config=None):
parser = ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
help="Specify the dataset name used in the experiment",
)
parser.add_argument(
"--subset",
type=str,
help="Specify the set used in the experiment",
)
parser.add_argument(
"--root-dir",
type=str,
help="The root directory of saved mouth patches or embeddings.",
)
parser.add_argument(
"--groups",
type=int,
help="Specify the number of threads to be used",
)
parser.add_argument(
"--seg-duration",
type=int,
default=16,
help="Specify the segment length",
)
args = parser.parse_args()
return args
args = load_args()
dataset = args.dataset
subset = args.subset
seg_duration = args.seg_duration
# Check that there is more than one group
assert args.groups > 1, "There is no need to use this script for merging when --groups is 1."
# Create the filename template for label files
label_template = os.path.join(
args.root_dir, "labels", f"{dataset}_{subset}_transcript_lengths_seg{seg_duration}s.{args.groups}"
)
lines = []
for job_index in range(args.groups):
label_filename = f"{label_template}.{job_index}.csv"
assert os.path.exists(label_filename), f"{label_filename} does not exist."
with open(label_filename, "r") as file:
lines.extend(file.read().splitlines())
# Write the merged labels to a new file
dst_label_filename = os.path.join(
args.root_dir, dataset, f"{dataset}_{subset}_transcript_lengths_seg{seg_duration}s.csv"
)
with open(dst_label_filename, "w") as file:
file.write("\n".join(lines))
# Print the number of files and total duration in hours
total_duration = sum(int(line.split(",")[2]) for line in lines) / 3600.0 / 25.0
print(f"The completed set has {len(lines)} files with a total of {total_duration} hours.")
# Remove the label files for each job index
print("** Remove the temporary label files **")
for job_index in range(args.groups):
label_filename = f"{label_template}.{job_index}.csv"
if os.path.exists(label_filename):
os.remove(label_filename)
print("** Finish **")
scikit-image
opencv-python
ffmpeg-python
## Face Detection
We provide [ibug.face_detection](https://github.com/hhj1897/face_detection) in this repository.
### Prerequisites
* [Git LFS](https://git-lfs.github.com/), needed for downloading the pretrained weights that are larger than 100 MB.
You can install *`Homebrew`* and then use it to install *`git-lfs`* without sudo privileges.
### From source
1. Install *`ibug.face_detection`*
```Shell
git clone https://github.com/hhj1897/face_detection.git
cd face_detection
git lfs pull
pip install -e .
cd ..
```
import os
import torchaudio
import torchvision
def split_file(filename, max_frames=600, fps=25.0):
lines = open(filename).read().splitlines()
flag = 0
stack = []
res = []
tmp = 0
start_timestamp = 0.0
threshold = max_frames / fps
for line in lines:
if "WORD START END ASDSCORE" in line:
flag = 1
continue
if flag:
word, start, end, score = line.split(" ")
start, end, score = float(start), float(end), float(score)
if end < tmp + threshold:
stack.append(word)
last_timestamp = end
else:
res.append([" ".join(stack), start_timestamp, last_timestamp, last_timestamp - start_timestamp])
tmp = start
start_timestamp = start
stack = [word]
if stack:
res.append([" ".join(stack), start_timestamp, end, end - start_timestamp])
return res
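`split_file` walks the word-alignment block of an LRS3 transcript (the lines after the `WORD START END ASDSCORE` header) and groups words into segments no longer than `max_frames` frames. A worked sketch on a synthetic transcript (words and timestamps are invented):

```python
# Worked sketch of split_file on a synthetic LRS3-style transcript written to a temp file.
# With max_frames=100 (4 seconds at 25 fps), the four words below split into two segments.
import tempfile

transcript = "\n".join(
    [
        "Text:  HELLO WORLD AGAIN FRIENDS",
        "WORD START END ASDSCORE",
        "HELLO 0.10 0.50 1.0",
        "WORLD 0.60 1.20 1.0",
        "AGAIN 5.00 5.40 1.0",
        "FRIENDS 5.50 6.00 1.0",
    ]
)
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
    tmp.write(transcript)
for content, start, end, duration in split_file(tmp.name, max_frames=100):
    print(content, start, end, round(duration, 2))
# Expected: "HELLO WORLD" (0.0-1.2) and "AGAIN FRIENDS" (5.0-6.0)
```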
def save_vid_txt(dst_vid_filename, dst_txt_filename, trim_video_data, content, video_fps=25):
# -- save video
save2vid(dst_vid_filename, trim_video_data, video_fps)
# -- save text
os.makedirs(os.path.dirname(dst_txt_filename), exist_ok=True)
f = open(dst_txt_filename, "w")
f.write(f"{content}")
f.close()
def save_vid_aud(
dst_vid_filename, dst_aud_filename, trim_vid_data, trim_aud_data, video_fps=25, audio_sample_rate=16000
):
# -- save video
save2vid(dst_vid_filename, trim_vid_data, video_fps)
# -- save audio
save2aud(dst_aud_filename, trim_aud_data, audio_sample_rate)
def save_vid_aud_txt(
dst_vid_filename,
dst_aud_filename,
dst_txt_filename,
trim_vid_data,
trim_aud_data,
content,
video_fps=25,
audio_sample_rate=16000,
):
# -- save video
save2vid(dst_vid_filename, trim_vid_data, video_fps)
# -- save audio
save2aud(dst_aud_filename, trim_aud_data, audio_sample_rate)
# -- save text
os.makedirs(os.path.dirname(dst_txt_filename), exist_ok=True)
f = open(dst_txt_filename, "w")
f.write(f"{content}")
f.close()
def save2vid(filename, vid, frames_per_second):
os.makedirs(os.path.dirname(filename), exist_ok=True)
torchvision.io.write_video(filename, vid, frames_per_second)
def save2aud(filename, aud, sample_rate):
os.makedirs(os.path.dirname(filename), exist_ok=True)
torchaudio.save(filename, aud, sample_rate)
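Finally, a minimal sketch of the save helpers on dummy data (output paths are hypothetical): `save2vid` expects a `(T, H, W, C)` uint8 tensor and `save2aud` a `(channels, samples)` float waveform.

```python
# Minimal sketch: save a 1-second dummy clip with matching audio and a transcript.
import torch

dummy_video = torch.zeros(25, 96, 96, 3, dtype=torch.uint8)  # 25 frames at 25 fps
dummy_audio = torch.zeros(1, 16000)                           # 1 second at 16 kHz
save_vid_aud_txt(
    "out/example.mp4",
    "out/example.wav",
    "out/example.txt",
    dummy_video,
    dummy_audio,
    "HELLO WORLD",
    video_fps=25,
    audio_sample_rate=16000,
)
```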
...@@ -61,9 +61,9 @@ def parse_args():
        required=True,
    )
    parser.add_argument(
        "--root-dir",
        type=str,
        help="Root directory to LRS3 audio-visual datasets.",
        required=True,
    )
    parser.add_argument(
......
...@@ -9,10 +9,10 @@ def _load_list(args, *filenames):
    output = []
    length = []
    for filename in filenames:
        filepath = os.path.join(args.root_dir, "labels", filename)
        for line in open(filepath).read().splitlines():
            dataset, rel_path, input_length = line.split(",")[0], line.split(",")[1], line.split(",")[2]
            path = os.path.normpath(os.path.join(args.root_dir, dataset, rel_path[:-4] + ".mp4"))
            length.append(int(input_length))
            output.append(path)
    return output, length
...@@ -62,11 +62,11 @@ class LRS3(Dataset):
        self.args = args
        if subset == "train":
            self._filelist, self._lengthlist = _load_list(self.args, "lrs3_train_transcript_lengths_seg16s.csv")
        if subset == "val":
            self._filelist, self._lengthlist = _load_list(self.args, "lrs3_test_transcript_lengths_seg16s.csv")
        if subset == "test":
            self._filelist, self._lengthlist = _load_list(self.args, "lrs3_test_transcript_lengths_seg16s.csv")
    def __getitem__(self, n):
        path = self._filelist[n]
......
...@@ -18,7 +18,7 @@ def get_trainer(args):
        monitor="monitoring_step",
        mode="max",
        save_last=True,
        filename="{epoch}",
        save_top_k=10,
    )
    lr_monitor = LearningRateMonitor(logging_interval="step")
...@@ -68,9 +68,9 @@ def parse_args():
        required=True,
    )
    parser.add_argument(
        "--root-dir",
        type=str,
        help="Root directory to LRS3 audio-visual datasets.",
        required=True,
    )
    parser.add_argument(
...@@ -91,7 +91,6 @@ def parse_args():
    )
    parser.add_argument(
        "--experiment-name",
        type=str,
        help="Experiment name",
    )
......
#!/usr/bin/env python3
"""Trains a SentencePiece model on transcripts across LRS3 pretrain and trainval.

- `[lrs3_path]` is the directory path for the LRS3 cropped-face dataset.

Example:
python train_spm.py --lrs3-path [lrs3_path]
"""
import io
...@@ -62,7 +64,7 @@ def parse_args():
def run_cli():
    args = parse_args()
    root = args.lrs3_path / "LRS3_text_seg16s"
    splits = ["pretrain", "trainval"]
    merged_transcripts = []
    for split in splits:
......