Commit c977afe0 authored by Pingchuan Ma, committed by Facebook GitHub Bot
Browse files

av-asr: move video loading outside detector (#3498)

Summary:
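This PR moves video loading out of the landmarks detectors during pre-processing. The data loader now loads the video first and passes the decoded frames to the detector, so the detectors no longer read the file themselves via `torchvision.io.read_video`.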

Pull Request resolved: https://github.com/pytorch/audio/pull/3498

Reviewed By: mthrok

Differential Revision: D47811044

Pulled By: mpc001

fbshipit-source-id: f17839b695b13d3cf2d9db343d7e9a0202eea7d5
parent da212020
...@@ -32,8 +32,8 @@ class AVSRDataLoader: ...@@ -32,8 +32,8 @@ class AVSRDataLoader:
audio = self.audio_process(audio, sample_rate) audio = self.audio_process(audio, sample_rate)
return audio return audio
if self.modality == "video": if self.modality == "video":
landmarks = self.landmarks_detector(data_filename)
video = self.load_video(data_filename) video = self.load_video(data_filename)
landmarks = self.landmarks_detector(video)
video = self.video_process(video, landmarks) video = self.video_process(video, landmarks)
video = torch.tensor(video) video = torch.tensor(video)
return video return video
......
...@@ -9,7 +9,6 @@ import warnings ...@@ -9,7 +9,6 @@ import warnings
import mediapipe as mp import mediapipe as mp
import numpy as np import numpy as np
import torchvision
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
...@@ -29,8 +28,7 @@ class LandmarksDetector: ...@@ -29,8 +28,7 @@ class LandmarksDetector:
assert any(l is not None for l in landmarks), "Cannot detect any frames in the video" assert any(l is not None for l in landmarks), "Cannot detect any frames in the video"
return landmarks return landmarks
def detect(self, filename, detector): def detect(self, video_frames, detector):
video_frames = torchvision.io.read_video(filename, pts_unit="sec")[0].numpy()
landmarks = [] landmarks = []
for frame in video_frames: for frame in video_frames:
results = detector.process(frame) results = detector.process(frame)
......
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
import warnings import warnings
import numpy as np import numpy as np
import torchvision
from ibug.face_detection import RetinaFacePredictor from ibug.face_detection import RetinaFacePredictor
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
...@@ -19,8 +18,7 @@ class LandmarksDetector: ...@@ -19,8 +18,7 @@ class LandmarksDetector:
device=device, threshold=0.8, model=RetinaFacePredictor.get_model(model_name) device=device, threshold=0.8, model=RetinaFacePredictor.get_model(model_name)
) )
def __call__(self, filename): def __call__(self, video_frames):
video_frames = torchvision.io.read_video(filename, pts_unit="sec")[0].numpy()
landmarks = [] landmarks = []
for frame in video_frames: for frame in video_frames:
detected_faces = self.face_detector(frame, rgb=False) detected_faces = self.face_detector(frame, rgb=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment