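"""Data transforms for LRS3 audio/visual speech recognition training.

Provides train/validation/test collate transforms that tokenize transcripts with
SentencePiece, run the video and audio streams through preprocessing pipelines,
and pack samples into ``Batch``/``AVBatch`` objects for ``LRS3DataModule``.
"""
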
import random
from typing import List

import sentencepiece as spm
import torch
import torchvision
from data_module import LRS3DataModule
from lightning import Batch
from lightning_av import AVBatch


class FunctionalModule(torch.nn.Module):
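    """Wraps an arbitrary callable (typically a lambda) so it can be composed
    inside ``torch.nn.Sequential``."""
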
    def __init__(self, functional):
        super().__init__()
        self.functional = functional

    def forward(self, input):
        return self.functional(input)


class AdaptiveTimeMask(torch.nn.Module):
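    """Zeroes out random spans along the time dimension, SpecAugment-style.

    The number of masks scales with the sequence length (roughly one per
    ``stride`` steps), and each mask covers at most ``window`` steps.
    """
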
    def __init__(self, window, stride):
        super().__init__()
        self.window = window
        self.stride = stride

    def forward(self, x):
        cloned = x.clone()
        length = cloned.size(1)
        n_mask = int((length + self.stride - 0.1) // self.stride)
        ts = torch.randint(0, self.window, size=(n_mask, 2))
        for t, t_end in ts:
            if length - t <= 0:
                continue
            t_start = random.randrange(0, length - t)
            if t_start == t_start + t:
                continue
            t_end += t_start
            cloned[:, t_start:t_end] = 0
        return cloned


def _extract_labels(sp_model, samples: List):
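    """Encode each sample's transcript (its last element) with the SentencePiece
    model and pad the token ids into a batch; returns the padded targets and
    their unpadded lengths."""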
    targets = [sp_model.encode(sample[-1].lower()) for sample in samples]
    lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
    targets = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(elem) for elem in targets],
        batch_first=True,
        padding_value=1.0,
    ).to(dtype=torch.int32)
    return targets, lengths


def _extract_features(video_pipeline, audio_pipeline, samples, args):
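    """Collate the raw streams for the selected modality: pad each stream to the
    batch maximum, run it through its pipeline, and return the features together
    with the unpadded per-sample lengths."""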
    raw_videos = []
    raw_audios = []
    for sample in samples:
        if args.modality == "video":
            raw_videos.append(sample[0])
        if args.modality == "audio":
            raw_audios.append(sample[0])
        if args.modality == "audiovisual":
            length = min(len(sample[0]) // 640, len(sample[1]))
            raw_audios.append(sample[0][: length * 640])
            raw_videos.append(sample[1][:length])

    if args.modality == "video" or args.modality == "audiovisual":
        # Record each clip's true length before padding, so the lengths reflect
        # the raw videos rather than the padded batch.
        video_lengths = torch.tensor([elem.shape[0] for elem in raw_videos], dtype=torch.int32)
        videos = torch.nn.utils.rnn.pad_sequence(raw_videos, batch_first=True)
        videos = video_pipeline(videos)
    if args.modality == "audio" or args.modality == "audiovisual":
        # Lengths are counted in 640-sample windows and taken from the raw
        # waveforms, so padding does not inflate them.
        audio_lengths = torch.tensor([elem.shape[0] // 640 for elem in raw_audios], dtype=torch.int32)
        audios = torch.nn.utils.rnn.pad_sequence(raw_audios, batch_first=True)
        audios = audio_pipeline(audios)
    if args.modality == "video":
        return videos, video_lengths
    if args.modality == "audio":
        return audios, audio_lengths
    if args.modality == "audiovisual":
        return audios, videos, audio_lengths, video_lengths


class TrainTransform:
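    """Training collate transform: tokenizes transcripts and applies the
    augmentation pipelines (random crop, horizontal flip, adaptive time masking)
    before packing samples into a ``Batch`` or ``AVBatch``."""
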
    def __init__(self, sp_model_path: str, args):
        self.args = args
        self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
        self.train_video_pipeline = torch.nn.Sequential(
            FunctionalModule(lambda x: x / 255.0),
            torchvision.transforms.RandomCrop(88),
            torchvision.transforms.RandomHorizontalFlip(0.5),
            FunctionalModule(lambda x: x.transpose(0, 1)),
            torchvision.transforms.Grayscale(),
            FunctionalModule(lambda x: x.transpose(0, 1)),
            AdaptiveTimeMask(10, 25),
            torchvision.transforms.Normalize(0.421, 0.165),
        )
        self.train_audio_pipeline = torch.nn.Sequential(
            AdaptiveTimeMask(10, 25),
        )

    def __call__(self, samples: List):
        targets, target_lengths = _extract_labels(self.sp_model, samples)
        if self.args.modality == "audio":
            audios, audio_lengths = _extract_features(
                self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
            )
            return Batch(audios, audio_lengths, targets, target_lengths)
        if self.args.modality == "video":
            videos, video_lengths = _extract_features(
                self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
            )
            return Batch(videos, video_lengths, targets, target_lengths)
        if self.args.modality == "audiovisual":
            audios, videos, audio_lengths, video_lengths = _extract_features(
                self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
            )
            return AVBatch(audios, videos, audio_lengths, video_lengths, targets, target_lengths)


class ValTransform:
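    """Validation collate transform: same tokenization as ``TrainTransform`` but
    with deterministic preprocessing (center crop, no time masking)."""
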
    def __init__(self, sp_model_path: str, args):
        self.args = args
        self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
        self.valid_video_pipeline = torch.nn.Sequential(
            FunctionalModule(lambda x: x / 255.0),
            torchvision.transforms.CenterCrop(88),
            FunctionalModule(lambda x: x.transpose(0, 1)),
            torchvision.transforms.Grayscale(),
            FunctionalModule(lambda x: x.transpose(0, 1)),
            torchvision.transforms.Normalize(0.421, 0.165),
        )
        self.valid_audio_pipeline = torch.nn.Sequential(
            FunctionalModule(lambda x: x),
        )

    def __call__(self, samples: List):
        targets, target_lengths = _extract_labels(self.sp_model, samples)
        if self.args.modality == "audio":
            audios, audio_lengths = _extract_features(
                self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
            )
            return Batch(audios, audio_lengths, targets, target_lengths)
        if self.args.modality == "video":
            videos, video_lengths = _extract_features(
                self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
            )
            return Batch(videos, video_lengths, targets, target_lengths)
        if self.args.modality == "audiovisual":
            audios, videos, audio_lengths, video_lengths = _extract_features(
                self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
            )
            return AVBatch(audios, videos, audio_lengths, video_lengths, targets, target_lengths)


class TestTransform:
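    """Wraps ``ValTransform`` for single-sample evaluation, returning both the
    collated batch and the original raw sample."""
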
    def __init__(self, sp_model_path: str, args):
        self.val_transforms = ValTransform(sp_model_path, args)

    def __call__(self, sample):
        return self.val_transforms([sample]), [sample]


def get_data_module(args, sp_model_path, max_frames=1800):
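    """Builds the LRS3 data module wired up with the train/val/test transforms,
    all sharing one SentencePiece model."""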
    train_transform = TrainTransform(sp_model_path=sp_model_path, args=args)
    val_transform = ValTransform(sp_model_path=sp_model_path, args=args)
    test_transform = TestTransform(sp_model_path=sp_model_path, args=args)
    return LRS3DataModule(
        args=args,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        max_frames=max_frames,
    )