Commit 6e2bd794 authored by Alexei Baevski, committed by Facebook Github Bot

wav2vec everstore support

Summary: changes for internal (everstore) support: split RawAudioDataset into an abstract base class plus a file-backed FileAudioDataset subclass, so alternative storage backends can reuse the same preprocessing and collation logic.

Differential Revision: D16646887

fbshipit-source-id: ac5bf6c32901819726249422324eae32a0a6e148
parent d4c9136c
fairseq/data/__init__.py

@@ -9,7 +9,7 @@ from .fairseq_dataset import FairseqDataset
 from .base_wrapper_dataset import BaseWrapperDataset
-from .audio.raw_audio_dataset import RawAudioDataset
+from .audio.raw_audio_dataset import FileAudioDataset
 from .backtranslation_dataset import BacktranslationDataset
 from .concat_dataset import ConcatDataset
 from .concat_sentences_dataset import ConcatSentencesDataset
@@ -78,9 +78,9 @@ __all__ = [
     'PadDataset',
     'PrependDataset',
     'PrependTokenDataset',
-    'RawAudioDataset',
-    'RawLabelDataset',
     'ReplaceDataset',
+    'FileAudioDataset',
+    "RawLabelDataset",
     'RightPadDataset',
     'RoundRobinZipDatasets',
     'ShardedDataset',
fairseq/data/audio/raw_audio_dataset.py

@@ -7,6 +7,7 @@
 import os
 import numpy as np
+import sys
 import torch
 import torch.nn.functional as F
@@ -14,61 +15,71 @@ from .. import FairseqDataset
 class RawAudioDataset(FairseqDataset):
-
-    def __init__(self, manifest_path, sample_rate, max_sample_size=None, min_sample_size=None,
-                 shuffle=True):
+    def __init__(
+        self,
+        sample_rate,
+        max_sample_size=None,
+        min_sample_size=None,
+        shuffle=True,
+        min_length=0,
+    ):
         super().__init__()
 
         self.sample_rate = sample_rate
-        self.fnames = []
         self.sizes = []
-        self.max_sample_size = max_sample_size if max_sample_size is not None else sys.maxsize
-        self.min_sample_size = min_sample_size if min_sample_size is not None else self.max_sample_size
-
-        with open(manifest_path, 'r') as f:
-            self.root_dir = f.readline().strip()
-            for line in f:
-                items = line.strip().split('\t')
-                assert len(items) == 2, line
-                self.fnames.append(items[0])
-                self.sizes.append(int(items[1]))
+        self.max_sample_size = (
+            max_sample_size if max_sample_size is not None else sys.maxsize
+        )
+        self.min_sample_size = (
+            min_sample_size if min_sample_size is not None else self.max_sample_size
+        )
+        self.min_length = min_length
         self.shuffle = shuffle
 
     def __getitem__(self, index):
-        fname = os.path.join(self.root_dir, self.fnames[index])
-        import soundfile as sf
-
-        wav, curr_sample_rate = sf.read(fname)
-        feats = torch.from_numpy(wav).float()
+        raise NotImplementedError()
+
+    def __len__(self):
+        return len(self.sizes)
+
+    def postprocess(self, feats, curr_sample_rate):
+        def resample(x, factor):
+            return F.interpolate(x.view(1, 1, -1), scale_factor=factor).squeeze()
 
         if feats.dim() == 2:
             feats = feats.mean(-1)
 
         if curr_sample_rate != self.sample_rate:
             factor = self.sample_rate / curr_sample_rate
-            feats = self.resample(feats, factor)
+            feats = resample(feats, factor)
 
         assert feats.dim() == 1, feats.dim()
+        return feats
 
-        return {
-            'id': index,
-            'source': feats,
-        }
-
-    def resample(self, x, factor):
-        return F.interpolate(x.view(1, 1, -1), scale_factor=factor).squeeze()
-
-    def __len__(self):
-        return len(self.fnames)
+    def crop_to_max_size(self, wav, target_size):
+        size = len(wav)
+        diff = size - target_size
+        if diff <= 0:
+            return wav
+
+        start = np.random.randint(0, diff + 1)
+        end = size - diff + start
+        return wav[start:end]
 
     def collater(self, samples):
+        samples = [
+            s for s in samples if s["source"] is not None and len(s["source"]) > 0
+        ]
         if len(samples) == 0:
             return {}
 
-        sources = [s['source'] for s in samples]
+        sources = [s["source"] for s in samples]
         sizes = [len(s) for s in sources]
         target_size = min(min(sizes), self.max_sample_size)
+        if target_size < self.min_length:
+            return {}
 
         if self.min_sample_size < target_size:
            target_size = np.random.randint(self.min_sample_size, target_size + 1)
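For orientation, the nested resample helper that postprocess() now uses is F.interpolate with its default (nearest-neighbor) mode over the sample axis. A minimal sketch with a synthetic tensor (not part of the commit): doubling the rate doubles the length.

    import torch
    import torch.nn.functional as F

    def resample(x, factor):
        # same one-liner as the nested helper in RawAudioDataset.postprocess
        return F.interpolate(x.view(1, 1, -1), scale_factor=factor).squeeze()

    wav = torch.randn(8000)            # one second of synthetic audio at 8 kHz
    out = resample(wav, 16000 / 8000)  # factor = target_rate / current_rate
    print(out.shape)                   # torch.Size([16000])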
@@ -79,32 +90,13 @@ class RawAudioDataset(FairseqDataset):
         if diff == 0:
             collated_sources[i] = source
         else:
-            start = np.random.randint(0, diff + 1)
-            end = size - diff + start
-            collated_sources[i] = source[start:end]
+            collated_sources[i] = self.crop_to_max_size(source, target_size)
 
         return {
-            'id': torch.LongTensor([s['id'] for s in samples]),
-            'net_input': {
-                'source': collated_sources,
-            },
+            "id": torch.LongTensor([s["id"] for s in samples]),
+            "net_input": {"source": collated_sources},
         }
 
-    def get_dummy_batch(
-        self, num_tokens, max_positions, src_len=2048, tgt_len=128,
-    ):
-        """Return a dummy batch with a given number of tokens."""
-        if isinstance(max_positions, float) or isinstance(max_positions, int):
-            src_len = min(src_len, max_positions)
-        bsz = num_tokens // src_len
-        return self.collater([
-            {
-                'id': i,
-                'source': torch.rand(src_len),
-            }
-            for i in range(bsz)
-        ])
-
     def num_tokens(self, index):
         return self.size(index)
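With __getitem__ now raising NotImplementedError, RawAudioDataset acts as an abstract base: a backend only needs to populate self.sizes and return {"id", "source"} dicts, and it inherits cropping and collation for free. A minimal in-memory sketch (the ToyAudioDataset name is hypothetical and not part of this commit):

    import torch
    from fairseq.data.audio.raw_audio_dataset import RawAudioDataset

    class ToyAudioDataset(RawAudioDataset):
        """Hypothetical backend that serves tensors from memory."""

        def __init__(self, tensors, sample_rate=16000, **kwargs):
            super().__init__(sample_rate=sample_rate, **kwargs)
            self.tensors = tensors
            self.sizes = [len(t) for t in tensors]

        def __getitem__(self, index):
            return {"id": index, "source": self.tensors[index]}

    ds = ToyAudioDataset([torch.randn(4000), torch.randn(5000)])
    batch = ds.collater([ds[0], ds[1]])
    # the longer waveform is randomly cropped to the batch minimum (4000),
    # so batch["net_input"]["source"] has shape (2, 4000)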
@@ -124,3 +116,41 @@ class RawAudioDataset(FairseqDataset):
         order.append(self.sizes)
         return np.lexsort(order)
+
+
+class FileAudioDataset(RawAudioDataset):
+    def __init__(
+        self,
+        manifest_path,
+        sample_rate,
+        max_sample_size=None,
+        min_sample_size=None,
+        shuffle=True,
+        min_length=0,
+    ):
+        super().__init__(
+            sample_rate=sample_rate,
+            max_sample_size=max_sample_size,
+            min_sample_size=min_sample_size,
+            shuffle=shuffle,
+            min_length=min_length,
+        )
+
+        self.fnames = []
+
+        with open(manifest_path, "r") as f:
+            self.root_dir = f.readline().strip()
+            for line in f:
+                items = line.strip().split("\t")
+                assert len(items) == 2, line
+                self.fnames.append(items[0])
+                self.sizes.append(int(items[1]))
+
+    def __getitem__(self, index):
+        import soundfile as sf
+
+        fname = os.path.join(self.root_dir, self.fnames[index])
+        wav, curr_sample_rate = sf.read(fname)
+        feats = torch.from_numpy(wav).float()
+        feats = self.postprocess(feats, curr_sample_rate)
+        return {"id": index, "source": feats}
fairseq/tasks/audio_pretraining.py

@@ -5,7 +5,7 @@
 import os
 
-from fairseq.data import RawAudioDataset
+from fairseq.data import FileAudioDataset
 
 from . import FairseqTask, register_task
@@ -46,10 +46,10 @@ class AudioPretrainingTask(FairseqTask):
         """
         manifest = os.path.join(self.args.data, '{}.tsv'.format(split))
-        self.datasets[split] = RawAudioDataset(manifest,
-                                               sample_rate=self.args.sample_rate,
-                                               max_sample_size=self.args.max_sample_size,
-                                               min_sample_size=self.args.min_sample_size)
+        self.datasets[split] = FileAudioDataset(manifest,
+                                                sample_rate=self.args.sample_rate,
+                                                max_sample_size=self.args.max_sample_size,
+                                                min_sample_size=self.args.min_sample_size)
 
     @property
     def target_dictionary(self):
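End to end, the task resolves a split name to a manifest and builds the dataset. A hedged sketch of that path (directory, split, and argument values are hypothetical; only the args.* names appear in the diff above):

    import os
    from fairseq.data import FileAudioDataset

    data_dir, split = "/datasets/librispeech", "train"          # self.args.data, split name
    manifest = os.path.join(data_dir, "{}.tsv".format(split))   # /datasets/librispeech/train.tsv
    dataset = FileAudioDataset(manifest,
                               sample_rate=16000,               # args.sample_rate
                               max_sample_size=250000,          # args.max_sample_size
                               min_sample_size=50000)           # args.min_sample_size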