Commit 980528e9 authored by Zhaoheng Ni
Browse files

Fix hubert fine-tuning recipe (#2851)

Summary:
- `_get_fileids_paths` in the `LibriLightLimited` dataset was changed in https://github.com/pytorch/audio/issues/2653, so the paths it returns are now relative instead of absolute. This PR fixes the usage in the HuBERT fine-tuning recipe to get the correct audio paths.
- model options should be `hubert_pretrain_large` and `hubert_pretrain_xlarge` instead of `hubert_large` and `hubert_xlarge`.
- The input dimension of the CTC linear layer varies depending on the model architecture; update it in the Lightning module.

cc simpleoier

Pull Request resolved: https://github.com/pytorch/audio/pull/2851

Reviewed By: carolineechen

Differential Revision: D41327998

Pulled By: nateanl

fbshipit-source-id: f92248ee84ec860b4e4dbef880c5794b338e1e2d
parent 2574e114
......@@ -415,13 +415,13 @@ class CollateFnHubert:
return waveforms, labels, lengths
def _get_lengths_librilightlimited(files: List[str]) -> List[int]:
def _get_lengths_librilightlimited(files: List[str], path: str, ext_audio: str) -> List[int]:
lengths = []
for file_path, fileid in files:
speaker_id, chapter_id, utterance_id = fileid.split("-")
# Load audio
file_audio = f"{speaker_id}-{chapter_id}-{utterance_id}.flac"
file_audio = os.path.join(file_path, speaker_id, chapter_id, file_audio)
file_audio = f"{speaker_id}-{chapter_id}-{utterance_id}{ext_audio}"
file_audio = os.path.join(path, file_path, speaker_id, chapter_id, file_audio)
length = torchaudio.info(file_audio).num_frames
lengths.append(length)
return lengths
......
......@@ -360,7 +360,8 @@ class HuBERTFineTuneModule(LightningModule):
mask_channel_length=mask_channel_length,
num_classes=num_classes,
)
elif model_name == "hubert_large":
self.aux = torch.nn.Linear(768, aux_num_out)
elif model_name == "hubert_pretrain_large":
self.model = torchaudio.models.hubert_pretrain_large(
encoder_projection_dropout=encoder_projection_dropout,
encoder_attention_dropout=encoder_attention_dropout,
......@@ -372,7 +373,8 @@ class HuBERTFineTuneModule(LightningModule):
mask_channel_length=mask_channel_length,
num_classes=num_classes,
)
elif model_name == "hubert_xlarge":
self.aux = torch.nn.Linear(1024, aux_num_out)
elif model_name == "hubert_pretrain_xlarge":
self.model = torchaudio.models.hubert_pretrain_xlarge(
encoder_projection_dropout=encoder_projection_dropout,
encoder_attention_dropout=encoder_attention_dropout,
......@@ -384,9 +386,9 @@ class HuBERTFineTuneModule(LightningModule):
mask_channel_length=mask_channel_length,
num_classes=num_classes,
)
self.aux = torch.nn.Linear(1280, aux_num_out)
else:
raise ValueError(f"Unsupported model name: {model_name}.")
self.aux = torch.nn.Linear(768, aux_num_out)
self._load_checkpoint(checkpoint)
for p in self.model.wav2vec2.feature_extractor.parameters():
p.requires_grad = False
......@@ -504,7 +506,7 @@ class HuBERTFineTuneModule(LightningModule):
def train_dataloader(self):
dataset = torchaudio.datasets.LibriLightLimited(self.dataset_path, self.subset)
lengths = _get_lengths_librilightlimited(dataset._fileids_paths)
lengths = _get_lengths_librilightlimited(dataset._fileids_paths, dataset._path, dataset._ext_audio)
sampler = BucketizeBatchSampler(
lengths, num_buckets=100, max_token_count=self.seconds_per_batch * 16000, shuffle=True
)
......
......@@ -16,12 +16,24 @@ _CHECKSUM = "5d1efdc777b548194d7e09ba89126e2188026df9fd57aa57eb14408d2b2342af"
_SUBSET_MAP = {"10min": ["1h/0"], "1h": ["1h/*"], "10h": ["1h/*", "9h"]}
def _get_fileids_paths(path, folders, _ext_audio) -> List[Tuple[str, str]]:
def _get_fileids_paths(path: Path, folders: List[str], _ext_audio: str) -> List[Tuple[str, str]]:
"""Get the file names and the corresponding file paths without `speaker_id`
and `chapter_id` directories.
The format of path is like:
{root}/{_ARCHIVE_NAME}/1h/[0-5]/[clean, other] or
{root}/{_ARCHIVE_NAME}/9h/[clean, other]
Args:
path (Path): Root path to the dataset.
folders (List[str]): Folders that contain the desired audio files.
_ext_audio (str): Extension of audio files.
Returns:
List[Tuple[str, str]]:
List of tuples where the first element is the relative path to the audio file.
The format of relative path is like:
1h/[0-5]/[clean, other] or 9h/[clean, other]
The second element is the file name without audio extension.
"""
path = Path(path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment