Unverified Commit 1d21471c authored by Magnus Pierrau's avatar Magnus Pierrau Committed by GitHub
Browse files

Added mask_time_prob and mask_time_length arguments to wav2vec2 pretraining script (#20985)

Added mask_time_prob and mask_time_length arguments to wav2vec2 pretraining script and readme - new branch
parent bc53fc62
......@@ -79,6 +79,8 @@ accelerate launch run_wav2vec2_pretraining_no_trainer.py \
--adam_beta2="0.98" \
--adam_epsilon="1e-06" \
--gradient_checkpointing \
--mask_time_prob="0.65" \
--mask_time_length="10"
```
The results of this run can be seen [here](https://wandb.ai/patrickvonplaten/wav2vec2-pretrained-demo/reports/Wav2Vec2-PreTraining-Demo-Run--VmlldzoxMDk3MjAw?accessToken=oa05s1y57lizo2ocxy3k01g6db1u4pt8m6ur2n8nl4cb0ug02ms2cw313kb8ruch).
......@@ -110,6 +112,8 @@ accelerate launch run_wav2vec2_pretraining_no_trainer.py \
--adam_beta2="0.98" \
--adam_epsilon="1e-06" \
--gradient_checkpointing \
--mask_time_prob="0.65" \
--mask_time_length="10"
```
The experiment was run on 8 V100 GPUs (16 GB RAM each) for 4 days.
......@@ -146,6 +150,8 @@ accelerate launch run_wav2vec2_pretraining_no_trainer.py \
--adam_beta2=0.98 \
--adam_epsilon=1e-06 \
--gradient_checkpointing \
--mask_time_prob=0.65 \
--mask_time_length=10
```
The experiment was run on 8 V100 GPUs (16 GB RAM each) for 7 days.
......
......@@ -247,6 +247,24 @@ def parse_args():
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
)
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--mask_time_prob",
type=float,
default=None,
help=(
"Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked in the"
" contrastive task. If omitted, will pull value from model config."
),
)
parser.add_argument(
"--mask_time_length",
type=int,
default=None,
help=(
"Length of each vector mask span to mask along the time axis in the contrastive task."
" If omitted, will pull value from model config."
),
)
args = parser.parse_args()
if args.push_to_hub:
......@@ -285,12 +303,22 @@ class DataCollatorForWav2Vec2Pretraining:
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.0 (Volta).
mask_time_prob (:obj:`float`, `optional`, defaults to :obj:`0.65`):
Fraction (between 0 and 1) of all feature vectors along the time axis which will be masked for the contrastive task.
Note that overlap between masked sequences may decrease the actual percentage of masked vectors.
The default value is taken from the original wav2vec 2.0 article (https://arxiv.org/abs/2006.11477),
and results in about 49 percent of each sequence being masked on average.
mask_time_length (:obj:`int`, `optional`, defaults to :obj:`10`):
Length of each mask span along the time axis in the contrastive task. The default value
originates from the original wav2vec 2.0 article and corresponds to the ``M`` variable mentioned there.
"""
model: Wav2Vec2ForPreTraining
feature_extractor: Wav2Vec2FeatureExtractor
padding: Union[bool, str] = "longest"
pad_to_multiple_of: Optional[int] = None
mask_time_prob: Optional[float] = 0.65
mask_time_length: Optional[int] = 10
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# reformat list to dict and set to pytorch format
......@@ -320,8 +348,8 @@ class DataCollatorForWav2Vec2Pretraining:
# sample randomly masked indices
mask_time_indices = _compute_mask_indices(
features_shape,
self.model.config.mask_time_prob,
self.model.config.mask_time_length,
self.mask_time_prob,
self.mask_time_length,
attention_mask=batch.get("sub_attention_mask"),
)
......@@ -515,8 +543,16 @@ def main():
model.gradient_checkpointing_enable()
# 4. Define data collator, optimizer and scheduler
mask_time_prob = config.mask_time_prob if args.mask_time_prob is None else args.mask_time_prob
mask_time_length = config.mask_time_length if args.mask_time_length is None else args.mask_time_length
data_collator = DataCollatorForWav2Vec2Pretraining(
model=model, feature_extractor=feature_extractor, pad_to_multiple_of=args.pad_to_multiple_of
model=model,
feature_extractor=feature_extractor,
pad_to_multiple_of=args.pad_to_multiple_of,
mask_time_prob=mask_time_prob,
mask_time_length=mask_time_length,
)
train_dataloader = DataLoader(
vectorized_datasets["train"],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment