Unverified Commit 47f502a6 authored by Oren Amsalem's avatar Oren Amsalem Committed by GitHub
Browse files

change modality naming inconsistency (visual & video) (#3631)

* Update transforms.py

* Update train.py
parent 92ded610
......@@ -59,6 +59,7 @@ def parse_args():
"--modality",
type=str,
help="Modality",
choices=["audio", "video", "audiovisual"],
required=True,
)
parser.add_argument(
......
......@@ -55,7 +55,7 @@ def _extract_features(video_pipeline, audio_pipeline, samples, args):
raw_videos = []
raw_audios = []
for sample in samples:
if args.modality == "visual":
if args.modality == "video":
raw_videos.append(sample[0])
if args.modality == "audio":
raw_audios.append(sample[0])
......@@ -64,7 +64,7 @@ def _extract_features(video_pipeline, audio_pipeline, samples, args):
raw_audios.append(sample[0][: length * 640])
raw_videos.append(sample[1][:length])
if args.modality == "visual" or args.modality == "audiovisual":
if args.modality == "video" or args.modality == "audiovisual":
videos = torch.nn.utils.rnn.pad_sequence(raw_videos, batch_first=True)
videos = video_pipeline(videos)
video_lengths = torch.tensor([elem.shape[0] for elem in videos], dtype=torch.int32)
......@@ -72,7 +72,7 @@ def _extract_features(video_pipeline, audio_pipeline, samples, args):
audios = torch.nn.utils.rnn.pad_sequence(raw_audios, batch_first=True)
audios = audio_pipeline(audios)
audio_lengths = torch.tensor([elem.shape[0] // 640 for elem in audios], dtype=torch.int32)
if args.modality == "visual":
if args.modality == "video":
return videos, video_lengths
if args.modality == "audio":
return audios, audio_lengths
......@@ -105,7 +105,7 @@ class TrainTransform:
self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
)
return Batch(audios, audio_lengths, targets, target_lengths)
if self.args.modality == "visual":
if self.args.modality == "video":
videos, video_lengths = _extract_features(
self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
)
......@@ -140,7 +140,7 @@ class ValTransform:
self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
)
return Batch(audios, audio_lengths, targets, target_lengths)
if self.args.modality == "visual":
if self.args.modality == "video":
videos, video_lengths = _extract_features(
self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment