Commit 40ff642e authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Fix hubert fine-tuning recipe (#2851)

Summary:
- `_get_fileids_paths` in the `LibriLightLimited` dataset was changed in https://github.com/pytorch/audio/issues/2653; it now returns relative paths instead of absolute ones. This PR fixes the usage in the HuBERT fine-tuning recipe so that the correct audio paths are constructed.
- The model options should be `hubert_pretrain_large` and `hubert_pretrain_xlarge` instead of `hubert_large` and `hubert_xlarge`.
- The input dimension of the CTC linear layer varies with the model architecture; update it accordingly in the lightning module (see the sketch below).
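
For illustration only, a minimal sketch of how the model name now maps to both the pre-training factory and the CTC projection width. `build_model_and_aux` is a hypothetical helper, not part of the recipe, and the factories are called with their defaults here, whereas the recipe passes its dropout and masking options:

```python
import torch
import torchaudio

# Encoder output widths of the three torchaudio HuBERT pre-training factories;
# the CTC linear layer must match the encoder width of the selected architecture.
_ENCODER_DIM = {
    "hubert_pretrain_base": 768,
    "hubert_pretrain_large": 1024,
    "hubert_pretrain_xlarge": 1280,
}


def build_model_and_aux(model_name: str, aux_num_out: int):
    # Hypothetical helper: pick the factory by name and size the CTC projection
    # to the matching encoder dimension.
    if model_name not in _ENCODER_DIM:
        raise ValueError(f"Unsupported model name: {model_name}.")
    model = getattr(torchaudio.models, model_name)()  # the recipe passes dropout/masking kwargs here
    aux = torch.nn.Linear(_ENCODER_DIM[model_name], aux_num_out)
    return model, aux
```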

cc simpleoier

Pull Request resolved: https://github.com/pytorch/audio/pull/2851

Reviewed By: carolineechen

Differential Revision: D41327998

Pulled By: nateanl

fbshipit-source-id: f92248ee84ec860b4e4dbef880c5794b338e1e2d
parent 26f62dc5
@@ -415,13 +415,13 @@ class CollateFnHubert:
         return waveforms, labels, lengths


-def _get_lengths_librilightlimited(files: List[str]) -> List[int]:
+def _get_lengths_librilightlimited(files: List[str], path: str, ext_audio: str) -> List[int]:
     lengths = []
     for file_path, fileid in files:
         speaker_id, chapter_id, utterance_id = fileid.split("-")
         # Load audio
-        file_audio = f"{speaker_id}-{chapter_id}-{utterance_id}.flac"
-        file_audio = os.path.join(file_path, speaker_id, chapter_id, file_audio)
+        file_audio = f"{speaker_id}-{chapter_id}-{utterance_id}{ext_audio}"
+        file_audio = os.path.join(path, file_path, speaker_id, chapter_id, file_audio)
         length = torchaudio.info(file_audio).num_frames
         lengths.append(length)
     return lengths
...
@@ -360,7 +360,8 @@ class HuBERTFineTuneModule(LightningModule):
                 mask_channel_length=mask_channel_length,
                 num_classes=num_classes,
             )
-        elif model_name == "hubert_large":
+            self.aux = torch.nn.Linear(768, aux_num_out)
+        elif model_name == "hubert_pretrain_large":
             self.model = torchaudio.models.hubert_pretrain_large(
                 encoder_projection_dropout=encoder_projection_dropout,
                 encoder_attention_dropout=encoder_attention_dropout,
@@ -372,7 +373,8 @@ class HuBERTFineTuneModule(LightningModule):
                 mask_channel_length=mask_channel_length,
                 num_classes=num_classes,
             )
-        elif model_name == "hubert_xlarge":
+            self.aux = torch.nn.Linear(1024, aux_num_out)
+        elif model_name == "hubert_pretrain_xlarge":
             self.model = torchaudio.models.hubert_pretrain_xlarge(
                 encoder_projection_dropout=encoder_projection_dropout,
                 encoder_attention_dropout=encoder_attention_dropout,
@@ -384,9 +386,9 @@ class HuBERTFineTuneModule(LightningModule):
                 mask_channel_length=mask_channel_length,
                 num_classes=num_classes,
             )
+            self.aux = torch.nn.Linear(1280, aux_num_out)
         else:
             raise ValueError(f"Unsupported model name: {model_name}.")
-        self.aux = torch.nn.Linear(768, aux_num_out)
         self._load_checkpoint(checkpoint)
         for p in self.model.wav2vec2.feature_extractor.parameters():
             p.requires_grad = False
@@ -504,7 +506,7 @@ class HuBERTFineTuneModule(LightningModule):

     def train_dataloader(self):
         dataset = torchaudio.datasets.LibriLightLimited(self.dataset_path, self.subset)
-        lengths = _get_lengths_librilightlimited(dataset._fileids_paths)
+        lengths = _get_lengths_librilightlimited(dataset._fileids_paths, dataset._path, dataset._ext_audio)
         sampler = BucketizeBatchSampler(
             lengths, num_buckets=100, max_token_count=self.seconds_per_batch * 16000, shuffle=True
         )
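
For context, a hedged sketch of how these pieces plug into a standard DataLoader outside the lightning module. The dataset root, subset, and per-batch audio budget below are placeholders, and `_get_lengths_librilightlimited` and `BucketizeBatchSampler` are assumed to be imported from the recipe's local modules (the recipe also supplies a collate function):

```python
from torch.utils.data import DataLoader

import torchaudio

# _get_lengths_librilightlimited and BucketizeBatchSampler come from the recipe's
# local modules; import paths are omitted here.
dataset = torchaudio.datasets.LibriLightLimited("/path/to/root", subset="10h")
lengths = _get_lengths_librilightlimited(dataset._fileids_paths, dataset._path, dataset._ext_audio)
# Bucket utterances of similar length and cap each batch at roughly 30 seconds of 16 kHz audio.
sampler = BucketizeBatchSampler(lengths, num_buckets=100, max_token_count=30 * 16000, shuffle=True)
loader = DataLoader(dataset, batch_sampler=sampler)  # the recipe also passes its collate_fn here
```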
...
@@ -16,12 +16,24 @@ _CHECKSUM = "5d1efdc777b548194d7e09ba89126e2188026df9fd57aa57eb14408d2b2342af"
 _SUBSET_MAP = {"10min": ["1h/0"], "1h": ["1h/*"], "10h": ["1h/*", "9h"]}


-def _get_fileids_paths(path, folders, _ext_audio) -> List[Tuple[str, str]]:
+def _get_fileids_paths(path: Path, folders: List[str], _ext_audio: str) -> List[Tuple[str, str]]:
     """Get the file names and the corresponding file paths without `speaker_id`
     and `chapter_id` directories.

     The format of path is like:
         {root}/{_ARCHIVE_NAME}/1h/[0-5]/[clean, other] or
         {root}/{_ARCHIVE_NAME}/9h/[clean, other]
+
+    Args:
+        path (Path): Root path to the dataset.
+        folders (List[str]): Folders that contain the desired audio files.
+        _ext_audio (str): Extension of audio files.
+
+    Returns:
+        List[Tuple[str, str]]:
+            List of tuples where the first element is the relative path to the audio file.
+            The format of relative path is like:
+                1h/[0-5]/[clean, other] or 9h/[clean, other]
+            The second element is the file name without audio extension.
     """
     path = Path(path)
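
For illustration, a hypothetical entry under the documented return format and how a caller turns it back into a full audio path; the speaker/chapter/utterance IDs and the dataset root below are made up:

```python
import os

# Hypothetical returned entry: (relative path, file name without extension).
file_path, fileid = "1h/0/clean", "1001-2001-0000"
speaker_id, chapter_id, utterance_id = fileid.split("-")

# Callers now prepend the dataset root (dataset._path) themselves, mirroring the
# updated _get_lengths_librilightlimited above; "/path/to/dataset" is a placeholder.
full_path = os.path.join("/path/to/dataset", file_path, speaker_id, chapter_id, f"{fileid}.flac")
# -> /path/to/dataset/1h/0/clean/1001/2001/1001-2001-0000.flac
```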
...