Unverified Commit 9f831bde authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[DocTests Speech] Add doc tests for all speech models (#15031)

* fix_torch_device_generate_test

* remove @

* doc tests

* up

* up

* fix doctests

* adapt files

* finish refactor

* up

* save intermediate

* add more logic

* new change

* improve

* next try

* next try

* next try

* next try

* fix final spaces

* fix final spaces

* improve

* renaming

* correct more bugs

* finish wavlm

* add comment

* run on test runner

* finish all speech models

* adapt

* finish
parent 4df69506
...@@ -19,7 +19,7 @@ env: ...@@ -19,7 +19,7 @@ env:
jobs: jobs:
run_doctests: run_doctests:
runs-on: [self-hosted, docker-gpu, single-gpu] runs-on: [self-hosted, docker-gpu-test, single-gpu]
container: container:
image: pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime image: pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
...@@ -35,8 +35,16 @@ jobs: ...@@ -35,8 +35,16 @@ jobs:
run: | run: |
apt -y update && apt install -y libsndfile1-dev apt -y update && apt install -y libsndfile1-dev
pip install --upgrade pip pip install --upgrade pip
pip install .[dev] pip install .[testing,torch-speech]
- name: Prepare files for doctests
run: |
python utils/prepare_for_doc_test.py src docs
- name: Run doctests - name: Run doctests
run: | run: |
pytest --doctest-modules $(cat utils/documentation_tests.txt) -sv --doctest-continue-on-failure pytest --doctest-modules $(cat utils/documentation_tests.txt) -sv --doctest-continue-on-failure --doctest-glob="*.mdx"
- name: Clean files after doctests
run: |
python utils/prepare_for_doc_test.py src docs --remove_new_line
...@@ -1127,9 +1127,11 @@ PT_SPEECH_BASE_MODEL_SAMPLE = r""" ...@@ -1127,9 +1127,11 @@ PT_SPEECH_BASE_MODEL_SAMPLE = r"""
```python ```python
>>> from transformers import {processor_class}, {model_class} >>> from transformers import {processor_class}, {model_class}
>>> import torch
>>> from datasets import load_dataset >>> from datasets import load_dataset
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate >>> sampling_rate = dataset.features["audio"].sampling_rate
>>> processor = {processor_class}.from_pretrained("{checkpoint}") >>> processor = {processor_class}.from_pretrained("{checkpoint}")
...@@ -1137,9 +1139,12 @@ PT_SPEECH_BASE_MODEL_SAMPLE = r""" ...@@ -1137,9 +1139,12 @@ PT_SPEECH_BASE_MODEL_SAMPLE = r"""
>>> # audio file is decoded on the fly >>> # audio file is decoded on the fly
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
>>> outputs = model(**inputs) >>> with torch.no_grad():
... outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state >>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
{expected_output}
``` ```
""" """
...@@ -1152,6 +1157,7 @@ PT_SPEECH_CTC_SAMPLE = r""" ...@@ -1152,6 +1157,7 @@ PT_SPEECH_CTC_SAMPLE = r"""
>>> import torch >>> import torch
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate >>> sampling_rate = dataset.features["audio"].sampling_rate
>>> processor = {processor_class}.from_pretrained("{checkpoint}") >>> processor = {processor_class}.from_pretrained("{checkpoint}")
...@@ -1159,17 +1165,24 @@ PT_SPEECH_CTC_SAMPLE = r""" ...@@ -1159,17 +1165,24 @@ PT_SPEECH_CTC_SAMPLE = r"""
>>> # audio file is decoded on the fly >>> # audio file is decoded on the fly
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
>>> logits = model(**inputs).logits >>> with torch.no_grad():
... logits = model(**inputs).logits
>>> predicted_ids = torch.argmax(logits, dim=-1) >>> predicted_ids = torch.argmax(logits, dim=-1)
>>> # transcribe speech >>> # transcribe speech
>>> transcription = processor.batch_decode(predicted_ids) >>> transcription = processor.batch_decode(predicted_ids)
>>> transcription[0]
{expected_output}
```
>>> # compute loss ```python
>>> with processor.as_target_processor(): >>> with processor.as_target_processor():
... inputs["labels"] = processor(dataset[0]["text"], return_tensors="pt").input_ids ... inputs["labels"] = processor(dataset[0]["text"], return_tensors="pt").input_ids
>>> # compute loss
>>> loss = model(**inputs).loss >>> loss = model(**inputs).loss
>>> round(loss.item(), 2)
{expected_loss}
``` ```
""" """
...@@ -1182,21 +1195,31 @@ PT_SPEECH_SEQ_CLASS_SAMPLE = r""" ...@@ -1182,21 +1195,31 @@ PT_SPEECH_SEQ_CLASS_SAMPLE = r"""
>>> import torch >>> import torch
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate >>> sampling_rate = dataset.features["audio"].sampling_rate
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}")
>>> # audio file is decoded on the fly >>> # audio file is decoded on the fly
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt") >>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
>>> logits = model(**inputs).logits
>>> predicted_class_ids = torch.argmax(logits, dim=-1) >>> with torch.no_grad():
... logits = model(**inputs).logits
>>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
>>> predicted_label = model.config.id2label[predicted_class_ids] >>> predicted_label = model.config.id2label[predicted_class_ids]
>>> predicted_label
{expected_output}
```
```python
>>> # compute loss - target_label is e.g. "down" >>> # compute loss - target_label is e.g. "down"
>>> target_label = model.config.id2label[0] >>> target_label = model.config.id2label[0]
>>> inputs["labels"] = torch.tensor([model.config.label2id[target_label]]) >>> inputs["labels"] = torch.tensor([model.config.label2id[target_label]])
>>> loss = model(**inputs).loss >>> loss = model(**inputs).loss
>>> round(loss.item(), 2)
{expected_loss}
``` ```
""" """
...@@ -1210,17 +1233,22 @@ PT_SPEECH_FRAME_CLASS_SAMPLE = r""" ...@@ -1210,17 +1233,22 @@ PT_SPEECH_FRAME_CLASS_SAMPLE = r"""
>>> import torch >>> import torch
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate >>> sampling_rate = dataset.features["audio"].sampling_rate
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}")
>>> # audio file is decoded on the fly >>> # audio file is decoded on the fly
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt") >>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate)
>>> logits = model(**inputs).logits >>> with torch.no_grad():
... logits = model(**inputs).logits
>>> probabilities = torch.sigmoid(logits[0]) >>> probabilities = torch.sigmoid(logits[0])
>>> # labels is a one-hot array of shape (num_frames, num_speakers) >>> # labels is a one-hot array of shape (num_frames, num_speakers)
>>> labels = (probabilities > 0.5).long() >>> labels = (probabilities > 0.5).long()
>>> labels[0].tolist()
{expected_output}
``` ```
""" """
...@@ -1234,14 +1262,19 @@ PT_SPEECH_XVECTOR_SAMPLE = r""" ...@@ -1234,14 +1262,19 @@ PT_SPEECH_XVECTOR_SAMPLE = r"""
>>> import torch >>> import torch
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate >>> sampling_rate = dataset.features["audio"].sampling_rate
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}")
>>> # audio file is decoded on the fly >>> # audio file is decoded on the fly
>>> inputs = feature_extractor(dataset[:2]["audio"]["array"], return_tensors="pt") >>> inputs = feature_extractor(
>>> embeddings = model(**inputs).embeddings ... [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True
... )
>>> with torch.no_grad():
... embeddings = model(**inputs).embeddings
>>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu() >>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
>>> # the resulting embeddings can be used for cosine similarity-based retrieval >>> # the resulting embeddings can be used for cosine similarity-based retrieval
...@@ -1250,6 +1283,8 @@ PT_SPEECH_XVECTOR_SAMPLE = r""" ...@@ -1250,6 +1283,8 @@ PT_SPEECH_XVECTOR_SAMPLE = r"""
>>> threshold = 0.7 # the optimal threshold is dataset-dependent >>> threshold = 0.7 # the optimal threshold is dataset-dependent
>>> if similarity < threshold: >>> if similarity < threshold:
... print("Speakers are not the same!") ... print("Speakers are not the same!")
>>> round(similarity.item(), 2)
{expected_output}
``` ```
""" """
...@@ -1553,9 +1588,11 @@ def add_code_sample_docstrings( ...@@ -1553,9 +1588,11 @@ def add_code_sample_docstrings(
checkpoint=None, checkpoint=None,
output_type=None, output_type=None,
config_class=None, config_class=None,
mask=None, mask="[MASK]",
model_cls=None, model_cls=None,
modality=None modality=None,
expected_output="",
expected_loss="",
): ):
def docstring_decorator(fn): def docstring_decorator(fn):
# model_class defaults to function's class if not specified otherwise # model_class defaults to function's class if not specified otherwise
...@@ -1568,7 +1605,17 @@ def add_code_sample_docstrings( ...@@ -1568,7 +1605,17 @@ def add_code_sample_docstrings(
else: else:
sample_docstrings = PT_SAMPLE_DOCSTRINGS sample_docstrings = PT_SAMPLE_DOCSTRINGS
doc_kwargs = dict(model_class=model_class, processor_class=processor_class, checkpoint=checkpoint) # putting all kwargs for docstrings in a dict to be used
# with the `.format(**doc_kwargs)`. Note that string might
# be formatted with non-existing keys, which is fine.
doc_kwargs = dict(
model_class=model_class,
processor_class=processor_class,
checkpoint=checkpoint,
mask=mask,
expected_output=expected_output,
expected_loss=expected_loss,
)
if "SequenceClassification" in model_class and modality == "audio": if "SequenceClassification" in model_class and modality == "audio":
code_sample = sample_docstrings["AudioClassification"] code_sample = sample_docstrings["AudioClassification"]
...@@ -1581,7 +1628,6 @@ def add_code_sample_docstrings( ...@@ -1581,7 +1628,6 @@ def add_code_sample_docstrings(
elif "MultipleChoice" in model_class: elif "MultipleChoice" in model_class:
code_sample = sample_docstrings["MultipleChoice"] code_sample = sample_docstrings["MultipleChoice"]
elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]: elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
doc_kwargs["mask"] = "[MASK]" if mask is None else mask
code_sample = sample_docstrings["MaskedLM"] code_sample = sample_docstrings["MaskedLM"]
elif "LMHead" in model_class or "CausalLM" in model_class: elif "LMHead" in model_class or "CausalLM" in model_class:
code_sample = sample_docstrings["LMHead"] code_sample = sample_docstrings["LMHead"]
......
...@@ -40,15 +40,29 @@ from .configuration_hubert import HubertConfig ...@@ -40,15 +40,29 @@ from .configuration_hubert import HubertConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "HubertConfig"
_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor" _FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
_HIDDEN_STATES_START_POSITION = 1 _HIDDEN_STATES_START_POSITION = 1
# General docstring
_CONFIG_FOR_DOC = "HubertConfig"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 22.68
# Audio class docstring
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
_SEQ_CLASS_EXPECTED_LOSS = 8.53
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/hubert-base-ls960", "facebook/hubert-base-ls960",
...@@ -1098,6 +1112,8 @@ class HubertForCTC(HubertPreTrainedModel): ...@@ -1098,6 +1112,8 @@ class HubertForCTC(HubertPreTrainedModel):
checkpoint=_CHECKPOINT_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput, output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
) )
def forward( def forward(
self, self,
...@@ -1228,6 +1244,8 @@ class HubertForSequenceClassification(HubertPreTrainedModel): ...@@ -1228,6 +1244,8 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
output_type=SequenceClassifierOutput, output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
) )
def forward( def forward(
self, self,
......
...@@ -36,16 +36,33 @@ from .configuration_sew import SEWConfig ...@@ -36,16 +36,33 @@ from .configuration_sew import SEWConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "SEWConfig"
_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor" _PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor" _FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "asapp/sew-tiny-100k"
_HIDDEN_STATES_START_POSITION = 1 _HIDDEN_STATES_START_POSITION = 1
# General docstring
_CONFIG_FOR_DOC = "SEWConfig"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
# Base docstring
_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k-ft-ls100h"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 512]
# CTC docstring
_CTC_EXPECTED_OUTPUT = (
"'MISTER QUILTER IS THE APPOSTILE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPOLLE'"
)
_CTC_EXPECTED_LOSS = 0.42
# Audio class docstring
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "anton-l/sew-mid-100k-ft-keyword-spotting"
_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
_SEQ_CLASS_EXPECTED_LOSS = 9.52
SEW_PRETRAINED_MODEL_ARCHIVE_LIST = [ SEW_PRETRAINED_MODEL_ARCHIVE_LIST = [
"asapp/sew-tiny-100k", "asapp/sew-tiny-100k",
"asapp/sew-small-100k", "asapp/sew-small-100k",
...@@ -879,6 +896,7 @@ class SEWModel(SEWPreTrainedModel): ...@@ -879,6 +896,7 @@ class SEWModel(SEWPreTrainedModel):
output_type=BaseModelOutput, output_type=BaseModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
) )
def forward( def forward(
self, self,
...@@ -978,6 +996,8 @@ class SEWForCTC(SEWPreTrainedModel): ...@@ -978,6 +996,8 @@ class SEWForCTC(SEWPreTrainedModel):
checkpoint=_CHECKPOINT_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput, output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
) )
def forward( def forward(
self, self,
...@@ -1108,6 +1128,8 @@ class SEWForSequenceClassification(SEWPreTrainedModel): ...@@ -1108,6 +1128,8 @@ class SEWForSequenceClassification(SEWPreTrainedModel):
output_type=SequenceClassifierOutput, output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
) )
def forward( def forward(
self, self,
......
...@@ -37,14 +37,27 @@ from .configuration_sew_d import SEWDConfig ...@@ -37,14 +37,27 @@ from .configuration_sew_d import SEWDConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 1
# General docstring
_CONFIG_FOR_DOC = "SEWDConfig" _CONFIG_FOR_DOC = "SEWDConfig"
_CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor" _PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "asapp/sew-d-tiny-100k" # Base docstring
_CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k-ft-ls100h"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 384]
_HIDDEN_STATES_START_POSITION = 1 # CTC docstring
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTIL OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 0.21
# Audio class docstring
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "anton-l/sew-d-mid-400k-ft-keyword-spotting"
_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
_SEQ_CLASS_EXPECTED_LOSS = 3.16
SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST = [ SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST = [
"asapp/sew-d-tiny-100k", "asapp/sew-d-tiny-100k",
...@@ -1415,6 +1428,7 @@ class SEWDModel(SEWDPreTrainedModel): ...@@ -1415,6 +1428,7 @@ class SEWDModel(SEWDPreTrainedModel):
output_type=BaseModelOutput, output_type=BaseModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
) )
def forward( def forward(
self, self,
...@@ -1514,6 +1528,8 @@ class SEWDForCTC(SEWDPreTrainedModel): ...@@ -1514,6 +1528,8 @@ class SEWDForCTC(SEWDPreTrainedModel):
checkpoint=_CHECKPOINT_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput, output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
) )
def forward( def forward(
self, self,
...@@ -1644,6 +1660,8 @@ class SEWDForSequenceClassification(SEWDPreTrainedModel): ...@@ -1644,6 +1660,8 @@ class SEWDForSequenceClassification(SEWDPreTrainedModel):
output_type=SequenceClassifierOutput, output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
) )
def forward( def forward(
self, self,
......
...@@ -42,14 +42,26 @@ from .configuration_unispeech import UniSpeechConfig ...@@ -42,14 +42,26 @@ from .configuration_unispeech import UniSpeechConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 2
# General docstring
_CONFIG_FOR_DOC = "UniSpeechConfig" _CONFIG_FOR_DOC = "UniSpeechConfig"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor" _PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
_CHECKPOINT_FOR_DOC = "microsoft/unispeech-large-1500h-cv"
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "microsoft/unispeech-large-1500h-cv" # Base docstring
_CHECKPOINT_FOR_DOC = "patrickvonplaten/unispeech-large-1500h-cv-timit"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
_HIDDEN_STATES_START_POSITION = 2 # CTC docstring
_CTC_EXPECTED_OUTPUT = "'mister quilter is the apposl of the midle classes and weare glad to welcom his gosepl'"
_CTC_EXPECTED_LOSS = 17.17
# Audio class docstring
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/tiny-random-unispeech"
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" # TODO(anton) - could you quickly fine-tune a KS WavLM Model
_SEQ_CLASS_EXPECTED_LOSS = 0.66 # TODO(anton) - could you quickly fine-tune a KS WavLM Model
UNISPEECH_PRETRAINED_MODEL_ARCHIVE_LIST = [ UNISPEECH_PRETRAINED_MODEL_ARCHIVE_LIST = [
"microsoft/unispeech-large-1500h-cv", "microsoft/unispeech-large-1500h-cv",
...@@ -1129,6 +1141,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel): ...@@ -1129,6 +1141,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
output_type=UniSpeechBaseModelOutput, output_type=UniSpeechBaseModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
) )
def forward( def forward(
self, self,
...@@ -1266,44 +1279,14 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel): ...@@ -1266,44 +1279,14 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
```python ```python
>>> import torch >>> import torch
>>> from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining >>> from transformers import Wav2Vec2FeatureExtractor, UniSpeechForPreTraining
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices >>> from transformers.models.unispeech.modeling_unispeech import _compute_mask_indices
>>> from datasets import load_dataset
>>> import soundfile as sf >>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
... "hf-internal-testing/tiny-random-unispeech-sat"
>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base") ... )
>>> model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base") >>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv")
>>> # TODO: Add full pretraining example
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = feature_extractor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
>>> # compute masked indices
>>> batch_size, raw_sequence_length = input_values.shape
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
>>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
>>> with torch.no_grad():
... outputs = model(input_values, mask_time_indices=mask_time_indices)
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
>>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
>>> # show that cosine similarity is much higher than random
>>> assert cosine_sim[mask_time_indices].mean() > 0.5
>>> # for contrastive loss training model should be put into train mode
>>> model.train()
>>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
```""" ```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...@@ -1406,6 +1389,8 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel): ...@@ -1406,6 +1389,8 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel):
checkpoint=_CHECKPOINT_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput, output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
) )
def forward( def forward(
self, self,
...@@ -1536,6 +1521,8 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel): ...@@ -1536,6 +1521,8 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
output_type=SequenceClassifierOutput, output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
) )
def forward( def forward(
self, self,
......
...@@ -43,16 +43,33 @@ from .configuration_unispeech_sat import UniSpeechSatConfig ...@@ -43,16 +43,33 @@ from .configuration_unispeech_sat import UniSpeechSatConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 2
# General docstring
_CONFIG_FOR_DOC = "UniSpeechSatConfig" _CONFIG_FOR_DOC = "UniSpeechSatConfig"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor" _PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-plus"
# Base docstring
_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-100h-libri-ft"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'MISTER QUILDER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 39.88
# Audio class docstring
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor" _FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/tiny-random-unispeech-sat"
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'" # TODO(anton) - could you quickly fine-tune a KS WavLM Model
_SEQ_CLASS_EXPECTED_LOSS = 0.71 # TODO(anton) - could you quickly fine-tune a KS WavLM Model
_SEQ_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus" # Frame class docstring
_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd" _FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd"
_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv" _FRAME_EXPECTED_OUTPUT = [0, 0]
_HIDDEN_STATES_START_POSITION = 2 # Speaker Verification docstring
_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv"
_XVECTOR_EXPECTED_OUTPUT = 0.97
UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST = [ UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST = [
# See all UniSpeechSat models at https://huggingface.co/models?filter=unispeech_sat # See all UniSpeechSat models at https://huggingface.co/models?filter=unispeech_sat
...@@ -1163,6 +1180,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel): ...@@ -1163,6 +1180,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
output_type=UniSpeechSatBaseModelOutput, output_type=UniSpeechSatBaseModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
) )
def forward( def forward(
self, self,
...@@ -1300,42 +1318,10 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel): ...@@ -1300,42 +1318,10 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
>>> import torch >>> import torch
>>> from transformers import Wav2Vec2FeatureExtractor, UniSpeechSatForPreTraining >>> from transformers import Wav2Vec2FeatureExtractor, UniSpeechSatForPreTraining
>>> from transformers.models.unispeech_sat.modeling_unispeech_sat import _compute_mask_indices >>> from transformers.models.unispeech_sat.modeling_unispeech_sat import _compute_mask_indices
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/unispeech-sat-base") >>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/unispeech-sat-base")
>>> model = UniSpeechSatForPreTraining.from_pretrained("microsoft/unispeech-sat-base") >>> model = UniSpeechSatForPreTraining.from_pretrained("microsoft/unispeech-sat-base")
>>> # TODO: Add full pretraining example
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = feature_extractor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
>>> # compute masked indices
>>> batch_size, raw_sequence_length = input_values.shape
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
>>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
>>> with torch.no_grad():
... outputs = model(input_values, mask_time_indices=mask_time_indices)
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
>>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
>>> # show that cosine similarity is much higher than random
>>> assert cosine_sim[mask_time_indices].mean() > 0.5
>>> # for contrastive loss training model should be put into train mode
>>> model.train()
>>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
```""" ```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...@@ -1431,6 +1417,8 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel): ...@@ -1431,6 +1417,8 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
checkpoint=_CHECKPOINT_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput, output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
) )
def forward( def forward(
self, self,
...@@ -1561,6 +1549,8 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel): ...@@ -1561,6 +1549,8 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
output_type=SequenceClassifierOutput, output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
) )
def forward( def forward(
self, self,
...@@ -1677,6 +1667,7 @@ class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel): ...@@ -1677,6 +1667,7 @@ class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel):
output_type=TokenClassifierOutput, output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_FRAME_EXPECTED_OUTPUT,
) )
def forward( def forward(
self, self,
...@@ -1853,6 +1844,7 @@ class UniSpeechSatForXVector(UniSpeechSatPreTrainedModel): ...@@ -1853,6 +1844,7 @@ class UniSpeechSatForXVector(UniSpeechSatPreTrainedModel):
output_type=XVectorOutput, output_type=XVectorOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_XVECTOR_EXPECTED_OUTPUT,
) )
def forward( def forward(
self, self,
......
...@@ -48,16 +48,34 @@ from .configuration_wav2vec2 import Wav2Vec2Config ...@@ -48,16 +48,34 @@ from .configuration_wav2vec2 import Wav2Vec2Config
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 2
# General docstring
_CONFIG_FOR_DOC = "Wav2Vec2Config" _CONFIG_FOR_DOC = "Wav2Vec2Config"
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor" _PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 53.48
# Audio class docstring
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-ks" _SEQ_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-ks"
_FRAME_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-sd" _SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
_XVECTOR_CHECKPOINT = "superb/wav2vec2-base-superb-sv" _SEQ_CLASS_EXPECTED_LOSS = 6.54
_HIDDEN_STATES_START_POSITION = 2 # Frame class docstring
_FRAME_CLASS_CHECKPOINT = "anton-l/wav2vec2-base-superb-sd"
_FRAME_EXPECTED_OUTPUT = [0, 0]
# Speaker Verification docstring
_XVECTOR_CHECKPOINT = "anton-l/wav2vec2-base-superb-sv"
_XVECTOR_EXPECTED_OUTPUT = 0.98
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [ WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
...@@ -1294,6 +1312,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): ...@@ -1294,6 +1312,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
output_type=Wav2Vec2BaseModelOutput, output_type=Wav2Vec2BaseModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
) )
def forward( def forward(
self, self,
...@@ -1469,10 +1488,11 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel): ...@@ -1469,10 +1488,11 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
>>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1) >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
>>> # show that cosine similarity is much higher than random >>> # show that cosine similarity is much higher than random
>>> assert cosine_sim[mask_time_indices].mean() > 0.5 >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
tensor(True)
>>> # for contrastive loss training model should be put into train mode >>> # for contrastive loss training model should be put into train mode
>>> model.train() >>> model = model.train()
>>> loss = model(input_values, mask_time_indices=mask_time_indices).loss >>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
```""" ```"""
...@@ -1697,6 +1717,8 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel): ...@@ -1697,6 +1717,8 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
checkpoint=_CHECKPOINT_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput, output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
) )
def forward( def forward(
self, self,
...@@ -1826,6 +1848,8 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel): ...@@ -1826,6 +1848,8 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
output_type=SequenceClassifierOutput, output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
) )
def forward( def forward(
self, self,
...@@ -1941,6 +1965,7 @@ class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel): ...@@ -1941,6 +1965,7 @@ class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
output_type=TokenClassifierOutput, output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_FRAME_EXPECTED_OUTPUT,
) )
def forward( def forward(
self, self,
...@@ -2114,6 +2139,7 @@ class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel): ...@@ -2114,6 +2139,7 @@ class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel):
output_type=XVectorOutput, output_type=XVectorOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_XVECTOR_EXPECTED_OUTPUT,
) )
def forward( def forward(
self, self,
......
...@@ -42,18 +42,34 @@ from .configuration_wavlm import WavLMConfig ...@@ -42,18 +42,34 @@ from .configuration_wavlm import WavLMConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 2
# General docstring
_CONFIG_FOR_DOC = "WavLMConfig" _CONFIG_FOR_DOC = "WavLMConfig"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor" _PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
# Base docstring
_CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus" _CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
_SEQ_CLASS_CHECKPOINT = "microsoft/wavlm-base" # CTC docstring
_CTC_EXPECTED_OUTPUT = "'mister quilter is the aposle of the middle classes and we are glad to welcome his gospel'"
_CTC_EXPECTED_LOSS = 12.51
# Audio class docstring
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor" _FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/tiny-random-wavlm"
_SEQ_CLASS_EXPECTED_OUTPUT = "'no'" # TODO(anton) - could you quickly fine-tune a KS WavLM Model
_SEQ_CLASS_EXPECTED_LOSS = 0.7 # TODO(anton) - could you quickly fine-tune a KS WavLM Model
_SEQ_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus" # Frame class docstring
_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd" _FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd"
_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv" _FRAME_EXPECTED_OUTPUT = [0, 0]
_HIDDEN_STATES_START_POSITION = 2 # Speaker Verification docstring
_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv"
_XVECTOR_EXPECTED_OUTPUT = 0.97
WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST = [ WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"microsoft/wavlm-base", "microsoft/wavlm-base",
...@@ -1247,6 +1263,7 @@ class WavLMModel(WavLMPreTrainedModel): ...@@ -1247,6 +1263,7 @@ class WavLMModel(WavLMPreTrainedModel):
output_type=WavLMBaseModelOutput, output_type=WavLMBaseModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
) )
def forward( def forward(
self, self,
...@@ -1350,6 +1367,8 @@ class WavLMForCTC(WavLMPreTrainedModel): ...@@ -1350,6 +1367,8 @@ class WavLMForCTC(WavLMPreTrainedModel):
checkpoint=_CHECKPOINT_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput, output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
) )
def forward( def forward(
self, self,
...@@ -1480,6 +1499,8 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel): ...@@ -1480,6 +1499,8 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
output_type=SequenceClassifierOutput, output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
) )
def forward( def forward(
self, self,
...@@ -1596,6 +1617,7 @@ class WavLMForAudioFrameClassification(WavLMPreTrainedModel): ...@@ -1596,6 +1617,7 @@ class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
output_type=TokenClassifierOutput, output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_FRAME_EXPECTED_OUTPUT,
) )
def forward( def forward(
self, self,
...@@ -1772,6 +1794,7 @@ class WavLMForXVector(WavLMPreTrainedModel): ...@@ -1772,6 +1794,7 @@ class WavLMForXVector(WavLMPreTrainedModel):
output_type=XVectorOutput, output_type=XVectorOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
modality="audio", modality="audio",
expected_output=_XVECTOR_EXPECTED_OUTPUT,
) )
def forward( def forward(
self, self,
......
docs/source/quicktour.rst src/transformers/models/wav2vec2/modeling_wav2vec2.py
docs/source/task_summary.rst src/transformers/models/hubert/modeling_hubert.py
\ No newline at end of file src/transformers/models/wavlm/modeling_wavlm.py
src/transformers/models/unispeech/modeling_unispeech.py
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
src/transformers/models/sew/modeling_sew.py
src/transformers/models/sew_d/modeling_sew_d.py
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Style utils to preprocess files for doc tests.
The doc precossing function can be run on a list of files and/org
directories of files. It will recursively check if the files have
a python code snippet by looking for a ```python or ```py syntax.
In the default mode - `remove_new_line==False` the script will
add a new line before every python code ending ``` line to make
the docstrings ready for pytest doctests.
However, we don't want to have empty lines displayed in the
official documentation which is why the new line command can be
reversed by adding the flag `--remove_new_line` which sets
`remove_new_line==True`.
When debugging the doc tests locally, please make sure to
always run:
```python utils/prepare_for_doc_test.py src doc```
before running the doc tests:
```pytest --doctest-modules $(cat utils/documentation_tests.txt) -sv --doctest-continue-on-failure --doctest-glob="*.mdx"```
Afterwards you should revert the changes by running
```python utils/prepare_for_doc_test.py src doc --remove_new_line```
"""
import argparse
import os
def process_code_block(code, add_new_line=True):
if add_new_line:
return maybe_append_new_line(code)
else:
return maybe_remove_new_line(code)
def maybe_append_new_line(code):
"""
Append new line if code snippet is a
Python code snippet
"""
lines = code.split("\n")
if lines[0] in ["py", "python"]:
# add new line before last line being ```
last_line = lines[-1]
lines.pop()
lines.append("\n" + last_line)
return "\n".join(lines)
def maybe_remove_new_line(code):
"""
Remove new line if code snippet is a
Python code snippet
"""
lines = code.split("\n")
if lines[0] in ["py", "python"]:
# add new line before last line being ```
lines = lines[:-2] + lines[-1:]
return "\n".join(lines)
def process_doc_file(code_file, add_new_line=True):
"""
Process given file.
Args:
code_file (`str` or `os.PathLike`): The file in which we want to style the docstring.
"""
with open(code_file, "r", encoding="utf-8", newline="\n") as f:
code = f.read()
# fmt: off
splits = code.split("```")
splits = [s if i % 2 == 0 else process_code_block(s, add_new_line=add_new_line) for i, s in enumerate(splits)]
clean_code = "```".join(splits)
# fmt: on
diff = clean_code != code
if diff:
print(f"Overwriting content of {code_file}.")
with open(code_file, "w", encoding="utf-8", newline="\n") as f:
f.write(clean_code)
def process_doc_files(*files, add_new_line=True):
"""
Applies doc styling or checks everything is correct in a list of files.
Args:
files (several `str` or `os.PathLike`): The files to treat.
Whether to restyle file or just check if they should be restyled.
Returns:
List[`str`]: The list of files changed or that should be restyled.
"""
for file in files:
# Treat folders
if os.path.isdir(file):
files = [os.path.join(file, f) for f in os.listdir(file)]
files = [f for f in files if os.path.isdir(f) or f.endswith(".mdx") or f.endswith(".py")]
process_doc_files(*files, add_new_line=add_new_line)
else:
try:
process_doc_file(file, add_new_line=add_new_line)
except Exception:
print(f"There is a problem in {file}.")
raise
def main(*files, add_new_line=True):
process_doc_files(*files, add_new_line=add_new_line)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("files", nargs="+", help="The file(s) or folder(s) to restyle.")
parser.add_argument(
"--remove_new_line",
action="store_true",
help="Whether to remove new line after each python code block instead of adding one.",
)
args = parser.parse_args()
main(*args.files, add_new_line=not args.remove_new_line)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment