Commit 39ac40a9 authored by chenzk

v1.0
#!/usr/bin/env bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
# Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
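# Example (hypothetical calling script): default values must be defined before sourcing,
# and each one becomes a settable --option:
#   cmd=run.pl   # overridable as --cmd
#   nj=4         # overridable as --nj
#   . parse_options.sh || exit 1
# Invoking that script as "script.sh --nj 8" then sets nj=8.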
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
if [ "${!argpos}" == "--config" ]; then
argpos_plus1=$((argpos+1))
config=${!argpos_plus1}
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
. $config # source the config file.
fi
done
###
### Now we process the command line options
###
while true; do
[ -z "${1:-}" ] && break; # break if there are no arguments
case "$1" in
# If the enclosing script is called with --help option, print the help
# message and exit. Scripts should put help messages in $help_message
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
else printf "$help_message\n" 1>&2 ; fi;
exit 0 ;;
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
exit 1 ;;
# If the first command-line argument begins with "--" (e.g. --foo-bar),
# then work out the variable name as $name, which will equal "foo_bar".
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
# Next we test whether the variable in question is undefined -- if so it's
# an invalid option and we die. Note: $0 evaluates to the name of the
# enclosing script.
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
# is undefined. We then have to wrap this test inside "eval" because
# foo_bar is itself inside a variable ($name).
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
oldval="`eval echo \\$$name`";
# Work out whether we seem to be expecting a Boolean argument.
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
was_bool=true;
else
was_bool=false;
fi
# Set the variable to the right value-- the escaped quotes make it work if
# the option had spaces, like --cmd "queue.pl -sync y"
eval $name=\"$2\";
# Check that Boolean-valued arguments are really Boolean.
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
exit 1;
fi
shift 2;
;;
*) break;
esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.
import os
import struct
import logging
import torch
import math
import numpy as np
import random
import yaml
import torch.distributed as dist
import torch.nn.functional as F
# ------------------------------ Logger ------------------------------
# log to console or a file
def get_logger(
name,
format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
date_format="%Y-%m-%d %H:%M:%S",
file=False):
"""
Get python logger instance
"""
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
# file or console
handler = logging.StreamHandler() if not file else logging.FileHandler(
name)
handler.setLevel(logging.INFO)
formatter = logging.Formatter(fmt=format_str, datefmt=date_format)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
# log to console and file at the same time
def get_logger_2(
name,
format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
date_format="%Y-%m-%d %H:%M:%S"):
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
# Create handlers
c_handler = logging.StreamHandler()
f_handler = logging.FileHandler(name)
c_handler.setLevel(logging.INFO)
f_handler.setLevel(logging.INFO)
# Create formatters and add it to handlers
c_format = logging.Formatter(fmt=format_str, datefmt=date_format)
f_format = logging.Formatter(fmt=format_str, datefmt=date_format)
c_handler.setFormatter(c_format)
f_handler.setFormatter(f_format)
# Add handlers to the logger
logger.addHandler(c_handler)
logger.addHandler(f_handler)
return logger
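# Example usage (the file name is illustrative):
#   logger = get_logger(__name__)        # console only
#   logger = get_logger_2("train.log")   # console plus the file "train.log"
#   logger.info("training started")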
# ------------------------------ Logger ------------------------------
# ------------------------------ Pytorch Distributed Training ------------------------------
def getoneNode():
nodelist = os.environ['SLURM_JOB_NODELIST']
nodelist = nodelist.strip().split(',')[0]
import re
text = re.split(r'[-\[\]]', nodelist)
if ('' in text):
text.remove('')
return text[0] + '-' + text[1] + '-' + text[2]
def dist_init(host_addr, rank, local_rank, world_size, port=23456):
host_addr_full = 'tcp://' + host_addr + ':' + str(port)
dist.init_process_group("nccl", init_method=host_addr_full,
rank=rank, world_size=world_size)
num_gpus = torch.cuda.device_count()
# torch.cuda.set_device(local_rank)
assert dist.is_initialized()
def cleanup():
dist.destroy_process_group()
def average_gradients(model, world_size):
size = float(world_size)
for param in model.parameters():
if (param.requires_grad and param.grad is not None):
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
param.grad.data /= size
def data_reduce(data):
dist.all_reduce(data, op=dist.ReduceOp.SUM)
return data / torch.distributed.get_world_size()
# ------------------------------ Pytorch Distributed Training ------------------------------
# ------------------------------ Hyper-parameter Dynamic Change ------------------------------
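# Both helpers below implement an exponential decay schedule:
#   lr(t) = initial_lr * (final_lr / initial_lr) ** (t / max_iter)
# so lr(0) == initial_lr and lr(max_iter) == final_lr (reduce_lr additionally scales by `coeff`).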
def reduce_lr(optimizer, initial_lr, final_lr, current_iter, max_iter, coeff=1.0):
current_lr = coeff * math.exp((current_iter / max_iter) * math.log(final_lr / initial_lr)) * initial_lr
for param_group in optimizer.param_groups:
param_group['lr'] = current_lr
def get_reduce_lr(initial_lr, final_lr, current_iter, max_iter):
current_lr = math.exp((current_iter / max_iter) * math.log(final_lr / initial_lr)) * initial_lr
return current_lr
def set_lr(optimizer, lr):
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# ------------------------------ Hyper-parameter Dynamic Change ------------------------------
# ---------------------- About Configuration --------------------
def parse_config_or_kwargs(config_file, **kwargs):
with open(config_file) as con_read:
yaml_config = yaml.load(con_read, Loader=yaml.FullLoader)
# passed kwargs will override yaml config
return dict(yaml_config, **kwargs)
def store_yaml(config_file, store_path, **kwargs):
with open(config_file, 'r') as f:
config_lines = f.readlines()
keys_list = list(kwargs.keys())
with open(store_path, 'w') as f:
for line in config_lines:
if ':' in line and line.split(':')[0] in keys_list:
key = line.split(':')[0]
line = '{}: {}\n'.format(key, kwargs[key])
f.write(line)
# ---------------------- About Configuration --------------------
def check_dir(dir):
if not os.path.exists(dir):
os.mkdir(dir)
def set_seed(seed=66):
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# When a model was saved with "module." key prefixes (e.g. after being wrapped in
# DataParallel/DistributedDataParallel), strip that prefix here.
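# e.g. a key "module.encoder.weight" (illustrative) becomes "encoder.weight".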
def correct_key(state_dict):
keys = list(state_dict.keys())
if 'module' not in keys[0]:
return state_dict
else:
new_state_dict = {}
for key in keys:
new_key = '.'.join(key.split('.')[1:])
new_state_dict[new_key] = state_dict[key]
return new_state_dict
def validate_path(dir_name):
"""
:param dir_name: Create the directory if it doesn't exist
:return: None
"""
dir_name = os.path.dirname(dir_name) # get the path
if not os.path.exists(dir_name) and (dir_name != ''):
os.makedirs(dir_name)
def get_lr(optimizer):
for param_group in optimizer.param_groups:
return param_group['lr']
## Pre-training Representations for Speaker Verification
### Pre-trained models
| Model | Fix pre-train | Vox1-O (EER%) | Vox1-E (EER%) | Vox1-H (EER%) |
| ------------------------------------------------------------ | ------------- | --------- | --------- | -------- |
| [ECAPA-TDNN](https://drive.google.com/file/d/1kWmLyTGkBExTdxtwmrXoP4DhWz_7ZAv3/view?usp=sharing) | - | 1.080 | 1.200 | 2.127 |
| [HuBERT large](https://drive.google.com/file/d/1njofuGpidjy_jdbq7rIbQMIDyyPLoAjb/view?usp=sharing) | Yes | 0.888 | 0.912 | 1.853 |
| [Wav2Vec2.0 (XLSR)](https://drive.google.com/file/d/1izV48ebxs6re252ELiksuk6-RQov-gvE/view?usp=sharing) | Yes | 0.915 | 0.945 | 1.895 |
| [UniSpeech-SAT large](https://drive.google.com/file/d/1sOhutb3XG7_OKQIztqjePDtRMrxjOdSf/view?usp=sharing) | Yes | 0.771 | 0.781 | 1.669 |
| [WavLM Base](https://drive.google.com/file/d/1qVKHG7OzltELgkoAdFT1xXzu_hHXj3e8/view?usp=sharing) | Yes | 0.84 | 0.928 | 1.758 |
| [**WavLM large**](https://drive.google.com/file/d/1D-dPa5H6Y2ctb4SJ5n21kRkdR6t0-awD/view?usp=sharing) | Yes | 0.75 | 0.764 | 1.548 |
| [HuBERT large](https://drive.google.com/file/d/1nit9Z6RyM8Sdb3n8ccaglOQVNnqsjnui/view?usp=sharing) | No | 0.585 | 0.654 | 1.342 |
| [Wav2Vec2.0 (XLSR)](https://drive.google.com/file/d/1TgKro9pp197TCgIF__IlE_rMVQOk50Eb/view?usp=sharing) | No | 0.564 | 0.605 | 1.23 |
| [UniSpeech-SAT large](https://drive.google.com/file/d/10o6NHZsPXJn2k8n57e8Z_FkKh3V4TC3g/view?usp=sharing) | No | 0.564 | 0.561 | 1.23 |
| [**WavLM large**](https://drive.google.com/file/d/18rekjal9NPo0VquVtali-80yy63252RX/view?usp=sharing) | No | **0.431** | **0.538** | **1.154** |
### How to use?
#### Environment Setup
1. `pip install --require-hashes -r requirements.txt`
2. Install fairseq code
- For HuBERT_Large and Wav2Vec2.0 (XLSR), install the official [fairseq](https://github.com/pytorch/fairseq) (see the sketch after this list).
- For UniSpeech-SAT large, install the [UniSpeech-SAT](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech-SAT) fairseq code.
- For WavLM, install the latest s3prl: `pip install s3prl@git+https://github.com/s3prl/s3prl.git@7ab62aaf2606d83da6c71ee74e7d16e0979edbc3#egg=s3prl`
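As a rough setup sketch (the clone location and the UniSpeech-SAT path are assumptions, not part of this repo), installing fairseq for the HuBERT / Wav2Vec2.0 checkpoints might look like this:
```bash
# Hypothetical commands; adjust paths and versions to your environment.
git clone https://github.com/pytorch/fairseq.git
cd fairseq
pip install --editable ./
# For the UniSpeech-SAT checkpoint, install the UniSpeech-SAT fairseq fork instead
# (see the UniSpeech repository linked above for the exact location).
```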
#### Example
Take `unispeech_sat` and `ecapa_tdnn` as examples:
1. First, download the pre-trained model from the table above to `checkpoint_path`.
2. Then, run the following commands:
- The wav files are sampled from [voxceleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html).
```bash
python verification.py --model_name unispeech_sat --wav1 vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav --wav2 vox1_data/Josh_Gad/HXUqYaOwrxA_0000015.wav --checkpoint $checkpoint_path
# output: The similarity score between two audios is 0.0317 (-1.0, 1.0).
python verification.py --model_name unispeech_sat --wav1 vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav --wav2 vox1_data/David_Faustino/xTOk1Jz-F_g_0000015.wav --checkpoint $checkpoint_path
# output: The similarity score between two audios is 0.5389 (-1.0, 1.0).
python verification.py --model_name ecapa_tdnn --wav1 vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav --wav2 vox1_data/Josh_Gad/HXUqYaOwrxA_0000015.wav --checkpoint $checkpoint_path
# output: The similarity score between two audios is 0.2053 (-1.0, 1.0).
python verification.py --model_name ecapa_tdnn --wav1 vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav --wav2 vox1_data/David_Faustino/xTOk1Jz-F_g_0000015.wav --checkpoint $checkpoint_path
# output: The similarity score between two audios is 0.5302 (-1.0, 1.0).
```
#### Example 2
```bash
git clone https://github.com/Sanyuan-Chen/UniSpeech.git -b t-schen/asv_eval
cd UniSpeech/downstreams/speaker_verification
pip install scipy==1.7.1 fire==0.4.0 sklearn==0.0 s3prl==0.3.1 torchaudio==0.9.0 sentencepiece==0.1.96
pip install s3prl@git+https://github.com/s3prl/s3prl.git@7ab62aaf2606d83da6c71ee74e7d16e0979edbc3#egg=s3prl
wget "https://msranlcmtteamdrive.blob.core.windows.net/share/wavlm/sv_finetuned/wavlm_large_finetune.pth?sv=2020-08-04&st=2022-12-02T09%3A48%3A45Z&se=2024-12-03T09%3A48%3A00Z&sr=b&sp=r&sig=jQPnEO9I5JqtoWylCvHIU0IvUxZ8jzC%2F64%2B6%2Fa1%2FKE4%3D" -O wavlm_large_finetune.pth
python verification_tsv.py $tsv1 $tsv2 --model_name wavlm_large --checkpoint wavlm_large_finetune.pth --scores $score_file --wav1_start_sr 0 --wav2_start_sr 0 --wav1_end_sr -1 --wav2_end_sr -1
```
If an error is raised in hubconf.py, replace that file with utils/hubconf.py.
TSV file example (the first line is the root directory; each following line is a wav path relative to it, and line i of tsv1 is scored against line i of tsv2):
```bash
root_path
wav1
wav2
...
```
import sys
import numpy as np
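# Compute the mean and variance of ASV similarity scores.
# Input: a file with one "<utterance>\t<score>" pair per line; the output file is a copy of the
# input with "ASV: <mean>" and "ASV-var: <variance>" lines appended.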
infile=sys.argv[1]
outfile=sys.argv[2]
fout = open(outfile, "w")
scores = []
for line in open(infile, "r").readlines():
item, score = line.strip().split("\t")
scores.append(float(score))
fout.write(line)
res = round(np.mean(np.array(scores)), 3)
res_var = round(np.var(np.array(scores)), 3)
fout.write(f"ASV: {res}\n")
fout.write(f"ASV-var: {res_var}\n")
print(f"ASV: {res}")
print(f"ASV-var: {res_var}")
fout.close()
import argparse
import os
import numpy as np
from tqdm import tqdm
from verification import extract_embedding
parser = argparse.ArgumentParser()
parser.add_argument("--infile", default="/mnt/bn/jcong5/pretrain_data/process_rp_prompt/thread-00.lst")
parser.add_argument("--outdir", default="/mnt/bn/jcong5/pretrain_data/rp_prompt_embedding/")
parser.add_argument('--checkpoint', default="/mnt/bn/jcong5/workspace/bigtts-eval/.cache_dir/wavlm_large_finetune.pth")
parser.add_argument('--device', default="cuda:0")
args = parser.parse_args()
model = None
lines = open(args.infile, "r").readlines()
os.makedirs(args.outdir, exist_ok=True)
for i, item in enumerate(tqdm(lines)):
wavpath = item.strip()
utt = os.path.splitext(os.path.basename(wavpath))[0]
output_path = os.path.join(args.outdir, utt+".npy")
print(output_path)
if os.path.exists(output_path):
print("skip")
continue
sim, model = extract_embedding("wavlm_large",
wavpath,
use_gpu=True,
checkpoint=args.checkpoint,
model=model,
device=args.device)
np.save(os.path.join(args.outdir, utt), sim[0].cpu())
'''
python extract_spks_from_score_file.py --scores $score_file --spks +0.9.12.18.21.34.64.71.87.92.99
python extract_spks_from_score_file.py --scores $score_file --spks +p225.p234.p238.p245.p248.p261.p294.p302.p326.p335.p347
'''
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--scores')
parser.add_argument('--spks')
args = parser.parse_args()
scores = open(args.scores)
scores_w = open(args.scores+'.'+args.spks, 'w')
is_in = args.spks[0] == '+'
spks = args.spks[1:].split('.')
#spks = [int(i) for i in args.spks.split('.')]
score_list = []
for line in scores:
if (is_in and line.split('_')[0].split('/')[0] in spks) or (not is_in and line.split('_')[0].split('/')[0] not in spks):
scores_w.write(line)
score_list.append(float(line.split()[-1]))
print(f'avg score: {sum(score_list)}/{len(score_list)}={sum(score_list)/len(score_list)}')
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as trans
# from .utils import UpstreamExpert
''' Res2Conv1d + BatchNorm1d + ReLU
'''
class Res2Conv1dReluBn(nn.Module):
'''
in_channels == out_channels == channels
'''
def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
super().__init__()
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
self.scale = scale
self.width = channels // scale
self.nums = scale if scale == 1 else scale - 1
self.convs = []
self.bns = []
for i in range(self.nums):
self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
self.bns.append(nn.BatchNorm1d(self.width))
self.convs = nn.ModuleList(self.convs)
self.bns = nn.ModuleList(self.bns)
def forward(self, x):
out = []
spx = torch.split(x, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = sp + spx[i]
# Order: conv -> relu -> bn
sp = self.convs[i](sp)
sp = self.bns[i](F.relu(sp))
out.append(sp)
if self.scale != 1:
out.append(spx[self.nums])
out = torch.cat(out, dim=1)
return out
''' Conv1d + BatchNorm1d + ReLU
'''
class Conv1dReluBn(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
super().__init__()
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
self.bn = nn.BatchNorm1d(out_channels)
def forward(self, x):
return self.bn(F.relu(self.conv(x)))
''' The SE connection of 1D case.
'''
class SE_Connect(nn.Module):
def __init__(self, channels, se_bottleneck_dim=128):
super().__init__()
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
def forward(self, x):
out = x.mean(dim=2)
out = F.relu(self.linear1(out))
out = torch.sigmoid(self.linear2(out))
out = x * out.unsqueeze(2)
return out
''' SE-Res2Block of the ECAPA-TDNN architecture.
'''
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
# return nn.Sequential(
# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
# SE_Connect(channels)
# )
class SE_Res2Block(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
super().__init__()
self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
self.shortcut = None
if in_channels != out_channels:
self.shortcut = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
)
def forward(self, x):
residual = x
if self.shortcut:
residual = self.shortcut(x)
x = self.Conv1dReluBn1(x)
x = self.Res2Conv1dReluBn(x)
x = self.Conv1dReluBn2(x)
x = self.SE_Connect(x)
return x + residual
''' Attentive weighted mean and standard deviation pooling.
'''
class AttentiveStatsPool(nn.Module):
def __init__(self, in_dim, attention_channels=128, global_context_att=False):
super().__init__()
self.global_context_att = global_context_att
# Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
if global_context_att:
self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper
else:
self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper
self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
def forward(self, x):
if self.global_context_att:
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
x_in = torch.cat((x, context_mean, context_std), dim=1)
else:
x_in = x
# DON'T use ReLU here! In experiments, ReLU made this attention hard to converge.
alpha = torch.tanh(self.linear1(x_in))
# alpha = F.relu(self.linear1(x_in))
alpha = torch.softmax(self.linear2(alpha), dim=2)
mean = torch.sum(alpha * x, dim=2)
residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
std = torch.sqrt(residuals.clamp(min=1e-9))
return torch.cat([mean, std], dim=1)
class ECAPA_TDNN(nn.Module):
def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
super().__init__()
self.feat_type = feat_type
self.feature_selection = feature_selection
self.update_extract = update_extract
self.sr = sr
if feat_type == "fbank" or feat_type == "mfcc":
self.update_extract = False
win_len = int(sr * 0.025)
hop_len = int(sr * 0.01)
if feat_type == 'fbank':
self.feature_extract = trans.MelSpectrogram(sample_rate=sr, n_fft=512, win_length=win_len,
hop_length=hop_len, f_min=0.0, f_max=sr // 2,
pad=0, n_mels=feat_dim)
elif feat_type == 'mfcc':
melkwargs = {
'n_fft': 512,
'win_length': win_len,
'hop_length': hop_len,
'f_min': 0.0,
'f_max': sr // 2,
'pad': 0
}
self.feature_extract = trans.MFCC(sample_rate=sr, n_mfcc=feat_dim, log_mels=False,
melkwargs=melkwargs)
else:
if config_path is None:
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
else:
self.feature_extract = UpstreamExpert(config_path)
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
self.feat_num = self.get_feat_num()
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
if feat_type != 'fbank' and feat_type != 'mfcc':
freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
for name, param in self.feature_extract.named_parameters():
for freeze_val in freeze_list:
if freeze_val in name:
param.requires_grad = False
break
if not self.update_extract:
for param in self.feature_extract.parameters():
param.requires_grad = False
self.instance_norm = nn.InstanceNorm1d(feat_dim)
# self.channels = [channels] * 4 + [channels * 3]
self.channels = [channels] * 4 + [1536]
self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
# self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
cat_channels = channels * 3
self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
def get_feat_num(self):
self.feature_extract.eval()
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
with torch.no_grad():
features = self.feature_extract(wav)
select_feature = features[self.feature_selection]
if isinstance(select_feature, (list, tuple)):
return len(select_feature)
else:
return 1
def get_feat(self, x):
if self.update_extract:
x = self.feature_extract([sample for sample in x])
else:
with torch.no_grad():
if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
else:
x = self.feature_extract([sample for sample in x])
if self.feat_type == 'fbank':
x = x.log()
if self.feat_type != "fbank" and self.feat_type != "mfcc":
x = x[self.feature_selection]
if isinstance(x, (list, tuple)):
x = torch.stack(x, dim=0)
else:
x = x.unsqueeze(0)
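# Learnable weighted sum over the upstream model's hidden layers (softmax-normalised layer weights).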
norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
x = (norm_weights * x).sum(dim=0)
x = torch.transpose(x, 1, 2) + 1e-6
x = self.instance_norm(x)
return x
def forward(self, x):
x = self.get_feat(x)
out1 = self.layer1(x)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
out = torch.cat([out2, out3, out4], dim=1)
out = F.relu(self.conv(out))
out = self.bn(self.pooling(out))
out = self.linear(out)
return out
def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
if __name__ == '__main__':
x = torch.zeros(2, 32000)
model = ECAPA_TDNN_SMALL(feat_dim=768, emb_dim=256, feat_type='hubert_base', feature_selection="hidden_states",
update_extract=False)
out = model(x)
# print(model)
print(out.shape)
import torch
import fairseq
from packaging import version
import torch.nn.functional as F
from fairseq import tasks
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from omegaconf import OmegaConf
from s3prl.upstream.interfaces import UpstreamBase
from torch.nn.utils.rnn import pad_sequence
def load_model(filepath):
state = torch.load(filepath, map_location=lambda storage, loc: storage)
# state = load_checkpoint_to_cpu(filepath)
state["cfg"] = OmegaConf.create(state["cfg"])
if "args" in state and state["args"] is not None:
cfg = convert_namespace_to_omegaconf(state["args"])
elif "cfg" in state and state["cfg"] is not None:
cfg = state["cfg"]
else:
raise RuntimeError(
f"Neither args nor cfg exist in state keys = {state.keys()}"
)
task = tasks.setup_task(cfg.task)
if "task_state" in state:
task.load_state_dict(state["task_state"])
model = task.build_model(cfg.model)
return model, cfg, task
###################
# UPSTREAM EXPERT #
###################
class UpstreamExpert(UpstreamBase):
def __init__(self, ckpt, **kwargs):
super().__init__(**kwargs)
assert version.parse(fairseq.__version__) > version.parse(
"0.10.2"
), "Please install the fairseq master branch."
model, cfg, task = load_model(ckpt)
self.model = model
self.task = task
if len(self.hooks) == 0:
module_name = "self.model.encoder.layers"
for module_id in range(len(eval(module_name))):
self.add_hook(
f"{module_name}[{module_id}]",
lambda input, output: input[0].transpose(0, 1),
)
self.add_hook("self.model.encoder", lambda input, output: output[0])
def forward(self, wavs):
if self.task.cfg.normalize:
wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]
device = wavs[0].device
wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
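# Boolean padding mask: True at positions beyond each wav's true length.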
wav_padding_mask = ~torch.lt(
torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
wav_lengths.unsqueeze(1),
)
padded_wav = pad_sequence(wavs, batch_first=True)
features, feat_padding_mask = self.model.extract_features(
padded_wav,
padding_mask=wav_padding_mask,
mask=None,
)
return {
"default": features,
}
scipy==1.7.1 \
--hash=sha256:2a0eeaab01258e0870c4022a6cd329aef3b7c6c2b606bd7cf7bb2ba9820ae561 \
--hash=sha256:3304bd5bc32e00954ac4b3f4cc382ca8824719bf348aacbec6347337d6b125fe \
--hash=sha256:3f52470e0548cdb74fb8ddf06773ffdcca7c97550f903b1c51312ec19243a7f7 \
--hash=sha256:4729b41a4cdaf4cd011aeac816b532f990bdf97710cef59149d3e293115cf467 \
--hash=sha256:4ee952f39a4a4c7ba775a32b664b1f4b74818548b65f765987adc14bb78f5802 \
--hash=sha256:611f9cb459d0707dd8e4de0c96f86e93f61aac7475fcb225e9ec71fecdc5cebf \
--hash=sha256:6b47d5fa7ea651054362561a28b1ccc8da9368a39514c1bbf6c0977a1c376764 \
--hash=sha256:71cfc96297617eab911e22216e8a8597703202e95636d9406df9af5c2ac99a2b \
--hash=sha256:787749110a23502031fb1643c55a2236c99c6b989cca703ea2114d65e21728ef \
--hash=sha256:90c07ba5f34f33299a428b0d4fa24c30d2ceba44d63f8385b2b05be460819fcb \
--hash=sha256:a496b42dbcd04ea9924f5e92be63af3d8e0f43a274b769bfaca0a297327d54ee \
--hash=sha256:bc61e3e5ff92d2f32bb263621d54a9cff5e3f7c420af3d1fa122ce2529de2bd9 \
--hash=sha256:c9951e3746b68974125e5e3445008a4163dd6d20ae0bbdae22b38cb8951dc11b \
--hash=sha256:d1388fbac9dd591ea630da75c455f4cc637a7ca5ecb31a6b6cef430914749cde \
--hash=sha256:d13f31457f2216e5705304d9f28e2826edf75487410a57aa99263fa4ffd792c2 \
--hash=sha256:d648aa85dd5074b1ed83008ae987c3fbb53d68af619fce1dee231f4d8bd40e2f \
--hash=sha256:da9c6b336e540def0b7fd65603da8abeb306c5fc9a5f4238665cbbb5ff95cf58 \
--hash=sha256:e101bceeb9e65a90dadbc5ca31283403a2d4667b9c178db29109750568e8d112 \
--hash=sha256:efdd3825d54c58df2cc394366ca4b9166cf940a0ebddeb87b6c10053deb625ea
fire==0.4.0 \
--hash=sha256:c5e2b8763699d1142393a46d0e3e790c5eb2f0706082df8f647878842c216a62
sklearn==0.0 \
--hash=sha256:e23001573aa194b834122d2b9562459bf5ae494a2d59ca6b8aa22c85a44c0e31
s3prl==0.3.1 \
--hash=sha256:e497989b10d4e058b619cf3e7a547820fceb3fe18c14c566427eb7b8c770d62e
torchaudio==0.9.0 \
--hash=sha256:0a387e78eeaf6e0abd36df70e9d8a15d242b49c2507dbd9522568f5d4af5fb96 \
--hash=sha256:18763c05cb7d85a08b8ea960e40f6984e9513b02e76f4526d920493c701b0671 \
--hash=sha256:48e33bb96b7ff2dc10a778a695429dbd6dfc8c8baa0d7c9b63569cb002bb87cd \
--hash=sha256:62fd9393ddbe40aadaabef7595f5bff0057e39f7e519195a010731542815f5a4 \
--hash=sha256:76a5b8ea0e4ddafd5b8f24abdf1a6f7afe847d892570da13cf0fc9bceeac437f \
--hash=sha256:87520525da10b5f00d3e5e1180db6ee37b1fa305edb2260c7335e0859dbe634e \
--hash=sha256:9d3f5d6df7d91676e67a38a448253b74d77da723f8e24bd833ff7ed0f82fa4ef \
--hash=sha256:acf0d736a5c1ea6b94adf08b0a31670009b6e78dfe50a1b0bdabf2b0f7895dc0 \
--hash=sha256:ad221258fc5d1d446f2c1ce9a1bb54cc05ca2b208491d4eaa5af443f1c0f16a2 \
--hash=sha256:ba52ae64611773bec7fc664c29f9ea3e02c9e5c817693726b978ed1bdedd07f2 \
--hash=sha256:c6126556d529df73b676e023063388d551be3c0cb2d42a4ff5c4cfd44ef3e012 \
--hash=sha256:ef5f0b22646a94f95869001b40ab940468b1ae399d0ffd3bc73d5c43342a013a \
--hash=sha256:ef8dc4ab1ec807382a713e71e8493d1985930537c933273e3c0739f02183cedc \
--hash=sha256:efb16c593b2a5ada07b180580c7612617e84f4714ce86928ad54baefe71ef29d
sentencepiece==0.1.96 \
--hash=sha256:1dac8c2ad02b5ebc1179c0a14cbc7d7c6f4fd73d4dd51820626402d0aefc974e \
--hash=sha256:26d20d713b3ba1b7a19205336afb1e93a4327c372b2f795e907b8dc2315ac92e \
--hash=sha256:335bf84d72112cc91f3c3b691d61802fc963503b7772fd8280d20368048b8f3e \
--hash=sha256:36e9ff61e7b67c5b7ee96733613622620b4802fc8cf188a4dbc1f355b03dde02 \
--hash=sha256:384148cead5cdab34a4d74fe1fb6a5a8abaafed25eaa4a7698b49dd9482e4c4e \
--hash=sha256:3c703e68ea192e45b65c5d5836f6980849d828a18da4189899d7150fad82dc9e \
--hash=sha256:3e61e0757e49c306fff78ea75d6b75773418fe22214b4a460959203be934e834 \
--hash=sha256:466e381f0a812da8fda97a9707498cef3210ea8385a3421bcbadcb5384063969 \
--hash=sha256:48c6d13b3bfff08060c138248e85df60f6fad11135ad7a8fc2ef6005aacca839 \
--hash=sha256:4997c7ccf2ae462320250314aa5709a88d8a09fa271d073458a07bebf33f8e7c \
--hash=sha256:5388882bb24d083f6cc8cffc5c435f3694a7772b018e06ea6fd84d1044009efb \
--hash=sha256:5513298d62fe63dd0862d08a6eb52a9aa3537006f597f2386184e3f95bb88889 \
--hash=sha256:78e18d9106c36dcca929e18fd2c412378deac661d47fa3ee25defc55eef8a215 \
--hash=sha256:8179785883b556cd517416cdbda6244745414b00ec83132cfe1d26000971f3ae \
--hash=sha256:81bb77ba3651114943b2f8f77829cf764137dff06e38f4bf7fa43efea12c7f84 \
--hash=sha256:89c038da7f827a6e2ca4c73aeb4e4b25b99d981ce47dd61b04d446c8200cba1e \
--hash=sha256:940a6999c7d3f55e9d7b194fd5e1f41a7dbed26d3519fb95333216292a39599e \
--hash=sha256:99ea2d9db19e63a2d17d5dc64f9ace83fb9308a735be05a1aaf98eb4b496fba7 \
--hash=sha256:9bdf097d5bd1d8ce42dfee51f6ff05f5578b96e48c6f6006aa4eff69edfa3639 \
--hash=sha256:a336575463d75d3aac1f7e32470b8998643ccd9a73786bd726f6b0470520b6b4 \
--hash=sha256:a697257a2cd7581732d7741a8d32a06927f0311c3d277dbc47fa1043350c9d17 \
--hash=sha256:a92e1932ee8fd500680ccbe1bf53eb33228f4c9d6524ed6f300bcc80ac359f27 \
--hash=sha256:aeb090ad462833df03af1debce4ae607a2766ef861f992003ad0c56d074ab805 \
--hash=sha256:b1c24c1d9405b2148184ff27c062493d5e3be5c144575f95b5a0d7c660a515af \
--hash=sha256:b77d27f59d515c43b61745b8173fbe7c7b3014b14b3702a75bf1793471e7def6 \
--hash=sha256:b8b1dd2712f8a7de5b4c8ec912e6c041d25750bf03e1ce325cdba43bae0944ae \
--hash=sha256:bedf0355117fb4e9b1fc9fc92b4d5ee743a7d468be9f6196e3b94447710ea589 \
--hash=sha256:cc969e6694fb27fba7cee2953f350804faf03913f25ae1ee713a7b8a1bc08018 \
--hash=sha256:d45e3f78e746aa161bc9f5a31c6a2839c512101113a4065f4d2e7a3ab8198d8c \
--hash=sha256:d501713a8396193883aa526f48dc609f5f031a5df1afbafa561cf9ab492ffc76 \
--hash=sha256:d954d25a8705f972e8bfc1dea5464d7e697dd6f4ade092f1a487387e6d6c829a \
--hash=sha256:dadccb2e49244b6e64b4527d13ec14d5e094a90b41cf9b963e457e64182f1941 \
--hash=sha256:e811984b0908c14c56de7d8226fdd494d87a7ccb75af8ac3a07423037aaafc35 \
--hash=sha256:e88354b61f59dfdeb41023f7be8ae31dc627c2dc2dacbc2de8b2d82a0997135c \
--hash=sha256:e8ec5bb6777e2060e1499750c50e1b69dca5a0f80f90f2c66656c5f3e5244593 \
--hash=sha256:e9e9fe8094ca57549d801e9a2017ac5c24108bbf485ea4f8994a72e8e96ee135 \
--hash=sha256:eba0471ab0bb2e07ed06d91ecf5185d402c83d194155a41d8e2aa547d187712e \
--hash=sha256:ef59ba19340dc1d002ce5713b911c0ef23c577b08f8ed57998ee3c8e62c5bf6e \
--hash=sha256:f8c90df663cd9759b2cf8dd29998b63140ac39e51ada2e739dc13bdac0b4f001 \
--hash=sha256:f8cb24d8d0b2f8b7463815a59183eb81ec1d7a06e3217bed456063f3303eddfb \
--hash=sha256:fd907a8f744e5337de7fc532dd800c4416b571ea47f8c3c66be10cd1bc67c925 \
--hash=sha256:ff7d752a7f82d87711ec1a95c2262cb74f98be5b457f0300d81a1aefe5be2a95
export https_proxy=http://bj-rd-proxy.byted.org:3128 http_proxy=http://bj-rd-proxy.byted.org:3128 no_proxy=code.byted.org
thread_dir=/mnt/bn/jcong5/pretrain_data/process_rp_prompt/
# num_job=4
# for rank in $(seq 1 $((num_job - 1))); do
# echo $rank
# python3 extract_embedding.py \
# --infile $thread_dir/thread-0$rank.lst \
# --device cuda:$rank &
# done
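# Launch num_job extract_embedding.py workers in parallel, one per GPU,
# over the list shards thread-04.lst ... thread-11.lst (part = rank + 4).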
num_job=8
for rank in $(seq 0 $((num_job - 1))); do
echo $rank
part=`expr $rank + 4`
padded=$(printf "%02d\n" $part)
echo thread-$padded.lst
python3 extract_embedding.py \
--infile $thread_dir/thread-$padded.lst \
--device cuda:$rank &
done
import argparse
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict
# Generate the sample embedding data
# Apply T-SNE dimensionality reduction to the sample data
# Plot according to the T-SNE result
parser = argparse.ArgumentParser()
parser.add_argument("--indir", default="/mnt/bn/jcong5/pretrain_data/rp_prompt_embedding/")
# parser.add_argument("--indir", default="/mnt/bn/jcong5/pretrain_data/process_rp_prompt/threshold-0.5/uniq_embeddings")
parser.add_argument('--checkpoint', default="/mnt/bn/jcong5/workspace/bigtts-eval/.cache_dir/wavlm_large_finetune.pth")
parser.add_argument('--output_dir', default="/mnt/bn/jcong5/pretrain_data/process_rp_prompt/")
parser.add_argument('--threshold', type=float, default=0.5)
args = parser.parse_args()
threshold=args.threshold
output_dir = os.path.join(args.output_dir, f"threshold-{threshold}")
os.makedirs(output_dir, exist_ok=True)
outfile=os.path.join(output_dir, "uniq.lst")
fout=open(outfile, "w")
repeat_out=open(os.path.join(output_dir, "repeat.lst"), "w")
model = None
sims = []
id2utt = {}
i=0
for item in tqdm(os.listdir(args.indir)):
wavpath = os.path.join(args.indir, item)
try:
sim = np.load(wavpath, allow_pickle=True)
except:
print(f"skip-{wavpath}")
continue
sims.append(sim)
id2utt[i] = wavpath
i=i+1
sims = np.array(sims)
sims = cosine_similarity(sims)
indices = np.arange(sims.shape[0])
delete = []
delete_pair = defaultdict(list)
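# Greedy deduplication: keep the first embedding of each group and mark every later
# embedding whose cosine similarity to a kept one is >= threshold as a duplicate.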
for i in range(sims.shape[0]):
if i in delete:
continue
for j in range(i+1, sims.shape[0]):
if sims[i, j] >= threshold:
delete.append(j)
delete_pair[i].append(j)
for i in range(sims.shape[0]):
if i in delete:
continue
fout.write(id2utt[i]+"\n")
for k, v in delete_pair.items():
repeat_out.write(id2utt[k]+"\n")
for i in v:
repeat_out.write(id2utt[i]+"\n")
repeat_out.write("\n")
sims = np.array(sims)
# Apply T-SNE dimensionality reduction to the pairwise similarity matrix
tsne = TSNE(n_components=2, perplexity=30, learning_rate=200)
X_tsne = tsne.fit_transform(sims)
# Plot according to the T-SNE result
plt.scatter(X_tsne[:, 0], X_tsne[:, 1])
plt.savefig("test2.png")
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# -*- coding: utf-8 -*- #
"""*********************************************************************************************"""
# FileName [ upstream/wavlm/hubconf.py ]
# Synopsis [ the WavLM torch hubconf ]
# Author [ Microsoft ]
"""*********************************************************************************************"""
###############
# IMPORTATION #
###############
import os
# -------------#
from s3prl.utility.download import _urls_to_filepaths
from .expert import UpstreamExpert as _UpstreamExpert
def wavlm_local(ckpt, *args, **kwargs):
"""
The model from local ckpt
ckpt (str): PATH
"""
assert os.path.isfile(ckpt)
return _UpstreamExpert(ckpt, *args, **kwargs)
def wavlm_url(ckpt, refresh=False, agent="wget", *args, **kwargs):
"""
The model from google drive id
ckpt (str): URL
refresh (bool): whether to download ckpt/config again if existed
"""
return wavlm_local(
_urls_to_filepaths(ckpt, refresh=refresh, agent=agent), *args, **kwargs
)
def wavlm(refresh=False, *args, **kwargs):
"""
The default model - Base-Plus
refresh (bool): whether to download ckpt/config again if existed
"""
return wavlm_base_plus(refresh=refresh, *args, **kwargs)
def wavlm_base(refresh=False, *args, **kwargs):
"""
The Base model
refresh (bool): whether to download ckpt/config again if existed
"""
# Azure Storage
kwargs["ckpt"] = "\"https://msranlcmtteamdrive.blob.core.windows.net/share/wavlm/WavLM-Base.pt?sv=2020-04-08&st=2021-11-05T00%3A35%3A31Z&se=2022-11-06T00%3A35%3A00Z&sr=b&sp=r&sig=JljnRVzyHY6AjHzhVmHV5KyQQCvvGfgp9D2M02oGJBU%3D\""
# Google Drive
# kwargs["ckpt"] = "https://drive.google.com/u/0/uc?id=19-C7SMQvEFAYLG5uc47NX_MY03JCbI4x&export=download"
# kwargs["agent"] = 'gdown'
return wavlm_url(refresh=refresh, *args, **kwargs)
def wavlm_base_plus(refresh=False, *args, **kwargs):
"""
The Base-Plus model
refresh (bool): whether to download ckpt/config again if existed
"""
# Azure Storage
kwargs["ckpt"] = "\"https://msranlcmtteamdrive.blob.core.windows.net/share/wavlm/WavLM-Base+.pt?sv=2020-04-08&st=2021-11-05T00%3A34%3A47Z&se=2022-10-06T00%3A34%3A00Z&sr=b&sp=r&sig=Gkf1IByHaIn1t%2FVEd9D6WHjZ3zu%2Fk5eSdoj21UytKro%3D\""
# Google Drive
# kwargs["ckpt"] = "https://drive.google.com/u/1/uc?id=1PlbT_9_B4F9BsD_ija84sUTVw7almNX8&export=download"
# kwargs["agent"] = 'gdown'
return wavlm_url(refresh=refresh, *args, **kwargs)
def wavlm_large(refresh=False, *args, **kwargs):
"""
The Large model
refresh (bool): whether to download ckpt/config again if existed
"""
# Azure Storage
kwargs["ckpt"] = "\"https://msranlcmtteamdrive.blob.core.windows.net/share/wavlm/WavLM-Large.pt?sv=2020-08-04&st=2021-11-22T10%3A03%3A53Z&se=2022-11-23T10%3A03%3A00Z&sr=b&sp=r&sig=3kB8dwTCyIS8YQ7gW5oXmDrXV%2FAaLmoxBS37oPpFsz4%3D\""
kwargs["ckpt"] = "\"https://msranlcmtteamdrive.blob.core.windows.net/share/wavlm/WavLM-Large.pt?sv=2020-08-04&st=2022-11-11T04%3A49%3A54Z&se=2023-11-12T04%3A49%3A00Z&sr=b&sp=r&sig=7jUNrI8FgkE6amYy1ge4Lqj0JhdNUxCCA1KF89YvV8s%3D\""
# Google Drive
# kwargs["ckpt"] = "https://drive.google.com/u/1/uc?id=1p8nbj16b7YA16sqPZ4E0JUL-oIDUBGwU&export=download"
# kwargs["agent"] = 'gdown'
return wavlm_url(refresh=refresh, *args, **kwargs)
import soundfile as sf
import torch
import fire
import torch.nn.functional as F
from torchaudio.transforms import Resample
from models.ecapa_tdnn import ECAPA_TDNN_SMALL
import librosa
MODEL_LIST = ['ecapa_tdnn', 'hubert_large', 'wav2vec2_xlsr', 'unispeech_sat', "wavlm_base_plus", "wavlm_large"]
def init_model(model_name, checkpoint=None):
if model_name == 'unispeech_sat':
config_path = 'config/unispeech_sat.th'
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='unispeech_sat', config_path=config_path)
elif model_name == 'wavlm_base_plus':
config_path = None
model = ECAPA_TDNN_SMALL(feat_dim=768, feat_type='wavlm_base_plus', config_path=config_path)
elif model_name == 'wavlm_large':
config_path = None
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=config_path)
elif model_name == 'hubert_large':
config_path = None
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='hubert_large_ll60k', config_path=config_path)
elif model_name == 'wav2vec2_xlsr':
config_path = None
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wav2vec2_xlsr', config_path=config_path)
else:
model = ECAPA_TDNN_SMALL(feat_dim=40, feat_type='fbank')
if checkpoint is not None:
state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)
model.load_state_dict(state_dict['model'], strict=False)
return model
def verification(model_name, wav1, wav2, use_gpu=True, checkpoint=None, wav1_start_sr=0, wav2_start_sr=0, wav1_end_sr=-1, wav2_end_sr=-1, model=None, wav2_cut_wav1=False, device="cuda:0"):
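# Load two wavs, resample both to 16 kHz, optionally crop them (wav*_start_sr/_end_sr are sample
# indices; wav2_cut_wav1 drops the first len(wav1) samples of wav2), embed both, and return the
# cosine similarity together with the (lazily initialised) model for reuse across calls.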
assert model_name in MODEL_LIST, 'The model_name should be in {}'.format(MODEL_LIST)
model = init_model(model_name, checkpoint) if model is None else model
wav1, sr1 = librosa.load(wav1, sr=None, mono=False)
# wav1, sr1 = sf.read(wav1)
if len(wav1.shape) == 2:
wav1 = wav1[:,0]
# wav2, sr2 = sf.read(wav2)
wav2, sr2 = librosa.load(wav2, sr=None, mono=False)
if len(wav2.shape) == 2:
wav2 = wav2[0,:] # wav2.shape: [channels, T]
wav1 = torch.from_numpy(wav1).unsqueeze(0).float()
wav2 = torch.from_numpy(wav2).unsqueeze(0).float()
resample1 = Resample(orig_freq=sr1, new_freq=16000)
resample2 = Resample(orig_freq=sr2, new_freq=16000)
wav1 = resample1(wav1)
wav2 = resample2(wav2)
# print(f'origin wav1 sr: {wav1.shape}, wav2 sr: {wav2.shape}')
if wav2_cut_wav1:
wav2 = wav2[...,wav1.shape[-1]:]
else:
wav1 = wav1[...,wav1_start_sr:wav1_end_sr if wav1_end_sr > 0 else wav1.shape[-1]]
wav2 = wav2[...,wav2_start_sr:wav2_end_sr if wav2_end_sr > 0 else wav2.shape[-1]]
# print(f'cutted wav1 sr: {wav1.shape}, wav2 sr: {wav2.shape}')
if use_gpu:
model = model.cuda(device)
wav1 = wav1.cuda(device)
wav2 = wav2.cuda(device)
model.eval()
with torch.no_grad():
emb1 = model(wav1)
emb2 = model(wav2)
sim = F.cosine_similarity(emb1, emb2)
# print("The similarity score between two audios is {:.4f} (-1.0, 1.0).".format(sim[0].item()))
return sim, model
def extract_embedding(model_name, wav1, use_gpu=True, checkpoint=None, wav1_start_sr=0, wav1_end_sr=-1, model=None, device="cuda:0"):
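# Extract a single speaker embedding for one wav (resampled to 16 kHz and optionally cropped);
# returns the embedding and the model so the model can be reused across calls.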
assert model_name in MODEL_LIST, 'The model_name should be in {}'.format(MODEL_LIST)
model = init_model(model_name, checkpoint) if model is None else model
wav1, sr1 = sf.read(wav1)
wav1 = torch.from_numpy(wav1).unsqueeze(0).float()
resample1 = Resample(orig_freq=sr1, new_freq=16000)
wav1 = resample1(wav1)
# print(f'origin wav1 sr: {wav1.shape}, wav2 sr: {wav2.shape}')
wav1 = wav1[...,wav1_start_sr:wav1_end_sr if wav1_end_sr > 0 else wav1.shape[-1]]
if use_gpu:
model = model.cuda(device)
wav1 = wav1.cuda(device)
model.eval()
with torch.no_grad():
emb1 = model(wav1)
# print("The similarity score between two audios is {:.4f} (-1.0, 1.0).".format(sim[0].item()))
return emb1, model
if __name__ == "__main__":
fire.Fire(verification)
wav_wav_text=$1
score_file=$2
python3 verification_pair_list_v2.py $wav_wav_text --model_name wavlm_large --checkpoint $PWD/wavlm_large_finetune.pth --scores $score_file --wav1_start_sr 0 --wav2_start_sr 0 --wav1_end_sr -1 --wav2_end_sr -1
'''
python verification_pair_list_v2.py $wav_wav_text --model_name wavlm_large --checkpoint wavlm_large_finetune.pth --scores $score_file --wav1_start_sr 0 --wav2_start_sr 0 --wav1_end_sr -1 --wav2_end_sr -1
'''
import tqdm
import argparse
from verification import verification
import os
parser = argparse.ArgumentParser()
parser.add_argument('pair')
parser.add_argument('--model_name')
parser.add_argument('--checkpoint')
parser.add_argument('--scores')
parser.add_argument('--wav1_start_sr', type=int)
parser.add_argument('--wav2_start_sr', type=int)
parser.add_argument('--wav1_end_sr', type=int)
parser.add_argument('--wav2_end_sr', type=int)
parser.add_argument('--wav2_cut_wav1', type=bool, default=False)
parser.add_argument('--device', default="cuda:0")
args = parser.parse_args()
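# Each line of the pair file is '|'-separated; the first and last fields are the two wav paths
# (4-field lines are "wav1|...|...|wav2", otherwise the first two fields are used).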
f = open(args.pair)
lines = f.readlines()
f.close()
tsv1 = []
tsv2 = []
for line in lines:
e = line.strip().split('|')
if len(e) == 4:
part1, _, _, part2 = line.strip().split('|')
else:
part1, part2 = line.strip().split('|')[:2]
tsv1.append(part1)
tsv2.append(part2)
scores_w = open(args.scores, 'w')
assert len(tsv1) == len(tsv2)
model = None
score_list = []
for t1, t2 in tqdm.tqdm(zip(tsv1, tsv2), total=len(tsv1)):
t1_path = t1.strip()
t2_path = t2.strip()
if not os.path.exists(t1_path) or not os.path.exists(t2_path):
continue
try:
sim, model = verification(args.model_name, t1_path, t2_path, use_gpu=True, checkpoint=args.checkpoint, wav1_start_sr=args.wav1_start_sr, wav2_start_sr=args.wav2_start_sr, wav1_end_sr=args.wav1_end_sr, wav2_end_sr=args.wav2_end_sr, model=model, wav2_cut_wav1=args.wav2_cut_wav1, device=args.device)
except Exception as e:
print(str(e))
continue
if sim is None:
continue
scores_w.write(f'{t1_path}_{args.wav1_start_sr}_{args.wav1_end_sr}|{t2_path}_{args.wav2_start_sr}_{args.wav2_end_sr}\t{sim.cpu().item()}\n')
# print(f'{t1_path}_{args.wav1_start_sr}_{args.wav1_end_sr}|{t2_path}_{args.wav2_start_sr}_{args.wav2_end_sr}\t{sim.cpu().item()}')
score_list.append(sim.cpu().item())
scores_w.flush()
scores_w.write(f'avg score: {sum(score_list)/len(score_list)}')
scores_w.flush()
# print(f'avg score: {round(sum(score_list)/len(score_list), 3)}')
'''
python verification_tsv.py $tsv1 $tsv2 --model_name wavlm_large --checkpoint wavlm_large_finetune.pth --scores $score_file --wav1_start_sr 0 --wav2_start_sr 0 --wav1_end_sr -1 --wav2_end_sr -1
'''
import tqdm
import argparse
from verification import verification
parser = argparse.ArgumentParser()
parser.add_argument('tsv1')
parser.add_argument('tsv2')
parser.add_argument('--model_name')
parser.add_argument('--checkpoint')
parser.add_argument('--scores')
parser.add_argument('--wav1_start_sr', type=int)
parser.add_argument('--wav2_start_sr', type=int)
parser.add_argument('--wav1_end_sr', type=int)
parser.add_argument('--wav2_end_sr', type=int)
parser.add_argument('--wav2_cut_wav1', type=bool, default=False)
args = parser.parse_args()
tsv1 = open(args.tsv1)
tsv1_root = tsv1.readline().strip()
tsv1 = tsv1.readlines()
tsv2 = open(args.tsv2)
tsv2_root = tsv2.readline().strip()
tsv2 = tsv2.readlines()
scores_w = open(args.scores, 'w')
assert len(tsv1) == len(tsv2)
model = None
score_list = []
for t1, t2 in tqdm.tqdm(zip(tsv1, tsv2), total=len(tsv1)):
t1_name = t1.split()[0]
t2_name = t2.split()[0]
try:
print(f"processing {t1_name} {t2_name}")
sim, model = verification(args.model_name, tsv1_root+'/'+t1_name, tsv2_root+'/'+t2_name, use_gpu=True, checkpoint=args.checkpoint, wav1_start_sr=args.wav1_start_sr, wav2_start_sr=args.wav2_start_sr, wav1_end_sr=args.wav1_end_sr, wav2_end_sr=args.wav2_end_sr, model=model, wav2_cut_wav1=args.wav2_cut_wav1)
except Exception:
continue
scores_w.write(f'{t1_name}_{args.wav1_start_sr}_{args.wav1_end_sr}|{t2_name}_{args.wav2_start_sr}_{args.wav2_end_sr}\t{sim.cpu().item()}\n')
print(f'{t1_name}_{args.wav1_start_sr}_{args.wav1_end_sr}|{t2_name}_{args.wav2_start_sr}_{args.wav2_end_sr}\t{sim.cpu().item()}')
score_list.append(sim.cpu().item())
scores_w.flush()
print(f'avg score: {sum(score_list)/len(score_list)}')