Commit c4f12526 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Add HuBERT-feature support in preprocessing of HuBERT training (#2143)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2143

Reviewed By: carolineechen

Differential Revision: D34722238

Pulled By: nateanl

fbshipit-source-id: 72809c9db91c94d8e853c80ed8522eeffe5ff136
parent a92ae368
...@@ -41,7 +41,19 @@ def _parse_args(): ...@@ -41,7 +41,19 @@ def _parse_args():
help="The path to the directory where the directory ``LibriSpeech`` or ``LibriLight`` is stored.", help="The path to the directory where the directory ``LibriSpeech`` or ``LibriLight`` is stored.",
) )
parser.add_argument("--num-rank", default=5, type=int) parser.add_argument("--num-rank", default=5, type=int)
parser.add_argument("--feat-type", default="mfcc", type=str) parser.add_argument("--feat-type", default="mfcc", choices=["mfcc", "hubert"], type=str)
parser.add_argument(
"--layer-index",
default=6,
type=int,
help="The layer index in HuBERT model for feature extraction. (``1`` means the first layer output)",
)
parser.add_argument(
"--checkpoint-path",
default=None,
type=Path,
help="The model checkpoint of hubert_pretrain_base model.",
)
parser.add_argument("--use-gpu", default=False, type=bool) parser.add_argument("--use-gpu", default=False, type=bool)
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
...@@ -63,10 +75,16 @@ def main(args): ...@@ -63,10 +75,16 @@ def main(args):
if not args.exp_dir.exists(): if not args.exp_dir.exists():
args.exp_dir.mkdir() args.exp_dir.mkdir()
tsv_dir = args.exp_dir / "tsv" if args.feat_type == "mfcc":
feat_dir = args.exp_dir / args.feat_type data_dir = args.exp_dir / "data" / "mfcc"
km_dir = args.exp_dir / "km_model" else:
label_dir = args.exp_dir / "label" data_dir = args.exp_dir / "data" / f"{args.feat_type}_{args.layer_index}"
data_dir.mkdir(parents=True, exist_ok=True)
tsv_dir = data_dir / "tsv"
feat_dir = data_dir / "feat"
km_dir = data_dir / "km_model"
label_dir = data_dir / "label"
if args.use_gpu: if args.use_gpu:
device = torch.device("cuda") device = torch.device("cuda")
...@@ -91,9 +109,11 @@ def main(args): ...@@ -91,9 +109,11 @@ def main(args):
args.num_rank, args.num_rank,
device, device,
args.feat_type, args.feat_type,
args.layer_index,
args.checkpoint_path,
16_000, 16_000,
) )
for rank in range(args.num_rank) for rank in range(1, args.num_rank + 1)
] ]
_ = p.starmap(dump_features, inputs) _ = p.starmap(dump_features, inputs)
p.close() p.close()
...@@ -108,7 +128,7 @@ def main(args): ...@@ -108,7 +128,7 @@ def main(args):
args.num_cluster, args.num_cluster,
) )
# Predict labels for MFCC features # Predict labels for MFCC or HuBERT features
for split in ["train", "valid"]: for split in ["train", "valid"]:
get_km_label( get_km_label(
feat_dir, feat_dir,
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import logging import logging
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Optional,
Tuple, Tuple,
Union, Union,
) )
...@@ -13,6 +14,7 @@ from typing import ( ...@@ -13,6 +14,7 @@ from typing import (
import torch import torch
import torchaudio import torchaudio
from torch import Tensor from torch import Tensor
from torch.nn import Module
from .common_utils import _get_feat_lens_paths from .common_utils import _get_feat_lens_paths
...@@ -31,26 +33,24 @@ def get_shard_range(num_lines: int, num_rank: int, rank: int) -> Tuple[int, int] ...@@ -31,26 +33,24 @@ def get_shard_range(num_lines: int, num_rank: int, rank: int) -> Tuple[int, int]
int: The start index for the current rank. int: The start index for the current rank.
int: The end index for the current rank. int: The end index for the current rank.
""" """
assert 0 <= rank < num_rank, f"invalid rank/num_rank {rank}/{num_rank}" assert 1 <= rank <= num_rank, f"invalid rank/num_rank {rank}/{num_rank}"
assert num_lines > 0, f"Found {num_lines} files, make sure you specify the correct root directory" assert num_lines > 0, f"Found {num_lines} files, make sure you specify the correct root directory"
start = round(num_lines / num_rank * rank) start = round(num_lines / num_rank * (rank - 1))
end = round(num_lines / num_rank * (rank + 1)) end = round(num_lines / num_rank * rank)
_LG.info(f"rank {rank} of {num_rank}, process {end-start} " f"({start}-{end}) out of {num_lines}") _LG.info(f"rank {rank} of {num_rank}, process {end-start} " f"({start}-{end}) out of {num_lines}")
return start, end return start, end
def extract_feature( def extract_feature_mfcc(
path: str, path: str,
device: torch.device, device: torch.device,
feature_type: str,
sample_rate: int, sample_rate: int,
) -> Tensor: ) -> Tensor:
r"""Extract features for KMeans clustering and pseudo label prediction. r"""Extract MFCC features for KMeans clustering and pseudo label prediction.
Args: Args:
path (str): The file path of the audio. path (str): The file path of the audio.
device (torch.device): The location to allocate for PyTorch Tensors. device (torch.device): The location to allocate for PyTorch Tensors.
Options: [``torch.device('cpu')``, torch.device('cuda')``]. Options: [``torch.device('cpu')``, torch.device('cuda')``].
feature_type (str): The type of the desired feature. Options: [``mfcc``, ``hubert``].
sample_rate (int): The sample rate of the audio. sample_rate (int): The sample rate of the audio.
Returns: Returns:
...@@ -58,23 +58,66 @@ def extract_feature( ...@@ -58,23 +58,66 @@ def extract_feature(
""" """
waveform, sr = torchaudio.load(path) waveform, sr = torchaudio.load(path)
assert sr == sample_rate assert sr == sample_rate
feature_extractor = torchaudio.transforms.MFCC(
sample_rate=sample_rate, n_mfcc=13, melkwargs={"n_fft": 400, "hop_length": 160, "center": False}
).to(device)
waveform = waveform[0].to(device) waveform = waveform[0].to(device)
if feature_type == "mfcc": mfccs = feature_extractor(waveform) # (freq, time)
feature_extractor = torchaudio.transforms.MFCC( # mfccs = torchaudio.compliance.kaldi.mfcc(
sample_rate=sample_rate, n_mfcc=13, melkwargs={"n_fft": 400, "hop_length": 160, "center": False} # waveform=waveform,
).to(device) # sample_frequency=sample_rate,
mfccs = feature_extractor(waveform) # (freq, time) # use_energy=False,
# mfccs = torchaudio.compliance.kaldi.mfcc( # ) # (time, freq)
# waveform=waveform, # mfccs = mfccs.transpose(0, 1) # (freq, time)
# sample_frequency=sample_rate, deltas = torchaudio.functional.compute_deltas(mfccs)
# use_energy=False, ddeltas = torchaudio.functional.compute_deltas(deltas)
# ) # (time, freq) concat = torch.cat([mfccs, deltas, ddeltas], dim=0)
# mfccs = mfccs.transpose(0, 1) # (freq, time) feat = concat.transpose(0, 1) # (time, freq)
deltas = torchaudio.functional.compute_deltas(mfccs) return feat
ddeltas = torchaudio.functional.compute_deltas(deltas)
concat = torch.cat([mfccs, deltas, ddeltas], dim=0)
concat = concat.transpose(0, 1) # (time, freq) def extract_feature_hubert(
return concat path: str,
device: torch.device,
sample_rate: int,
model: Module,
layer_index: int,
) -> Tensor:
r"""Extract HuBERT features for KMeans clustering and pseudo label prediction.
Args:
path (str): The file path of the audio.
device (torch.device): The location to allocate for PyTorch Tensors.
Options: [``torch.device('cpu')``, torch.device('cuda')``].
sample_rate (int): The sample rate of the audio.
model (Module): The loaded ``HuBERTPretrainModel`` model.
layer_index (int): The index of transformer layers in
``torchaudio.models.HuBERTPretrainModel`` for extracting features.
(``1`` means the first layer output).
Returns:
Tensor: The desired feature tensor of the given audio file.
"""
waveform, sr = torchaudio.load(path)
assert sr == sample_rate
waveform = waveform.to(device)
with torch.inference_mode():
feat = model.wav2vec2.extract_features(waveform, num_layers=layer_index)[0][-1][0] # (time, feat_dim)
return feat
def _load_state(model: Module, checkpoint_path: Path) -> Module:
"""Load weights from HuBERTPretrainModel checkpoint into hubert_pretrain_base model.
Args:
model (Module): The hubert_pretrain_base model.
checkpoint_path (Path): The model checkpoint.
Returns:
(Module): The pretrained model.
"""
state_dict = torch.load(checkpoint_path)
state_dict = {k.replace("model.", ""): v for k, v in state_dict["state_dict"].items()}
model.load_state_dict(state_dict)
return model
def dump_features( def dump_features(
...@@ -85,6 +128,8 @@ def dump_features( ...@@ -85,6 +128,8 @@ def dump_features(
num_rank: int, num_rank: int,
device: torch.device, device: torch.device,
feature_type: str = "mfcc", feature_type: str = "mfcc",
layer_index: Optional[int] = None,
checkpoint_path: Optional[Path] = None,
sample_rate: int = 16_000, sample_rate: int = 16_000,
) -> None: ) -> None:
r"""Dump the feature tensors given a ``.tsv`` file list. The feature and lengths tensors r"""Dump the feature tensors given a ``.tsv`` file list. The feature and lengths tensors
...@@ -99,18 +144,34 @@ def dump_features( ...@@ -99,18 +144,34 @@ def dump_features(
Options: [``torch.device('cpu')``, torch.device('cuda')``]. Options: [``torch.device('cpu')``, torch.device('cuda')``].
feature_type (str, optional): The type of the desired feature. Options: [``mfcc``, ``hubert``]. feature_type (str, optional): The type of the desired feature. Options: [``mfcc``, ``hubert``].
(Default: ``mfcc``) (Default: ``mfcc``)
sample_rate (int, optional): The sample rate of the audio. (Default: 16000) layer_index (int or None, optional): The index of transformer layers in
``torchaudio.models.HuBERTPretrainModel`` for extracting features.
(``1`` means the first layer output). Only active when ``feature_type``
is set to ``hubert``. (Default: ``None``)
checkpoint_path(Path or None, optional): The checkpoint path of ``torchaudio.models.HuBERTPretrainModel``.
Only active when ``feature_type`` is set to ``hubert``. (Default: ``None``)
sample_rate (int, optional): The sample rate of the audio. (Default: ``16000``)
Returns: Returns:
None None
""" """
if feature_type not in ["mfcc", "hubert"]: if feature_type not in ["mfcc", "hubert"]:
raise ValueError("Unexpected feature type.") raise ValueError(f"Expected feature type to be 'mfcc' or 'hubert'. Found {feature_type}.")
if feature_type == "hubert" and layer_index is None:
assert ValueError("Please set the layer_index for HuBERT feature.")
features = [] features = []
lens = [] lens = []
out_dir = Path(out_dir) out_dir = Path(out_dir)
feat_path, len_path = _get_feat_lens_paths(out_dir, split, rank, num_rank) feat_path, len_path = _get_feat_lens_paths(out_dir, split, rank, num_rank)
if feature_type == "hubert":
from torchaudio.models import hubert_pretrain_base
model = hubert_pretrain_base()
model = _load_state(model, checkpoint_path)
model.to(device)
with open(tsv_file, "r") as f: with open(tsv_file, "r") as f:
root = f.readline().rstrip() root = f.readline().rstrip()
lines = [line.rstrip() for line in f] lines = [line.rstrip() for line in f]
...@@ -120,7 +181,10 @@ def dump_features( ...@@ -120,7 +181,10 @@ def dump_features(
path, nsample = line.split("\t") path, nsample = line.split("\t")
path = f"{root}/{path}" path = f"{root}/{path}"
nsample = int(nsample) nsample = int(nsample)
feature = extract_feature(path, device, feature_type, sample_rate) if feature_type == "mfcc":
feature = extract_feature_mfcc(path, device, sample_rate)
else:
feature = extract_feature_hubert(path, device, sample_rate, model, layer_index)
features.append(feature.cpu()) features.append(feature.cpu())
lens.append(feature.shape[0]) lens.append(feature.shape[0])
features = torch.cat(features) features = torch.cat(features)
......
...@@ -37,7 +37,7 @@ def load_feature( ...@@ -37,7 +37,7 @@ def load_feature(
""" """
feats = [] feats = []
lens = [] lens = []
for rank in range(num_rank): for rank in range(1, num_rank + 1):
feat_path, len_path = _get_feat_lens_paths(feat_dir, split, rank, num_rank) feat_path, len_path = _get_feat_lens_paths(feat_dir, split, rank, num_rank)
feat = torch.load(feat_path) feat = torch.load(feat_path)
length = torch.load(len_path) length = torch.load(len_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment