Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
hehl2
Torchaudio
Commits
4fa77623
"vscode:/vscode.git/clone" did not exist on "98c1117d00edd38d72610d6a87c0c8d706873863"
Unverified
Commit
4fa77623
authored
Oct 30, 2021
by
nateanl
Committed by
GitHub
Oct 30, 2021
Browse files
Add preprocessing scripts for HuBERT model training (#1911)
parent
207d8119
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
561 additions
and
0 deletions
+561
-0
examples/hubert/preprocess.py
examples/hubert/preprocess.py
+124
-0
examples/hubert/utils/__init__.py
examples/hubert/utils/__init__.py
+10
-0
examples/hubert/utils/common_utils.py
examples/hubert/utils/common_utils.py
+112
-0
examples/hubert/utils/feature_utils.py
examples/hubert/utils/feature_utils.py
+137
-0
examples/hubert/utils/kmeans.py
examples/hubert/utils/kmeans.py
+178
-0
No files found.
examples/hubert/preprocess.py
0 → 100644
View file @
4fa77623
#!/usr/bin/env python3
"""This is the preprocessing script for HuBERT model training.
The script includes:
- File list creation
- MFCC/HuBERT feature extraction
- KMeans clustering model training
- Pseudo-label generation
"""
import
logging
from
argparse
import
ArgumentParser
,
RawTextHelpFormatter
from
multiprocessing
import
Pool
from
pathlib
import
Path
import
torch
from
utils
import
(
create_tsv
,
dump_features
,
learn_kmeans
,
get_km_label
,
)
def _init_logger(debug=False):
    """Configure root logging; use a verbose record format when *debug* is on."""
    if debug:
        message_fmt = "%(levelname)5s: %(funcName)10s: %(message)s"
        level = logging.DEBUG
    else:
        message_fmt = "%(message)s"
        level = logging.INFO
    logging.basicConfig(level=level, format=f"%(asctime)s: {message_fmt}")
def _str_to_bool(value):
    """Parse a boolean command-line value explicitly.

    ``type=bool`` is a classic argparse pitfall: ``bool("False")`` is ``True``
    because any non-empty string is truthy, silently inverting
    ``--use-gpu False``. argparse converts a ``ValueError`` raised by a
    ``type`` callable into a usage error.
    """
    if value.lower() in ("true", "t", "yes", "y", "1"):
        return True
    if value.lower() in ("false", "f", "no", "n", "0"):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


def _parse_args():
    """Parse the command-line arguments of the preprocessing script.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = ArgumentParser(
        description=__doc__,
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug log")
    parser.add_argument(
        "--dataset",
        default="librispeech",
        type=str,
        choices=["librispeech", "librilight"],
    )
    parser.add_argument(
        "--root-dir",
        type=Path,
        help="The path to the directory where the directory ``LibriSpeech`` or ``LibriLight`` is stored.",
    )
    parser.add_argument("--num-rank", default=5, type=int)
    parser.add_argument("--feat-type", default="mfcc", type=str)
    parser.add_argument(
        "--use-gpu",
        default=False,
        # ``type=bool`` would treat any non-empty string (e.g. "False") as True.
        type=_str_to_bool,
    )
    parser.add_argument(
        "--exp-dir",
        type=Path,
        help="The directory to store the experiment outputs.",
    )
    parser.add_argument(
        "--num-cluster",
        default=100,
        type=int,
        help="The number of clusters for KMeans clustering.",
    )
    args = parser.parse_args()
    return args
def main(args):
    """Run the HuBERT preprocessing pipeline end to end: file-list creation,
    feature extraction, KMeans model training, and pseudo-label generation.

    Args:
        args (argparse.Namespace): The parsed arguments (see ``_parse_args``).
    """
    _init_logger(args.debug)

    if not args.exp_dir.exists():
        args.exp_dir.mkdir()
    tsv_dir = args.exp_dir / "tsv"
    feat_dir = args.exp_dir / args.feat_type
    km_dir = args.exp_dir / "km_model"
    label_dir = args.exp_dir / "label"

    if args.use_gpu:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Create file lists for training and validation (optional)
    create_tsv(args.root_dir, tsv_dir)

    # Extract features for KMeans clustering
    if not feat_dir.exists():
        feat_dir.mkdir()

    for split in ["train", "valid"]:
        # One worker per rank; ``with`` guarantees the pool is torn down even
        # if a worker raises (the original leaked workers on failure).
        with Pool(args.num_rank) as p:
            inputs = [
                (
                    tsv_dir / f"{args.dataset}_{split}.tsv",
                    feat_dir,
                    split,
                    rank,
                    args.num_rank,
                    device,
                    args.feat_type,
                    16_000,
                )
                for rank in range(args.num_rank)
            ]
            # ``starmap`` blocks until every shard is dumped.
            p.starmap(dump_features, inputs)

    # Fit KMeans clustering model
    learn_kmeans(
        feat_dir,
        "train",
        args.num_rank,
        km_dir,
        args.num_cluster,
    )

    # Predict labels for MFCC features
    for split in ["train", "valid"]:
        get_km_label(
            feat_dir,
            km_dir,
            label_dir,
            split,
            args.num_rank,
            device,
        )
# Script entry point: parse the CLI arguments and run the preprocessing pipeline.
if __name__ == "__main__":
    main(_parse_args())
examples/hubert/utils/__init__.py
0 → 100644
View file @
4fa77623
# Re-export the public helpers of the ``utils`` package so that callers can use
# ``from utils import create_tsv, dump_features, learn_kmeans, get_km_label``.
from .common_utils import create_tsv
from .feature_utils import dump_features
from .kmeans import learn_kmeans, get_km_label

# Explicit public API of the package.
__all__ = [
    "create_tsv",
    "dump_features",
    "learn_kmeans",
    "get_km_label",
]
examples/hubert/utils/common_utils.py
0 → 100644
View file @
4fa77623
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# https://github.com/pytorch/fairseq/blob/265df7144c79446f5ea8d835bda6e727f54dad9d/LICENSE
"""
Data pre-processing: create tsv files for training (and valiation).
"""
import
logging
import
re
from
pathlib
import
Path
from
typing
import
(
Tuple
,
Union
,
)
import
torch
import
torchaudio
_LG = logging.getLogger(__name__)


def create_tsv(
    root_dir: Union[str, Path],
    out_dir: Union[str, Path],
    dataset: str = "librispeech",
    valid_percent: float = 0.01,
    seed: int = 0,
    extension: str = "flac",
) -> None:
    """Create file lists for training and validation.

    Args:
        root_dir (str or Path): The directory of the dataset.
        out_dir (str or Path): The directory to store the file lists.
        dataset (str, optional): The dataset to use. Options:
            [``librispeech``, ``libri-light``]. (Default: ``librispeech``)
        valid_percent (float, optional): The percentage of data for validation. (Default: 0.01)
        seed (int): The seed for randomly selecting the validation files.
        extension (str, optional): The extension of audio files. (Default: ``flac``)

    Returns:
        None

    Raises:
        ValueError: If ``valid_percent`` is outside ``[0.0, 1.0]``.
    """
    # Raise instead of ``assert`` so the check survives ``python -O``.
    if not 0.0 <= valid_percent <= 1.0:
        raise ValueError(f"valid_percent must be within [0.0, 1.0], got {valid_percent}")
    torch.manual_seed(seed)
    root_dir = Path(root_dir)
    out_dir = Path(out_dir)
    if not out_dir.exists():
        out_dir.mkdir()

    valid_f = (
        open(out_dir / f"{dataset}_valid.tsv", "w") if valid_percent > 0 else None
    )
    search_pattern = ".*train.*"
    try:
        with open(out_dir / f"{dataset}_train.tsv", "w") as train_f:
            # First line of each tsv file is the dataset root directory.
            print(root_dir, file=train_f)

            if valid_f is not None:
                print(root_dir, file=valid_f)

            for fname in root_dir.glob(f"**/*.{extension}"):
                # Only index the training subsets of the corpus.
                if re.match(search_pattern, str(fname)):
                    frames = torchaudio.info(fname).num_frames
                    # Route to the validation list with probability ``valid_percent``.
                    # Guard against ``valid_f is None`` (the original could pass
                    # ``file=None`` — i.e. stdout — when ``torch.rand`` hit exactly 0).
                    if valid_f is not None and torch.rand(1) <= valid_percent:
                        dest = valid_f
                    else:
                        dest = train_f
                    print(f"{fname.relative_to(root_dir)}\t{frames}", file=dest)
    finally:
        # Close the validation list even when the scan fails
        # (the original leaked the handle on exception).
        if valid_f is not None:
            valid_f.close()
    _LG.info("Finished creating the file lists successfully")
def _get_feat_lens_paths(feat_dir: Path, split: str, rank: int, num_rank: int) -> Tuple[Path, Path]:
    r"""Get the feature and lengths paths based on feature directory,
    data split, rank, and number of ranks.

    Args:
        feat_dir (Path): The directory that stores the feature and lengths tensors.
        split (str): The split of data. Options: [``train``, ``valid``].
        rank (int): The rank in the multi-processing.
        num_rank (int): The number of ranks for multi-processing in feature extraction.

    Returns:
        (Path, Path)
        Path: The file path of the feature tensor for the current rank.
        Path: The file path of the lengths tensor for the current rank.
    """
    # Both files share the same ``{split}_{rank}_{num_rank}.pt`` stem; the
    # lengths tensor carries an extra ``len_`` prefix.
    stem = f"{split}_{rank}_{num_rank}.pt"
    return feat_dir / stem, feat_dir / f"len_{stem}"
def _get_model_path(km_dir: Path) -> Path:
    r"""Get the file path of the KMeans clustering model.

    Args:
        km_dir (Path): The directory to store the KMeans clustering model.

    Returns:
        Path: The file path of the model.
    """
    # The model file name is fixed; only its directory varies.
    return km_dir.joinpath("model.pt")
examples/hubert/utils/feature_utils.py
0 → 100644
View file @
4fa77623
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# https://github.com/pytorch/fairseq/blob/265df7144c79446f5ea8d835bda6e727f54dad9d/LICENSE
import
logging
from
pathlib
import
Path
from
typing
import
(
Tuple
,
Union
,
)
import
torch
import
torchaudio
from
torch
import
Tensor
from
.common_utils
import
_get_feat_lens_paths
_LG = logging.getLogger(__name__)


def get_shard_range(num_lines: int, num_rank: int, rank: int) -> Tuple[int, int]:
    r"""Get the range of indices for the current rank in multi-processing.

    Args:
        num_lines (int): The number of lines to process.
        num_rank (int): The number of ranks for multi-processing in feature extraction.
        rank (int): The rank in the multi-processing.

    Returns:
        (int, int):
        int: The start index for the current rank.
        int: The end index for the current rank.
    """
    assert 0 <= rank < num_rank, f"invalid rank/num_rank {rank}/{num_rank}"
    assert num_lines > 0, f"Found {num_lines} files, make sure you specify the correct root directory"
    # Evenly partition [0, num_lines) into ``num_rank`` contiguous shards;
    # rounding the fractional boundaries keeps shard sizes within one of each other.
    shard = num_lines / num_rank
    start = round(shard * rank)
    end = round(shard * (rank + 1))
    _LG.info(
        f"rank {rank} of {num_rank}, process {end - start} "
        f"({start}-{end}) out of {num_lines}"
    )
    return start, end
def extract_feature(
    path: str,
    device: torch.device,
    feature_type: str,
    sample_rate: int,
) -> Tensor:
    r"""Extract features for KMeans clustering and pseudo label prediction.

    Args:
        path (str): The file path of the audio.
        device (torch.device): The location to allocate for PyTorch Tensors.
            Options: [``torch.device('cpu')``, ``torch.device('cuda')``].
        feature_type (str): The type of the desired feature. Options: [``mfcc``].
        sample_rate (int): The sample rate of the audio.

    Returns:
        Tensor: The desired feature tensor of the given audio file.

    Raises:
        ValueError: If ``feature_type`` is not supported, or if the audio
            sample rate differs from ``sample_rate``.
    """
    # Fail loudly for unsupported feature types instead of silently falling
    # through and returning ``None`` (``hubert`` extraction is not implemented
    # here; the original then crashed later on ``None.cpu()``).
    if feature_type != "mfcc":
        raise ValueError(f"Unsupported feature type: {feature_type!r}")
    waveform, sr = torchaudio.load(path)
    # Raise instead of ``assert`` so the check survives ``python -O``.
    if sr != sample_rate:
        raise ValueError(f"Expected sample rate {sample_rate}, but got {sr} ({path})")
    # Use the first channel only and move it to the target device.
    waveform = waveform[0].to(device)
    feature_extractor = torchaudio.transforms.MFCC(sample_rate=sample_rate).to(device)
    mfccs = feature_extractor(waveform)  # (freq, time)
    # Augment the raw MFCCs with first- and second-order differences.
    deltas = torchaudio.functional.compute_deltas(mfccs)
    ddeltas = torchaudio.functional.compute_deltas(deltas)
    concat = torch.cat([mfccs, deltas, ddeltas], dim=0)
    return concat.transpose(0, 1)  # (time, freq)
def dump_features(
    tsv_file: Union[str, Path],
    out_dir: Union[str, Path],
    split: str,
    rank: int,
    num_rank: int,
    device: torch.device,
    feature_type: str = "mfcc",
    sample_rate: int = 16_000,
) -> None:
    r"""Dump the feature tensors given a ``.tsv`` file list. The feature and lengths
    tensors will be stored under the ``out_dir`` directory.

    Args:
        tsv_file (str or Path): The path of the tsv file.
        out_dir (str or Path): The directory to store the feature tensors.
        split (str): The split of data. Options: [``train``, ``valid``].
        rank (int): The rank in the multi-processing.
        num_rank (int): The number of ranks for multi-processing in feature extraction.
        device (torch.device): The location to allocate for PyTorch Tensors.
            Options: [``torch.device('cpu')``, ``torch.device('cuda')``].
        feature_type (str, optional): The type of the desired feature. Options: [``mfcc``, ``hubert``].
            (Default: ``mfcc``)
        sample_rate (int, optional): The sample rate of the audio. (Default: 16000)

    Returns:
        None
    """
    if feature_type not in ["mfcc", "hubert"]:
        raise ValueError("Unexpected feature type.")
    out_dir = Path(out_dir)
    feat_path, len_path = _get_feat_lens_paths(out_dir, split, rank, num_rank)

    features = []
    lens = []
    with open(tsv_file, "r") as f:
        # The first line of the tsv file is the dataset root directory.
        root = f.readline().rstrip()
        entries = [line.rstrip() for line in f]
        # Only process the slice of the file list assigned to this rank.
        start, end = get_shard_range(len(entries), num_rank, rank)
        for entry in entries[start:end]:
            rel_path, nsample = entry.split("\t")
            audio_path = f"{root}/{rel_path}"
            nsample = int(nsample)
            feature = extract_feature(audio_path, device, feature_type, sample_rate)
            features.append(feature.cpu())
            lens.append(feature.shape[0])
    torch.save(torch.cat(features), feat_path)
    torch.save(torch.Tensor(lens), len_path)
    _LG.info(f"Finished dumping features for rank {rank} of {num_rank} successfully")
examples/hubert/utils/kmeans.py
0 → 100644
View file @
4fa77623
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# https://github.com/pytorch/fairseq/blob/265df7144c79446f5ea8d835bda6e727f54dad9d/LICENSE
import
logging
from
pathlib
import
Path
from
typing
import
(
Tuple
,
)
import
joblib
import
torch
from
sklearn.cluster
import
MiniBatchKMeans
from
torch
import
Tensor
from
.common_utils
import
_get_feat_lens_paths
,
_get_model_path
_LG = logging.getLogger(__name__)


def load_feature(
    feat_dir: Path,
    split: str,
    num_rank: int,
) -> Tuple[Tensor, Tensor]:
    r"""Loading features from pre-saved `.pt` files.

    Args:
        feat_dir (Path): The directory that stores the feature files.
        split (str): The split of data. Options: [``train``, ``valid``].
        num_rank (int): The number of ranks for multi-processing in feature extraction.

    Returns:
        (Tensor, Tensor)
        Tensor: The concatenated feature tensor of shape `(frame, feature_dim)`.
        Tensor: The lengths tensor of shape `(num_utterance,)`.
    """
    # Gather the per-rank shards, then concatenate them in rank order.
    feat_shards = []
    len_shards = []
    for rank in range(num_rank):
        feat_path, len_path = _get_feat_lens_paths(feat_dir, split, rank, num_rank)
        feat_shards.append(torch.load(feat_path))
        len_shards.append(torch.load(len_path))
    return torch.cat(feat_shards), torch.cat(len_shards)
def learn_kmeans(
    feat_dir: Path,
    split: str,
    num_rank: int,
    km_dir: Path,
    n_clusters: int,
    init: str = "k-means++",
    max_iter: int = 100,
    batch_size: int = 10000,
    tol: float = 0.0,
    n_init: int = 20,
    reassignment_ratio: float = 0.0,
    max_no_improvement: int = 100,
) -> None:
    r"""Build and train the KMeans clustering model. The model is saved in "{km_dir}/model.pt"

    Args:
        feat_dir (Path): The directory that stores the feature files.
        split (str): The split of data. Options: [``train``, ``valid``].
        num_rank (int): The number of ranks for multi-processing in feature extraction.
        km_dir (Path): The directory to store the KMeans clustering model.
        n_clusters (int): The number of clusters.
        init (str, optional): Method for initialization. Options: [``k-means++``, ``random``].
            (Default: ``k-means++``)
        max_iter (int, optional): Maximum number of iterations over the complete dataset. (Default: 100)
        batch_size (int, optional): Batch size for training the KMeans clustering model. (Default: 10000)
        tol (float, optional): Control early stopping based on the relative center changes as measured by a smoothed,
            variance-normalized of the mean center squared position changes. (Default: 0.0)
        n_init (int, optional): Number of random initializations that are tried. (Default: 20)
        reassignment_ratio (float, optional): Control the fraction of the maximum number of counts for a center
            to be reassigned. A higher value means that low count centers are more easily reassigned. (Default: 0.0)
        max_no_improvement (int, optional): Control early stopping based on the consecutive number of mini batches
            that does not yield an improvement on the smoothed inertia. (Default: 100)

    Returns:
        None
    """
    if not km_dir.exists():
        km_dir.mkdir()

    km_model = MiniBatchKMeans(
        n_clusters=n_clusters,
        init=init,
        max_iter=max_iter,
        batch_size=batch_size,
        verbose=0,
        compute_labels=False,
        tol=tol,
        max_no_improvement=max_no_improvement,
        init_size=None,
        n_init=n_init,
        reassignment_ratio=reassignment_ratio,
    )
    feats, _ = load_feature(
        feat_dir,
        split,
        num_rank,
    )
    feats = feats.numpy()
    km_model.fit(feats)
    km_path = _get_model_path(km_dir)
    joblib.dump(km_model, km_path)

    # Report the average inertia (negated score per frame) as a quality metric.
    inertia = -km_model.score(feats) / len(feats)
    _LG.info("Total inertia: %.5f", inertia)  # fixed log typo "intertia"
    _LG.info("Finished training the KMeans clustering model successfully")
class ApplyKmeans(object):
    r"""Callable that assigns each feature frame to its nearest KMeans cluster center.

    Args:
        km_path: The path of the serialized KMeans clustering model (joblib format).
        device: The device on which the distance computation is performed.
    """

    def __init__(self, km_path, device):
        self.km_model = joblib.load(km_path)
        # Centers stored column-wise: (feature_dim, n_clusters).
        self.C_np = self.km_model.cluster_centers_.transpose()
        # Pre-computed squared norms of the centers: (1, n_clusters).
        self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True)
        self.C = torch.from_numpy(self.C_np).to(device)
        self.Cnorm = torch.from_numpy(self.Cnorm_np).to(device)

    def __call__(self, x):
        # Squared Euclidean distance via the expansion ||x - c||^2 =
        # ||x||^2 - 2 x.c + ||c||^2, computed for all centers at once.
        x_sq = x.pow(2).sum(1, keepdim=True)
        cross = torch.matmul(x, self.C)
        dist = x_sq - 2 * cross + self.Cnorm
        return dist.argmin(dim=1).cpu().numpy()
def get_km_label(
    feat_dir: Path,
    km_dir: Path,
    label_dir: Path,
    split: str,
    num_rank: int,
    device: torch.device,
) -> None:
    r"""Predict the labels by the KMeans clustering model.

    Args:
        feat_dir (Path): The directory that stores the dumped features.
        km_dir (Path): The directory that stores the KMeans model.
        label_dir (Path): The directory to save the predicted labels.
        split (str): The split of data. Options: [``train``, ``valid``].
        num_rank (int): The number of ranks for multi-processing in feature extraction.
        device (torch.device): The location to allocate for PyTorch Tensors.
            Options: [``torch.device('cpu')``, ``torch.device('cuda')``].

    Returns:
        None
    """
    if not label_dir.exists():
        label_dir.mkdir()

    km_path = _get_model_path(km_dir)
    label_path = label_dir / f"label_{split}.pt"
    apply_kmeans = ApplyKmeans(km_path, device)
    feats, lens = load_feature(
        feat_dir,
        split,
        num_rank,
    )
    # (Removed the no-op ``feats = feats`` self-assignment of the original.)
    lens = lens.long()
    # Sanity check: the flat frame count must equal the sum of utterance
    # lengths, otherwise the offset slicing below would be misaligned.
    assert feats.shape[0] == lens.sum()
    offset = 0
    with open(label_path, "w") as f:
        # Slice the flat feature tensor back into utterances and write one
        # line of space-separated cluster ids per utterance.
        for i in range(lens.shape[0]):
            feat = feats[offset : offset + lens[i]].to(device)
            offset += lens[i]
            label = apply_kmeans(feat).tolist()
            f.write(" ".join(map(str, label)) + "\n")
    _LG.info("Finished predicting labels successfully")
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment