Commit 392fce8a authored by alexeib's avatar alexeib Committed by Facebook Github Bot
Browse files

wav2vec model (#654)

Summary:
Merging wav2vec to master. Includes renames (Cpc -> wav2vec) and some light example files.
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/654

Differential Revision: D15913409

Pulled By: alexeib

fbshipit-source-id: f723e6f211706cd9431c7d76dc12c4e80c9cfc80
parent bd710e75
......@@ -9,6 +9,7 @@ of various sequence-to-sequence models, including:
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md)
- [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
- **_New_** [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
- **LightConv and DynamicConv models**
- **_New_** [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
- **Long Short-Term Memory (LSTM) networks**
......@@ -82,6 +83,7 @@ as well as example training and evaluation commands.
- [Language Modeling](examples/language_model/README.md): convolutional models are available
We also have more detailed READMEs to reproduce results from specific papers:
- [Schneider et al. (2019): wav2vec: Unsupervised Pre-training for Speech Recognition](examples/wav2vec/README.md)
- [Shen et al. (2019) Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
- [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
- [Edunov et al. (2018): Understanding Back-Translation at Scale](examples/backtranslation/README.md)
......
# wav2vec
Example to train a wav2vec model as described in [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](https://arxiv.org/abs/1904.05862).
## Training a new model with the CLI tools
Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate file 10 to 30 seconds in length)
### Prepare training data manifest:
```
$ python scripts/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav
```
### Train a wav2vec model:
```
$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \
--arch wav2vec --task audio_pretraining --lr 1e-06 --min-lr 1e-09 --optimizer adam --max-lr 0.005 --lr-scheduler cosine \
--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \
--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \
--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion binary_cross_entropy --num-negatives 10 \
--max-sample-size 150000 --max-tokens 1500000 ---skip-invalid-size-inputs-valid-test
```
### Extract embeddings from the downstream task data:
```
$ PYTHONPATH /path/to/fairseq python scripts/wav2vec_featurize.py --input /path/to/task/waves --output /path/to/output \
--model /model/path/checkpoint_best.pt --split train valid test
```
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
import torch.nn.functional as F
from fairseq import utils
from . import FairseqCriterion, register_criterion
@register_criterion('binary_cross_entropy')
class BinaryCrossEntropyCriterion(FairseqCriterion):
def __init__(self, args, task):
super().__init__(args, task)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
logits = model.get_logits(net_output).float()
target = model.get_targets(sample, net_output, expand_steps=False).float()
if hasattr(model, 'get_target_weights'):
weights = model.get_target_weights(target, net_output)
if torch.is_tensor(weights):
weights = weights.float()
else:
weights = 1.
loss = F.binary_cross_entropy_with_logits(logits, target, reduce=False)
loss = loss * weights
if reduce:
loss = loss.sum()
sample_size = target.numel()
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample_size,
'nsentences': logits.size(0),
'sample_size': sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {
'loss': loss_sum / sample_size / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
if sample_size != ntokens:
agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
return agg_output
\ No newline at end of file
......@@ -13,7 +13,7 @@ class FairseqCriterion(_Loss):
def __init__(self, args, task):
super().__init__()
self.args = args
self.padding_idx = task.target_dictionary.pad()
self.padding_idx = task.target_dictionary.pad() if task.target_dictionary is not None else -100
@staticmethod
def add_args(parser):
......
......@@ -10,6 +10,7 @@ from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary
from .fairseq_dataset import FairseqDataset
from .audio.raw_audio_dataset import RawAudioDataset
from .backtranslation_dataset import BacktranslationDataset
from .block_pair_dataset import BlockPairDataset
from .concat_dataset import ConcatDataset
......@@ -51,6 +52,7 @@ __all__ = [
'MMapIndexedDataset',
'MonolingualDataset',
'NoisingDataset',
'RawAudioDataset',
'RoundRobinZipDatasets',
'ShardedIterator',
'TokenBlockDataset',
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
import numpy as np
import sys
import torch
import torch.nn.functional as F
from .. import FairseqDataset
class RawAudioDataset(FairseqDataset):
def __init__(self, manifest_path, sample_rate, max_sample_size=None, min_sample_size=None,
shuffle=True):
super().__init__()
self.sample_rate = sample_rate
self.fnames = []
self.sizes = []
self.max_sample_size = max_sample_size if max_sample_size is not None else sys.maxsize
self.min_sample_size = min_sample_size if min_sample_size is not None else self.max_sample_size
with open(manifest_path, 'r') as f:
self.root_dir = f.readline().strip()
for line in f:
items = line.strip().split('\t')
assert len(items) == 2, line
self.fnames.append(items[0])
self.sizes.append(int(items[1]))
self.shuffle = shuffle
def __getitem__(self, index):
fname = os.path.join(self.root_dir, self.fnames[index])
import soundfile as sf
wav, curr_sample_rate = sf.read(fname)
feats = torch.from_numpy(wav).float()
if feats.dim() == 2:
feats = feats.mean(-1)
if curr_sample_rate != self.sample_rate:
factor = self.sample_rate / curr_sample_rate
feats = self.resample(feats, factor)
assert feats.dim() == 1, feats.dim()
return {
'id': index,
'source': feats,
}
def resample(self, x, factor):
return F.interpolate(x.view(1, 1, -1), scale_factor=factor).squeeze()
def __len__(self):
return len(self.fnames)
def collater(self, samples):
if len(samples) == 0:
return {}
sources = [s['source'] for s in samples]
sizes = [len(s) for s in sources]
target_size = min(min(sizes), self.max_sample_size)
if self.min_sample_size < target_size:
target_size = np.random.randint(self.min_sample_size, target_size + 1)
collated_sources = sources[0].new(len(sources), target_size)
for i, (source, size) in enumerate(zip(sources, sizes)):
diff = size - target_size
assert diff >= 0
if diff == 0:
collated_sources[i] = source
else:
start = np.random.randint(0, diff + 1)
end = size - diff + start
collated_sources[i] = source[start:end]
return {
'id': torch.LongTensor([s['id'] for s in samples]),
'net_input': {
'source': collated_sources,
},
}
def get_dummy_batch(
self, num_tokens, max_positions, src_len=2048, tgt_len=128,
):
"""Return a dummy batch with a given number of tokens."""
if isinstance(max_positions, float) or isinstance(max_positions, int):
src_len = min(src_len, max_positions)
bsz = num_tokens // src_len
return self.collater([
{
'id': i,
'source': torch.rand(src_len),
}
for i in range(bsz)
])
def num_tokens(self, index):
return self.size(index)
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return min(self.sizes[index], self.max_sample_size)
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
order.append(self.sizes)
return np.lexsort(order)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import sys
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import (
BaseFairseqModel, register_model, register_model_architecture
)
@register_model('wav2vec')
class Wav2VecModel(BaseFairseqModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument('--prediction-steps', type=int, metavar='N', help='number of steps ahead to predict')
parser.add_argument('--sample-distance', type=int, metavar='N',
help='sample distance from target. does not work properly with cross-sampling')
parser.add_argument('--cross-sample-negatives', action='store_true',
help='whether to sample negatives across examples in the same batch')
parser.add_argument('--num-negatives', type=int, metavar='N', help='number of negative examples')
parser.add_argument('--conv-feature-layers', type=str, metavar='EXPR',
help='convolutional feature extraction layers [(dim, kernel_size, stride), ...]')
parser.add_argument('--conv-aggregator-layers', type=str, metavar='EXPR',
help='convolutional feature extraction layers [(dim, kernel_size, stride), ...]')
parser.add_argument('--dropout', type=float, metavar='D', help='dropout to apply within the model')
parser.add_argument('--dropout-features', type=float, metavar='D', help='dropout to apply to the features')
parser.add_argument('--dropout-agg', type=float, metavar='D', help='dropout to apply after aggregation step')
parser.add_argument('--encoder', type=str, choices=['cnn'], help='type of encoder to use')
parser.add_argument('--aggregator', type=str, choices=['cnn', 'gru'],
help='type of aggregator to use')
parser.add_argument('--gru-dim', type=int, metavar='N', help='GRU dimensionality')
parser.add_argument('--no-conv-bias', action='store_true',
help='if set, does not learn bias for conv layers')
parser.add_argument('--agg-zero-pad', action='store_true',
help='if set, zero pads in aggregator instead of repl pad')
parser.add_argument('--skip-connections-feat', action='store_true',
help='if set, adds skip connections to the feature extractor')
parser.add_argument('--skip-connections-agg', action='store_true',
help='if set, adds skip connections to the aggregator')
parser.add_argument('--residual-scale', type=float, metavar='D',
help='scales residual by sqrt(value)')
parser.add_argument('--log-compression', action='store_true',
help='if set, adds a log compression to feature extractor')
parser.add_argument('--balanced-classes', action='store_true',
help='if set, loss is scaled to balance for number of negatives')
parser.add_argument('--project-features', choices=['none', 'same', 'new'],
help='if not none, features are projected using the (same or new) aggregator')
parser.add_argument('--non-affine-group-norm', action='store_true',
help='if set, group norm is not affine')
parser.add_argument('--offset', help='if set, introduces an offset from target to predictions. '
'if set to "auto", it is computed automatically from the receptive field')
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_wav2vec_architecture(args)
model = Wav2VecModel(args)
print(model)
return model
def __init__(self, args):
super().__init__()
self.prediction_steps = args.prediction_steps
offset = args.offset
if args.encoder == 'cnn':
feature_enc_layers = eval(args.conv_feature_layers)
self.feature_extractor = ConvFeatureExtractionModel(
conv_layers=feature_enc_layers,
dropout=0.,
log_compression=args.log_compression,
skip_connections=args.skip_connections_feat,
residual_scale=args.residual_scale,
non_affine_group_norm=args.non_affine_group_norm,
)
embed = feature_enc_layers[-1][0]
else:
raise Exception('unknown encoder type ' + args.encoder)
if args.offset == 'auto':
assert args.encoder == 'cnn'
jin = 0
rin = 0
for _, k, stride in feature_enc_layers:
if rin == 0:
rin = k
rin = rin + (k - 1) * jin
if jin == 0:
jin = stride
else:
jin *= stride
offset = math.ceil(rin / jin)
offset = int(offset)
def make_aggregator():
if args.aggregator == 'cnn':
agg_layers = eval(args.conv_aggregator_layers)
agg_dim = agg_layers[-1][0]
feature_aggregator = ConvAggegator(
conv_layers=agg_layers,
embed=embed,
dropout=args.dropout,
skip_connections=args.skip_connections_agg,
residual_scale=args.residual_scale,
non_affine_group_norm=args.non_affine_group_norm,
conv_bias=not args.no_conv_bias,
zero_pad=args.agg_zero_pad,
)
elif args.aggregator == 'gru':
agg_dim = args.gru_dim
feature_aggregator = nn.Sequential(
TransposeLast(),
nn.GRU(
input_size=embed,
hidden_size=agg_dim,
num_layers=1,
dropout=args.dropout,
),
TransposeLast(deconstruct_idx=0),
)
else:
raise Exception('unknown aggregator type ' + args.aggregator)
return feature_aggregator, agg_dim
self.feature_aggregator, agg_dim = make_aggregator()
self.wav2vec_predictions = Wav2VecPredictionsModel(
in_dim=agg_dim,
out_dim=embed,
prediction_steps=args.prediction_steps,
n_negatives=args.num_negatives,
cross_sample_negatives=args.cross_sample_negatives,
sample_distance=args.sample_distance,
dropout=args.dropout,
offset=offset,
balanced_classes=args.balanced_classes,
)
self.dropout_feats = nn.Dropout(p=args.dropout_features)
self.dropout_agg = nn.Dropout(p=args.dropout_agg)
if args.project_features == 'none':
self.project_features = None
elif args.project_features == 'same':
self.project_features = self.feature_aggregator
elif args.project_features == 'new':
self.project_features, _ = make_aggregator()
def forward(self, source):
result = {}
features = self.feature_extractor(source)
x = self.dropout_feats(features)
x = self.feature_aggregator(x)
x = self.dropout_agg(x)
if self.project_features is not None:
features = self.project_features(features)
x, targets = self.wav2vec_predictions(x, features)
result['cpc_logits'] = x
result['cpc_targets'] = targets
return result
def upgrade_state_dict_named(self, state_dict, name):
return state_dict
def max_positions(self):
"""Maximum length supported by the model."""
return sys.maxsize
def get_logits(self, net_output):
logits = net_output['cpc_logits']
return logits
def get_targets(self, sample, net_output, expand_steps=True):
t = net_output['cpc_targets']
return t.contiguous()
def get_target_weights(self, targets, net_output):
targets = net_output['cpc_targets']
if isinstance(targets, tuple) and targets[-1] is not None:
return targets[-1]
return 1.
class TransposeLast(nn.Module):
def __init__(self, deconstruct_idx=None):
super().__init__()
self.deconstruct_idx = deconstruct_idx
def forward(self, x):
if self.deconstruct_idx is not None:
x = x[self.deconstruct_idx]
return x.transpose(-2, -1)
class Fp32GroupNorm(nn.GroupNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, input):
output = F.group_norm(
input.float(), self.num_groups, self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None, self.eps)
return output.type_as(input)
class Fp32LayerNorm(nn.LayerNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, input):
output = F.layer_norm(
input.float(), self.normalized_shape, self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None, self.eps)
return output.type_as(input)
def norm_block(is_layer_norm, dim, affine=True):
if is_layer_norm:
mod = nn.Sequential(
TransposeLast(),
Fp32LayerNorm(dim, elementwise_affine=affine),
TransposeLast(),
)
else:
mod = Fp32GroupNorm(1, dim, affine=affine)
return mod
class ConvFeatureExtractionModel(nn.Module):
def __init__(self, conv_layers, dropout, log_compression, skip_connections, residual_scale, non_affine_group_norm):
super().__init__()
def block(n_in, n_out, k, stride):
return nn.Sequential(
nn.Conv1d(n_in, n_out, k, stride=stride, bias=False),
nn.Dropout(p=dropout),
norm_block(is_layer_norm=False, dim=n_out, affine=not non_affine_group_norm),
nn.ReLU(),
)
in_d = 1
self.conv_layers = nn.ModuleList()
for i, (dim, k, stride) in enumerate(conv_layers):
self.conv_layers.append(
block(in_d, dim, k, stride))
in_d = dim
self.log_compression = log_compression
self.skip_connections = skip_connections
self.residual_scale = math.sqrt(residual_scale)
def forward(self, x):
# BxT -> BxCxT
x = x.unsqueeze(1)
for conv in self.conv_layers:
residual = x
x = conv(x)
if self.skip_connections and x.size(1) == residual.size(1):
tsz = x.size(2)
r_tsz = residual.size(2)
residual = residual[..., ::r_tsz // tsz][..., :tsz]
x = (x + residual) * self.residual_scale
if self.log_compression:
x = x.abs()
x = x + 1
x = x.log()
return x
class ZeroPad1d(nn.Module):
def __init__(self, pad_left, pad_right):
super().__init__()
self.pad_left = pad_left
self.pad_right = pad_right
def forward(self, x):
return F.pad(x, (self.pad_left, self.pad_right))
class ConvAggegator(nn.Module):
def __init__(self, conv_layers, embed, dropout, skip_connections, residual_scale, non_affine_group_norm, conv_bias,
zero_pad):
super().__init__()
def block(n_in, n_out, k, stride):
# padding dims only really make sense for stride = 1
ka = k // 2
kb = ka - 1 if k % 2 == 0 else ka
pad = ZeroPad1d(ka + kb, 0) if zero_pad else nn.ReplicationPad1d((ka + kb, 0))
return nn.Sequential(
pad,
nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias),
nn.Dropout(p=dropout),
norm_block(False, n_out, affine=not non_affine_group_norm),
nn.ReLU(),
)
in_d = embed
self.conv_layers = nn.ModuleList()
self.residual_proj = nn.ModuleList()
for i, (dim, k, stride) in enumerate(conv_layers):
if in_d != dim and skip_connections:
self.residual_proj.append(
nn.Conv1d(in_d, dim, 1, bias=False),
)
else:
self.residual_proj.append(None)
self.conv_layers.append(
block(in_d, dim, k, stride))
in_d = dim
self.conv_layers = nn.Sequential(*self.conv_layers)
self.skip_connections = skip_connections
self.residual_scale = math.sqrt(residual_scale)
def forward(self, x):
for rproj, conv in zip(self.residual_proj, self.conv_layers):
residual = x
x = conv(x)
if self.skip_connections:
if rproj != None:
residual = rproj(residual)
x = (x + residual) * self.residual_scale
return x
class Wav2VecPredictionsModel(nn.Module):
def __init__(self, in_dim, out_dim, prediction_steps, n_negatives, cross_sample_negatives, sample_distance,
dropout, offset, balanced_classes):
super().__init__()
self.n_negatives = n_negatives
self.cross_sample_negatives = cross_sample_negatives
self.sample_distance = sample_distance
self.project_to_steps = nn.ConvTranspose2d(in_dim, out_dim, (1, prediction_steps))
self.dropout = nn.Dropout(p=dropout)
self.offset = offset
self.balanced_classes = balanced_classes
def sample_negatives(self, y):
bsz, fsz, tsz = y.shape
y = y.transpose(0, 1) # BCT -> CBT
y = y.contiguous().view(fsz, -1) # CBT => C(BxT)
if self.cross_sample_negatives:
high = tsz * bsz
assert self.sample_distance is None, 'sample distance is not supported with cross sampling'
else:
high = tsz if self.sample_distance is None else min(tsz, self.sample_distance)
neg_idxs = torch.randint(low=0, high=high, size=(bsz, self.n_negatives * tsz))
if self.sample_distance is not None and self.sample_distance < tsz:
neg_idxs += torch.cat(
[torch.arange(start=1, end=tsz - self.sample_distance, device=neg_idxs.device, dtype=neg_idxs.dtype),
torch.arange(start=tsz - self.sample_distance, end=tsz - self.sample_distance * 2 - 1, step=-1,
device=neg_idxs.device, dtype=neg_idxs.dtype)])
if not self.cross_sample_negatives:
for i in range(1, bsz):
neg_idxs[i] += i * high
negs = y[..., neg_idxs.view(-1)]
negs = negs.view(fsz, bsz, self.n_negatives, tsz).permute(2, 1, 0, 3) # to NxBxCxT
return negs
def forward(self, x, y):
negatives = self.sample_negatives(y)
y = y.unsqueeze(0)
targets = torch.cat([y, negatives], dim=0)
x = x.unsqueeze(-1)
x = self.project_to_steps(x) # BxCxTxS
x = self.dropout(x)
x = x.unsqueeze(0).expand(targets.size(0), -1, -1, -1, -1)
copies, bsz, dim, tsz, steps = x.shape
steps = min(steps, tsz - self.offset)
predictions = x.new(bsz * copies * (tsz - self.offset + 1) * steps - ((steps + 1) * steps // 2) * copies * bsz)
labels = torch.zeros_like(predictions)
weights = torch.full_like(labels, 1 / self.n_negatives) if self.balanced_classes else None
start = end = 0
for i in range(steps):
offset = i + self.offset
end = start + (tsz - offset) * bsz * copies
pos_num = (end - start) // copies
predictions[start:end] = (x[..., :-offset, i] * targets[..., offset:]).sum(dim=2).flatten()
labels[start:start + pos_num] = 1.
if weights is not None:
weights[start:start + pos_num] = 1.
start = end
assert end == predictions.numel(), '{} != {}'.format(end, predictions.numel())
if weights is not None:
labels = (labels, weights)
return predictions, labels
@register_model_architecture('wav2vec', 'wav2vec')
def base_wav2vec_architecture(args):
conv_feature_layers = '[(512, 10, 5)]'
conv_feature_layers += ' + [(512, 8, 4)]'
conv_feature_layers += ' + [(512, 4, 2)] * 3'
args.conv_feature_layers = getattr(args, 'conv_feature_layers', conv_feature_layers)
args.conv_aggregator_layers = getattr(args, 'conv_aggregator_layers', '[(512, 3, 1)] * 9')
args.prediction_steps = getattr(args, 'prediction_steps', 12)
args.num_negatives = getattr(args, 'num_negatives', 1)
args.sample_distance = getattr(args, 'sample_distance', None)
args.cross_sample_negatives = getattr(args, 'cross_sample_negatives', False)
args.dropout = getattr(args, 'dropout', 0.)
args.dropout_features = getattr(args, 'dropout_features', 0.)
args.dropout_agg = getattr(args, 'dropout_agg', 0.)
args.encoder = getattr(args, 'encoder', 'cnn')
args.aggregator = getattr(args, 'aggregator', 'cnn')
args.skip_connections_feat = getattr(args, 'skip_connections_feat', False)
args.skip_connections_agg = getattr(args, 'skip_connections_agg', False)
args.residual_scale = getattr(args, 'residual_scale', 0.5)
args.gru_dim = getattr(args, 'gru_dim', 512)
args.no_conv_bias = getattr(args, 'no_conv_bias', False)
args.agg_zero_pad = getattr(args, 'agg_zero_pad', False)
args.log_compression = getattr(args, 'log_compression', False)
args.balanced_classes = getattr(args, 'balanced_classes', False)
args.project_features = getattr(args, 'project_features', 'none')
args.non_affine_group_norm = getattr(args, 'non_affine_group_norm', False)
args.offset = getattr(args, 'offset', 'auto')
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
from fairseq.data import RawAudioDataset
from . import FairseqTask, register_task
@register_task('audio_pretraining')
class AudioPretrainingTask(FairseqTask):
"""
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', help='path to data directory')
parser.add_argument('--sample-rate', default=16000, type=int,
help='target sample rate. audio files will be up/down sampled to this rate')
parser.add_argument('--max-sample-size', default=None, type=int,
help='max sample size to crop to for batching. default = min sample length')
parser.add_argument('--min-sample-size', default=None, type=int,
help='min sample size to crop to for batching. default = same as --max-sample-size')
def __init__(self, args):
super().__init__(args)
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
args (argparse.Namespace): parsed command-line arguments
"""
return cls(args)
def load_dataset(self, split, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
manifest = os.path.join(self.args.data, '{}.tsv'.format(split))
self.datasets[split] = RawAudioDataset(manifest,
sample_rate=self.args.sample_rate,
max_sample_size=self.args.max_sample_size,
min_sample_size=self.args.min_sample_size)
@property
def target_dictionary(self):
"""Return the :class:`~fairseq.data.Dictionary` for the language
model."""
return None
\ No newline at end of file
......@@ -233,11 +233,11 @@ class Trainer(object):
# forward and backward pass
logging_outputs, sample_sizes, ooms = [], [], 0
for i, sample in enumerate(samples):
sample = self._prepare_sample(sample)
sample = self._prepare_sample(sample, self.args.fp16)
if sample is None:
# when sample is None, run forward/backward on a dummy batch
# and ignore the resulting gradients
sample = self._prepare_sample(self._dummy_batch)
sample = self._prepare_sample(self._dummy_batch, self.args.fp16)
ignore_grad = True
else:
ignore_grad = False
......@@ -381,9 +381,9 @@ class Trainer(object):
self.model.eval()
self.criterion.eval()
sample = self._prepare_sample(sample)
sample = self._prepare_sample(sample, self.args.fp16)
if sample is None:
sample = self._prepare_sample(self._dummy_batch)
sample = self._prepare_sample(self._dummy_batch, self.args.fp16)
ignore_results = True
else:
ignore_results = False
......@@ -488,12 +488,19 @@ class Trainer(object):
self._num_updates = num_updates
self.lr_step_update()
def _prepare_sample(self, sample):
def _prepare_sample(self, sample, fp16):
if sample is None or len(sample) == 0:
return None
if self.cuda:
sample = utils.move_to_cuda(sample)
return sample
def apply_half(t):
if t.dtype is torch.float32:
return t.half()
return t
return utils.apply(apply_half, sample) if fp16 else sample
def _set_seed(self):
# Set seed based on args.seed and the update number so that we get
......
......@@ -31,24 +31,26 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
)
def move_to_cuda(sample):
def apply(f, sample):
if len(sample) == 0:
return {}
def _move_to_cuda(maybe_tensor):
if torch.is_tensor(maybe_tensor):
return maybe_tensor.cuda()
elif isinstance(maybe_tensor, dict):
if torch.is_tensor(sample):
return f(sample)
elif isinstance(sample, dict):
return {
key: _move_to_cuda(value)
for key, value in maybe_tensor.items()
key: apply(f, value)
for key, value in sample.items()
}
elif isinstance(maybe_tensor, list):
return [_move_to_cuda(x) for x in maybe_tensor]
elif isinstance(sample, list):
return [apply(f, x) for x in sample]
else:
return maybe_tensor
return sample
return _move_to_cuda(sample)
def move_to_cuda(sample):
def _move_to_cuda(tensor):
return tensor.cuda()
return apply(_move_to_cuda, sample)
INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
......
""" Helper script to pre-compute embeddings for a wav2letter++ dataset
"""
import glob, os
import tqdm
from shutil import copy
import soundfile as sf
import h5py
import numpy as np
import torch
from torch import nn
from fairseq.models.wav2vec import Wav2VecModel
import argparse
def read_audio(fname):
""" Load an audio file and return PCM along with the sample rate """
wav, sr = sf.read(fname)
assert sr == 16e3
return wav, 16e3
class PretrainedWav2VecModel(nn.Module):
def __init__(self, fname):
super().__init__()
checkpoint = torch.load(fname)
self.args = checkpoint["args"]
model = Wav2VecModel.build_model(self.args, None)
model.load_state_dict(checkpoint["model"])
model.eval()
self.model = model
def forward(self, x):
with torch.no_grad():
z = self.model.feature_extractor(x)
if isinstance(z, tuple):
z = z[0]
c = self.model.feature_aggregator(z)
return z, c
class EmbeddingWriterConfig(argparse.ArgumentParser):
def __init__(self):
super().__init__("Pre-compute embeddings for wav2letter++ datasets")
kwargs = {"action": "store", "type": str, "required": True}
self.add_argument("--input", "-i",
help="Input Directory", **kwargs)
self.add_argument("--output", "-o",
help="Output Directory", **kwargs)
self.add_argument("--model",
help="Path to model checkpoint", **kwargs)
self.add_argument("--split",
help="Dataset Splits", nargs='+', **kwargs)
self.add_argument("--ext", default="wav", required=False,
help="Audio file extension")
self.add_argument("--no-copy-labels", action="store_true",
help="Do not copy label files. Useful for large datasets, use --targetdir in wav2letter then.")
self.add_argument("--use-feat", action="store_true",
help="Use the feature vector ('z') instead of context vector ('c') for features")
self.add_argument("--gpu",
help="GPU to use", default=0, type=int)
class Prediction():
""" Lightweight wrapper around a fairspeech embedding model """
def __init__(self, fname, gpu=0):
self.gpu = gpu
self.model = PretrainedWav2VecModel(fname).cuda(gpu)
def __call__(self, x):
x = torch.from_numpy(x).float().cuda(self.gpu)
with torch.no_grad():
z, c = self.model(x.unsqueeze(0))
return z.squeeze(0).cpu().numpy(), c.squeeze(0).cpu().numpy()
class H5Writer():
""" Write features as hdf5 file in wav2letter++ compatible format """
def __init__(self, fname):
self.fname = fname
os.makedirs(os.path.dirname(self.fname), exist_ok=True)
def write(self, data):
channel, T = data.shape
with h5py.File(self.fname, "w") as out_ds:
data = data.T.flatten()
out_ds["features"] = data
out_ds["info"] = np.array([16e3 // 160, T, channel])
class EmbeddingDatasetWriter(object):
""" Given a model and a wav2letter++ dataset, pre-compute and store embeddings
Args:
input_root, str :
Path to the wav2letter++ dataset
output_root, str :
Desired output directory. Will be created if non-existent
split, str :
Dataset split
"""
def __init__(self, input_root, output_root, split,
model_fname,
extension="wav",
gpu=0,
verbose=False,
use_feat=False,
):
assert os.path.exists(model_fname)
self.model_fname = model_fname
self.model = Prediction(self.model_fname, gpu)
self.input_root = input_root
self.output_root = output_root
self.split = split
self.verbose = verbose
self.extension = extension
self.use_feat = use_feat
assert os.path.exists(self.input_path), \
"Input path '{}' does not exist".format(self.input_path)
def _progress(self, iterable, **kwargs):
if self.verbose:
return tqdm.tqdm(iterable, **kwargs)
return iterable
def require_output_path(self, fname=None):
path = self.get_output_path(fname)
os.makedirs(path, exist_ok=True)
@property
def input_path(self):
return self.get_input_path()
@property
def output_path(self):
return self.get_output_path()
def get_input_path(self, fname=None):
if fname is None:
return os.path.join(self.input_root, self.split)
return os.path.join(self.get_input_path(), fname)
def get_output_path(self, fname=None):
if fname is None:
return os.path.join(self.output_root, self.split)
return os.path.join(self.get_output_path(), fname)
def copy_labels(self):
self.require_output_path()
labels = list(filter(lambda x: self.extension not in x, glob.glob(self.get_input_path("*"))))
for fname in tqdm.tqdm(labels):
copy(fname, self.output_path)
@property
def input_fnames(self):
return sorted(glob.glob(self.get_input_path("*.{}".format(self.extension))))
def __len__(self):
return len(self.input_fnames)
def write_features(self):
paths = self.input_fnames
fnames_context = map(lambda x: os.path.join(self.output_path, x.replace("." + self.extension, ".h5context")), \
map(os.path.basename, paths))
for name, target_fname in self._progress(zip(paths, fnames_context), total=len(self)):
wav, sr = read_audio(name)
z, c = self.model(wav)
feat = z if self.use_feat else c
writer = H5Writer(target_fname)
writer.write(feat)
def __repr__(self):
return "EmbeddingDatasetWriter ({n_files} files)\n\tinput:\t{input_root}\n\toutput:\t{output_root}\n\tsplit:\t{split})".format(
n_files=len(self), **self.__dict__)
if __name__ == "__main__":
args = EmbeddingWriterConfig().parse_args()
for split in args.split:
writer = EmbeddingDatasetWriter(
input_root=args.input,
output_root=args.output,
split=split,
model_fname=args.model,
gpu=args.gpu,
extension=args.ext,
use_feat=args.use_feat,
)
print(writer)
writer.require_output_path()
print("Writing Features...")
writer.write_features()
print("Done.")
if not args.no_copy_labels:
print("Copying label data...")
writer.copy_labels()
print("Done.")
\ No newline at end of file
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Data pre-processing: build vocabularies and binarize training data.
"""
import argparse
import glob
import os
import soundfile
import random
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument('root', metavar='DIR', help='root directory containing flac files to index')
parser.add_argument('--valid-percent', default=0.01, type=float, metavar='D',
help='percentage of data to use as validation set (between 0 and 1)')
parser.add_argument('--dest', default='.', type=str, metavar='DIR', help='output directory')
parser.add_argument('--ext', default='flac', type=str, metavar='EXT', help='extension to look for')
parser.add_argument('--seed', default=42, type=int, metavar='N', help='random seed')
parser.add_argument('--path-must-contain', default=None, type=str, metavar='FRAG',
help='if set, path must contain this substring for a file to be included in the manifest')
return parser
def main(args):
assert args.valid_percent >= 0 and args.valid_percent <= 1.
dir_path = os.path.realpath(args.root)
search_path = os.path.join(dir_path, '**/*.' + args.ext)
rand = random.Random(args.seed)
with open(os.path.join(args.dest, 'train.tsv'), 'w') as train_f, open(
os.path.join(args.dest, 'valid.tsv'), 'w') as valid_f:
print(dir_path, file=train_f)
print(dir_path, file=valid_f)
for fname in glob.iglob(search_path, recursive=True):
file_path = os.path.realpath(fname)
if args.path_must_contain and args.path_must_contain not in file_path:
continue
frames = soundfile.info(fname).frames
dest = train_f if rand.random() > args.valid_percent else valid_f
print('{}\t{}'.format(os.path.relpath(file_path, dir_path), frames), file=dest)
if __name__ == '__main__':
parser = get_parser()
args = parser.parse_args()
main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment