# UniSpeech
This is the official implementation of the paper "[UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597)". The implementation is mainly based on the [fairseq](https://github.com/pytorch/fairseq) codebase. We release training recipes for the CommonVoice dataset.
## Requirements and Installation
- PyTorch >= 1.6.0
- Python >= 3.6
``` bash
cd src
pip install soundfile
pip install librosa
pip install pydub
pip install --editable ./
```
## Data Preparation
Download the pretraining audio data from [here](https://commonvoice.mozilla.org/datasets). (We use the June 2020 release in our paper.)
Get the wav list and the transcription for each dataset by running:
```
python examples/unispeech/unispeech_manifest.py input_meta_file --dest examples/unispeech/data/LANG
```
Then convert the Common Voice audio files to 16 kHz using the following command:
```
python examples/unispeech/adjust_sample_rate.py --wav-path /path/to/wav/ --dest-path /path/to/16kwav/ --input examples/unispeech/data/LANG/*.tsv --output examples/unispeech/data/LANG/*_16k.tsv
```
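If you prefer to do the resampling directly in Python, a minimal sketch using `librosa` and `soundfile` (installed above) is shown below; the paths are placeholders, and `adjust_sample_rate.py` additionally takes care of rewriting the .tsv manifests:
```python
import librosa
import soundfile as sf

# Placeholder paths for illustration only; decoding Common Voice mp3 files requires an
# audio backend (e.g. ffmpeg) available to librosa/audioread.
wav, _ = librosa.load("/path/to/wav/clip.mp3", sr=16000)   # decode and resample to 16 kHz mono
sf.write("/path/to/16kwav/clip.wav", wav, 16000)
```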
For the fine-tuning data, our train/val/test splits follow [this](https://dl.fbaipublicfiles.com/cpc_audio/common_voices_splits.tar.gz).
The phoneme transcriptions are generated with [phonemizer](https://github.com/bootphon/phonemizer), which converts text to phonemes. We then create .id files using different vocabularies. All our pre-processed data as well as the dictionaries can be downloaded from [here].
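For reference, a minimal [phonemizer](https://github.com/bootphon/phonemizer) call with the espeak backend looks like the sketch below; the exact options used to produce our transcriptions may differ:
```python
from phonemizer import phonemize

# Example sentence for illustration; in practice the Common Voice transcripts are phonemized.
phones = phonemize(["hello world"], language="en-us", backend="espeak", strip=True)
print(phones)
```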
## Pretraining
We provide the training examples for the large model here.
### Stage 1. Pretraining UniSpeech with labeled data.
The following script can be used to pre-train an English model:
```
bash examples/unispeech/scripts/one2one_large_pretrain_en1350.sh
```
To train a multilingual model:
```
bash examples/unispeech/scripts/multilingual_large_pretrain.sh
```
### Stage 2. Continue pre-training with low-resource unlabeled data (Optional)
After Stage 1, you can continue pre-training the UniSpeech model with only the contrastive loss:
```
bash examples/unispeech/scripts/continue_pretran.sh
```
### Stage 3. Finetuning with low-resource labeled data.
Finally, fine-tune the model with 1 hour of labeled data.
For multilingual models, you can choose either a separate vocabulary (examples/unispeech/data/en/vocab_sep.json) or a shared vocabulary (examples/unispeech/data/en/vocab_share.json):
```
bash examples/unispeech/scripts/finetune.sh
```
# WavLM
<!--**Pre-trained models for speech related tasks**-->
[**WavLM**](https://arxiv.org/pdf/2110.13900.pdf) : **WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing**
Official PyTorch implementation and pretrained models of WavLM.
- Oct 2021: released the preprint on [arXiv](https://arxiv.org/pdf/2110.13900.pdf)
## Pre-Trained Models
Model | Pre-training Dataset | Fine-tuning Dataset | Download
|---|---|---|---
WavLM Base | [960 hrs LibriSpeech](http://www.openslr.org/12)| - | [Azure Storage](https://msranlcmtteamdrive.blob.core.windows.net/share/wavlm/WavLM-Base.pt?sv=2020-04-08&st=2021-11-05T00%3A35%3A31Z&se=2022-11-06T00%3A35%3A00Z&sr=b&sp=r&sig=JljnRVzyHY6AjHzhVmHV5KyQQCvvGfgp9D2M02oGJBU%3D) <br> [Google Drive](https://drive.google.com/file/d/19-C7SMQvEFAYLG5uc47NX_MY03JCbI4x/view?usp=sharing)
WavLM Base+ | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [10k hrs GigaSpeech](https://github.com/SpeechColab/GigaSpeech) + [24k hrs VoxPopuli](https://github.com/facebookresearch/voxpopuli/tree/main)| - | [Azure Storage](https://msranlcmtteamdrive.blob.core.windows.net/share/wavlm/WavLM-Base+.pt?sv=2020-04-08&st=2021-11-05T00%3A34%3A47Z&se=2022-10-06T00%3A34%3A00Z&sr=b&sp=r&sig=Gkf1IByHaIn1t%2FVEd9D6WHjZ3zu%2Fk5eSdoj21UytKro%3D) <br> [Google Drive](https://drive.google.com/file/d/1PlbT_9_B4F9BsD_ija84sUTVw7almNX8/view?usp=sharing)
WavLM Large | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [10k hrs GigaSpeech](https://github.com/SpeechColab/GigaSpeech) + [24k hrs VoxPopuli](https://github.com/facebookresearch/voxpopuli/tree/main)| - | [Azure Storage](https://msranlcmtteamdrive.blob.core.windows.net/share/wavlm/WavLM-Large.pt?sv=2020-08-04&st=2021-11-22T10%3A03%3A53Z&se=2022-11-23T10%3A03%3A00Z&sr=b&sp=r&sig=3kB8dwTCyIS8YQ7gW5oXmDrXV%2FAaLmoxBS37oPpFsz4%3D) <br> [Google Drive](https://drive.google.com/file/d/1p8nbj16b7YA16sqPZ4E0JUL-oIDUBGwU/view?usp=sharing)
## Load Pre-Trained Models for Inference
```python
import torch
from WavLM import WavLM, WavLMConfig
# load the pre-trained checkpoints
checkpoint = torch.load('/path/to/wavlm.pt')
cfg = WavLMConfig(checkpoint['cfg'])
model = WavLM(cfg)
model.load_state_dict(checkpoint['model'])
model.eval()
# extract the representation of the last layer
wav_input_16khz = torch.randn(1,10000)
rep = model.extract_features(wav_input_16khz)[0]
# extract the representation of each layer
wav_input_16khz = torch.randn(1,10000)
rep, layer_results = model.extract_features(wav_input_16khz, output_layer=model.cfg.encoder_layers, ret_layer_results=True)[0]
layer_reps = [x.transpose(0, 1) for x, _ in layer_results]
```
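Note that some checkpoints are trained with input normalization (`cfg.normalize` is True in `WavLMConfig`). In that case the waveform should be normalized before calling `extract_features`, for example:
```python
if cfg.normalize:
    wav_input_16khz = torch.nn.functional.layer_norm(wav_input_16khz, wav_input_16khz.shape)
```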
## Universal Representation Evaluation on SUPERB
![alt text](WavLM_SUPERB_Results.png)
![alt text](WavLM_SUPERB_Leaderboard.png)
## Downstream Task Performance
We also evaluate our models on typical speech processing benchmarks.
### Speaker Verification
Evaluated on [VoxCeleb](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/#:~:text=VoxCeleb%20is%20an%20audio%2Dvisual,interview%20videos%20uploaded%20to%20YouTube)
| Model |Fix pre-train| Vox1-O | Vox1-E | Vox1-H |
| ------------- |------------- | ---------- | ---------- | ---------- |
| ECAPA-TDNN | - | 0.87 | 1.12 | 2.12 |
| HuBERT large | Yes| 0.888 |0.912| 1.853 |
| Wav2Vec2.0 (XLSR)| Yes | 0.915| 0.945 |1.895|
| UniSpeech-SAT large | Yes | 0.771 | 0.781| 1.669|
| WavLM large | Yes | 0.638 | 0.687| 1.457|
| HuBERT large | No| 0.585| 0.654 |1.342|
| Wav2Vec2.0 (XLSR) | No| 0.564| 0.605 |1.23|
| UniSpeech-SAT large | No | 0.564 | 0.561| 1.23 |
| **WavLM large** | No | **0.431** | **0.538**| **1.154** |
### Speech Separation
Evaluated on [LibriCSS](https://github.com/chenzhuo1011/libri_css)
| Model |0S | 0L | OV10 | OV20 |OV30 |OV40 |
| ---------------- |------| ------ | ------ | ------ | ------ | ------ |
| [Conformer](https://ieeexplore.ieee.org/abstract/document/9413423/) (SOTA) | 4.5 | 4.4 |6.2 |8.5| 11 |12.6|
| HuBERT base | 4.7| 4.6 | 6.1 | 7.9| 10.6| 12.3|
| UniSpeech-SAT base | 4.4| 4.4 |5.4| 7.2| 9.2 |10.5|
| UniSpeech-SAT large | 4.3| 4.2 |5.0 |6.3| 8.2| 8.8|
| WavLM base+ | 4.5| 4.4 |5.6| 7.5| 9.4 |10.9|
| **WavLM large** | 4.2| 4.1 | 4.8 | 5.8 | 7.4| 8.5|
### Speaker Diarization
Evaluated on [CALLHOME](https://arxiv.org/pdf/1909.06247.pdf)
| Model |spk_2 |spk_3| spk_4| spk_5| spk_6| spk_all |
| ---------------- |------| ------ | ------ | ------ | ------ | ------ |
| [EEND-vector clustering](https://arxiv.org/pdf/2105.09040.pdf) | 7.96 | 11.93 | 16.38 | 21.21 | 23.1 | 12.49 |
| [EEND-EDA clustering](https://arxiv.org/abs/2107.01545) (SOTA) | 7.11 | 11.88 | 14.37 | 25.95 | 21.95 | 11.84 |
| HuBERT base| 7.93|12.07| 15.21 |19.59| 23.32| 12.63|
| HuBERT large| 7.39| 11.97| 15.76 |19.82| 22.10| 12.40|
| UniSpeech-SAT large| 5.93| 10.66| 12.9 |16.48| 23.25| 10.92|
| WavLM Base| 6.99| 11.12| 15.20 |16.48| 21.61| 11.75|
| **WavLM large** | 6.46| 10.69| 11.84 |12.89| 20.70| 10.35|
### Speech Recognition
Evaluated on [LibriSpeech](https://www.openslr.org/12)
![alt text](WavLM_ASR.PNG)
## License
This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) project.
[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
### Reference
If you find our work useful in your research, please cite the following paper:
``` latex
@article{Chen2021WavLM,
title = {WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing},
author = {Sanyuan Chen and Chengyi Wang and Zhengyang Chen and Yu Wu and Shujie Liu and Zhuo Chen and Jinyu Li and Naoyuki Kanda and Takuya Yoshioka and Xiong Xiao and Jian Wu and Long Zhou and Shuo Ren and Yanmin Qian and Yao Qian and Jian Wu and Michael Zeng and Furu Wei},
eprint={2110.13900},
archivePrefix={arXiv},
primaryClass={cs.CL},
year={2021}
}
```
### Contact Information
For help or issues using WavLM models, please submit a GitHub issue.
For other communications related to WavLM, please contact Yu Wu (`yuwu1@microsoft.com`).
# --------------------------------------------------------
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import math
import logging
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm
from modules import (
Fp32GroupNorm,
Fp32LayerNorm,
GradMultiply,
MultiheadAttention,
SamePad,
init_bert_params,
get_activation_fn,
TransposeLast,
GLU_Linear,
)
logger = logging.getLogger(__name__)
def compute_mask_indices(
shape: Tuple[int, int],
padding_mask: Optional[torch.Tensor],
mask_prob: float,
mask_length: int,
mask_type: str = "static",
mask_other: float = 0.0,
min_masks: int = 0,
no_overlap: bool = False,
min_space: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape
Args:
shape: the shape for which to compute masks.
should be of size 2 where first element is batch size and 2nd is timesteps
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_type: how to compute mask lengths
static = fixed size
uniform = sample from uniform distribution [mask_other, mask_length*2]
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
poisson = sample from poisson distribution with lambda = mask length
min_masks: minimum number of masked spans
no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
"""
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
all_num_mask = int(
# add a random number for probabilistic rounding
mask_prob * all_sz / float(mask_length)
+ np.random.rand()
)
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
for i in range(bsz):
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ np.random.rand()
)
num_mask = max(min_masks, num_mask)
else:
sz = all_sz
num_mask = all_num_mask
if mask_type == "static":
lengths = np.full(num_mask, mask_length)
elif mask_type == "uniform":
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
elif mask_type == "normal":
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
lengths = [max(1, int(round(x))) for x in lengths]
elif mask_type == "poisson":
lengths = np.random.poisson(mask_length, size=num_mask)
lengths = [int(round(x)) for x in lengths]
else:
raise Exception("unknown mask selection " + mask_type)
if sum(lengths) == 0:
lengths[0] = min(mask_length, sz - 1)
if no_overlap:
mask_idc = []
def arrange(s, e, length, keep_length):
span_start = np.random.randint(s, e - length)
mask_idc.extend(span_start + i for i in range(length))
new_parts = []
if span_start - s - min_space >= keep_length:
new_parts.append((s, span_start - min_space + 1))
if e - span_start - keep_length - min_space > keep_length:
new_parts.append((span_start + length + min_space, e))
return new_parts
parts = [(0, sz)]
min_length = min(lengths)
for length in sorted(lengths, reverse=True):
lens = np.fromiter(
(e - s if e - s >= length + min_space else 0 for s, e in parts),
int,
)
l_sum = np.sum(lens)
if l_sum == 0:
break
probs = lens / np.sum(lens)
c = np.random.choice(len(parts), p=probs)
s, e = parts.pop(c)
parts.extend(arrange(s, e, length, min_length))
mask_idc = np.asarray(mask_idc)
else:
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
mask_idc = np.asarray(
[
mask_idc[j] + offset
for j in range(len(mask_idc))
for offset in range(lengths[j])
]
)
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
min_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if len(mask_idc) > min_len:
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
mask[i, mask_idc] = True
return mask
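# Illustrative usage sketch (not part of the original file): how compute_mask_indices is
# typically consumed, mirroring WavLM.apply_mask below. `_demo_time_masking` is a
# hypothetical helper added for documentation only.
def _demo_time_masking():
    B, T, D = 2, 100, 768
    mask = compute_mask_indices((B, T), None, mask_prob=0.65, mask_length=10, min_masks=2)
    mask = torch.from_numpy(mask)  # bool tensor of shape (B, T); True marks masked time steps
    x = torch.randn(B, T, D)
    x[mask] = 0.0  # the real model writes a learned mask embedding here instead of zeros
    return x, mask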
class WavLMConfig:
def __init__(self, cfg=None):
self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
self.encoder_layers: int = 12 # num encoder layers in the transformer
self.encoder_embed_dim: int = 768 # encoder embedding dimension
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
self.encoder_attention_heads: int = 12 # num encoder attention heads
self.activation_fn: str = "gelu" # activation function to use
self.layer_norm_first: bool = False # apply layernorm first in the transformer
self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
self.conv_bias: bool = False # include bias in conv encoder
self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
# dropouts
self.dropout: float = 0.1 # dropout probability for the transformer
self.attention_dropout: float = 0.1 # dropout probability for attention weights
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
# masking
self.mask_length: int = 10 # mask length
self.mask_prob: float = 0.65 # probability of replacing a token with mask
self.mask_selection: str = "static" # how to choose mask length
self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
self.no_mask_overlap: bool = False # whether to allow masks to overlap
self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
# channel masking
self.mask_channel_length: int = 10 # length of the mask for features (channels)
self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
# positional embeddings
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
# relative position embedding
self.relative_position_embedding: bool = False # apply relative position embedding
self.num_buckets: int = 320 # number of buckets for relative position embedding
self.max_distance: int = 1280 # maximum distance for relative position embedding
self.gru_rel_pos: bool = False # apply gated relative position embedding
if cfg is not None:
self.update(cfg)
def update(self, cfg: dict):
self.__dict__.update(cfg)
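# Illustrative sketch (not part of the original file): building a config and overriding a few
# fields via update(); the values below are examples, not a released checkpoint recipe.
def _demo_config():
    cfg = WavLMConfig()
    cfg.update({
        "encoder_layers": 24,
        "encoder_embed_dim": 1024,
        "encoder_ffn_embed_dim": 4096,
        "encoder_attention_heads": 16,
        "relative_position_embedding": True,
        "gru_rel_pos": True,
    })
    return cfg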
class WavLM(nn.Module):
def __init__(
self,
cfg: WavLMConfig,
) -> None:
super().__init__()
logger.info(f"WavLM Config: {cfg.__dict__}")
self.cfg = cfg
feature_enc_layers = eval(cfg.conv_feature_layers)
self.embed = feature_enc_layers[-1][0]
self.feature_extractor = ConvFeatureExtractionModel(
conv_layers=feature_enc_layers,
dropout=0.0,
mode=cfg.extractor_mode,
conv_bias=cfg.conv_bias,
)
self.post_extract_proj = (
nn.Linear(self.embed, cfg.encoder_embed_dim)
if self.embed != cfg.encoder_embed_dim
else None
)
self.mask_prob = cfg.mask_prob
self.mask_selection = cfg.mask_selection
self.mask_other = cfg.mask_other
self.mask_length = cfg.mask_length
self.no_mask_overlap = cfg.no_mask_overlap
self.mask_min_space = cfg.mask_min_space
self.mask_channel_prob = cfg.mask_channel_prob
self.mask_channel_selection = cfg.mask_channel_selection
self.mask_channel_other = cfg.mask_channel_other
self.mask_channel_length = cfg.mask_channel_length
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
self.mask_channel_min_space = cfg.mask_channel_min_space
self.dropout_input = nn.Dropout(cfg.dropout_input)
self.dropout_features = nn.Dropout(cfg.dropout_features)
self.feature_grad_mult = cfg.feature_grad_mult
self.mask_emb = nn.Parameter(
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
)
self.encoder = TransformerEncoder(cfg)
self.layer_norm = LayerNorm(self.embed)
def apply_mask(self, x, padding_mask):
B, T, C = x.shape
if self.mask_prob > 0:
mask_indices = compute_mask_indices(
(B, T),
padding_mask,
self.mask_prob,
self.mask_length,
self.mask_selection,
self.mask_other,
min_masks=2,
no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space,
)
mask_indices = torch.from_numpy(mask_indices).to(x.device)
x[mask_indices] = self.mask_emb
else:
mask_indices = None
if self.mask_channel_prob > 0:
mask_channel_indices = compute_mask_indices(
(B, C),
None,
self.mask_channel_prob,
self.mask_channel_length,
self.mask_channel_selection,
self.mask_channel_other,
no_overlap=self.no_mask_channel_overlap,
min_space=self.mask_channel_min_space,
)
mask_channel_indices = (
torch.from_numpy(mask_channel_indices)
.to(x.device)
.unsqueeze(1)
.expand(-1, T, -1)
)
x[mask_channel_indices] = 0
return x, mask_indices
def forward_padding_mask(
self, features: torch.Tensor, padding_mask: torch.Tensor,
) -> torch.Tensor:
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(
padding_mask.size(0), features.size(1), -1
)
padding_mask = padding_mask.all(-1)
return padding_mask
def extract_features(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = False,
ret_conv: bool = False,
output_layer: Optional[int] = None,
ret_layer_results: bool = False,
):
if self.feature_grad_mult > 0:
features = self.feature_extractor(source)
if self.feature_grad_mult != 1.0:
features = GradMultiply.apply(features, self.feature_grad_mult)
else:
with torch.no_grad():
features = self.feature_extractor(source)
features = features.transpose(1, 2)
features = self.layer_norm(features)
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
features = self.dropout_input(features)
if mask:
x, mask_indices = self.apply_mask(
features, padding_mask
)
else:
x = features
# feature: (B, T, D), float
# target: (B, T), long
# x: (B, T, D), float
# padding_mask: (B, T), bool
# mask_indices: (B, T), bool
x, layer_results = self.encoder(
x,
padding_mask=padding_mask,
layer=None if output_layer is None else output_layer - 1
)
res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
feature = res["features"] if ret_conv else res["x"]
if ret_layer_results:
feature = (feature, res["layer_results"])
return feature, res["padding_mask"]
class ConvFeatureExtractionModel(nn.Module):
def __init__(
self,
conv_layers: List[Tuple[int, int, int]],
dropout: float = 0.0,
mode: str = "default",
conv_bias: bool = False,
conv_type: str = "default"
):
super().__init__()
assert mode in {"default", "layer_norm"}
def block(
n_in,
n_out,
k,
stride,
is_layer_norm=False,
is_group_norm=False,
conv_bias=False,
):
def make_conv():
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
nn.init.kaiming_normal_(conv.weight)
return conv
assert (
is_layer_norm and is_group_norm
) == False, "layer norm and group norm are exclusive"
if is_layer_norm:
return nn.Sequential(
make_conv(),
nn.Dropout(p=dropout),
nn.Sequential(
TransposeLast(),
Fp32LayerNorm(dim, elementwise_affine=True),
TransposeLast(),
),
nn.GELU(),
)
elif is_group_norm:
return nn.Sequential(
make_conv(),
nn.Dropout(p=dropout),
Fp32GroupNorm(dim, dim, affine=True),
nn.GELU(),
)
else:
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
self.conv_type = conv_type
if self.conv_type == "default":
in_d = 1
self.conv_layers = nn.ModuleList()
for i, cl in enumerate(conv_layers):
assert len(cl) == 3, "invalid conv definition: " + str(cl)
(dim, k, stride) = cl
self.conv_layers.append(
block(
in_d,
dim,
k,
stride,
is_layer_norm=mode == "layer_norm",
is_group_norm=mode == "default" and i == 0,
conv_bias=conv_bias,
)
)
in_d = dim
elif self.conv_type == "conv2d":
in_d = 1
self.conv_layers = nn.ModuleList()
for i, cl in enumerate(conv_layers):
assert len(cl) == 3
(dim, k, stride) = cl
self.conv_layers.append(
torch.nn.Conv2d(in_d, dim, k, stride)
)
self.conv_layers.append(torch.nn.ReLU())
in_d = dim
elif self.conv_type == "custom":
in_d = 1
idim = 80
self.conv_layers = nn.ModuleList()
for i, cl in enumerate(conv_layers):
assert len(cl) == 3
(dim, k, stride) = cl
self.conv_layers.append(
torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
)
self.conv_layers.append(
torch.nn.LayerNorm([dim, idim])
)
self.conv_layers.append(torch.nn.ReLU())
in_d = dim
if (i + 1) % 2 == 0:
self.conv_layers.append(
torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
)
idim = int(math.ceil(idim / 2))
else:
pass
def forward(self, x, mask=None):
# BxT -> BxCxT
x = x.unsqueeze(1)
if self.conv_type == "custom":
for conv in self.conv_layers:
if isinstance(conv, nn.LayerNorm):
x = x.transpose(1, 2)
x = conv(x).transpose(1, 2)
else:
x = conv(x)
x = x.transpose(2, 3).contiguous()
x = x.view(x.size(0), -1, x.size(-1))
else:
for conv in self.conv_layers:
x = conv(x)
if self.conv_type == "conv2d":
b, c, t, f = x.size()
x = x.transpose(2, 3).contiguous().view(b, c * f, t)
return x
class TransformerEncoder(nn.Module):
def __init__(self, args):
super().__init__()
self.dropout = args.dropout
self.embedding_dim = args.encoder_embed_dim
self.pos_conv = nn.Conv1d(
self.embedding_dim,
self.embedding_dim,
kernel_size=args.conv_pos,
padding=args.conv_pos // 2,
groups=args.conv_pos_groups,
)
dropout = 0
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
nn.init.constant_(self.pos_conv.bias, 0)
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
if hasattr(args, "relative_position_embedding"):
self.relative_position_embedding = args.relative_position_embedding
self.num_buckets = args.num_buckets
self.max_distance = args.max_distance
else:
self.relative_position_embedding = False
self.num_buckets = 0
self.max_distance = 0
self.layers = nn.ModuleList(
[
TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
has_relative_attention_bias=(self.relative_position_embedding and i == 0),
num_buckets=self.num_buckets,
max_distance=self.max_distance,
gru_rel_pos=args.gru_rel_pos,
)
for i in range(args.encoder_layers)
]
)
self.layer_norm_first = args.layer_norm_first
self.layer_norm = LayerNorm(self.embedding_dim)
self.layerdrop = args.encoder_layerdrop
self.apply(init_bert_params)
def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
if self.layer_norm_first and layer is None:
x = self.layer_norm(x)
return x, layer_results
def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
if padding_mask is not None:
x[padding_mask] = 0
x_conv = self.pos_conv(x.transpose(1, 2))
x_conv = x_conv.transpose(1, 2)
x += x_conv
if not self.layer_norm_first:
x = self.layer_norm(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
layer_results = []
z = None
if tgt_layer is not None:
layer_results.append((x, z))
r = None
pos_bias = None
for i, layer in enumerate(self.layers):
dropout_probability = np.random.random()
if not self.training or (dropout_probability > self.layerdrop):
x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False,
self_attn_mask=streaming_mask, pos_bias=pos_bias)
if tgt_layer is not None:
layer_results.append((x, z))
if i == tgt_layer:
r = x
break
if r is not None:
x = r
# T x B x C -> B x T x C
x = x.transpose(0, 1)
return x, layer_results
class TransformerSentenceEncoderLayer(nn.Module):
"""
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
models.
"""
def __init__(
self,
embedding_dim: float = 768,
ffn_embedding_dim: float = 3072,
num_attention_heads: float = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
activation_fn: str = "relu",
layer_norm_first: bool = False,
has_relative_attention_bias: bool = False,
num_buckets: int = 0,
max_distance: int = 0,
rescale_init: bool = False,
gru_rel_pos: bool = False,
) -> None:
super().__init__()
# Initialize parameters
self.embedding_dim = embedding_dim
self.dropout = dropout
self.activation_dropout = activation_dropout
# Initialize blocks
self.activation_name = activation_fn
self.activation_fn = get_activation_fn(activation_fn)
self.self_attn = MultiheadAttention(
self.embedding_dim,
num_attention_heads,
dropout=attention_dropout,
self_attention=True,
has_relative_attention_bias=has_relative_attention_bias,
num_buckets=num_buckets,
max_distance=max_distance,
rescale_init=rescale_init,
gru_rel_pos=gru_rel_pos,
)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(self.activation_dropout)
self.dropout3 = nn.Dropout(dropout)
self.layer_norm_first = layer_norm_first
# layer norm associated with the self attention layer
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
if self.activation_name == "glu":
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
else:
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
# layer norm associated with the position wise feed-forward NN
self.final_layer_norm = LayerNorm(self.embedding_dim)
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
need_weights: bool = False,
pos_bias=None
):
"""
LayerNorm is applied either before or after the self-attention/ffn
modules similar to the original Transformer implementation.
"""
residual = x
if self.layer_norm_first:
x = self.self_attn_layer_norm(x)
x, attn, pos_bias = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=False,
attn_mask=self_attn_mask,
position_bias=pos_bias
)
x = self.dropout1(x)
x = residual + x
residual = x
x = self.final_layer_norm(x)
if self.activation_name == "glu":
x = self.fc1(x)
else:
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual + x
else:
x, attn, pos_bias = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=need_weights,
attn_mask=self_attn_mask,
position_bias=pos_bias
)
x = self.dropout1(x)
x = residual + x
x = self.self_attn_layer_norm(x)
residual = x
if self.activation_name == "glu":
x = self.fc1(x)
else:
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual + x
x = self.final_layer_norm(x)
return x, attn, pos_bias
# --------------------------------------------------------
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import math
import warnings
from typing import Dict, Optional, Tuple
import torch
from torch import Tensor, nn
from torch.nn import Parameter
import torch.nn.functional as F
class TransposeLast(nn.Module):
def __init__(self, deconstruct_idx=None):
super().__init__()
self.deconstruct_idx = deconstruct_idx
def forward(self, x):
if self.deconstruct_idx is not None:
x = x[self.deconstruct_idx]
return x.transpose(-2, -1)
class Fp32LayerNorm(nn.LayerNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, input):
output = F.layer_norm(
input.float(),
self.normalized_shape,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
)
return output.type_as(input)
class Fp32GroupNorm(nn.GroupNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, input):
output = F.group_norm(
input.float(),
self.num_groups,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
)
return output.type_as(input)
class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
return res
@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
class SamePad(nn.Module):
def __init__(self, kernel_size, causal=False):
super().__init__()
if causal:
self.remove = kernel_size - 1
else:
self.remove = 1 if kernel_size % 2 == 0 else 0
def forward(self, x):
if self.remove > 0:
x = x[:, :, : -self.remove]
return x
class Swish(nn.Module):
"""Swish function
"""
def __init__(self):
"""Construct an MultiHeadedAttention object."""
super(Swish, self).__init__()
self.act = torch.nn.Sigmoid()
def forward(self, x):
return x * self.act(x)
class GLU_Linear(nn.Module):
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
super(GLU_Linear, self).__init__()
self.glu_type = glu_type
self.output_dim = output_dim
if glu_type == "sigmoid":
self.glu_act = torch.nn.Sigmoid()
elif glu_type == "swish":
self.glu_act = Swish()
elif glu_type == "relu":
self.glu_act = torch.nn.ReLU()
elif glu_type == "gelu":
self.glu_act = torch.nn.GELU()
if bias_in_glu:
self.linear = nn.Linear(input_dim, output_dim * 2, True)
else:
self.linear = nn.Linear(input_dim, output_dim * 2, False)
def forward(self, x):
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
x = self.linear(x)
if self.glu_type == "bilinear":
x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
else:
x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
return x
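# Illustrative sketch (not part of the original file): GLU_Linear projects to 2 * output_dim and
# gates one half with the activation of the other half, so the output keeps output_dim channels.
def _demo_glu_linear():
    glu = GLU_Linear(input_dim=768, output_dim=3072, glu_type="swish")
    y = glu(torch.randn(2, 50, 768))  # (batch, time, channels) in, same layout out
    return y.shape  # torch.Size([2, 50, 3072])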
def gelu_accurate(x):
if not hasattr(gelu_accurate, "_a"):
gelu_accurate._a = math.sqrt(2 / math.pi)
return (
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
)
def gelu(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.gelu(x.float()).type_as(x)
def get_activation_fn(activation: str):
"""Returns the activation function corresponding to `activation`"""
if activation == "relu":
return F.relu
elif activation == "gelu":
return gelu
elif activation == "gelu_fast":
warnings.warn(
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
)
return gelu_accurate
elif activation == "gelu_accurate":
return gelu_accurate
elif activation == "tanh":
return torch.tanh
elif activation == "linear":
return lambda x: x
elif activation == "glu":
return lambda x: x
else:
raise RuntimeError("--activation-fn {} not supported".format(activation))
def init_bert_params(module):
"""
Initialize the weights specific to the BERT Model.
This overrides the default initializations depending on the specified arguments.
1. If normal_init_linear_weights is set then weights of linear
layer will be initialized using the normal distribution and
bias will be set to the specified value.
2. If normal_init_embed_weights is set then weights of embedding
layer will be initialized using the normal distribution.
3. If normal_init_proj_weights is set then weights of
in_project_weight for MultiHeadAttention initialized using
the normal distribution (to be validated).
"""
def normal_(data):
# with FSDP, module params will be on CUDA, so we cast them back to CPU
# so that the RNG is consistent with and without FSDP
data.copy_(
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
)
if isinstance(module, nn.Linear):
normal_(module.weight.data)
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
normal_(module.weight.data)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, MultiheadAttention):
normal_(module.q_proj.weight.data)
normal_(module.k_proj.weight.data)
normal_(module.v_proj.weight.data)
def quant_noise(module, p, block_size):
"""
Wraps modules and applies quantization noise to the weights for
subsequent quantization with Iterative Product Quantization as
described in "Training with Quantization Noise for Extreme Model Compression"
Args:
- module: nn.Module
- p: amount of Quantization Noise
- block_size: size of the blocks for subsequent quantization with iPQ
Remarks:
- Module weights must have the right sizes wrt the block size
- Only Linear, Embedding and Conv2d modules are supported for the moment
- For more detail on how to quantize by blocks with convolutional weights,
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
- We implement the simplest form of noise here as stated in the paper
which consists in randomly dropping blocks
"""
# if no quantization noise, don't register hook
if p <= 0:
return module
# supported modules
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
# test whether module.weight has the right sizes wrt block_size
is_conv = module.weight.ndim == 4
# 2D matrix
if not is_conv:
assert (
module.weight.size(1) % block_size == 0
), "Input features must be a multiple of block sizes"
# 4D matrix
else:
# 1x1 convolutions
if module.kernel_size == (1, 1):
assert (
module.in_channels % block_size == 0
), "Input channels must be a multiple of block sizes"
# regular convolutions
else:
k = module.kernel_size[0] * module.kernel_size[1]
assert k % block_size == 0, "Kernel size must be a multiple of block size"
def _forward_pre_hook(mod, input):
# no noise for evaluation
if mod.training:
if not is_conv:
# gather weight and sizes
weight = mod.weight
in_features = weight.size(1)
out_features = weight.size(0)
# split weight matrix into blocks and randomly drop selected blocks
mask = torch.zeros(
in_features // block_size * out_features, device=weight.device
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
else:
# gather weight and sizes
weight = mod.weight
in_channels = mod.in_channels
out_channels = mod.out_channels
# split weight matrix into blocks and randomly drop selected blocks
if mod.kernel_size == (1, 1):
mask = torch.zeros(
int(in_channels // block_size * out_channels),
device=weight.device,
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
else:
mask = torch.zeros(
weight.size(0), weight.size(1), device=weight.device
)
mask.bernoulli_(p)
mask = (
mask.unsqueeze(2)
.unsqueeze(3)
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
)
# scale weights and apply mask
mask = mask.to(
torch.bool
) # x.bool() is not currently supported in TorchScript
s = 1 / (1 - p)
mod.weight.data = s * weight.masked_fill(mask, 0)
module.register_forward_pre_hook(_forward_pre_hook)
return module
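# Illustrative sketch (not part of the original file): wrapping a Linear layer with quantization
# noise. With p == 0, quant_noise is a no-op; with p > 0 it registers a forward pre-hook that
# randomly zeroes whole weight blocks during training (in_features must be divisible by block_size).
def _demo_quant_noise():
    layer = quant_noise(nn.Linear(16, 32), p=0.1, block_size=8)
    layer.train()  # the noise hook only fires in training mode
    return layer(torch.randn(4, 16))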
class MultiheadAttention(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
self_attention=False,
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
has_relative_attention_bias=False,
num_buckets=32,
max_distance=128,
gru_rel_pos=False,
rescale_init=False,
):
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout_module = nn.Dropout(dropout)
self.has_relative_attention_bias = has_relative_attention_bias
self.num_buckets = num_buckets
self.max_distance = max_distance
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
self.head_dim = embed_dim // num_heads
self.q_head_dim = self.head_dim
self.k_head_dim = self.head_dim
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention
assert not self.self_attention or self.qkv_same_dim, (
"Self-attention requires query, key and " "value to be of the same size"
)
k_bias = True
if rescale_init:
k_bias = False
k_embed_dim = embed_dim
q_embed_dim = embed_dim
self.k_proj = quant_noise(
nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
)
self.v_proj = quant_noise(
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.q_proj = quant_noise(
nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
)
self.out_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
)
if add_bias_kv:
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self.gru_rel_pos = gru_rel_pos
if self.gru_rel_pos:
self.grep_linear = nn.Linear(self.q_head_dim, 8)
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
self.reset_parameters()
def reset_parameters(self):
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
else:
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.q_proj.weight)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
nn.init.constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
nn.init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
nn.init.xavier_normal_(self.bias_v)
if self.has_relative_attention_bias:
nn.init.xavier_normal_(self.relative_attention_bias.weight)
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
num_buckets = self.num_buckets
max_distance = self.max_distance
relative_buckets = 0
if bidirectional:
num_buckets = num_buckets // 2
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
relative_positions = torch.abs(relative_positions)
else:
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
max_exact = num_buckets // 2
is_small = relative_positions < max_exact
relative_postion_if_large = max_exact + (
torch.log(relative_positions.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_postion_if_large = torch.min(
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
return relative_buckets
def compute_bias(self, query_length, key_length):
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = memory_position - context_position
relative_position_bucket = self._relative_positions_bucket(
relative_position,
bidirectional=True
)
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
values = self.relative_attention_bias(relative_position_bucket)
values = values.permute([2, 0, 1])
return values
def forward(
self,
query,
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
before_softmax: bool = False,
need_head_weights: bool = False,
position_bias: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: False).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
is_tpu = query.device.type == "xla"
tgt_len, bsz, embed_dim = query.size()
src_len = tgt_len
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if key is not None:
src_len, key_bsz, _ = key.size()
if not torch.jit.is_scripting():
assert key_bsz == bsz
assert value is not None
assert (src_len, bsz) == value.shape[:2]
if self.has_relative_attention_bias and position_bias is None:
position_bias = self.compute_bias(tgt_len, src_len)
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
if (
not is_tpu # don't use PyTorch version on TPUs
and incremental_state is None
and not static_kv
# A workaround for quantization to work. Otherwise JIT compilation
# treats bias in linear module as method.
and not torch.jit.is_scripting()
and self.q_head_dim == self.head_dim
):
assert key is not None and value is not None
assert attn_mask is None
attn_mask_rel_pos = None
if position_bias is not None:
attn_mask_rel_pos = position_bias
if self.gru_rel_pos:
query_layer = query.transpose(0, 1)
new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
query_layer = query_layer.view(*new_x_shape)
query_layer = query_layer.permute(0, 2, 1, 3)
_B, _H, _L, __ = query_layer.size()
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
k_proj_bias = self.k_proj.bias
if k_proj_bias is None:
k_proj_bias = torch.zeros_like(self.q_proj.bias)
x, attn = F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
torch.empty([0]),
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout_module.p,
self.out_proj.weight,
self.out_proj.bias,
self.training,
# self.training or self.dropout_module.apply_during_inference,
key_padding_mask,
need_weights,
attn_mask_rel_pos,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
)
return x, attn, position_bias
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
q = self.q_proj(query)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
],
dim=1,
)
q = (
q.contiguous()
.view(tgt_len, bsz * self.num_heads, self.q_head_dim)
.transpose(0, 1)
)
if k is not None:
k = (
k.contiguous()
.view(-1, bsz * self.num_heads, self.k_head_dim)
.transpose(0, 1)
)
if v is not None:
v = (
v.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
src_len = k.size(1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
prev_key_padding_mask: Optional[Tensor] = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
key_padding_mask=key_padding_mask,
prev_key_padding_mask=prev_key_padding_mask,
batch_size=bsz,
src_len=k.size(1),
static_kv=static_kv,
)
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
assert k.size(1) == src_len
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
assert v is not None
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
torch.zeros(key_padding_mask.size(0), 1).type_as(
key_padding_mask
),
],
dim=1,
)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_weights += attn_mask
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
if not is_tpu:
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
else:
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if before_softmax:
return attn_weights, v, position_bias
if position_bias is not None:
if self.gru_rel_pos == 1:
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
_B, _H, _L, __ = query_layer.size()
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
position_bias = position_bias.view(attn_weights.size())
attn_weights = attn_weights + position_bias
attn_weights_float = F.softmax(
attn_weights, dim=-1
)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = self.dropout_module(attn_weights)
assert v is not None
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
attn_weights: Optional[Tensor] = None
if need_weights:
attn_weights = attn_weights_float.view(
bsz, self.num_heads, tgt_len, src_len
).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
return attn, attn_weights, position_bias
@staticmethod
def _append_prev_key_padding_mask(
key_padding_mask: Optional[Tensor],
prev_key_padding_mask: Optional[Tensor],
batch_size: int,
src_len: int,
static_kv: bool,
) -> Optional[Tensor]:
# saved key padding masks have shape (bsz, seq_len)
if prev_key_padding_mask is not None and static_kv:
new_key_padding_mask = prev_key_padding_mask
elif prev_key_padding_mask is not None and key_padding_mask is not None:
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
)
# During incremental decoding, as the padding token enters and
# leaves the frame, there will be a time when prev or current
# is None
elif prev_key_padding_mask is not None:
if src_len > prev_key_padding_mask.size(1):
filler = torch.zeros(
(batch_size, src_len - prev_key_padding_mask.size(1)),
device=prev_key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), filler.float()], dim=1
)
else:
new_key_padding_mask = prev_key_padding_mask.float()
elif key_padding_mask is not None:
if src_len > key_padding_mask.size(1):
filler = torch.zeros(
(batch_size, src_len - key_padding_mask.size(1)),
device=key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[filler.float(), key_padding_mask.float()], dim=1
)
else:
new_key_padding_mask = key_padding_mask.float()
else:
new_key_padding_mask = prev_key_padding_mask
return new_key_padding_mask
def _get_input_buffer(
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
) -> Dict[str, Optional[Tensor]]:
result = self.get_incremental_state(incremental_state, "attn_state")
if result is not None:
return result
else:
empty_result: Dict[str, Optional[Tensor]] = {}
return empty_result
def _set_input_buffer(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
buffer: Dict[str, Optional[Tensor]],
):
return self.set_incremental_state(incremental_state, "attn_state", buffer)
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
return attn_weights
trigger:
- master
pool:
vmImage: 'windows-latest'
steps:
- script: echo Hello, world!
displayName: 'Run a one-line script'
- script: |
echo Add other tasks to build, test, and deploy your project.
echo See https://aka.ms/yaml
displayName: 'Run a multi-line script'
- task: CredScan@2
inputs:
toolMajorVersion: 'V2'
- task: Semmle@0
env:
SYSTEM_ACCESSTOKEN: $(PATToken)
inputs:
sourceCodeDirectory: '$(Build.SourcesDirectory)'
language: 'python'
includeNodeModules: true
querySuite: 'Recommended'
timeout: '1800'
ram: '16384'
addProjectDirToScanningExclusionList: true
- task: ComponentGovernanceComponentDetection@0
inputs:
scanType: 'Register'
verbosity: 'Verbose'
alertWarningLevel: 'High'
- task: PublishSecurityAnalysisLogs@2
inputs:
ArtifactName: 'CodeAnalysisLogs'
ArtifactType: 'Container'
AllTools: true
ToolLogsNotFoundAction: 'Standard'
## Pre-training Representations for Speaker Diarization
### Downstream Model
[EEND-vector-clustering](https://arxiv.org/abs/2105.09040)
### Pre-trained models
- Note that the diarization system is trained on 8 kHz audio data.
| Model | 2 spk DER | 3 spk DER | 4 spk DER | 5 spk DER | 6 spk DER | ALL spk DER |
| ------------------------------------------------------------ | --------- | --------- | --------- | --------- | --------- | ----------- |
| EEND-vector-clustering | 7.96 | 11.93 | 16.38 | 21.21 | 23.1 | 12.49 |
| [**UniSpeech-SAT large**](https://drive.google.com/file/d/16OwIyOk2uYm0aWtSPaS0S12xE8RxF7k_/view?usp=sharing) | 5.93 | 10.66 | 12.90 | 16.48 | 23.25 | 10.92 |
### How to use?
#### Environment Setup
1. `pip install --require-hashes -r requirements.txt`
2. Install the fairseq code
- For UniSpeech-SAT large, install the [UniSpeech-SAT](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech-SAT) fairseq code.
#### Example
1. First, download a pre-trained model from the table above to `checkpoint_path`.
2. Then, run the following command (the wav file is multi-talker speech simulated from the LibriSpeech corpus):
3. The output will be written to `out.rttm` by default (see the parsing sketch below).
```bash
python diarization.py --wav_path tmp/mix_0000496.wav --model_init $checkpoint_path
```
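The output follows the standard RTTM layout (`SPEAKER <file> <chan> <onset> <dur> <NA> <NA> <speaker> <NA> <NA>`). A minimal sketch for reading it back, assuming that layout (the helper below is illustrative and not part of this repo):
```python
def read_rttm(path):
    # Each line: SPEAKER <file> <chan> <onset> <dur> <NA> <NA> <speaker> <NA> <NA>
    segments = []
    with open(path) as f:
        for line in f:
            fields = line.split()
            if fields and fields[0] == "SPEAKER":
                onset, dur, speaker = float(fields[3]), float(fields[4]), fields[7]
                segments.append((speaker, onset, onset + dur))
    return segments

print(read_rttm("out.rttm"))
```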
# Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
# All rights reserved
# inference options
est_nspk: 1
sil_spk_th: 0.05
ahc_dis_th: 1.0
clink_dis: 1.0e+4
model:
n_speakers: 3
all_n_speakers: 0
feat_dim: 1024
n_units: 256
n_heads: 8
n_layers: 6
dropout_rate: 0.1
spk_emb_dim: 256
sr: 8000
frame_shift: 320
frame_size: 200
context_size: 0
subsampling: 1
feat_type: "config/unispeech_sat.th"
feature_selection: "hidden_states"
interpolate_mode: "linear"
dataset:
chunk_size: 750
frame_shift: 320
sampling_rate: 8000
subsampling: 1
num_speakers: 3
import sys
import h5py
import soundfile as sf
import fire
import math
import yamlargparse
import numpy as np
from torch.utils.data import DataLoader
import torch
from utils.utils import parse_config_or_kwargs
from utils.dataset import DiarizationDataset
from models.models import TransformerDiarization
from scipy.signal import medfilt
from sklearn.cluster import AgglomerativeClustering
from scipy.spatial import distance
from utils.kaldi_data import KaldiData
def get_cl_sil(args, acti, cls_num):
n_chunks = len(acti)
mean_acti = np.array([np.mean(acti[i], axis=0)
for i in range(n_chunks)]).flatten()
n = args.num_speakers
sil_spk_th = args.sil_spk_th
cl_lst = []
sil_lst = []
for chunk_idx in range(n_chunks):
if cls_num is not None:
if args.num_speakers > cls_num:
mean_acti_bi = np.array([mean_acti[n * chunk_idx + s_loc_idx]
for s_loc_idx in range(n)])
min_idx = np.argmin(mean_acti_bi)
mean_acti[n * chunk_idx + min_idx] = 0.0
for s_loc_idx in range(n):
a = n * chunk_idx + (s_loc_idx + 0) % n
b = n * chunk_idx + (s_loc_idx + 1) % n
if mean_acti[a] > sil_spk_th and mean_acti[b] > sil_spk_th:
cl_lst.append((a, b))
else:
if mean_acti[a] <= sil_spk_th:
sil_lst.append(a)
return cl_lst, sil_lst
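# Constrained agglomerative hierarchical clustering of the speaker vectors:
# silence outputs are removed, cannot-link pairs are kept apart by assigning
# them a large distance (clink_dis), and the silence label is re-inserted
# afterwards so that every chunk keeps args.num_speakers label slots.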
def clustering(args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst):
org_svec_len = len(svec)
svec = np.delete(svec, sil_lst, 0)
# update cl_lst idx
_tbl = [i - sum(sil < i for sil in sil_lst) for i in range(org_svec_len)]
cl_lst = [(_tbl[_cl[0]], _tbl[_cl[1]]) for _cl in cl_lst]
distMat = distance.cdist(svec, svec, metric='euclidean')
for cl in cl_lst:
distMat[cl[0], cl[1]] = args.clink_dis
distMat[cl[1], cl[0]] = args.clink_dis
clusterer = AgglomerativeClustering(
n_clusters=cls_num,
affinity='precomputed',
linkage='average',
distance_threshold=ahc_dis_th)
clusterer.fit(distMat)
if cls_num is not None:
print("oracle n_clusters is known")
else:
print("oracle n_clusters is unknown")
print("estimated n_clusters by constraind AHC: {}"
.format(len(np.unique(clusterer.labels_))))
cls_num = len(np.unique(clusterer.labels_))
sil_lab = cls_num
insert_sil_lab = [sil_lab for i in range(len(sil_lst))]
insert_sil_lab_idx = [sil_lst[i] - i for i in range(len(sil_lst))]
print("insert_sil_lab : {}".format(insert_sil_lab))
print("insert_sil_lab_idx : {}".format(insert_sil_lab_idx))
clslab = np.insert(clusterer.labels_,
insert_sil_lab_idx,
insert_sil_lab).reshape(-1, args.num_speakers)
print("clslab : {}".format(clslab))
return clslab, cls_num
def merge_act_max(act, i, j):
for k in range(len(act)):
act[k, i] = max(act[k, i], act[k, j])
act[k, j] = 0.0
return act
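# If two local outputs of the same chunk are assigned the same (non-silence)
# global cluster, merge their activities frame-wise (max) and relabel the
# duplicate as silence.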
def merge_acti_clslab(args, acti, clslab, cls_num):
sil_lab = cls_num
for i in range(len(clslab)):
_lab = clslab[i].reshape(-1, 1)
distM = distance.cdist(_lab, _lab, metric='euclidean').astype(np.int64)
for j in range(len(distM)):
distM[j][:j] = -1
idx_lst = np.where(np.count_nonzero(distM == 0, axis=1) > 1)
merge_done = []
for j in idx_lst[0]:
for k in (np.where(distM[j] == 0))[0]:
if j != k and clslab[i, j] != sil_lab and k not in merge_done:
print("merge : (i, j, k) == ({}, {}, {})".format(i, j, k))
acti[i] = merge_act_max(acti[i], j, k)
clslab[i, k] = sil_lab
merge_done.append(j)
return acti, clslab
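# Re-order ("stitch") the local speaker outputs of each chunk according to the
# global cluster labels so that the activities of the same global speaker can
# be concatenated across chunks; the silence column is dropped at the end.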
def stitching(args, acti, clslab, cls_num):
n_chunks = len(acti)
s_loc = args.num_speakers
sil_lab = cls_num
s_tot = max(cls_num, s_loc-1)
# Extend the max value of s_loc_idx to s_tot+1
add_acti = []
for chunk_idx in range(n_chunks):
zeros = np.zeros((len(acti[chunk_idx]), s_tot+1))
if s_tot+1 > s_loc:
zeros[:, :-(s_tot+1-s_loc)] = acti[chunk_idx]
else:
zeros = acti[chunk_idx]
add_acti.append(zeros)
acti = np.array(add_acti)
out_chunks = []
for chunk_idx in range(n_chunks):
# Make sloci2lab_dct.
# key: s_loc_idx
# value: estimated label by clustering or sil_lab
cls_set = set()
for s_loc_idx in range(s_tot+1):
cls_set.add(s_loc_idx)
sloci2lab_dct = {}
for s_loc_idx in range(s_tot+1):
if s_loc_idx < s_loc:
sloci2lab_dct[s_loc_idx] = clslab[chunk_idx][s_loc_idx]
if clslab[chunk_idx][s_loc_idx] in cls_set:
cls_set.remove(clslab[chunk_idx][s_loc_idx])
else:
if clslab[chunk_idx][s_loc_idx] != sil_lab:
raise ValueError
else:
sloci2lab_dct[s_loc_idx] = list(cls_set)[s_loc_idx-s_loc]
# Sort by label value
sloci2lab_lst = sorted(sloci2lab_dct.items(), key=lambda x: x[1])
# Select sil_lab_idx
sil_lab_idx = None
for idx_lab in sloci2lab_lst:
if idx_lab[1] == sil_lab:
sil_lab_idx = idx_lab[0]
break
if sil_lab_idx is None:
raise ValueError
# Get swap_idx
# [idx of label(0), idx of label(1), ..., idx of label(s_tot)]
swap_idx = [sil_lab_idx for j in range(s_tot+1)]
for lab in range(s_tot+1):
for idx_lab in sloci2lab_lst:
if lab == idx_lab[1]:
swap_idx[lab] = idx_lab[0]
print("swap_idx {}".format(swap_idx))
swap_acti = acti[chunk_idx][:, swap_idx]
swap_acti = np.delete(swap_acti, sil_lab, 1)
out_chunks.append(swap_acti)
return out_chunks
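# Run the diarization network chunk by chunk and collect the frame-level
# speaker activities (acti), the per-chunk speaker embeddings (svec) and the
# valid length of every chunk.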
def prediction(num_speakers, net, wav_list, chunk_len_list):
acti_lst = []
svec_lst = []
len_list = []
with torch.no_grad():
for wav, chunk_len in zip(wav_list, chunk_len_list):
wav = wav.to('cuda')
outputs = net.batch_estimate(torch.unsqueeze(wav, 0))
ys = outputs[0]
for i in range(num_speakers):
spkivecs = outputs[i+1]
svec_lst.append(spkivecs[0].cpu().detach().numpy())
acti = ys[0][-chunk_len:].cpu().detach().numpy()
acti_lst.append(acti)
len_list.append(chunk_len)
acti_arr = np.concatenate(acti_lst, axis=0) # total_len x num_speakers
svec_arr = np.stack(svec_lst) # (chunk_num x num_speakers) x emb_dim
len_arr = np.array(len_list) # chunk_num
return acti_arr, svec_arr, len_arr
def cluster(args, conf, acti_arr, svec_arr, len_arr):
acti_list = []
n_chunks = len_arr.shape[0]
start = 0
for i in range(n_chunks):
chunk_len = len_arr[i]
acti_list.append(acti_arr[start: start+chunk_len])
start += chunk_len
acti = np.array(acti_list)
svec = svec_arr
# initialize clustering setting
cls_num = None
ahc_dis_th = args.ahc_dis_th
# Get cannot-link index list and silence index list
cl_lst, sil_lst = get_cl_sil(args, acti, cls_num)
n_samples = n_chunks * args.num_speakers - len(sil_lst)
min_n_samples = 2
if cls_num is not None:
min_n_samples = cls_num
if n_samples >= min_n_samples:
# clustering (if cls_num is None, update cls_num)
clslab, cls_num =\
clustering(args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst)
# merge
acti, clslab = merge_acti_clslab(args, acti, clslab, cls_num)
# stitching
out_chunks = stitching(args, acti, clslab, cls_num)
else:
out_chunks = acti
outdata = np.vstack(out_chunks)
# Return the clustered activities
return outdata
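# Convert the frame-level activities to RTTM segments: binarize with a
# threshold, optionally smooth with a median filter, and map frame indices to
# seconds via frame_shift, subsampling and the sampling rate.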
def make_rttm(args, conf, cluster_data):
args.frame_shift = conf['model']['frame_shift']
args.subsampling = conf['model']['subsampling']
args.sampling_rate = conf['dataset']['sampling_rate']
with open(args.out_rttm_file, 'w') as wf:
a = np.where(cluster_data > args.threshold, 1, 0)
if args.median > 1:
a = medfilt(a, (args.median, 1))
for spkid, frames in enumerate(a.T):
frames = np.pad(frames, (1, 1), 'constant')
changes, = np.where(np.diff(frames, axis=0) != 0)
fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
for s, e in zip(changes[::2], changes[1::2]):
print(fmt.format(
args.session,
s * args.frame_shift * args.subsampling / args.sampling_rate,
(e - s) * args.frame_shift * args.subsampling / args.sampling_rate,
args.session + "_" + str(spkid)), file=wf)
def main(args):
conf = parse_config_or_kwargs(args.config_path)
num_speakers = conf['dataset']['num_speakers']
args.num_speakers = num_speakers
# Prepare model
model_parameter_dict = torch.load(args.model_init)['model']
model_all_n_speakers = model_parameter_dict["embed.weight"].shape[0]
conf['model']['all_n_speakers'] = model_all_n_speakers
net = TransformerDiarization(**conf['model'])
net.load_state_dict(model_parameter_dict, strict=False)
net.eval()
net = net.to("cuda")
audio, sr = sf.read(args.wav_path, dtype="float32")
audio_len = audio.shape[0]
chunk_size, frame_shift, subsampling = conf['dataset']['chunk_size'], conf['model']['frame_shift'], conf['model']['subsampling']
scale_ratio = int(frame_shift * subsampling)
chunk_audio_size = chunk_size * scale_ratio
wav_list, chunk_len_list = [], []
for i in range(0, math.ceil(1.0 * audio_len / chunk_audio_size)):
start, end = i*chunk_audio_size, (i+1)*chunk_audio_size
if end > audio_len:
chunk_len_list.append(int((audio_len-start) / scale_ratio))
end = audio_len
start = max(0, audio_len - chunk_audio_size)
else:
chunk_len_list.append(chunk_size)
wav_list.append(audio[start:end])
wav_list = [torch.from_numpy(wav).float() for wav in wav_list]
acti_arr, svec_arr, len_arr = prediction(num_speakers, net, wav_list, chunk_len_list)
cluster_data = cluster(args, conf, acti_arr, svec_arr, len_arr)
make_rttm(args, conf, cluster_data)
if __name__ == '__main__':
parser = yamlargparse.ArgumentParser(description='decoding')
parser.add_argument('--wav_path',
help='the input wav path',
default="tmp/mix_0000496.wav")
parser.add_argument('--config_path',
help='config file path',
default="config/infer_est_nspk1.yaml")
parser.add_argument('--model_init',
help='model initialize path',
default="")
parser.add_argument('--sil_spk_th', default=0.05, type=float)
parser.add_argument('--ahc_dis_th', default=1.0, type=float)
parser.add_argument('--clink_dis', default=1.0e+4, type=float)
parser.add_argument('--session', default='Anonymous', help='the session (recording) name used in the output rttm')
parser.add_argument('--out_rttm_file', default='out.rttm', help='the output rttm file')
parser.add_argument('--threshold', default=0.4, type=float)
parser.add_argument('--median', default=25, type=int)
args = parser.parse_args()
main(args)
# Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
# All rights reserved
import sys
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchaudio.transforms as trans
from collections import OrderedDict
from itertools import permutations
from models.transformer import TransformerEncoder
from .utils import UpstreamExpert
class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
return res
@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
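# GradMultiply returns its input unchanged in the forward pass and scales the
# gradient by `scale` in the backward pass; it is used below to damp gradients
# flowing into the feature extractor (feature_grad_mult).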
"""
P: number of permutations
T: number of frames
C: number of speakers (classes)
B: mini-batch size
"""
def batch_pit_loss_parallel(outputs, labels, ilens=None):
""" calculate the batch pit loss parallelly
Args:
outputs (torch.Tensor): B x T x C
labels (torch.Tensor): B x T x C
ilens (torch.Tensor): B
Returns:
perm (torch.Tensor): permutation for outputs (Batch, num_spk)
loss
"""
if ilens is None:
mask, scale = 1.0, outputs.shape[1]
else:
scale = torch.unsqueeze(torch.LongTensor(ilens), 1).to(outputs.device)
mask = outputs.new_zeros(outputs.size()[:-1])
for i, chunk_len in enumerate(ilens):
mask[i, :chunk_len] += 1.0
mask /= scale
def loss_func(output, label):
# return torch.mean(F.binary_cross_entropy_with_logits(output, label, reduction='none'), dim=tuple(range(1, output.dim())))
return torch.sum(F.binary_cross_entropy_with_logits(output, label, reduction='none') * mask, dim=-1)
def pair_loss(outputs, labels, permutation):
return sum([loss_func(outputs[:,:,s], labels[:,:,t]) for s, t in enumerate(permutation)]) / len(permutation)
device = outputs.device
num_spk = outputs.shape[-1]
all_permutations = list(permutations(range(num_spk)))
losses = torch.stack([pair_loss(outputs, labels, p) for p in all_permutations], dim=1)
loss, perm = torch.min(losses, dim=1)
perm = torch.index_select(torch.tensor(all_permutations, device=device, dtype=torch.long), 0, perm)
return torch.mean(loss), perm
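# Example usage (shapes only):
#   outputs, labels: (B, T, C) logits and 0/1 targets, ilens: valid lengths per sample
#   loss, perm = batch_pit_loss_parallel(outputs, labels, ilens)
#   perm[b][s] is the label channel matched to output channel s of sample b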
def fix_state_dict(state_dict):
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('module.'):
# remove 'module.' of DataParallel
k = k[7:]
if k.startswith('net.'):
# remove 'net.' of PadertorchModel
k = k[4:]
new_state_dict[k] = v
return new_state_dict
class TransformerDiarization(nn.Module):
def __init__(self,
n_speakers,
all_n_speakers,
feat_dim,
n_units,
n_heads,
n_layers,
dropout_rate,
spk_emb_dim,
sr=8000,
frame_shift=256,
frame_size=1024,
context_size=0,
subsampling=1,
feat_type='fbank',
feature_selection='default',
interpolate_mode='linear',
update_extract=False,
feature_grad_mult=1.0
):
super(TransformerDiarization, self).__init__()
self.context_size = context_size
self.subsampling = subsampling
self.feat_type = feat_type
self.feature_selection = feature_selection
self.sr = sr
self.frame_shift = frame_shift
self.interpolate_mode = interpolate_mode
self.update_extract = update_extract
self.feature_grad_mult = feature_grad_mult
if feat_type == 'fbank':
self.feature_extract = trans.MelSpectrogram(sample_rate=sr,
n_fft=frame_size,
win_length=frame_size,
hop_length=frame_shift,
f_min=0.0,
f_max=sr // 2,
pad=0,
n_mels=feat_dim)
else:
self.feature_extract = UpstreamExpert(feat_type)
# self.feature_extract = torch.hub.load('s3prl/s3prl', 'hubert_local', ckpt=feat_type)
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
self.feat_num = self.get_feat_num()
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
# for param in self.feature_extract.parameters():
# param.requires_grad = False
self.resample = trans.Resample(orig_freq=sr, new_freq=16000)
if feat_type != 'fbank' and feat_type != 'mfcc':
freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer', 'spk_proj', 'layer_norm_for_extract']
for name, param in self.feature_extract.named_parameters():
for freeze_val in freeze_list:
if freeze_val in name:
param.requires_grad = False
break
if not self.update_extract:
for param in self.feature_extract.parameters():
param.requires_grad = False
self.instance_norm = nn.InstanceNorm1d(feat_dim)
feat_dim = feat_dim * (self.context_size*2 + 1)
self.enc = TransformerEncoder(
feat_dim, n_layers, n_units, h=n_heads, dropout_rate=dropout_rate)
self.linear = nn.Linear(n_units, n_speakers)
for i in range(n_speakers):
setattr(self, '{}{:d}'.format("linear", i), nn.Linear(n_units, spk_emb_dim))
self.n_speakers = n_speakers
self.embed = nn.Embedding(all_n_speakers, spk_emb_dim)
self.alpha = nn.Parameter(torch.rand(1)[0] + torch.Tensor([0.5])[0])
self.beta = nn.Parameter(torch.rand(1)[0] + torch.Tensor([0.5])[0])
def get_feat_num(self):
self.feature_extract.eval()
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
with torch.no_grad():
features = self.feature_extract(wav)
select_feature = features[self.feature_selection]
if isinstance(select_feature, (list, tuple)):
return len(select_feature)
else:
return 1
def fix_except_embedding(self, requires_grad=False):
for name, param in self.named_parameters():
if 'embed' not in name:
param.requires_grad = requires_grad
def modfy_emb(self, weight):
self.embed = nn.Embedding.from_pretrained(weight)
def splice(self, data, context_size):
# data: B x feat_dim x time_len
data = torch.unsqueeze(data, -1)
kernel_size = context_size*2 + 1
splice_data = F.unfold(data, kernel_size=(kernel_size, 1), padding=(context_size, 0))
return splice_data
def get_feat(self, xs):
wav_len = xs.shape[-1]
chunk_size = int(wav_len / self.frame_shift)
chunk_size = int(chunk_size / self.subsampling)
self.feature_extract.eval()
if self.update_extract:
xs = self.resample(xs)
feature = self.feature_extract([sample for sample in xs])
else:
with torch.no_grad():
if self.feat_type == 'fbank':
feature = self.feature_extract(xs) + 1e-6 # B x feat_dim x time_len
feature = feature.log()
else:
xs = self.resample(xs)
feature = self.feature_extract([sample for sample in xs])
if self.feat_type != "fbank" and self.feat_type != "mfcc":
feature = feature[self.feature_selection]
if isinstance(feature, (list, tuple)):
feature = torch.stack(feature, dim=0)
else:
feature = feature.unsqueeze(0)
norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
feature = (norm_weights * feature).sum(dim=0)
feature = torch.transpose(feature, 1, 2) + 1e-6
feature = self.instance_norm(feature)
feature = self.splice(feature, self.context_size)
feature = feature[:, :, ::self.subsampling]
feature = F.interpolate(feature, chunk_size, mode=self.interpolate_mode)
feature = torch.transpose(feature, 1, 2)
if self.feature_grad_mult != 1.0:
feature = GradMultiply.apply(feature, self.feature_grad_mult)
return feature
def forward(self, inputs):
if isinstance(inputs, list):
xs = inputs[0]
else:
xs = inputs
feature = self.get_feat(xs)
pad_shape = feature.shape
emb = self.enc(feature)
ys = self.linear(emb)
ys = ys.reshape(pad_shape[0], pad_shape[1], -1)
spksvecs = []
for i in range(self.n_speakers):
spkivecs = getattr(self, '{}{:d}'.format("linear", i))(emb)
spkivecs = spkivecs.reshape(pad_shape[0], pad_shape[1], -1)
spksvecs.append(spkivecs)
return ys, spksvecs
def get_loss(self, inputs, ys, spksvecs, cal_spk_loss=True):
ts = inputs[1]
ss = inputs[2]
ns = inputs[3]
ilens = inputs[4]
ilens = [ilen.item() for ilen in ilens]
pit_loss, sigmas = batch_pit_loss_parallel(ys, ts, ilens)
if cal_spk_loss:
spk_loss = self.spk_loss_parallel(spksvecs, ys, ts, ss, sigmas, ns, ilens)
else:
spk_loss = torch.tensor(0.0).to(pit_loss.device)
alpha = torch.clamp(self.alpha, min=sys.float_info.epsilon)
return {'spk_loss':spk_loss,
'pit_loss': pit_loss}
def batch_estimate(self, xs):
out = self(xs)
ys = out[0]
spksvecs = out[1]
spksvecs = list(zip(*spksvecs))
outputs = [
self.estimate(spksvec, y)
for (spksvec, y) in zip(spksvecs, ys)]
outputs = list(zip(*outputs))
return outputs
def batch_estimate_with_perm(self, xs, ts, ilens=None):
out = self(xs)
ys = out[0]
if ts[0].shape[1] > ys[0].shape[1]:
# e.g. the case of training 3-spk model with 4-spk data
add_dim = ts[0].shape[1] - ys[0].shape[1]
y_device = ys[0].device
zeros = [torch.zeros(ts[0].shape).to(y_device)
for i in range(len(ts))]
_ys = []
for zero, y in zip(zeros, ys):
_zero = zero
_zero[:, :-add_dim] = y
_ys.append(_zero)
_, sigmas = batch_pit_loss_parallel(_ys, ts, ilens)
else:
_, sigmas = batch_pit_loss_parallel(ys, ts, ilens)
spksvecs = out[1]
spksvecs = list(zip(*spksvecs))
outputs = [self.estimate(spksvec, y)
for (spksvec, y) in zip(spksvecs, ys)]
outputs = list(zip(*outputs))
zs = outputs[0]
if ts[0].shape[1] > ys[0].shape[1]:
# e.g. the case of training 3-spk model with 4-spk data
add_dim = ts[0].shape[1] - ys[0].shape[1]
z_device = zs[0].device
zeros = [torch.zeros(ts[0].shape).to(z_device)
for i in range(len(ts))]
_zs = []
for zero, z in zip(zeros, zs):
_zero = zero
_zero[:, :-add_dim] = z
_zs.append(_zero)
zs = _zs
outputs[0] = zs
outputs.append(sigmas)
# outputs: [zs, nmz_wavg_spk0vecs, nmz_wavg_spk1vecs, ..., sigmas]
return outputs
def estimate(self, spksvec, y):
outputs = []
z = torch.sigmoid(y.transpose(1, 0))
outputs.append(z.transpose(1, 0))
for spkid, spkvec in enumerate(spksvec):
norm_spkvec_inv = 1.0 / torch.norm(spkvec, dim=1)
# Normalize speaker vectors before weighted average
spkvec = torch.mul(
spkvec.transpose(1, 0), norm_spkvec_inv
).transpose(1, 0)
wavg_spkvec = torch.mul(
spkvec.transpose(1, 0), z[spkid]
).transpose(1, 0)
sum_wavg_spkvec = torch.sum(wavg_spkvec, dim=0)
nmz_wavg_spkvec = sum_wavg_spkvec / torch.norm(sum_wavg_spkvec)
outputs.append(nmz_wavg_spkvec)
# outputs: [z, nmz_wavg_spk0vec, nmz_wavg_spk1vec, ...]
return outputs
def spk_loss_parallel(self, spksvecs, ys, ts, ss, sigmas, ns, ilens):
'''
spksvecs (List[torch.Tensor, ...]): [B x T x emb_dim, ...]
ys (torch.Tensor): B x T x 3
ts (torch.Tensor): B x T x 3
ss (torch.Tensor): B x 3
sigmas (torch.Tensor): B x 3
ns (torch.Tensor): B x total_spk_num x 1
ilens (List): B
'''
chunk_spk_num = len(spksvecs) # 3
len_mask = ys.new_zeros((ys.size()[:-1])) # B x T
for i, len_val in enumerate(ilens):
len_mask[i,:len_val] += 1.0
ts = ts * len_mask.unsqueeze(-1)
len_mask = len_mask.repeat((chunk_spk_num, 1)) # B*3 x T
spk_vecs = torch.cat(spksvecs, dim=0) # B*3 x T x emb_dim
# Normalize speaker vectors before weighted average
spk_vecs = F.normalize(spk_vecs, dim=-1)
ys = torch.permute(torch.sigmoid(ys), dims=(2, 0, 1)) # 3 x B x T
ys = ys.reshape(-1, ys.shape[-1]).unsqueeze(-1) # B*3 x T x 1
weight_spk_vec = ys * spk_vecs # B*3 x T x emb_dim
weight_spk_vec *= len_mask.unsqueeze(-1)
sum_spk_vec = torch.sum(weight_spk_vec, dim=1) # B*3 x emb_dim
norm_spk_vec = F.normalize(sum_spk_vec, dim=1)
embeds = F.normalize(self.embed(ns[0]).squeeze(), dim=1) # total_spk_num x emb_dim
dist = torch.cdist(norm_spk_vec, embeds) # B*3 x total_spk_num
logits = -1.0 * torch.add(torch.clamp(self.alpha, min=sys.float_info.epsilon) * torch.pow(dist, 2), self.beta)
label = torch.gather(ss, 1, sigmas).transpose(0, 1).reshape(-1, 1).squeeze() # B*3
label[label==-1] = 0
valid_spk_mask = torch.gather(torch.sum(ts, dim=1), 1, sigmas).transpose(0, 1) # 3 x B
valid_spk_mask = (torch.flatten(valid_spk_mask) > 0).float() # B*3
valid_spk_loss_num = torch.sum(valid_spk_mask).item()
if valid_spk_loss_num > 0:
loss = F.cross_entropy(logits, label, reduction='none') * valid_spk_mask / valid_spk_loss_num
# if the line below is used instead, the loss result is the same as batch_spk_loss
# loss = F.cross_entropy(logits, label, reduction='none') * valid_spk_mask / valid_spk_mask.shape[0]
return torch.sum(loss)
else:
return torch.tensor(0.0).to(ys.device)
# Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
# All rights reserved
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler
class NoamScheduler(_LRScheduler):
""" learning rate scheduler used in the transformer
See https://arxiv.org/pdf/1706.03762.pdf
lrate = d_model**(-0.5) * \
min(step_num**(-0.5), step_num*warmup_steps**(-1.5))
Scaling factor is implemented as in
http://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer
"""
def __init__(
self, optimizer, d_model, warmup_steps, tot_step, scale,
last_epoch=-1
):
self.d_model = d_model
self.warmup_steps = warmup_steps
self.tot_step = tot_step
self.scale = scale
super(NoamScheduler, self).__init__(optimizer, last_epoch)
def get_lr(self):
self.last_epoch = max(1, self.last_epoch)
step_num = self.last_epoch
val = self.scale * self.d_model ** (-0.5) * \
min(step_num ** (-0.5), step_num * self.warmup_steps ** (-1.5))
return [val for _ in self.base_lrs]  # every param group gets the same Noam learning rate
class MultiHeadSelfAttention(nn.Module):
""" Multi head "self" attention layer
"""
def __init__(self, n_units, h=8, dropout_rate=0.1):
super(MultiHeadSelfAttention, self).__init__()
self.linearQ = nn.Linear(n_units, n_units)
self.linearK = nn.Linear(n_units, n_units)
self.linearV = nn.Linear(n_units, n_units)
self.linearO = nn.Linear(n_units, n_units)
self.d_k = n_units // h
self.h = h
self.dropout = nn.Dropout(p=dropout_rate)
# attention for plot
self.att = None
def forward(self, x, batch_size):
# x: (BT, F)
q = self.linearQ(x).reshape(batch_size, -1, self.h, self.d_k)
k = self.linearK(x).reshape(batch_size, -1, self.h, self.d_k)
v = self.linearV(x).reshape(batch_size, -1, self.h, self.d_k)
scores = torch.matmul(
q.transpose(1, 2), k.permute(0, 2, 3, 1)) / np.sqrt(self.d_k)
# scores: (B, h, T, T) = (B, h, T, d_k) x (B, h, d_k, T)
self.att = F.softmax(scores, dim=3)
p_att = self.dropout(self.att)
x = torch.matmul(p_att, v.transpose(1, 2))
x = x.transpose(1, 2).reshape(-1, self.h * self.d_k)
return self.linearO(x)
class PositionwiseFeedForward(nn.Module):
""" Positionwise feed-forward layer
"""
def __init__(self, n_units, d_units, dropout_rate):
super(PositionwiseFeedForward, self).__init__()
self.linear1 = nn.Linear(n_units, d_units)
self.linear2 = nn.Linear(d_units, n_units)
self.dropout = nn.Dropout(p=dropout_rate)
def forward(self, x):
return self.linear2(self.dropout(F.relu(self.linear1(x))))
class PositionalEncoding(nn.Module):
""" Positional encoding function
"""
def __init__(self, n_units, dropout_rate, max_len):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout_rate)
positions = np.arange(0, max_len, dtype='f')[:, None]
dens = np.exp(
np.arange(0, n_units, 2, dtype='f') * -(np.log(10000.) / n_units))
self.enc = np.zeros((max_len, n_units), dtype='f')
self.enc[:, ::2] = np.sin(positions * dens)
self.enc[:, 1::2] = np.cos(positions * dens)
self.scale = np.sqrt(n_units)
def forward(self, x):
# x: (B, T, F); add the precomputed sinusoidal table for the first T positions
x = x * self.scale + torch.from_numpy(self.enc[:x.shape[1]]).to(x.device)
return self.dropout(x)
class TransformerEncoder(nn.Module):
def __init__(self, idim, n_layers, n_units,
e_units=2048, h=8, dropout_rate=0.1):
super(TransformerEncoder, self).__init__()
self.linear_in = nn.Linear(idim, n_units)
# self.lnorm_in = nn.LayerNorm(n_units)
self.pos_enc = PositionalEncoding(n_units, dropout_rate, 5000)
self.n_layers = n_layers
self.dropout = nn.Dropout(p=dropout_rate)
for i in range(n_layers):
setattr(self, '{}{:d}'.format("lnorm1_", i),
nn.LayerNorm(n_units))
setattr(self, '{}{:d}'.format("self_att_", i),
MultiHeadSelfAttention(n_units, h, dropout_rate))
setattr(self, '{}{:d}'.format("lnorm2_", i),
nn.LayerNorm(n_units))
setattr(self, '{}{:d}'.format("ff_", i),
PositionwiseFeedForward(n_units, e_units, dropout_rate))
self.lnorm_out = nn.LayerNorm(n_units)
def forward(self, x):
# x: (B, T, F) ... batch, time, (mel)freq
BT_size = x.shape[0] * x.shape[1]
# e: (BT, F)
e = self.linear_in(x.reshape(BT_size, -1))
# Encoder stack
for i in range(self.n_layers):
# layer normalization
e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e)
# self-attention
s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0])
# residual
e = e + self.dropout(s)
# layer normalization
e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e)
# positionwise feed-forward
s = getattr(self, '{}{:d}'.format("ff_", i))(e)
# residual
e = e + self.dropout(s)
# final layer normalization
# output: (BT, F)
return self.lnorm_out(e)
import torch
import fairseq
from packaging import version
import torch.nn.functional as F
from fairseq import tasks
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from omegaconf import OmegaConf
from s3prl.upstream.interfaces import UpstreamBase
from torch.nn.utils.rnn import pad_sequence
def load_model(filepath):
state = torch.load(filepath, map_location=lambda storage, loc: storage)
# state = load_checkpoint_to_cpu(filepath)
state["cfg"] = OmegaConf.create(state["cfg"])
if "args" in state and state["args"] is not None:
cfg = convert_namespace_to_omegaconf(state["args"])
elif "cfg" in state and state["cfg"] is not None:
cfg = state["cfg"]
else:
raise RuntimeError(
f"Neither args nor cfg exist in state keys = {state.keys()}"
)
task = tasks.setup_task(cfg.task)
if "task_state" in state:
task.load_state_dict(state["task_state"])
model = task.build_model(cfg.model)
return model, cfg, task
###################
# UPSTREAM EXPERT #
###################
class UpstreamExpert(UpstreamBase):
def __init__(self, ckpt, **kwargs):
super().__init__(**kwargs)
assert version.parse(fairseq.__version__) > version.parse(
"0.10.2"
), "Please install the fairseq master branch."
model, cfg, task = load_model(ckpt)
self.model = model
self.task = task
if len(self.hooks) == 0:
module_name = "self.model.encoder.layers"
for module_id in range(len(eval(module_name))):
self.add_hook(
f"{module_name}[{module_id}]",
lambda input, output: input[0].transpose(0, 1),
)
self.add_hook("self.model.encoder", lambda input, output: output[0])
def forward(self, wavs):
if self.task.cfg.normalize:
wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]
device = wavs[0].device
wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
wav_padding_mask = ~torch.lt(
torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
wav_lengths.unsqueeze(1),
)
padded_wav = pad_sequence(wavs, batch_first=True)
features, feat_padding_mask = self.model.extract_features(
padded_wav,
padding_mask=wav_padding_mask,
mask=None,
)
return {
"default": features,
}
SoundFile==0.10.3.post1 \
--hash=sha256:2d17e0a6fc2af0d6c1d868bafa5ec80aae6e186a97fec8db07ad6af29842fbc7 \
--hash=sha256:4555438c2c4f02b39fea2ed40f6ddeda88a80cd1ee9dd129be4d5f5134698cc2 \
--hash=sha256:490cff42650733d1832728b937fe99fa1802896f5ef4d61bcf78cf7ebecb107b \
--hash=sha256:5e342ee293b896d31da67617fe65d0bdca217af193991b0cb6052353b1e0e506 \
--hash=sha256:b361d4ac1519a2e516cabafa6bf7e93492f999f35d7d25350cd87fdc3e5cb27e
fire==0.4.0 \
--hash=sha256:c5e2b8763699d1142393a46d0e3e790c5eb2f0706082df8f647878842c216a62
sentencepiece==0.1.96 \
--hash=sha256:1dac8c2ad02b5ebc1179c0a14cbc7d7c6f4fd73d4dd51820626402d0aefc974e \
--hash=sha256:26d20d713b3ba1b7a19205336afb1e93a4327c372b2f795e907b8dc2315ac92e \
--hash=sha256:335bf84d72112cc91f3c3b691d61802fc963503b7772fd8280d20368048b8f3e \
--hash=sha256:36e9ff61e7b67c5b7ee96733613622620b4802fc8cf188a4dbc1f355b03dde02 \
--hash=sha256:384148cead5cdab34a4d74fe1fb6a5a8abaafed25eaa4a7698b49dd9482e4c4e \
--hash=sha256:3c703e68ea192e45b65c5d5836f6980849d828a18da4189899d7150fad82dc9e \
--hash=sha256:3e61e0757e49c306fff78ea75d6b75773418fe22214b4a460959203be934e834 \
--hash=sha256:466e381f0a812da8fda97a9707498cef3210ea8385a3421bcbadcb5384063969 \
--hash=sha256:48c6d13b3bfff08060c138248e85df60f6fad11135ad7a8fc2ef6005aacca839 \
--hash=sha256:4997c7ccf2ae462320250314aa5709a88d8a09fa271d073458a07bebf33f8e7c \
--hash=sha256:5388882bb24d083f6cc8cffc5c435f3694a7772b018e06ea6fd84d1044009efb \
--hash=sha256:5513298d62fe63dd0862d08a6eb52a9aa3537006f597f2386184e3f95bb88889 \
--hash=sha256:78e18d9106c36dcca929e18fd2c412378deac661d47fa3ee25defc55eef8a215 \
--hash=sha256:8179785883b556cd517416cdbda6244745414b00ec83132cfe1d26000971f3ae \
--hash=sha256:81bb77ba3651114943b2f8f77829cf764137dff06e38f4bf7fa43efea12c7f84 \
--hash=sha256:89c038da7f827a6e2ca4c73aeb4e4b25b99d981ce47dd61b04d446c8200cba1e \
--hash=sha256:940a6999c7d3f55e9d7b194fd5e1f41a7dbed26d3519fb95333216292a39599e \
--hash=sha256:99ea2d9db19e63a2d17d5dc64f9ace83fb9308a735be05a1aaf98eb4b496fba7 \
--hash=sha256:9bdf097d5bd1d8ce42dfee51f6ff05f5578b96e48c6f6006aa4eff69edfa3639 \
--hash=sha256:a336575463d75d3aac1f7e32470b8998643ccd9a73786bd726f6b0470520b6b4 \
--hash=sha256:a697257a2cd7581732d7741a8d32a06927f0311c3d277dbc47fa1043350c9d17 \
--hash=sha256:a92e1932ee8fd500680ccbe1bf53eb33228f4c9d6524ed6f300bcc80ac359f27 \
--hash=sha256:aeb090ad462833df03af1debce4ae607a2766ef861f992003ad0c56d074ab805 \
--hash=sha256:b1c24c1d9405b2148184ff27c062493d5e3be5c144575f95b5a0d7c660a515af \
--hash=sha256:b77d27f59d515c43b61745b8173fbe7c7b3014b14b3702a75bf1793471e7def6 \
--hash=sha256:b8b1dd2712f8a7de5b4c8ec912e6c041d25750bf03e1ce325cdba43bae0944ae \
--hash=sha256:bedf0355117fb4e9b1fc9fc92b4d5ee743a7d468be9f6196e3b94447710ea589 \
--hash=sha256:cc969e6694fb27fba7cee2953f350804faf03913f25ae1ee713a7b8a1bc08018 \
--hash=sha256:d45e3f78e746aa161bc9f5a31c6a2839c512101113a4065f4d2e7a3ab8198d8c \
--hash=sha256:d501713a8396193883aa526f48dc609f5f031a5df1afbafa561cf9ab492ffc76 \
--hash=sha256:d954d25a8705f972e8bfc1dea5464d7e697dd6f4ade092f1a487387e6d6c829a \
--hash=sha256:dadccb2e49244b6e64b4527d13ec14d5e094a90b41cf9b963e457e64182f1941 \
--hash=sha256:e811984b0908c14c56de7d8226fdd494d87a7ccb75af8ac3a07423037aaafc35 \
--hash=sha256:e88354b61f59dfdeb41023f7be8ae31dc627c2dc2dacbc2de8b2d82a0997135c \
--hash=sha256:e8ec5bb6777e2060e1499750c50e1b69dca5a0f80f90f2c66656c5f3e5244593 \
--hash=sha256:e9e9fe8094ca57549d801e9a2017ac5c24108bbf485ea4f8994a72e8e96ee135 \
--hash=sha256:eba0471ab0bb2e07ed06d91ecf5185d402c83d194155a41d8e2aa547d187712e \
--hash=sha256:ef59ba19340dc1d002ce5713b911c0ef23c577b08f8ed57998ee3c8e62c5bf6e \
--hash=sha256:f8c90df663cd9759b2cf8dd29998b63140ac39e51ada2e739dc13bdac0b4f001 \
--hash=sha256:f8cb24d8d0b2f8b7463815a59183eb81ec1d7a06e3217bed456063f3303eddfb \
--hash=sha256:fd907a8f744e5337de7fc532dd800c4416b571ea47f8c3c66be10cd1bc67c925 \
--hash=sha256:ff7d752a7f82d87711ec1a95c2262cb74f98be5b457f0300d81a1aefe5be2a95
tqdm==4.62.0 \
--hash=sha256:3642d483b558eec80d3c831e23953582c34d7e4540db86d9e5ed9dad238dabc6 \
--hash=sha256:706dea48ee05ba16e936ee91cb3791cd2ea6da348a0e50b46863ff4363ff4340
PyYAML==5.4.1 \
--hash=sha256:08682f6b72c722394747bddaf0aa62277e02557c0fd1c42cb853016a38f8dedf \
--hash=sha256:0f5f5786c0e09baddcd8b4b45f20a7b5d61a7e7e99846e3c799b05c7c53fa696 \
--hash=sha256:129def1b7c1bf22faffd67b8f3724645203b79d8f4cc81f674654d9902cb4393 \
--hash=sha256:294db365efa064d00b8d1ef65d8ea2c3426ac366c0c4368d930bf1c5fb497f77 \
--hash=sha256:3b2b1824fe7112845700f815ff6a489360226a5609b96ec2190a45e62a9fc922 \
--hash=sha256:3bd0e463264cf257d1ffd2e40223b197271046d09dadf73a0fe82b9c1fc385a5 \
--hash=sha256:4465124ef1b18d9ace298060f4eccc64b0850899ac4ac53294547536533800c8 \
--hash=sha256:49d4cdd9065b9b6e206d0595fee27a96b5dd22618e7520c33204a4a3239d5b10 \
--hash=sha256:4e0583d24c881e14342eaf4ec5fbc97f934b999a6828693a99157fde912540cc \
--hash=sha256:5accb17103e43963b80e6f837831f38d314a0495500067cb25afab2e8d7a4018 \
--hash=sha256:607774cbba28732bfa802b54baa7484215f530991055bb562efbed5b2f20a45e \
--hash=sha256:6c78645d400265a062508ae399b60b8c167bf003db364ecb26dcab2bda048253 \
--hash=sha256:72a01f726a9c7851ca9bfad6fd09ca4e090a023c00945ea05ba1638c09dc3347 \
--hash=sha256:74c1485f7707cf707a7aef42ef6322b8f97921bd89be2ab6317fd782c2d53183 \
--hash=sha256:895f61ef02e8fed38159bb70f7e100e00f471eae2bc838cd0f4ebb21e28f8541 \
--hash=sha256:8c1be557ee92a20f184922c7b6424e8ab6691788e6d86137c5d93c1a6ec1b8fb \
--hash=sha256:bb4191dfc9306777bc594117aee052446b3fa88737cd13b7188d0e7aa8162185 \
--hash=sha256:bfb51918d4ff3d77c1c856a9699f8492c612cde32fd3bcd344af9be34999bfdc \
--hash=sha256:c20cfa2d49991c8b4147af39859b167664f2ad4561704ee74c1de03318e898db \
--hash=sha256:cb333c16912324fd5f769fff6bc5de372e9e7a202247b48870bc251ed40239aa \
--hash=sha256:d2d9808ea7b4af864f35ea216be506ecec180628aced0704e34aca0b040ffe46 \
--hash=sha256:d483ad4e639292c90170eb6f7783ad19490e7a8defb3e46f97dfe4bacae89122 \
--hash=sha256:dd5de0646207f053eb0d6c74ae45ba98c3395a571a2891858e87df7c9b9bd51b \
--hash=sha256:e1d4970ea66be07ae37a3c2e48b5ec63f7ba6804bdddfdbd3cfd954d25a82e63 \
--hash=sha256:e4fac90784481d221a8e4b1162afa7c47ed953be40d31ab4629ae917510051df \
--hash=sha256:fa5ae20527d8e831e8230cbffd9f8fe952815b2b7dae6ffec25318803a7528fc \
--hash=sha256:fd7f6999a8070df521b6384004ef42833b9bd62cfee11a09bda1079b4b704247 \
--hash=sha256:fdc842473cd33f45ff6bce46aea678a54e3d21f1b61a7750ce3c498eedfe25d6 \
--hash=sha256:fe69978f3f768926cfa37b867e3843918e012cf83f680806599ddce33c2c68b0
h5py==3.3.0 \
--hash=sha256:09e78cefdef0b7566ab66366c5c7d9984c7b23142245bd51b82b744ad1eebf65 \
--hash=sha256:13355234c004ff8bd819f7d3420188aa1936b17d7f8470d622974a373421b7a5 \
--hash=sha256:5e2f22e66a3fb1815405cfe5711670450c973b8552507c535a546a23a468af3d \
--hash=sha256:7ca7d23ebbdd59a4be9b4820de52fe67adc74e6a44d5084881305461765aac47 \
--hash=sha256:89d7e10409b62fed81c571e35798763cb8375442b98f8ebfc52ba41ac019e081 \
--hash=sha256:8e09b682e4059c8cd259ddcc34bee35d639b9170105efeeae6ad195e7c1cea7a \
--hash=sha256:baef1a2cdef287a83e7f95ce9e0f4d762a9852fe7117b471063442c78b973695 \
--hash=sha256:e0dac887d779929778b3cfd13309a939359cc9e74756fc09af7c527a82797186 \
--hash=sha256:e0ea3330bf136f8213e43db67448994046ce501585dddc7ea4e8ceef0ef1600c \
--hash=sha256:f3bba8ffddd1fd2bf06127c5ff7b73f022cc1c8b7164355ddc760dc3f8570136
yamlargparse==1.31.1 \
--hash=sha256:2c09fc8e20c147d074f765512b880757a6fea669d57a3dc672a5e1be6c68c667
sklearn==0.0 \
--hash=sha256:e23001573aa194b834122d2b9562459bf5ae494a2d59ca6b8aa22c85a44c0e31
matplotlib==3.4.2 \
--hash=sha256:0bea5ec5c28d49020e5d7923c2725b837e60bc8be99d3164af410eb4b4c827da \
--hash=sha256:1c1779f7ab7d8bdb7d4c605e6ffaa0614b3e80f1e3c8ccf7b9269a22dbc5986b \
--hash=sha256:21b31057bbc5e75b08e70a43cefc4c0b2c2f1b1a850f4a0f7af044eb4163086c \
--hash=sha256:32fa638cc10886885d1ca3d409d4473d6a22f7ceecd11322150961a70fab66dd \
--hash=sha256:3a5c18dbd2c7c366da26a4ad1462fe3e03a577b39e3b503bbcf482b9cdac093c \
--hash=sha256:5826f56055b9b1c80fef82e326097e34dc4af8c7249226b7dd63095a686177d1 \
--hash=sha256:6382bc6e2d7e481bcd977eb131c31dee96e0fb4f9177d15ec6fb976d3b9ace1a \
--hash=sha256:6475d0209024a77f869163ec3657c47fed35d9b6ed8bccba8aa0f0099fbbdaa8 \
--hash=sha256:6a6a44f27aabe720ec4fd485061e8a35784c2b9ffa6363ad546316dfc9cea04e \
--hash=sha256:7a58f3d8fe8fac3be522c79d921c9b86e090a59637cb88e3bc51298d7a2c862a \
--hash=sha256:7ad19f3fb6145b9eb41c08e7cbb9f8e10b91291396bee21e9ce761bb78df63ec \
--hash=sha256:85f191bb03cb1a7b04b5c2cca4792bef94df06ef473bc49e2818105671766fee \
--hash=sha256:956c8849b134b4a343598305a3ca1bdd3094f01f5efc8afccdebeffe6b315247 \
--hash=sha256:a9d8cb5329df13e0cdaa14b3b43f47b5e593ec637f13f14db75bb16e46178b05 \
--hash=sha256:b1d5a2cedf5de05567c441b3a8c2651fbde56df08b82640e7f06c8cd91e201f6 \
--hash=sha256:b26535b9de85326e6958cdef720ecd10bcf74a3f4371bf9a7e5b2e659c17e153 \
--hash=sha256:c541ee5a3287efe066bbe358320853cf4916bc14c00c38f8f3d8d75275a405a9 \
--hash=sha256:d8d994cefdff9aaba45166eb3de4f5211adb4accac85cbf97137e98f26ea0219 \
--hash=sha256:df815378a754a7edd4559f8c51fc7064f779a74013644a7f5ac7a0c31f875866
torchaudio==0.9.0 \
--hash=sha256:0a387e78eeaf6e0abd36df70e9d8a15d242b49c2507dbd9522568f5d4af5fb96 \
--hash=sha256:18763c05cb7d85a08b8ea960e40f6984e9513b02e76f4526d920493c701b0671 \
--hash=sha256:48e33bb96b7ff2dc10a778a695429dbd6dfc8c8baa0d7c9b63569cb002bb87cd \
--hash=sha256:62fd9393ddbe40aadaabef7595f5bff0057e39f7e519195a010731542815f5a4 \
--hash=sha256:76a5b8ea0e4ddafd5b8f24abdf1a6f7afe847d892570da13cf0fc9bceeac437f \
--hash=sha256:87520525da10b5f00d3e5e1180db6ee37b1fa305edb2260c7335e0859dbe634e \
--hash=sha256:9d3f5d6df7d91676e67a38a448253b74d77da723f8e24bd833ff7ed0f82fa4ef \
--hash=sha256:acf0d736a5c1ea6b94adf08b0a31670009b6e78dfe50a1b0bdabf2b0f7895dc0 \
--hash=sha256:ad221258fc5d1d446f2c1ce9a1bb54cc05ca2b208491d4eaa5af443f1c0f16a2 \
--hash=sha256:ba52ae64611773bec7fc664c29f9ea3e02c9e5c817693726b978ed1bdedd07f2 \
--hash=sha256:c6126556d529df73b676e023063388d551be3c0cb2d42a4ff5c4cfd44ef3e012 \
--hash=sha256:ef5f0b22646a94f95869001b40ab940468b1ae399d0ffd3bc73d5c43342a013a \
--hash=sha256:ef8dc4ab1ec807382a713e71e8493d1985930537c933273e3c0739f02183cedc \
--hash=sha256:efb16c593b2a5ada07b180580c7612617e84f4714ce86928ad54baefe71ef29d
s3prl==0.3.1 \
--hash=sha256:e497989b10d4e058b619cf3e7a547820fceb3fe18c14c566427eb7b8c770d62e
/mnt/lustre/sjtu/home/czy97/workspace/sd/EEND-vec-clustering/EEND-vector-clustering/egs/mini_librispeech/v1/data/simu/wav/dev_clean_2_ns3_beta2_500/100/mix_0000496.wav
# -*- coding: utf-8 -*- #
"""*********************************************************************************************"""
# FileName [ dataset.py ]
# Synopsis [ the speaker diarization dataset ]
# Source [ Refactored from https://github.com/hitachi-speech/EEND ]
# Author [ Jiatong Shi ]
# Copyright [ Copyright(c), Johns Hopkins University ]
"""*********************************************************************************************"""
###############
# IMPORTATION #
###############
import io
import os
import subprocess
import sys
# -------------#
import numpy as np
import soundfile as sf
import torch
from torch.nn.utils.rnn import pad_sequence
# -------------#
from torch.utils.data.dataset import Dataset
# -------------#
def _count_frames(data_len, size, step):
# no padding at edges, last remaining samples are ignored
return int((data_len - size + step) / step)
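# Yield (start, end) frame ranges of length `size` with hop `step`; when frames
# remain at the end, yield one final window aligned to the end of the recording
# (or the whole recording if it is shorter than `size`).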
def _gen_frame_indices(data_length, size=2000, step=2000):
i = -1
for i in range(_count_frames(data_length, size, step)):
yield i * step, i * step + size
if i * step + size < data_length:
if data_length - (i + 1) * step > 0:
if i == -1:
yield (i + 1) * step, data_length
else:
yield data_length - size, data_length
def _gen_chunk_indices(data_len, chunk_size):
step = chunk_size
start = 0
while start < data_len:
end = min(data_len, start + chunk_size)
yield start, end
start += step
#######################
# Diarization Dataset #
#######################
class DiarizationDataset(Dataset):
def __init__(
self,
mode,
data_dir,
chunk_size=2000,
frame_shift=256,
sampling_rate=16000,
subsampling=1,
use_last_samples=True,
num_speakers=3,
filter_spk=False
):
super(DiarizationDataset, self).__init__()
self.mode = mode
self.data_dir = data_dir
self.chunk_size = chunk_size
self.frame_shift = frame_shift
self.subsampling = subsampling
self.n_speakers = num_speakers
self.chunk_indices = [] if mode != "test" else {}
self.data = KaldiData(self.data_dir)
self.all_speakers = sorted(self.data.spk2utt.keys())
self.all_n_speakers = len(self.all_speakers)
# make chunk indices: filepath, start_frame, end_frame
for rec in self.data.wavs:
data_len = int(self.data.reco2dur[rec] * sampling_rate / frame_shift)
data_len = int(data_len / self.subsampling)
if mode == "test":
self.chunk_indices[rec] = []
if mode != "test":
for st, ed in _gen_frame_indices(data_len, chunk_size, chunk_size):
self.chunk_indices.append(
(rec, st * self.subsampling, ed * self.subsampling)
)
else:
for st, ed in _gen_chunk_indices(data_len, chunk_size):
self.chunk_indices[rec].append(
(rec, st, ed)
)
if mode != "test":
if filter_spk:
self.filter_spk()
print(len(self.chunk_indices), " chunks")
else:
self.rec_list = list(self.chunk_indices.keys())
print(len(self.rec_list), " recordings")
def __len__(self):
return (
len(self.rec_list)
if type(self.chunk_indices) == dict
else len(self.chunk_indices)
)
def filter_spk(self):
# filter out speakers that appear in spk2utt but will never be used in training,
# i.e. speakers that only occur in chunks containing more speakers than self.n_speakers
occur_spk_set = set()
new_chunk_indices = []  # keep only chunks with at most self.n_speakers active speakers
for idx in range(self.__len__()):
rec, st, ed = self.chunk_indices[idx]
filtered_segments = self.data.segments[rec]
# all the speakers in this recording, not only those in this chunk
speakers = np.unique(
[self.data.utt2spk[seg['utt']] for seg in filtered_segments]
).tolist()
n_speakers = self.n_speakers
# we assume the number of speakers in each chunk is less than or equal to self.n_speakers,
# but the number of speakers in the whole recording may exceed self.n_speakers
if self.n_speakers < len(speakers):
n_speakers = len(speakers)
# Y: (length,), T: (frame_num, n_speakers)
Y, T = self._get_labeled_speech(rec, st, ed, n_speakers)
# the indices of the speakers that are active in this chunk
exist_spk_idx = np.sum(T, axis=0) > 0.5 # bool index
chunk_spk_num = np.sum(exist_spk_idx)
if chunk_spk_num <= self.n_speakers:
spk_arr = np.array(speakers)
valid_spk_arr = spk_arr[exist_spk_idx[:spk_arr.shape[0]]]
for spk in valid_spk_arr:
occur_spk_set.add(spk)
new_chunk_indices.append((rec, st, ed))
self.chunk_indices = new_chunk_indices
self.all_speakers = sorted(list(occur_spk_set))
self.all_n_speakers = len(self.all_speakers)
def __getitem__(self, i):
if self.mode != "test":
rec, st, ed = self.chunk_indices[i]
filtered_segments = self.data.segments[rec]
# all the speakers in this recording, not only those in this chunk
speakers = np.unique(
[self.data.utt2spk[seg['utt']] for seg in filtered_segments]
).tolist()
n_speakers = self.n_speakers
# we assume the number of speakers in each chunk is less than or equal to self.n_speakers,
# but the number of speakers in the whole recording may exceed self.n_speakers
if self.n_speakers < len(speakers):
n_speakers = len(speakers)
# Y: (length,), T: (frame_num, n_speakers)
Y, T = self._get_labeled_speech(rec, st, ed, n_speakers)
# the indices of the speakers that are active in this chunk
exist_spk_idx = np.sum(T, axis=0) > 0.5 # bool index
chunk_spk_num = np.sum(exist_spk_idx)
if chunk_spk_num > self.n_speakers:
# the number of speakers in this chunk exceeds our pre-set value
return None, None, None
# the map from within-recording speaker index to global speaker index
S_arr = -1 * np.ones(n_speakers).astype(np.int64)
for seg in filtered_segments:
speaker_index = speakers.index(self.data.utt2spk[seg['utt']])
try:
all_speaker_index = self.all_speakers.index(
self.data.utt2spk[seg['utt']])
except:
# some speakers were filtered out beforehand in self.filter_spk
all_speaker_index = -1
S_arr[speaker_index] = all_speaker_index
# If T[:, n_speakers - 1] == 0.0, then S_arr[n_speakers - 1] == -1,
# so S_arr[n_speakers - 1] is not used for training,
# e.g., in the case of training 3-spk model with 2-spk data
# drop the speakers that are not active in this chunk and ensure there are exactly self.n_speakers outputs
T_exist = T[:,exist_spk_idx]
T = np.zeros((T_exist.shape[0], self.n_speakers), dtype=np.int32)
T[:,:T_exist.shape[1]] = T_exist
# subsampling for Y will be done in the model forward function
T = T[::self.subsampling]
S_arr_exist = S_arr[exist_spk_idx]
S_arr = -1 * np.ones(self.n_speakers).astype(np.int64)
S_arr[:S_arr_exist.shape[0]] = S_arr_exist
n = np.arange(self.all_n_speakers, dtype=np.int64).reshape(self.all_n_speakers, 1)
return Y, T, S_arr, n, T.shape[0]
else:
len_ratio = self.frame_shift * self.subsampling
chunks = self.chunk_indices[self.rec_list[i]]
Ys = []
chunk_len_list = []
for (rec, st, ed) in chunks:
chunk_len = ed - st
if chunk_len != self.chunk_size:
st = max(0, ed - self.chunk_size)
Y, _ = self.data.load_wav(rec, st * len_ratio, ed * len_ratio)
Ys.append(Y)
chunk_len_list.append(chunk_len)
return Ys, self.rec_list[i], chunk_len_list
def get_allnspk(self):
return self.all_n_speakers
def _get_labeled_speech(
self, rec, start, end, n_speakers=None, use_speaker_id=False
):
"""Extracts speech chunks and corresponding labels
Extracts speech chunks and corresponding diarization labels for
given recording id and start/end times
Args:
rec (str): recording id
start (int): start frame index
end (int): end frame index
n_speakers (int): number of speakers
if None, the value is inferred from the data
Returns:
data: speech chunk
(n_samples)
T: label
(n_frames, n_speakers)-shaped np.int32 array.
"""
data, rate = self.data.load_wav(
rec, start * self.frame_shift, end * self.frame_shift
)
frame_num = end - start
filtered_segments = self.data.segments[rec]
# filtered_segments = self.data.segments[self.data.segments['rec'] == rec]
speakers = np.unique(
[self.data.utt2spk[seg["utt"]] for seg in filtered_segments]
).tolist()
if n_speakers is None:
n_speakers = len(speakers)
T = np.zeros((frame_num, n_speakers), dtype=np.int32)
if use_speaker_id:
all_speakers = sorted(self.data.spk2utt.keys())
S = np.zeros((frame_num, len(all_speakers)), dtype=np.int32)
for seg in filtered_segments:
speaker_index = speakers.index(self.data.utt2spk[seg["utt"]])
if use_speaker_id:
all_speaker_index = all_speakers.index(self.data.utt2spk[seg["utt"]])
start_frame = np.rint(seg["st"] * rate / self.frame_shift).astype(int)
end_frame = np.rint(seg["et"] * rate / self.frame_shift).astype(int)
rel_start = rel_end = None
if start <= start_frame and start_frame < end:
rel_start = start_frame - start
if start < end_frame and end_frame <= end:
rel_end = end_frame - start
if rel_start is not None or rel_end is not None:
T[rel_start:rel_end, speaker_index] = 1
if use_speaker_id:
S[rel_start:rel_end, all_speaker_index] = 1
if use_speaker_id:
return data, T, S
else:
return data, T
def collate_fn(self, batch):
valid_samples = [sample for sample in batch if sample[0] is not None]
wav_list, binary_label_list, spk_label_list= [], [], []
all_spk_idx_list, len_list = [], []
for sample in valid_samples:
wav_list.append(torch.from_numpy(sample[0]).float())
binary_label_list.append(torch.from_numpy(sample[1]).long())
spk_label_list.append(torch.from_numpy(sample[2]).long())
all_spk_idx_list.append(torch.from_numpy(sample[3]).long())
len_list.append(sample[4])
wav_batch = pad_sequence(wav_list, batch_first=True, padding_value=0.0)
binary_label_batch = pad_sequence(binary_label_list, batch_first=True, padding_value=1).long()
spk_label_batch = torch.stack(spk_label_list)
all_spk_idx_batch = torch.stack(all_spk_idx_list)
len_batch = torch.LongTensor(len_list)
return wav_batch, binary_label_batch.float(), spk_label_batch, all_spk_idx_batch, len_batch
def collate_fn_infer(self, batch):
assert len(batch) == 1 # each batch should contain one recording
Ys, rec, chunk_len_list = batch[0]
wav_list = [torch.from_numpy(Y).float() for Y in Ys]
return wav_list, rec, chunk_len_list
#######################
# Kaldi-style Dataset #
#######################
class KaldiData:
"""This class holds data in kaldi-style directory."""
def __init__(self, data_dir):
"""Load kaldi data directory."""
self.data_dir = data_dir
self.segments = self._load_segments_rechash(
os.path.join(self.data_dir, "segments")
)
self.utt2spk = self._load_utt2spk(os.path.join(self.data_dir, "utt2spk"))
self.wavs = self._load_wav_scp(os.path.join(self.data_dir, "wav.scp"))
self.reco2dur = self._load_reco2dur(os.path.join(self.data_dir, "reco2dur"))
self.spk2utt = self._load_spk2utt(os.path.join(self.data_dir, "spk2utt"))
def load_wav(self, recid, start=0, end=None):
"""Load wavfile given recid, start time and end time."""
data, rate = self._load_wav(self.wavs[recid], start, end)
return data, rate
def _load_segments(self, segments_file):
"""Load segments file as array."""
if not os.path.exists(segments_file):
return None
return np.loadtxt(
segments_file,
dtype=[("utt", "object"), ("rec", "object"), ("st", "f"), ("et", "f")],
ndmin=1,
)
def _load_segments_hash(self, segments_file):
"""Load segments file as dict with uttid index."""
ret = {}
if not os.path.exists(segments_file):
return None
for line in open(segments_file):
utt, rec, st, et = line.strip().split()
ret[utt] = (rec, float(st), float(et))
return ret
def _load_segments_rechash(self, segments_file):
"""Load segments file as dict with recid index."""
ret = {}
if not os.path.exists(segments_file):
return None
for line in open(segments_file):
utt, rec, st, et = line.strip().split()
if rec not in ret:
ret[rec] = []
ret[rec].append({"utt": utt, "st": float(st), "et": float(et)})
return ret
def _load_wav_scp(self, wav_scp_file):
"""Return dictionary { rec: wav_rxfilename }."""
if os.path.exists(wav_scp_file):
lines = [line.strip().split(None, 1) for line in open(wav_scp_file)]
return {x[0]: x[1] for x in lines}
else:
wav_dir = os.path.join(self.data_dir, "wav")
return {
os.path.splitext(filename)[0]: os.path.join(wav_dir, filename)
for filename in sorted(os.listdir(wav_dir))
}
def _load_wav(self, wav_rxfilename, start=0, end=None):
"""This function reads audio file and return data in numpy.float32 array.
"lru_cache" holds recently loaded audio so that can be called
many times on the same audio file.
OPTIMIZE: controls lru_cache size for random access,
considering memory size
"""
if wav_rxfilename.endswith("|"):
# input piped command
p = subprocess.Popen(
wav_rxfilename[:-1],
shell=True,
stdout=subprocess.PIPE,
)
data, samplerate = sf.read(
io.BytesIO(p.stdout.read()),
dtype="float32",
)
# cannot seek
data = data[start:end]
elif wav_rxfilename == "-":
# stdin
data, samplerate = sf.read(sys.stdin, dtype="float32")
# cannot seek
data = data[start:end]
else:
# normal wav file
data, samplerate = sf.read(wav_rxfilename, start=start, stop=end)
return data, samplerate
def _load_utt2spk(self, utt2spk_file):
"""Returns dictionary { uttid: spkid }."""
lines = [line.strip().split(None, 1) for line in open(utt2spk_file)]
return {x[0]: x[1] for x in lines}
def _load_spk2utt(self, spk2utt_file):
"""Returns dictionary { spkid: list of uttids }."""
if not os.path.exists(spk2utt_file):
return None
lines = [line.strip().split() for line in open(spk2utt_file)]
return {x[0]: x[1:] for x in lines}
def _load_reco2dur(self, reco2dur_file):
"""Returns dictionary { recid: duration }."""
if not os.path.exists(reco2dur_file):
return None
lines = [line.strip().split(None, 1) for line in open(reco2dur_file)]
return {x[0]: float(x[1]) for x in lines}
def _process_wav(self, wav_rxfilename, process):
"""This function returns preprocessed wav_rxfilename.
Args:
wav_rxfilename:
input
process:
command which can be connected via pipe, use stdin and stdout
Returns:
wav_rxfilename: output piped command
"""
if wav_rxfilename.endswith("|"):
# input piped command
return wav_rxfilename + process + "|"
# stdin "-" or normal file
return "cat {0} | {1} |".format(wav_rxfilename, process)
def _extract_segments(self, wavs, segments=None):
"""This function returns generator of segmented audio.
Yields (utterance id, numpy.float32 array).
TODO?: sampling rate is not converted.
"""
if segments is not None:
# segments should be sorted by rec-id
for seg in segments:
wav = wavs[seg["rec"]]
data, samplerate = self.load_wav(wav)
st_sample = np.rint(seg["st"] * samplerate).astype(int)
et_sample = np.rint(seg["et"] * samplerate).astype(int)
yield seg["utt"], data[st_sample:et_sample]
else:
# segments file not found,
# wav.scp is used as segmented audio list
for rec in wavs:
data, samplerate = self.load_wav(wavs[rec])
yield rec, data
if __name__ == "__main__":
args = {
'mode': 'train',
'data_dir': "/mnt/lustre/sjtu/home/czy97/workspace/sd/EEND-vec-clustering/EEND-vector-clustering/egs/mini_librispeech/v1/data/simu/data/train_clean_5_ns3_beta2_500",
'chunk_size': 2001,
'frame_shift': 256,
'sampling_rate': 8000,
'num_speakers':3
}
torch.manual_seed(6)
dataset = DiarizationDataset(**args)
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=dataset.collate_fn)
data_iter = iter(dataloader)
# wav_batch, binary_label_batch, spk_label_batch, all_spk_idx_batch, len_batch = next(data_iter)
data = next(data_iter)
for val in data:
print(val.shape)
# from torch.utils.data import DataLoader
# dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=dataset.collate_fn_infer)
# data_iter = iter(dataloader)
# wav_list, binary_label_list, rec = next(data_iter)
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
#
# This library provides utilities for a kaldi-style data directory.
from __future__ import print_function
import os
import sys
import numpy as np
import subprocess
import soundfile as sf
import io
from functools import lru_cache
def load_segments(segments_file):
""" load segments file as array """
if not os.path.exists(segments_file):
return None
return np.loadtxt(
segments_file,
dtype=[('utt', 'object'),
('rec', 'object'),
('st', 'f'),
('et', 'f')],
ndmin=1)
def load_segments_hash(segments_file):
ret = {}
if not os.path.exists(segments_file):
return None
for line in open(segments_file):
utt, rec, st, et = line.strip().split()
ret[utt] = (rec, float(st), float(et))
return ret
def load_segments_rechash(segments_file):
ret = {}
if not os.path.exists(segments_file):
return None
for line in open(segments_file):
utt, rec, st, et = line.strip().split()
if rec not in ret:
ret[rec] = []
ret[rec].append({'utt':utt, 'st':float(st), 'et':float(et)})
return ret
def load_wav_scp(wav_scp_file):
""" return dictionary { rec: wav_rxfilename } """
lines = [line.strip().split(None, 1) for line in open(wav_scp_file)]
return {x[0]: x[1] for x in lines}
@lru_cache(maxsize=1)
def load_wav(wav_rxfilename, start=0, end=None):
""" This function reads audio file and return data in numpy.float32 array.
"lru_cache" holds recently loaded audio so that can be called
many times on the same audio file.
OPTIMIZE: controls lru_cache size for random access,
considering memory size
"""
if wav_rxfilename.endswith('|'):
# input piped command
p = subprocess.Popen(wav_rxfilename[:-1], shell=True,
stdout=subprocess.PIPE)
data, samplerate = sf.read(io.BytesIO(p.stdout.read()),
dtype='float32')
# cannot seek
data = data[start:end]
elif wav_rxfilename == '-':
# stdin
data, samplerate = sf.read(sys.stdin, dtype='float32')
# cannot seek
data = data[start:end]
else:
# normal wav file
data, samplerate = sf.read(wav_rxfilename, start=start, stop=end)
return data, samplerate
def load_utt2spk(utt2spk_file):
""" returns dictionary { uttid: spkid } """
lines = [line.strip().split(None, 1) for line in open(utt2spk_file)]
return {x[0]: x[1] for x in lines}
def load_spk2utt(spk2utt_file):
""" returns dictionary { spkid: list of uttids } """
if not os.path.exists(spk2utt_file):
return None
lines = [line.strip().split() for line in open(spk2utt_file)]
return {x[0]: x[1:] for x in lines}
def load_reco2dur(reco2dur_file):
""" returns dictionary { recid: duration } """
if not os.path.exists(reco2dur_file):
return None
lines = [line.strip().split(None, 1) for line in open(reco2dur_file)]
return {x[0]: float(x[1]) for x in lines}
def process_wav(wav_rxfilename, process):
""" This function returns preprocessed wav_rxfilename
Args:
wav_rxfilename: input
process: command which can be connected via pipe,
use stdin and stdout
Returns:
wav_rxfilename: output piped command
"""
if wav_rxfilename.endswith('|'):
# input piped command
return wav_rxfilename + process + "|"
else:
# stdin "-" or normal file
return "cat {} | {} |".format(wav_rxfilename, process)
def extract_segments(wavs, segments=None):
""" This function returns generator of segmented audio as
(utterance id, numpy.float32 array)
TODO?: sampling rate is not converted.
"""
if segments is not None:
# segments should be sorted by rec-id
for seg in segments:
wav = wavs[seg['rec']]
data, samplerate = load_wav(wav)
st_sample = np.rint(seg['st'] * samplerate).astype(int)
et_sample = np.rint(seg['et'] * samplerate).astype(int)
yield seg['utt'], data[st_sample:et_sample]
else:
# segments file not found,
# wav.scp is used as segmented audio list
for rec in wavs:
data, samplerate = load_wav(wavs[rec])
yield rec, data
class KaldiData:
def __init__(self, data_dir):
self.data_dir = data_dir
self.segments = load_segments_rechash(
os.path.join(self.data_dir, 'segments'))
self.utt2spk = load_utt2spk(
os.path.join(self.data_dir, 'utt2spk'))
self.wavs = load_wav_scp(
os.path.join(self.data_dir, 'wav.scp'))
self.reco2dur = load_reco2dur(
os.path.join(self.data_dir, 'reco2dur'))
self.spk2utt = load_spk2utt(
os.path.join(self.data_dir, 'spk2utt'))
def load_wav(self, recid, start=0, end=None):
data, rate = load_wav(
self.wavs[recid], start, end)
return data, rate