#! /usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch
import torchaudio
import torchvision


class AVSRDataLoader:
    """Load a single audio or video sample for the configured modality.

    ``load_data`` handles exactly two modalities: ``"audio"`` (waveform is
    loaded, resampled and averaged over its first dimension) and ``"video"``
    (landmarks are detected, frames are read, and both are handed to the
    video processor).
    """

    def __init__(self, modality, detector="retinaface", resize=None, device="cuda:0"):
        """Store the modality and, for video, build the detector pipeline.

        Args:
            modality: ``"audio"`` or ``"video"``; any other value is stored
                as-is and rejected later by ``load_data``.
            detector: Name of the landmark detector backend. Only
                ``"retinaface"`` is wired up here; with any other value no
                detector or video processor attribute is created.
            resize: Forwarded unchanged to ``VideoProcess(resize=...)``.
            device: Forwarded unchanged to ``LandmarksDetector(device=...)``.
                Defaults to ``"cuda:0"``, the value that used to be
                hard-coded.
        """
        self.modality = modality
        if modality == "video":
            if detector == "retinaface":
                # Imported here rather than at module level so that the
                # detector package is only required when video is requested.
                from detectors.retinaface.detector import LandmarksDetector
                from detectors.retinaface.video_process import VideoProcess

                self.landmarks_detector = LandmarksDetector(device=device)
                self.video_process = VideoProcess(resize=resize)

    def load_data(self, data_filename, transform=True):
        """Load and preprocess the file at ``data_filename``.

        Args:
            data_filename: Path of the media file to load.
            transform: Currently unused; kept so existing callers that pass
                it keep working.

        Returns:
            For ``"audio"``: the tensor produced by ``audio_process``.
            For ``"video"``: ``torch.tensor`` of whatever the video
            processor returns for the loaded frames and detected landmarks.

        Raises:
            ValueError: If ``self.modality`` is neither ``"audio"`` nor
                ``"video"`` (previously this case fell through and returned
                ``None`` implicitly).
        """
        if self.modality == "audio":
            audio, sample_rate = self.load_audio(data_filename)
            return self.audio_process(audio, sample_rate)

        if self.modality == "video":
            # The detector is given the file name, not the decoded frames.
            landmarks = self.landmarks_detector(data_filename)
            video = self.load_video(data_filename)
            video = self.video_process(video, landmarks)
            return torch.tensor(video)

        raise ValueError(
            f"Unsupported modality: {self.modality!r} (expected 'audio' or 'video')"
        )

    def load_audio(self, data_filename):
        """Read an audio file with ``torchaudio.load(..., normalize=True)``.

        Returns:
            A ``(waveform, sample_rate)`` tuple exactly as returned by
            torchaudio.
        """
        waveform, sample_rate = torchaudio.load(data_filename, normalize=True)
        return waveform, sample_rate

    def load_video(self, data_filename):
        """Read a video file and return its frames as a NumPy array.

        ``torchvision.io.read_video`` returns a tuple whose first element is
        the video frame tensor; the remaining elements are discarded here.
        """
        frames = torchvision.io.read_video(data_filename, pts_unit="sec")[0]
        return frames.numpy()

    def audio_process(self, waveform, sample_rate, target_sample_rate=16000):
        """Resample ``waveform`` if needed and average over dimension 0.

        Args:
            waveform: Audio tensor; dimension 0 is averaged away (with
                ``keepdim=True``), so it is presumably the channel axis as
                returned by ``torchaudio.load`` - confirm for other sources.
            sample_rate: Sample rate of ``waveform``.
            target_sample_rate: Desired sample rate; resampling is skipped
                when it already equals ``sample_rate``.

        Returns:
            The (possibly resampled) waveform with dimension 0 reduced to
            size 1.
        """
        if sample_rate != target_sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, target_sample_rate
            )
        return torch.mean(waveform, dim=0, keepdim=True)