unify_dataset.py 20.8 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from io import BytesIO

import math
import logging
import random
import warnings

import numpy as np
import torch
import base64
from torchvision import transforms

from PIL import Image, ImageFile

from data import data_utils
from data.ofa_dataset import OFADataset
from utils.vision_helper import RandomAugment
import utils.transforms as T

from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
from fairseq.data.audio.feature_transforms import *
from fairseq.data.audio.audio_utils import (
    convert_waveform, _get_kaldi_fbank, _get_torchaudio_fbank
)
from pathlib import Path
import soundfile as sf
import librosa
import torchaudio
from typing import List

from pypinyin import pinyin, Style
from utils.text2phone import Text2Phone
from g2p_en import G2p
g2p = G2p()

logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)


def collate(
    samples,
    pad_idx,
    eos_idx,
    left_pad_source=False,
    left_pad_target=False,
    feature_only=True,
    mask=False,
    mask_prob=0.0
):
    """Collate a list of per-sample dicts into a single padded mini-batch.

    Optional fields (``fbank``, ``audio_code_mask``, ``phone_item``,
    ``phone_mask``, ``target``, ``prev_output_tokens``, ``constraint_masks``)
    are only collated when present on the FIRST sample; otherwise the
    corresponding batch entries are ``None``.

    Args:
        samples: list of example dicts.
        pad_idx: padding index for token sequences (also used to count
            non-pad tokens).
        eos_idx: end-of-sentence index forwarded to ``collate_tokens``.
        left_pad_source: left-pad ``source`` / ``phone_item`` if True.
        left_pad_target: left-pad ``target`` / ``prev_output_tokens`` if True.
        feature_only, mask, mask_prob: copied verbatim into ``net_input``
            as ``encoder_features_only``, ``mask`` and ``mask_prob``.

    Returns:
        The batch dict, or ``{}`` when ``samples`` is empty.
    """
    if not samples:
        return {}

    first = samples[0]

    def merge(key, left_pad, move_eos_to_beginning=False):
        # Pad the 1-D token tensors stored under `key` into one 2-D tensor.
        return data_utils.collate_tokens(
            [s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning,
        )

    def pad_2d(tensors, as_bool=False):
        # Right-pad 2-D tensors of shape (L_i, D) with zeros into a single
        # (len(tensors), max L_i, D) tensor, optionally cast to bool.
        longest = max(t.size(0) for t in tensors)
        padded = tensors[0].new_zeros((len(tensors), longest, tensors[0].size(1)))
        if as_bool:
            padded = padded.bool()
        for row, t in enumerate(tensors):
            padded[row, : t.size(0)] = t
        return padded

    ids = np.array([s["id"] for s in samples])
    src_tokens = merge("source", left_pad=left_pad_source)
    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])

    fbank = fbank_length = fbank_masks = None
    if first.get("fbank", None) is not None:
        fbank_list = [s["fbank"] for s in samples]
        fbank = pad_2d(fbank_list)
        fbank_length = torch.tensor([f.size(0) for f in fbank_list], dtype=torch.long)
        fbank_masks = torch.tensor([s["fbank_mask"] for s in samples])

    audio_code_masks = None
    if first.get("audio_code_mask", None) is not None:
        audio_code_masks = torch.cat([s["audio_code_mask"] for s in samples])

    phone_items = phone_lengths = None
    if first.get("phone_item", None) is not None:
        phone_items = merge("phone_item", left_pad=left_pad_source)
        phone_lengths = torch.LongTensor([len(s["phone_item"]) for s in samples])

    phone_masks = None
    if first.get("phone_mask", None) is not None:
        phone_masks = torch.cat([s["phone_mask"] for s in samples])

    target = prev_output_tokens = None
    if first.get("target", None) is None:
        # No targets: fall back to counting non-pad source tokens.
        ntokens = src_lengths.sum().item()
    else:
        target = merge("target", left_pad=left_pad_target)
        tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
        ntokens = tgt_lengths.sum().item()
        if first.get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)

    constraint_masks = None
    if first.get("constraint_masks", None) is not None:
        constraint_masks = pad_2d([s["constraint_masks"] for s in samples], as_bool=True)

    net_input = {
        "src_tokens": src_tokens,
        "src_lengths": src_lengths,
        "fbank": fbank,
        "fbank_length": fbank_length,
        "fbank_masks": fbank_masks,
        "phone_items": phone_items,
        "phone_lengths": phone_lengths,
        "phone_masks": phone_masks,
        "audio_code_masks": audio_code_masks,
        "prev_output_tokens": prev_output_tokens,
        "encoder_features_only": feature_only,
        "mask": mask,
        "mask_prob": mask_prob,
    }

    return {
        "id": ids,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": net_input,
        "target": target,
        # The phone sequence doubles as the CTC target.
        "ctc_outputs": phone_items,
        "ctc_output_lengths": phone_lengths,
        "constraint_masks": constraint_masks,
    }

class UnifyDataset(OFADataset):
    """Mixed speech / phone / text dataset driven by ``train_stage``.

    ``__getitem__`` returns a tuple of example lists whose layout depends on
    ``train_stage`` and ``split``:

    * stage 1: ``(pure_text_examples, [], [])``
    * stage 2: ``(speech_text_pairs, pure_text_examples, pure_audio_examples)``
    * any other stage: ``(speech_text_pairs, pure_text_examples, [])``

    ``collater`` concatenates those lists across the batch and turns each of
    them into a mini-batch with the module-level ``collate``.
    """

    def __init__(
        self,
        split,
        dataset,
        bpe,
        src_dict,
        tgt_dict=None,
        phone_dict=None,
        max_src_length=128,
        max_tgt_length=30,
        seed=7,
        code_dict_size=8192,
        audio_code_dict_size=30000,
        num_bins=1000,
        pure_text_dataset=None,
        pure_audio_dataset=None,
        speech_text_dataset=None,
        config_yaml_path=None,
        lang="zh",
        text2phone_path=None,
        train_stage=2,
        n_frames_per_step=1,
        sample_rate=16000,
    ):
        """Create the dataset.

        ``phone_dict`` and ``config_yaml_path`` must not be None despite
        their defaults: the phone dictionary is queried for ``<blank>`` /
        ``<mask>`` here and the YAML path is wrapped in ``Path``.
        When ``text2phone_path`` is given, Chinese text is converted with
        ``Text2Phone`` and ``<unk>`` is used as the blank id instead.
        """
        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
        self.phone_dict = phone_dict
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        self.seed = seed
        self.code_dict_size = code_dict_size
        self.audio_code_dict_size = audio_code_dict_size
        self.num_bins = num_bins

        self.pure_text_dataset = pure_text_dataset
        self.pure_audio_dataset = pure_audio_dataset
        self.speech_text_dataset = speech_text_dataset
        self.epoch = 0
        self.remove_pure_audio = self.pure_audio_dataset is None
        self.remove_pure_text = self.pure_text_dataset is None

        self.data_cfg = S2TDataConfig(Path(config_yaml_path))
        self.lang = lang
        self.train_stage = train_stage
        # Feature transforms (e.g. SpecAugment) are enabled only for train splits.
        self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
            self.data_cfg.get_feature_transforms(split, split.startswith("train"))
        )
        self.n_frames_per_step = n_frames_per_step
        self.sample_rate = sample_rate
        self.blank_id = self.phone_dict.index("<blank>")
        self.phone_mask_idx = self.phone_dict.index("<mask>")
        self.text2phone_tokenizer = None
        if text2phone_path is not None:
            self.blank_id = self.phone_dict.index("<unk>")
            self.text2phone_tokenizer = Text2Phone(text2phone_path)

    def set_epoch(self, epoch, **unused):
        """Record the epoch; it seeds numpy in ``__getitem__``."""
        self.epoch = epoch

    # ------------------------------------------------------------------ #
    # Small shared helpers
    # ------------------------------------------------------------------ #
    def _extra_token_start(self):
        """First ``tgt_dict`` index past the plain-text vocabulary.

        The last ``audio_code_dict_size + code_dict_size + num_bins`` entries
        of ``tgt_dict`` are treated as non-text tokens. Audio code ``k`` is
        stored at index ``k + _extra_token_start()``.
        """
        return len(self.tgt_dict) - self.audio_code_dict_size - self.code_dict_size - self.num_bins

    def _text_constraint_masks(self, length):
        """Return a ``(length, len(tgt_dict))`` bool mask allowing only text tokens."""
        row = torch.arange(len(self.tgt_dict)) < self._extra_token_start()
        # Every target position gets the same row.
        return row.unsqueeze(0).repeat(length, 1)

    def _preprocess_text(self, text):
        """Normalize ``text`` for ``self.lang``; raise ValueError otherwise."""
        if self.lang == "en":
            return self.pre_caption(text, self.max_tgt_length)
        if self.lang == "zh":
            return self.pre_chinese(text, self.max_tgt_length)
        raise ValueError("lang must be en or zh")

    def _sample_speed(self):
        """Random speed perturbation factor for training, 1.0 otherwise."""
        if self.split == "train":
            return random.choice([0.9, 1.0, 1.1])
        return 1.0

    @staticmethod
    def _waveform_tensor(wav):
        """Turn an audio array into a float32 tensor with a leading channel dim."""
        # Equivalent to torch.tensor([wav]) but avoids building a tensor from
        # a Python list of ndarrays (slow path that PyTorch warns about).
        return torch.tensor(wav, dtype=torch.float32).unsqueeze(0)

    # ------------------------------------------------------------------ #
    # Example builders
    # ------------------------------------------------------------------ #
    def process_pure_text(self, index):
        """Build one phone -> text example from a text-only source.

        Reads ``self.dataset`` in stage 1 and ``self.pure_text_dataset``
        otherwise. The text's phone sequence is corrupted by
        ``add_noise_to_phone`` (30% of positions) and used as input. Audio
        inputs are zero placeholders switched off via ``fbank_mask`` and
        ``audio_code_mask``.

        Returns:
            A single-element list holding the example dict.
        """
        source = self.dataset if self.train_stage == 1 else self.pure_text_dataset
        speech_id, text = source[index]
        conf = torch.tensor([1.0])

        # Placeholder audio input, disabled by the masks.
        fbank = torch.zeros((8, self.data_cfg.input_feat_per_channel))
        fbank_mask = torch.tensor([False])
        audio_code_mask = torch.tensor([False])

        text = self._preprocess_text(text)

        # to_phone returns a fresh LongTensor, so in-place noising is safe.
        phone_item = self.add_noise_to_phone(self.to_phone(text, self.lang), 0.3)
        phone_mask = torch.tensor([True])

        src_item = self.encode_text(" what does the phone say?")
        tgt_item = self.encode_text(" {}".format(text))

        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        target_item = torch.cat([tgt_item, self.eos_item])
        prev_output_item = torch.cat([self.bos_item, tgt_item])

        example = {
            "id": speech_id,
            "source": src_item,
            "fbank": fbank,
            "fbank_mask": fbank_mask,
            "phone_item": phone_item,
            "phone_mask": phone_mask,
            "audio_code_mask": audio_code_mask,
            "target": target_item,
            "prev_output_tokens": prev_output_item,
            "conf": conf,
            "constraint_masks": self._text_constraint_masks(len(target_item)),
        }
        return [example]

    def process_pure_audio(self, index):
        """Build one audio -> audio-code example from an audio-only source.

        Reads ``self.dataset`` in stage 2 and ``self.pure_audio_dataset``
        otherwise. ``code`` is a comma-separated string of integer codes. When
        it is empty or None, the placeholder target ``[1, 2, 3]`` is used.
        The phone input is a placeholder disabled by ``phone_mask``.

        Returns:
            A single-element list holding the example dict.
        """
        source = self.dataset if self.train_stage == 2 else self.pure_audio_dataset
        speech_id, wav_data, code = source[index]

        # Placeholder phone input, disabled by phone_mask.
        phone_item = torch.tensor([6, 6, 6])
        phone_mask = torch.tensor([False])

        speed = self._sample_speed()
        wav, sr = sf.read(wav_data)
        fbank = self.prepare_fbank(self._waveform_tensor(wav), sr, speed)

        fbank_mask = torch.tensor([True])
        audio_code_mask = torch.tensor([True])

        if code is not None and len(code) > 0:
            codes = torch.LongTensor([int(num) for num in code.strip().split(",")])
            # Shift raw audio codes into their reserved slice of tgt_dict.
            tgt_item = codes + self._extra_token_start()
        else:
            # Placeholder target when no codes are provided.
            tgt_item = torch.LongTensor([1, 2, 3])

        conf = torch.tensor([1.0])

        # Prompt is encoded for interface uniformity.
        src_item = self.encode_text(' what does the audio say?')

        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        target_item = torch.cat([tgt_item, self.eos_item])
        prev_output_item = torch.cat([self.bos_item, tgt_item])

        # NOTE(review): no "constraint_masks" entry is attached here, so
        # collate() leaves constraint_masks as None for pure-audio batches.
        # A code-only mask (index >= _extra_token_start(), plus the first 3
        # ids) used to be computed and then discarded; it was dropped to save
        # a (len(target) x vocab) allocation per sample. Confirm whether it
        # was meant to be attached.
        example = {
            "id": speech_id,
            "source": src_item,
            "fbank": fbank,
            "fbank_mask": fbank_mask,
            "phone_item": phone_item,
            "phone_mask": phone_mask,
            "audio_code_mask": audio_code_mask,
            "target": target_item,
            "prev_output_tokens": prev_output_item,
            "conf": conf,
        }
        return [example]

    def process_speech_text_pair(self, index, dataset=None):
        """Build one speech -> text example.

        Reads from ``dataset`` if given, else ``self.speech_text_dataset`` in
        stage 2, else ``self.dataset``. Audio is loaded and resampled to
        ``self.sample_rate``. The text's phone ids (shifted down by 3) are
        attached as ``phone_item`` with ``phone_mask`` False.

        Returns:
            A single-element list holding the example dict.
        """
        if dataset is not None:
            source = dataset
        elif self.train_stage == 2:
            source = self.speech_text_dataset
        else:
            source = self.dataset
        speech_id, wav_data, text = source[index]

        conf = torch.tensor([1.0])
        audio_code_mask = torch.tensor([False])

        speed = self._sample_speed()
        # `sr` must be passed by keyword: it is keyword-only since librosa 0.10.
        wav, sr = librosa.load(wav_data, sr=self.sample_rate)
        fbank = self.prepare_fbank(self._waveform_tensor(wav), sr, speed, speech_id)

        fbank_mask = torch.tensor([True])

        text = self._preprocess_text(text)

        # NOTE(review): ids are shifted down by 3, presumably to drop leading
        # special symbols of phone_dict for the CTC target — verify.
        phone_item = self.to_phone(text, self.lang) - 3
        phone_mask = torch.tensor([False])

        src_item = self.encode_text(" what does the audio say?")
        tgt_item = self.encode_text(" {}".format(text))

        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        target_item = torch.cat([tgt_item, self.eos_item])
        prev_output_item = torch.cat([self.bos_item, tgt_item])

        example = {
            "id": speech_id,
            "source": src_item,
            "fbank": fbank,
            "fbank_mask": fbank_mask,
            "phone_item": phone_item,
            "phone_mask": phone_mask,
            "audio_code_mask": audio_code_mask,
            "target": target_item,
            "prev_output_tokens": prev_output_item,
            "conf": conf,
            "constraint_masks": self._text_constraint_masks(len(target_item)),
        }
        return [example]

    def __getitem__(self, index):
        """Return a 3-tuple of example lists for ``index`` (see class docstring).

        In training, extra sources are sub-sampled using
        ``self.dataset.data_cnt``: speech-text pairs when it is divisible by 4
        (stage 2) and pure text when it is divisible by 2.
        """
        with data_utils.numpy_seed(self.seed, self.epoch):
            if self.train_stage == 1:
                extra_samples = []
                if self.dataset is not None:
                    extra_samples += self.process_pure_text(index) if not self.remove_pure_text else []
                return extra_samples, [], []
            elif self.train_stage == 2:
                pair_examples = []
                audio_examples = []
                extra_samples = []
                if self.split == 'train':
                    if self.dataset is not None:
                        audio_examples += self.process_pure_audio(index) if not self.remove_pure_audio else []
                    if self.speech_text_dataset is not None and self.dataset.data_cnt % 4 == 0:
                        pair_examples += self.process_speech_text_pair(index)
                    if self.pure_text_dataset is not None and self.dataset.data_cnt % 2 == 0:
                        extra_samples += self.process_pure_text(index) if not self.remove_pure_text else []
                else:
                    if self.dataset is not None:
                        pair_examples += self.process_speech_text_pair(index, self.dataset)
                return pair_examples, extra_samples, audio_examples
            else:
                pair_examples = []
                extra_samples = []
                if self.split == 'train':
                    if self.dataset is not None:
                        pair_examples += self.process_speech_text_pair(index)
                    if self.pure_text_dataset is not None and self.dataset.data_cnt % 2 == 0:
                        extra_samples += self.process_pure_text(index) if not self.remove_pure_text else []
                else:
                    if self.dataset is not None:
                        pair_examples += self.process_speech_text_pair(index, self.dataset)
                return pair_examples, extra_samples, []

    # ------------------------------------------------------------------ #
    # Phone / audio processing
    # ------------------------------------------------------------------ #
    def to_phone(self, text, lang):
        """Convert ``text`` to a LongTensor of phone ids from ``phone_dict``.

        English uses g2p_en. Chinese uses ``Text2Phone`` when configured,
        otherwise pypinyin initials + tone-3 finals.

        Raises:
            ValueError: if ``lang`` is neither "en" nor "zh".
        """
        if lang == "en":
            try:
                phones = " ".join(g2p(text))
            except Exception as e:
                # Re-raise: encoding a missing phone string cannot succeed.
                logger.warning("g2p failed on %r: %s", text, e)
                raise
            return self.encode_phone(phones)

        if lang == "zh":
            if self.text2phone_tokenizer is not None:
                return self.encode_phone(self.text2phone_tokenizer.trans(text))
            shengmu = pinyin(text, style=Style.INITIALS, strict=False)
            yunmu = pinyin(text, style=Style.FINALS_TONE3, strict=False)
            assert len(shengmu) == len(yunmu)
            final_phone = []
            for s, y in zip(shengmu, yunmu):
                # Emit only the final when there is no separate initial
                # (empty initial, or initial identical to the final).
                if s[0] == y[0] or s[0] == "":
                    final_phone.append(y[0])
                else:
                    final_phone.append(s[0] + " " + y[0])
            return self.encode_phone(" ".join(final_phone))

        raise ValueError("lang must be en or zh")

    def encode_phone(self, phone_item):
        """Encode a space-separated phone string to a LongTensor (no EOS, no new symbols)."""
        tokens = self.phone_dict.encode_line(
            line=phone_item, add_if_not_exist=False, append_eos=False).long()
        return tokens

    def add_noise_to_phone(self, phone, p, random_p=0.1):
        """Corrupt ``phone`` IN PLACE and return it.

        ``ceil(len * p)`` random positions are set to ``<mask>``. About a
        ``random_p`` fraction of those are replaced instead by a random id
        in ``[4, phone_mask_idx)``.
        """
        num_to_mask = int(math.ceil(phone.size(0) * p))
        indices = torch.randperm(phone.size(0))[:num_to_mask]
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < random_p
        phone[indices] = self.phone_mask_idx
        num_random = int(mask_random.sum())
        if num_random > 0:
            phone[indices[mask_random]] = torch.randint(
                4, self.phone_mask_idx, size=(num_random,)
            )
        return phone

    def prepare_fbank(self, waveform, sample_rate, speed, speech_id=None):
        """Compute 80-bin log-mel filterbank features for ``waveform``.

        Applies sox speed perturbation (resampled back to ``sample_rate``),
        mono/volume normalization, optional feature transforms, and frame
        packing. ``speech_id`` is accepted but unused.

        Raises:
            ImportError: if neither pyKaldi nor torchaudio fbank is available.
        """
        waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                waveform, sample_rate,
                [['speed', str(speed)], ['rate', str(sample_rate)]])
        _waveform, _ = convert_waveform(waveform, sample_rate, to_mono=True, normalize_volume=True)
        # Kaldi compliance: 16-bit signed integers
        _waveform = _waveform * (2 ** 15)
        _waveform = _waveform.numpy()
        fbank = _get_kaldi_fbank(_waveform, sample_rate, 80)
        if fbank is None:
            fbank = _get_torchaudio_fbank(_waveform, sample_rate, 80)
        if fbank is None:
            raise ImportError(
                "Please install pyKaldi or torchaudio to enable fbank feature extraction"
            )
        if self.feature_transforms is not None:
            fbank = self.feature_transforms(fbank)
        fbank = torch.from_numpy(fbank).float()
        fbank = self.pack_frames(fbank)
        return fbank

    def pack_frames(self, feature: torch.Tensor):
        """Stack every ``n_frames_per_step`` frames into one row; drop the remainder."""
        if self.n_frames_per_step == 1:
            return feature
        n_packed_frames = feature.shape[0] // self.n_frames_per_step
        feature = feature[: self.n_frames_per_step * n_packed_frames]
        return feature.reshape(n_packed_frames, -1)

    def collater(self, samples, pad_to_length=None):
        """Merge samples of different tasks into up to three mini-batches.

        Args:
            samples (List[Tuple]): tuples returned by ``__getitem__``.
            pad_to_length: unused.

        Returns:
            ``res_v1`` alone when not training or in stage 1. Otherwise
            ``(res_v1, res_v2)`` when there is no pure-audio data, else
            ``(res_v1, res_v2, res_v3)``. ``res_v2`` is ``{}`` when no
            pure-text samples are available.
        """
        samples_v1 = []   # phone-text pairs at stage 1, speech-text pairs otherwise
        samples_v2 = []   # phone-text pairs
        samples_v3 = []   # pure-audio examples
        for sample_tuple in samples:
            samples_v1 += sample_tuple[0]
            samples_v2 += sample_tuple[1]
            if len(sample_tuple) > 2:
                samples_v3 += sample_tuple[2]

        # Guarantee a non-empty primary batch by falling back to example 0.
        if not samples_v1:
            if self.train_stage == 1:
                samples_v1 += self.process_pure_text(0)
            else:
                samples_v1 += self.process_speech_text_pair(0)

        mask = False
        mask_prob = None
        if self.split == "train" and self.train_stage != 1:
            mask = True
            mask_prob = 0.3

        res_v1 = collate(
            samples_v1,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            feature_only=True,
            mask=mask,
            mask_prob=mask_prob
        )

        if self.split == 'train' and self.train_stage != 1:
            if not samples_v2:
                if self.pure_text_dataset is not None:
                    samples_v2 += self.process_pure_text(0) if not self.remove_pure_text else []
            res_v2 = collate(
                samples_v2,
                pad_idx=self.src_dict.pad(),
                eos_idx=self.eos
            )
            if not samples_v3:
                if self.pure_audio_dataset is not None:
                    samples_v3 += self.process_pure_audio(0) if not self.remove_pure_audio else []
                else:
                    return res_v1, res_v2
            res_v3 = collate(
                samples_v3,
                pad_idx=self.src_dict.pad(),
                eos_idx=self.eos,
                feature_only=False,
                mask=True
            )
            return res_v1, res_v2, res_v3
        else:
            return res_v1