Commit c977afe0 authored by Pingchuan Ma's avatar Pingchuan Ma Committed by Facebook GitHub Bot
Browse files

av-asr: move video loading outside detector (#3498)

Summary:
This PR moves video loading outside detector during pre-processing.

Pull Request resolved: https://github.com/pytorch/audio/pull/3498

Reviewed By: mthrok

Differential Revision: D47811044

Pulled By: mpc001

fbshipit-source-id: f17839b695b13d3cf2d9db343d7e9a0202eea7d5
parent da212020
......@@ -32,8 +32,8 @@ class AVSRDataLoader:
audio = self.audio_process(audio, sample_rate)
return audio
if self.modality == "video":
landmarks = self.landmarks_detector(data_filename)
video = self.load_video(data_filename)
landmarks = self.landmarks_detector(video)
video = self.video_process(video, landmarks)
video = torch.tensor(video)
return video
......
......@@ -9,7 +9,6 @@ import warnings
import mediapipe as mp
import numpy as np
import torchvision
warnings.filterwarnings("ignore")
......@@ -29,8 +28,7 @@ class LandmarksDetector:
assert any(l is not None for l in landmarks), "Cannot detect any frames in the video"
return landmarks
def detect(self, filename, detector):
video_frames = torchvision.io.read_video(filename, pts_unit="sec")[0].numpy()
def detect(self, video_frames, detector):
landmarks = []
for frame in video_frames:
results = detector.process(frame)
......
......@@ -7,7 +7,6 @@
import warnings
import numpy as np
import torchvision
from ibug.face_detection import RetinaFacePredictor
warnings.filterwarnings("ignore")
......@@ -19,8 +18,7 @@ class LandmarksDetector:
device=device, threshold=0.8, model=RetinaFacePredictor.get_model(model_name)
)
def __call__(self, filename):
video_frames = torchvision.io.read_video(filename, pts_unit="sec")[0].numpy()
def __call__(self, video_frames):
landmarks = []
for frame in video_frames:
detected_faces = self.face_detector(frame, rgb=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment