You need to sign in or sign up before continuing.
Unverified commit 47f502a6, authored by Oren Amsalem and committed by GitHub
Browse files

Fix modality naming inconsistency: use "video" instead of "visual" (#3631)

* Update transforms.py

* Update train.py
parent 92ded610
...@@ -59,6 +59,7 @@ def parse_args(): ...@@ -59,6 +59,7 @@ def parse_args():
"--modality", "--modality",
type=str, type=str,
help="Modality", help="Modality",
choices=["audio", "video", "audiovisual"],
required=True, required=True,
) )
parser.add_argument( parser.add_argument(
......
...@@ -55,7 +55,7 @@ def _extract_features(video_pipeline, audio_pipeline, samples, args): ...@@ -55,7 +55,7 @@ def _extract_features(video_pipeline, audio_pipeline, samples, args):
raw_videos = [] raw_videos = []
raw_audios = [] raw_audios = []
for sample in samples: for sample in samples:
if args.modality == "visual": if args.modality == "video":
raw_videos.append(sample[0]) raw_videos.append(sample[0])
if args.modality == "audio": if args.modality == "audio":
raw_audios.append(sample[0]) raw_audios.append(sample[0])
...@@ -64,7 +64,7 @@ def _extract_features(video_pipeline, audio_pipeline, samples, args): ...@@ -64,7 +64,7 @@ def _extract_features(video_pipeline, audio_pipeline, samples, args):
raw_audios.append(sample[0][: length * 640]) raw_audios.append(sample[0][: length * 640])
raw_videos.append(sample[1][:length]) raw_videos.append(sample[1][:length])
if args.modality == "visual" or args.modality == "audiovisual": if args.modality == "video" or args.modality == "audiovisual":
videos = torch.nn.utils.rnn.pad_sequence(raw_videos, batch_first=True) videos = torch.nn.utils.rnn.pad_sequence(raw_videos, batch_first=True)
videos = video_pipeline(videos) videos = video_pipeline(videos)
video_lengths = torch.tensor([elem.shape[0] for elem in videos], dtype=torch.int32) video_lengths = torch.tensor([elem.shape[0] for elem in videos], dtype=torch.int32)
...@@ -72,7 +72,7 @@ def _extract_features(video_pipeline, audio_pipeline, samples, args): ...@@ -72,7 +72,7 @@ def _extract_features(video_pipeline, audio_pipeline, samples, args):
audios = torch.nn.utils.rnn.pad_sequence(raw_audios, batch_first=True) audios = torch.nn.utils.rnn.pad_sequence(raw_audios, batch_first=True)
audios = audio_pipeline(audios) audios = audio_pipeline(audios)
audio_lengths = torch.tensor([elem.shape[0] // 640 for elem in audios], dtype=torch.int32) audio_lengths = torch.tensor([elem.shape[0] // 640 for elem in audios], dtype=torch.int32)
if args.modality == "visual": if args.modality == "video":
return videos, video_lengths return videos, video_lengths
if args.modality == "audio": if args.modality == "audio":
return audios, audio_lengths return audios, audio_lengths
...@@ -105,7 +105,7 @@ class TrainTransform: ...@@ -105,7 +105,7 @@ class TrainTransform:
self.train_video_pipeline, self.train_audio_pipeline, samples, self.args self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
) )
return Batch(audios, audio_lengths, targets, target_lengths) return Batch(audios, audio_lengths, targets, target_lengths)
if self.args.modality == "visual": if self.args.modality == "video":
videos, video_lengths = _extract_features( videos, video_lengths = _extract_features(
self.train_video_pipeline, self.train_audio_pipeline, samples, self.args self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
) )
...@@ -140,7 +140,7 @@ class ValTransform: ...@@ -140,7 +140,7 @@ class ValTransform:
self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
) )
return Batch(audios, audio_lengths, targets, target_lengths) return Batch(audios, audio_lengths, targets, target_lengths)
if self.args.modality == "visual": if self.args.modality == "video":
videos, video_lengths = _extract_features( videos, video_lengths = _extract_features(
self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment