Unverified Commit 48463ebb authored by Anton Lozhkov, committed by GitHub

Add Speaker Diarization and Verification heads (#14723)

* Models

* Squashed commit of the following:

commit 72278e1e931a16d0879acc77f65762f3364833d0
Author: anton-l <aglozhkov@gmail.com>
Date:   Fri Dec 10 21:45:08 2021 +0300

* Add unispeech heads

* Add sd/sv automodels

* Docs cleanup

* Fix docstrings

* rename xvector classes

* examples

* Tests cleanup

* Style

* Better checkpoints for tests

* leftover docs

* apply review suggestions

* Style + init tests

* Update unispeech-sat tdnn downsampling
parent 2e07180c
@@ -181,6 +181,13 @@ AutoModelForAudioClassification
:members:
AutoModelForAudioFrameClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.AutoModelForAudioFrameClassification
:members:
AutoModelForCTC
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -195,6 +202,13 @@ AutoModelForSpeechSeq2Seq
:members:
AutoModelForAudioXVector
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.AutoModelForAudioXVector
:members:
AutoModelForObjectDetection
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
@@ -85,6 +85,20 @@ UniSpeechSatForSequenceClassification
:members: forward
UniSpeechSatForAudioFrameClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.UniSpeechSatForAudioFrameClassification
:members: forward
UniSpeechSatForXVector
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.UniSpeechSatForXVector
:members: forward
UniSpeechSatForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
@@ -114,6 +114,20 @@ Wav2Vec2ForSequenceClassification
:members: forward
Wav2Vec2ForAudioFrameClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Wav2Vec2ForAudioFrameClassification
:members: forward
Wav2Vec2ForXVector
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Wav2Vec2ForXVector
:members: forward
Wav2Vec2ForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
@@ -649,6 +649,8 @@ if is_torch_available():
"MODEL_WITH_LM_HEAD_MAPPING",
"AutoModel",
"AutoModelForAudioClassification",
"AutoModelForAudioFrameClassification",
"AutoModelForAudioXVector",
"AutoModelForCausalLM", "AutoModelForCausalLM",
"AutoModelForCTC", "AutoModelForCTC",
"AutoModelForImageClassification", "AutoModelForImageClassification",
...@@ -1325,9 +1327,11 @@ if is_torch_available(): ...@@ -1325,9 +1327,11 @@ if is_torch_available():
_import_structure["models.unispeech_sat"].extend( _import_structure["models.unispeech_sat"].extend(
[ [
"UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST", "UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST",
"UniSpeechSatForAudioFrameClassification",
"UniSpeechSatForCTC", "UniSpeechSatForCTC",
"UniSpeechSatForPreTraining", "UniSpeechSatForPreTraining",
"UniSpeechSatForSequenceClassification", "UniSpeechSatForSequenceClassification",
"UniSpeechSatForXVector",
"UniSpeechSatModel", "UniSpeechSatModel",
"UniSpeechSatPreTrainedModel", "UniSpeechSatPreTrainedModel",
] ]
...@@ -1358,10 +1362,12 @@ if is_torch_available(): ...@@ -1358,10 +1362,12 @@ if is_torch_available():
_import_structure["models.wav2vec2"].extend( _import_structure["models.wav2vec2"].extend(
[ [
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2ForAudioFrameClassification",
"Wav2Vec2ForCTC", "Wav2Vec2ForCTC",
"Wav2Vec2ForMaskedLM", "Wav2Vec2ForMaskedLM",
"Wav2Vec2ForPreTraining", "Wav2Vec2ForPreTraining",
"Wav2Vec2ForSequenceClassification", "Wav2Vec2ForSequenceClassification",
"Wav2Vec2ForXVector",
"Wav2Vec2Model", "Wav2Vec2Model",
"Wav2Vec2PreTrainedModel", "Wav2Vec2PreTrainedModel",
] ]
...@@ -2603,6 +2609,8 @@ if TYPE_CHECKING: ...@@ -2603,6 +2609,8 @@ if TYPE_CHECKING:
MODEL_WITH_LM_HEAD_MAPPING, MODEL_WITH_LM_HEAD_MAPPING,
AutoModel, AutoModel,
AutoModelForAudioClassification, AutoModelForAudioClassification,
AutoModelForAudioFrameClassification,
AutoModelForAudioXVector,
AutoModelForCausalLM,
AutoModelForCTC,
AutoModelForImageClassification,
@@ -3164,9 +3172,11 @@ if TYPE_CHECKING:
)
from .models.unispeech_sat import (
UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST,
UniSpeechSatForAudioFrameClassification,
UniSpeechSatForCTC,
UniSpeechSatForPreTraining,
UniSpeechSatForSequenceClassification,
UniSpeechSatForXVector,
UniSpeechSatModel,
UniSpeechSatPreTrainedModel,
)
@@ -3191,10 +3201,12 @@ if TYPE_CHECKING:
)
from .models.wav2vec2 import (
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2ForAudioFrameClassification,
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2ForSequenceClassification,
Wav2Vec2ForXVector,
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
...
@@ -1117,6 +1117,54 @@ PT_SPEECH_SEQ_CLASS_SAMPLE = r"""
"""
PT_SPEECH_FRAME_CLASS_SAMPLE = r"""
Example::
>>> from transformers import {processor_class}, {model_class}
>>> from datasets import load_dataset
>>> import torch
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> feature_extractor = {processor_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> # audio file is decoded on the fly
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt")
>>> logits = model(**inputs).logits
>>> probabilities = torch.sigmoid(logits[0])
>>> # labels is a one-hot array of shape (num_frames, num_speakers)
>>> labels = (probabilities > 0.5).long()
"""
PT_SPEECH_XVECTOR_SAMPLE = r"""
Example::
>>> from transformers import {processor_class}, {model_class}
>>> from datasets import load_dataset
>>> import torch
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> feature_extractor = {processor_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> # audio file is decoded on the fly
>>> inputs = feature_extractor(dataset[:2]["audio"]["array"], return_tensors="pt")
>>> embeddings = model(**inputs).embeddings
>>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
>>> # the resulting embeddings can be used for cosine similarity-based retrieval
>>> cosine_sim = torch.nn.CosineSimilarity(dim=-1)
>>> similarity = cosine_sim(embeddings[0], embeddings[1])
>>> threshold = 0.7 # the optimal threshold is dataset-dependent
>>> if similarity < threshold:
... print("Speakers are not the same!")
"""
PT_SAMPLE_DOCSTRINGS = {
"SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
"QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
@@ -1128,6 +1176,8 @@ PT_SAMPLE_DOCSTRINGS = {
"SpeechBaseModel": PT_SPEECH_BASE_MODEL_SAMPLE,
"CTC": PT_SPEECH_CTC_SAMPLE,
"AudioClassification": PT_SPEECH_SEQ_CLASS_SAMPLE,
"AudioFrameClassification": PT_SPEECH_FRAME_CLASS_SAMPLE,
"AudioXVector": PT_SPEECH_XVECTOR_SAMPLE,
}
@@ -1419,6 +1469,10 @@ def add_code_sample_docstrings(
code_sample = sample_docstrings["LMHead"]
elif "CTC" in model_class:
code_sample = sample_docstrings["CTC"]
elif "AudioFrameClassification" in model_class:
code_sample = sample_docstrings["AudioFrameClassification"]
elif "XVector" in model_class and modality == "audio":
code_sample = sample_docstrings["AudioXVector"]
elif "Model" in model_class and modality == "audio": elif "Model" in model_class and modality == "audio":
code_sample = sample_docstrings["SpeechBaseModel"] code_sample = sample_docstrings["SpeechBaseModel"]
elif "Model" in model_class or "Encoder" in model_class: elif "Model" in model_class or "Encoder" in model_class:
......
@@ -53,6 +53,8 @@ if is_torch_available():
"MODEL_WITH_LM_HEAD_MAPPING",
"AutoModel",
"AutoModelForAudioClassification",
"AutoModelForAudioFrameClassification",
"AutoModelForAudioXVector",
"AutoModelForCausalLM", "AutoModelForCausalLM",
"AutoModelForCTC", "AutoModelForCTC",
"AutoModelForImageClassification", "AutoModelForImageClassification",
...@@ -161,6 +163,8 @@ if TYPE_CHECKING: ...@@ -161,6 +163,8 @@ if TYPE_CHECKING:
MODEL_WITH_LM_HEAD_MAPPING, MODEL_WITH_LM_HEAD_MAPPING,
AutoModel, AutoModel,
AutoModelForAudioClassification, AutoModelForAudioClassification,
AutoModelForAudioFrameClassification,
AutoModelForAudioXVector,
AutoModelForCausalLM,
AutoModelForCTC,
AutoModelForImageClassification,
...
@@ -538,6 +538,22 @@ MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
]
)
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Audio Frame Classification mapping
("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
]
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
[
# Model for Audio XVector mapping
("wav2vec2", "Wav2Vec2ForXVector"),
("unispeech-sat", "UniSpeechSatForXVector"),
]
)
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
@@ -578,6 +594,10 @@ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
class AutoModel(_BaseAutoModelClass):
@@ -726,6 +746,22 @@ AutoModelForSpeechSeq2Seq = auto_class_update(
)
class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING
AutoModelForAudioFrameClassification = auto_class_update(
AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification"
)
class AutoModelForAudioXVector(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING
AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")
class AutoModelWithLMHead(_AutoModelWithLMHead):
@classmethod
def from_config(cls, config):
...
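As an aside (not part of the diff): a minimal sketch of how the two new auto classes are intended to be used once a fine-tuned checkpoint exists. The `superb/wav2vec2-base-superb-sd` diarization checkpoint is the one referenced later in this PR; the synthetic audio is purely illustrative::

    import torch
    from transformers import AutoFeatureExtractor, AutoModelForAudioFrameClassification

    checkpoint = "superb/wav2vec2-base-superb-sd"  # diarization checkpoint referenced in this PR
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = AutoModelForAudioFrameClassification.from_pretrained(checkpoint)

    # Hypothetical 2 seconds of 16 kHz audio; a real pipeline would decode a file instead.
    waveform = torch.randn(32000).numpy()
    inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits  # (batch, num_frames, num_speakers)
    # Frame-level, multi-label speaker activity, hence sigmoid rather than softmax
    speaker_activity = torch.sigmoid(logits) > 0.5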
@@ -42,9 +42,9 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "HubertConfig"
_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks" _SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_HIDDEN_STATES_START_POSITION = 1
@@ -1182,7 +1182,7 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
-processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_SEQ_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
...
@@ -38,9 +38,9 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "SEWConfig"
_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "asapp/sew-tiny-100k"
-_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_HIDDEN_STATES_START_POSITION = 1
@@ -1067,7 +1067,7 @@ class SEWForSequenceClassification(SEWPreTrainedModel):
@add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
-processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_SEQ_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
...
@@ -39,9 +39,9 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "SEWDConfig"
_CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "asapp/sew-d-tiny-100k"
-_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_HIDDEN_STATES_START_POSITION = 1
@@ -1598,7 +1598,7 @@ class SEWDForSequenceClassification(SEWDPreTrainedModel):
@add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
-processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_SEQ_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
...
@@ -44,9 +44,9 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "UniSpeechConfig"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
_CHECKPOINT_FOR_DOC = "microsoft/unispeech-large-1500h-cv"
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "microsoft/unispeech-large-1500h-cv"
-_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_HIDDEN_STATES_START_POSITION = 2
@@ -1481,7 +1481,7 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
@add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
-processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_SEQ_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
...
@@ -27,9 +27,11 @@ _import_structure = {
if is_torch_available():
_import_structure["modeling_unispeech_sat"] = [
"UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST",
"UniSpeechSatForAudioFrameClassification",
"UniSpeechSatForCTC", "UniSpeechSatForCTC",
"UniSpeechSatForPreTraining", "UniSpeechSatForPreTraining",
"UniSpeechSatForSequenceClassification", "UniSpeechSatForSequenceClassification",
"UniSpeechSatForXVector",
"UniSpeechSatModel", "UniSpeechSatModel",
"UniSpeechSatPreTrainedModel", "UniSpeechSatPreTrainedModel",
] ]
...@@ -40,9 +42,11 @@ if TYPE_CHECKING: ...@@ -40,9 +42,11 @@ if TYPE_CHECKING:
if is_torch_available(): if is_torch_available():
from .modeling_unispeech_sat import ( from .modeling_unispeech_sat import (
UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST, UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST,
UniSpeechSatForAudioFrameClassification,
UniSpeechSatForCTC,
UniSpeechSatForPreTraining,
UniSpeechSatForSequenceClassification,
UniSpeechSatForXVector,
UniSpeechSatModel,
UniSpeechSatPreTrainedModel,
)
...
@@ -153,6 +153,17 @@ class UniSpeechSatConfig(PretrainedConfig):
instance of :class:`~transformers.UniSpeechSatForSequenceClassification`.
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
tdnn_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 1500)`):
A tuple of integers defining the number of output channels of each 1D convolutional layer in the `TDNN`
module of the `XVector` model. The length of `tdnn_dim` defines the number of `TDNN` layers.
tdnn_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 3, 3, 1, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the `TDNN` module of the
`XVector` model. The length of `tdnn_kernel` has to match the length of `tdnn_dim`.
tdnn_dilation (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(1, 2, 3, 1, 1)`):
A tuple of integers defining the dilation factor of each 1D convolutional layer in the `TDNN` module of the
`XVector` model. The length of `tdnn_dilation` has to match the length of `tdnn_dim`.
xvector_output_dim (:obj:`int`, `optional`, defaults to 512):
Dimensionality of the `XVector` embedding vectors.
Example::
@@ -213,6 +224,10 @@ class UniSpeechSatConfig(PretrainedConfig):
ctc_zero_infinity=False,
use_weighted_layer_sum=False,
classifier_proj_size=256,
tdnn_dim=(512, 512, 512, 512, 1500),
tdnn_kernel=(5, 3, 3, 1, 1),
tdnn_dilation=(1, 2, 3, 1, 1),
xvector_output_dim=512,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
@@ -246,7 +261,6 @@ class UniSpeechSatConfig(PretrainedConfig):
self.num_clusters = num_clusters
self.do_stable_layer_norm = do_stable_layer_norm
self.use_weighted_layer_sum = use_weighted_layer_sum
-self.classifier_proj_size = classifier_proj_size
if (
(len(self.conv_stride) != self.num_feat_extract_layers)
@@ -282,3 +296,12 @@ class UniSpeechSatConfig(PretrainedConfig):
# ctc loss
self.ctc_loss_reduction = ctc_loss_reduction
self.ctc_zero_infinity = ctc_zero_infinity
# SequenceClassification-specific parameter. Feel free to ignore for other classes.
self.classifier_proj_size = classifier_proj_size
# XVector-specific parameters. Feel free to ignore for other classes.
self.tdnn_dim = list(tdnn_dim)
self.tdnn_kernel = list(tdnn_kernel)
self.tdnn_dilation = list(tdnn_dilation)
self.xvector_output_dim = xvector_output_dim
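Side note (not part of the commit): the new TDNN fields behave like any other config argument, and the three tuples must have equal length since they describe the same stack of layers. A minimal sketch with illustrative values::

    from transformers import UniSpeechSatConfig

    # Narrower TDNN stack than the documented defaults of tdnn_dim=(512, 512, 512, 512, 1500),
    # tdnn_kernel=(5, 3, 3, 1, 1), tdnn_dilation=(1, 2, 3, 1, 1), xvector_output_dim=512.
    config = UniSpeechSatConfig(
        tdnn_dim=(256, 256, 256, 256, 750),
        tdnn_kernel=(5, 3, 3, 1, 1),
        tdnn_dilation=(1, 2, 3, 1, 1),
        xvector_output_dim=256,
    )
    assert len(config.tdnn_dim) == len(config.tdnn_kernel) == len(config.tdnn_dilation)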
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Hubert checkpoint."""
import argparse
import torch
from transformers import (
UniSpeechSatConfig,
UniSpeechSatForAudioFrameClassification,
UniSpeechSatForSequenceClassification,
UniSpeechSatForXVector,
Wav2Vec2FeatureExtractor,
logging,
)
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def convert_classification(base_model_name, hf_config, downstream_dict):
model = UniSpeechSatForSequenceClassification.from_pretrained(base_model_name, config=hf_config)
model.projector.weight.data = downstream_dict["projector.weight"]
model.projector.bias.data = downstream_dict["projector.bias"]
model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"]
model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"]
return model
def convert_diarization(base_model_name, hf_config, downstream_dict):
model = UniSpeechSatForAudioFrameClassification.from_pretrained(base_model_name, config=hf_config)
model.classifier.weight.data = downstream_dict["model.linear.weight"]
model.classifier.bias.data = downstream_dict["model.linear.bias"]
return model
def convert_xvector(base_model_name, hf_config, downstream_dict):
model = UniSpeechSatForXVector.from_pretrained(base_model_name, config=hf_config)
model.projector.weight.data = downstream_dict["connector.weight"]
model.projector.bias.data = downstream_dict["connector.bias"]
for i, kernel_size in enumerate(hf_config.tdnn_kernel):
model.tdnn[i].kernel.weight.data = downstream_dict[
f"model.framelevel_feature_extractor.module.{i}.kernel.weight"
]
model.tdnn[i].kernel.bias.data = downstream_dict[f"model.framelevel_feature_extractor.module.{i}.kernel.bias"]
model.feature_extractor.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.weight"]
model.feature_extractor.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.bias"]
model.classifier.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.weight"]
model.classifier.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.bias"]
model.objective.weight.data = downstream_dict["objective.W"]
return model
@torch.no_grad()
def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path):
"""
Copy/paste/tweak model's weights to transformers design.
"""
checkpoint = torch.load(checkpoint_path, map_location="cpu")
downstream_dict = checkpoint["Downstream"]
hf_config = UniSpeechSatConfig.from_pretrained(config_path)
hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
base_model_name, return_attention_mask=True, do_normalize=False
)
arch = hf_config.architectures[0]
if arch.endswith("ForSequenceClassification"):
hf_model = convert_classification(base_model_name, hf_config, downstream_dict)
elif arch.endswith("ForAudioFrameClassification"):
hf_model = convert_diarization(base_model_name, hf_config, downstream_dict)
elif arch.endswith("ForXVector"):
hf_model = convert_xvector(base_model_name, hf_config, downstream_dict)
else:
raise NotImplementedError(f"S3PRL weights conversion is not supported for {arch}")
if hf_config.use_weighted_layer_sum:
hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]
hf_feature_extractor.save_pretrained(model_dump_path)
hf_model.save_pretrained(model_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model."
)
parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.")
parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.")
args = parser.parse_args()
convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path)
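For orientation (hypothetical paths, not from the commit): the converter dispatches on the `architectures` field of the supplied config, so a direct call would look roughly like::

    # The config.json must name the target head, e.g. "architectures": ["UniSpeechSatForXVector"],
    # and the checkpoint is an s3prl downstream checkpoint containing a "Downstream" state dict.
    convert_s3prl_checkpoint(
        base_model_name="microsoft/unispeech-sat-base-plus",
        config_path="./config.json",
        checkpoint_path="./s3prl_checkpoint.ckpt",
        model_dump_path="./converted_model",
    )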
@@ -33,7 +33,7 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
-from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_unispeech_sat import UniSpeechSatConfig
@@ -45,9 +45,11 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "UniSpeechSatConfig"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-plus"
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus"
-_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd"
_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv"
_HIDDEN_STATES_START_POSITION = 2
@@ -123,6 +125,38 @@ class UniSpeechSatForPreTrainingOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class XVectorOutput(ModelOutput):
"""
Output type of :class:`~transformers.UniSpeechSatForXVector`.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
Classification hidden states before AMSoftmax.
embeddings (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
Utterance embeddings used for vector similarity-based retrieval.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
embeddings: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices( def _compute_mask_indices(
shape: Tuple[int, int], shape: Tuple[int, int],
@@ -1472,7 +1506,7 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
-processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_SEQ_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
@@ -1538,3 +1572,285 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
UniSpeech-SAT Model with a frame classification head on top for tasks like Speaker Diarization.
""",
UNISPEECH_SAT_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.unispeech_sat = UniSpeechSatModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.init_weights()
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor so that its parameters
will not be updated during training.
"""
self.unispeech_sat.feature_extractor._freeze_parameters()
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.unispeech_sat.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_FRAME_CLASS_CHECKPOINT,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the frame classification loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
outputs = self.unispeech_sat(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.use_weighted_layer_sum:
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = outputs[0]
logits = self.classifier(hidden_states)
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return output
return TokenClassifierOutput(
loss=None,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
super(AMSoftmaxLoss, self).__init__()
self.scale = scale
self.margin = margin
self.num_labels = num_labels
self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
self.loss = nn.CrossEntropyLoss()
def forward(self, hidden_states, labels):
labels = labels.flatten()
weight = nn.functional.normalize(self.weight, dim=0)
hidden_states = nn.functional.normalize(hidden_states, dim=1)
cos_theta = torch.mm(hidden_states, weight)
psi = cos_theta - self.margin
onehot = nn.functional.one_hot(labels, self.num_labels)
logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
loss = self.loss(logits, labels)
return loss
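Background (not from the diff): `AMSoftmaxLoss` implements additive-margin softmax. With L2-normalized embeddings and class weights, `cos_theta` holds cosine similarities, and the `torch.where` keeps `cos_theta` for non-target classes while substituting the margin-penalized `psi = cos_theta - margin` for the ground-truth class `y`, all scaled by `scale` before cross-entropy:

    \mathcal{L}_{\mathrm{AMS}} = -\log \frac{e^{s(\cos\theta_y - m)}}{e^{s(\cos\theta_y - m)} + \sum_{j \neq y} e^{s \cos\theta_j}}, \qquad s = 30.0,\ m = 0.4 \text{ (the defaults above)}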
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
self.out_conv_dim = config.tdnn_dim[layer_id]
self.kernel_size = config.tdnn_kernel[layer_id]
self.dilation = config.tdnn_dilation[layer_id]
self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
self.activation = nn.ReLU()
def forward(self, hidden_states):
hidden_states = hidden_states.unsqueeze(1)
hidden_states = nn.functional.unfold(
hidden_states,
(self.kernel_size, self.in_conv_dim),
stride=(1, self.in_conv_dim),
dilation=(self.dilation, 1),
)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.kernel(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
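Aside (a sketch, not part of the PR): the `unfold` + `Linear` combination above is an unrolled dilated `Conv1d` over the frame axis. The two are numerically equivalent once the Linear weight of shape `(out_dim, kernel_size * in_dim)` is reshaped into Conv1d's `(out_dim, in_dim, kernel_size)` layout::

    import torch
    import torch.nn as nn

    in_dim, out_dim, k, d, T = 4, 6, 3, 2, 20  # illustrative sizes only

    linear = nn.Linear(in_dim * k, out_dim)
    conv = nn.Conv1d(in_dim, out_dim, kernel_size=k, dilation=d)
    # Unfold flattens each block as (frame offset, channel), so the Linear weight
    # reshapes to (out, k, in) and transposes into Conv1d's (out, in, k) layout.
    conv.weight.data = linear.weight.data.view(out_dim, k, in_dim).permute(0, 2, 1).contiguous()
    conv.bias.data = linear.bias.data

    x = torch.randn(1, T, in_dim)  # (batch, frames, channels), as in TDNNLayer.forward

    # TDNN path: unfold the frame axis, then apply the Linear "kernel" (pre-activation)
    u = nn.functional.unfold(x.unsqueeze(1), (k, in_dim), stride=(1, in_dim), dilation=(d, 1))
    tdnn_out = linear(u.transpose(1, 2))

    # Conv1d path expects (batch, channels, frames)
    conv_out = conv(x.transpose(1, 2)).transpose(1, 2)
    assert torch.allclose(tdnn_out, conv_out, atol=1e-5)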
@add_start_docstrings(
"""
UniSpeech-SAT Model with an XVector feature extraction head on top for tasks like Speaker Verification.
""",
UNISPEECH_SAT_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
class UniSpeechSatForXVector(UniSpeechSatPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.unispeech_sat = UniSpeechSatModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
self.tdnn = nn.ModuleList(tdnn_layers)
self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
self.init_weights()
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor so that its parameters
will not be updated during training.
"""
self.unispeech_sat.feature_extractor._freeze_parameters()
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.unispeech_sat.parameters():
param.requires_grad = False
def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the TDNN layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
for kernel_size in self.config.tdnn_kernel:
input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
return input_lengths
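Worked example (not from the diff): with the default `tdnn_kernel=(5, 3, 3, 1, 1)` and stride fixed at 1, an input of 1000 frames shrinks as 1000 → 996 → 994 → 992 → 992 → 992. Note that the helper does not fold `tdnn_dilation` into the formula; an exact count for a dilated `Conv1d` would use the effective kernel size `dilation * (kernel_size - 1) + 1`::

    length = 1000
    for kernel_size in (5, 3, 3, 1, 1):
        length = (length - kernel_size) // 1 + 1  # stride is always 1 in the TDNN
    print(length)  # 992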
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_XVECTOR_CHECKPOINT,
output_type=XVectorOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels (speaker indices) for computing the AMSoftmax classification loss. Indices should be in :obj:`[0,
..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
outputs = self.unispeech_sat(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.use_weighted_layer_sum:
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = outputs[0]
hidden_states = self.projector(hidden_states)
for tdnn_layer in self.tdnn:
hidden_states = tdnn_layer(hidden_states)
# Statistic Pooling
if attention_mask is None:
mean_features = hidden_states.mean(dim=1)
std_features = hidden_states.std(dim=1)
else:
feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
mean_features = []
std_features = []
for i, length in enumerate(tdnn_output_lengths):
mean_features.append(hidden_states[i, :length].mean(dim=0))
std_features.append(hidden_states[i, :length].std(dim=0))
mean_features = torch.stack(mean_features)
std_features = torch.stack(std_features)
statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
output_embeddings = self.feature_extractor(statistic_pooling)
logits = self.classifier(output_embeddings)
loss = None
if labels is not None:
loss = self.objective(logits, labels)
if not return_dict:
output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
return ((loss,) + output) if loss is not None else output
return XVectorOutput(
loss=loss,
logits=logits,
embeddings=output_embeddings,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
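Background on the Statistic Pooling step above (a sketch, not from the diff): the variable-length TDNN frame sequence is collapsed into one fixed-size utterance vector by concatenating the per-channel mean and standard deviation, which is why `feature_extractor` is an `nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)`::

    import torch

    # Hypothetical TDNN output: 1 utterance, 300 frames, tdnn_dim[-1] = 1500 channels
    frames = torch.randn(1, 300, 1500)
    pooled = torch.cat([frames.mean(dim=1), frames.std(dim=1)], dim=-1)
    print(pooled.shape)  # torch.Size([1, 3000])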
@@ -31,10 +31,12 @@ _import_structure = {
if is_torch_available():
_import_structure["modeling_wav2vec2"] = [
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2ForAudioFrameClassification",
"Wav2Vec2ForCTC", "Wav2Vec2ForCTC",
"Wav2Vec2ForMaskedLM", "Wav2Vec2ForMaskedLM",
"Wav2Vec2ForPreTraining", "Wav2Vec2ForPreTraining",
"Wav2Vec2ForSequenceClassification", "Wav2Vec2ForSequenceClassification",
"Wav2Vec2ForXVector",
"Wav2Vec2Model", "Wav2Vec2Model",
"Wav2Vec2PreTrainedModel", "Wav2Vec2PreTrainedModel",
] ]
...@@ -65,10 +67,12 @@ if TYPE_CHECKING: ...@@ -65,10 +67,12 @@ if TYPE_CHECKING:
if is_torch_available(): if is_torch_available():
from .modeling_wav2vec2 import ( from .modeling_wav2vec2 import (
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2ForAudioFrameClassification,
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2ForSequenceClassification,
Wav2Vec2ForXVector,
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
...
@@ -80,10 +80,10 @@ class Wav2Vec2Config(PretrainedConfig):
feature extractor. The length of `conv_dim` defines the number of 1D convolutional layers.
conv_stride (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 2, 2, 2, 2, 2, 2)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature extractor. The length
-of `conv_stride` defines the number of convolutional layers and has to match the the length of `conv_dim`.
of `conv_stride` defines the number of convolutional layers and has to match the length of `conv_dim`.
conv_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(10, 3, 3, 3, 3, 3, 3)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature extractor. The
-length of `conv_kernel` defines the number of convolutional layers and has to match the the length of
length of `conv_kernel` defines the number of convolutional layers and has to match the length of
`conv_dim`.
conv_bias (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether the 1D convolutional layers have a bias.
@@ -153,6 +153,17 @@ class Wav2Vec2Config(PretrainedConfig):
instance of :class:`~transformers.Wav2Vec2ForSequenceClassification`.
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
tdnn_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 1500)`):
A tuple of integers defining the number of output channels of each 1D convolutional layer in the `TDNN`
module of the `XVector` model. The length of `tdnn_dim` defines the number of `TDNN` layers.
tdnn_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 3, 3, 1, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the `TDNN` module of the
`XVector` model. The length of `tdnn_kernel` has to match the length of `tdnn_dim`.
tdnn_dilation (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(1, 2, 3, 1, 1)`):
A tuple of integers defining the dilation factor of each 1D convolutional layer in the `TDNN` module of the
`XVector` model. The length of `tdnn_dilation` has to match the length of `tdnn_dim`.
xvector_output_dim (:obj:`int`, `optional`, defaults to 512):
Dimensionality of the `XVector` embedding vectors.
add_adapter (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
@@ -226,6 +237,10 @@ class Wav2Vec2Config(PretrainedConfig):
ctc_zero_infinity=False,
use_weighted_layer_sum=False,
classifier_proj_size=256,
tdnn_dim=(512, 512, 512, 512, 1500),
tdnn_kernel=(5, 3, 3, 1, 1),
tdnn_dilation=(1, 2, 3, 1, 1),
xvector_output_dim=512,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
@@ -262,7 +277,6 @@ class Wav2Vec2Config(PretrainedConfig):
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.use_weighted_layer_sum = use_weighted_layer_sum
-self.classifier_proj_size = classifier_proj_size
if (
(len(self.conv_stride) != self.num_feat_extract_layers)
@@ -305,3 +319,12 @@ class Wav2Vec2Config(PretrainedConfig):
self.adapter_stride = adapter_stride
self.num_adapter_layers = num_adapter_layers
self.output_hidden_size = output_hidden_size or hidden_size
# SequenceClassification-specific parameter. Feel free to ignore for other classes.
self.classifier_proj_size = classifier_proj_size
# XVector-specific parameters. Feel free to ignore for other classes.
self.tdnn_dim = list(tdnn_dim)
self.tdnn_kernel = list(tdnn_kernel)
self.tdnn_dilation = list(tdnn_dilation)
self.xvector_output_dim = xvector_output_dim
@@ -19,13 +19,52 @@ import argparse
import torch
-from transformers import Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification, logging
from transformers import (
Wav2Vec2Config,
Wav2Vec2FeatureExtractor,
Wav2Vec2ForAudioFrameClassification,
Wav2Vec2ForSequenceClassification,
Wav2Vec2ForXVector,
logging,
)
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
-SUPPORTED_MODELS = ["UtteranceLevel"]
def convert_classification(base_model_name, hf_config, downstream_dict):
model = Wav2Vec2ForSequenceClassification.from_pretrained(base_model_name, config=hf_config)
model.projector.weight.data = downstream_dict["projector.weight"]
model.projector.bias.data = downstream_dict["projector.bias"]
model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"]
model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"]
return model
def convert_diarization(base_model_name, hf_config, downstream_dict):
model = Wav2Vec2ForAudioFrameClassification.from_pretrained(base_model_name, config=hf_config)
model.classifier.weight.data = downstream_dict["model.linear.weight"]
model.classifier.bias.data = downstream_dict["model.linear.bias"]
return model
def convert_xvector(base_model_name, hf_config, downstream_dict):
model = Wav2Vec2ForXVector.from_pretrained(base_model_name, config=hf_config)
model.projector.weight.data = downstream_dict["connector.weight"]
model.projector.bias.data = downstream_dict["connector.bias"]
for i, kernel_size in enumerate(hf_config.tdnn_kernel):
model.tdnn[i].kernel.weight.data = downstream_dict[
f"model.framelevel_feature_extractor.module.{i}.kernel.weight"
]
model.tdnn[i].kernel.bias.data = downstream_dict[f"model.framelevel_feature_extractor.module.{i}.kernel.bias"]
model.feature_extractor.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.weight"]
model.feature_extractor.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.bias"]
model.classifier.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.weight"]
model.classifier.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.bias"]
model.objective.weight.data = downstream_dict["objective.W"]
return model
@torch.no_grad()
@@ -34,25 +73,27 @@ def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path):
Copy/paste/tweak model's weights to transformers design.
"""
checkpoint = torch.load(checkpoint_path, map_location="cpu")
if checkpoint["Config"]["downstream_expert"]["modelrc"]["select"] not in SUPPORTED_MODELS:
raise NotImplementedError(f"The supported s3prl models are {SUPPORTED_MODELS}")
downstream_dict = checkpoint["Downstream"] downstream_dict = checkpoint["Downstream"]
hf_congfig = Wav2Vec2Config.from_pretrained(config_path) hf_config = Wav2Vec2Config.from_pretrained(config_path)
hf_model = Wav2Vec2ForSequenceClassification.from_pretrained(base_model_name, config=hf_congfig)
hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
base_model_name, return_attention_mask=True, do_normalize=False base_model_name, return_attention_mask=True, do_normalize=False
) )
if hf_congfig.use_weighted_layer_sum: arch = hf_config.architectures[0]
if arch.endswith("ForSequenceClassification"):
hf_model = convert_classification(base_model_name, hf_config, downstream_dict)
elif arch.endswith("ForAudioFrameClassification"):
hf_model = convert_diarization(base_model_name, hf_config, downstream_dict)
elif arch.endswith("ForXVector"):
hf_model = convert_xvector(base_model_name, hf_config, downstream_dict)
else:
raise NotImplementedError(f"S3PRL weights conversion is not supported for {arch}")
if hf_config.use_weighted_layer_sum:
hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]
-hf_model.projector.weight.data = downstream_dict["projector.weight"]
-hf_model.projector.bias.data = downstream_dict["projector.bias"]
-hf_model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"]
-hf_model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"]
hf_feature_extractor.save_pretrained(model_dump_path)
hf_model.save_pretrained(model_dump_path)
...
@@ -34,7 +34,13 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
-from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput, SequenceClassifierOutput
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
MaskedLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_wav2vec2 import Wav2Vec2Config
@@ -45,9 +51,11 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Wav2Vec2Config"
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_SEQ_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-ks"
-_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
_FRAME_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-sd"
_XVECTOR_CHECKPOINT = "superb/wav2vec2-base-superb-sv"
_HIDDEN_STATES_START_POSITION = 2
@@ -93,7 +101,7 @@ class Wav2Vec2BaseModelOutput(ModelOutput):
@dataclass
class Wav2Vec2ForPreTrainingOutput(ModelOutput):
"""
-Output type of :class:`~transformers.Wav2Vec2ForPreTrainingOutput`, with potential hidden states and attentions.
Output type of :class:`~transformers.Wav2Vec2ForPreTraining`, with potential hidden states and attentions.
Args:
loss (`optional`, returned when :obj:`sample_negative_indices` are passed, ``torch.FloatTensor`` of shape :obj:`(1,)`):
@@ -132,6 +140,38 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
diversity_loss: Optional[torch.FloatTensor] = None
@dataclass
class XVectorOutput(ModelOutput):
"""
Output type of :class:`~transformers.Wav2Vec2ForXVector`.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
Classification hidden states before AMSoftmax.
embeddings (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
Utterance embeddings used for vector similarity-based retrieval.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
embeddings: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
...@@ -1707,7 +1747,7 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_SEQ_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
...@@ -1773,3 +1813,281 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
Wav2Vec2 Model with a frame classification head on top for tasks like Speaker Diarization.
""",
WAV_2_VEC_2_START_DOCSTRING,
)
class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.wav2vec2 = Wav2Vec2Model(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.init_weights()
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor so that its parameters
will not be updated during training.
"""
self.wav2vec2.feature_extractor._freeze_parameters()
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.wav2vec2.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_FRAME_CLASS_CHECKPOINT,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the frame classification loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
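# With config.use_weighted_layer_sum, pool every encoder layer (plus the
# transformer input embeddings) with a learned, softmax-normalized weight per
# layer, rather than reading only the top layer's hidden states.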
if self.config.use_weighted_layer_sum:
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = outputs[0]
logits = self.classifier(hidden_states)
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return output
return TokenClassifierOutput(
loss=None,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
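A hedged usage sketch for the new diarization head (the sigmoid/0.5 post-processing follows the SUPERB SD convention; the zero waveform is a placeholder, not real speech):

import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForAudioFrameClassification

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-sd")
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("superb/wav2vec2-base-superb-sd")

waveform = torch.zeros(32000).numpy()  # 2 s of 16 kHz audio; replace with real speech
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

# One logit vector per ~20 ms frame; a label is "on" when that speaker is active.
speaker_activity = (torch.sigmoid(logits) > 0.5).long()
print(speaker_activity.shape)  # (batch, num_frames, num_labels)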
class AMSoftmaxLoss(nn.Module):
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
super().__init__()
self.scale = scale
self.margin = margin
self.num_labels = num_labels
self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
self.loss = nn.CrossEntropyLoss()
def forward(self, hidden_states, labels):
labels = labels.flatten()
weight = nn.functional.normalize(self.weight, dim=0)
hidden_states = nn.functional.normalize(hidden_states, dim=1)
cos_theta = torch.mm(hidden_states, weight)
psi = cos_theta - self.margin
onehot = nn.functional.one_hot(labels, self.num_labels)
logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
loss = self.loss(logits, labels)
return loss
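AM-Softmax subtracts a fixed margin from the target speaker's cosine score before scaling, so an utterance must beat every other class by at least ``margin`` in cosine space; this tightens same-speaker clusters, which is exactly what cosine-similarity verification needs. A minimal usage sketch of the class above (the dimensions are illustrative assumptions):

import torch

loss_fn = AMSoftmaxLoss(input_dim=512, num_labels=10, scale=30.0, margin=0.4)
embeddings = torch.randn(8, 512, requires_grad=True)  # (batch, input_dim)
labels = torch.randint(0, 10, (8,))                   # speaker ids in [0, num_labels)
loss = loss_fn(embeddings, labels)
loss.backward()  # gradients flow to both the embeddings and loss_fn.weight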
class TDNNLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
self.out_conv_dim = config.tdnn_dim[layer_id]
self.kernel_size = config.tdnn_kernel[layer_id]
self.dilation = config.tdnn_dilation[layer_id]
self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
self.activation = nn.ReLU()
def forward(self, hidden_states):
hidden_states = hidden_states.unsqueeze(1)
hidden_states = nn.functional.unfold(
hidden_states,
(self.kernel_size, self.in_conv_dim),
stride=(1, self.in_conv_dim),
dilation=(self.dilation, 1),
)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.kernel(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
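The ``unfold`` call implements a dilated 1-D convolution over time as a plain linear layer: it extracts every ``kernel_size x in_conv_dim`` window of frames, flattens it, and lets ``self.kernel`` act as the convolution weights. A quick shape check (the two-layer config below is an illustrative assumption, not the model's defaults):

import torch
from transformers import Wav2Vec2Config

config = Wav2Vec2Config(tdnn_dim=(512, 512), tdnn_kernel=(5, 3), tdnn_dilation=(1, 2))
layer = TDNNLayer(config, layer_id=0)      # kernel_size=5, dilation=1
hidden_states = torch.randn(2, 100, 512)   # (batch, frames, tdnn_dim[0])
print(layer(hidden_states).shape)          # torch.Size([2, 96, 512]): 100 - 1 * (5 - 1) frames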
@add_start_docstrings(
"""
Wav2Vec2 Model with an XVector feature extraction head on top for tasks like Speaker Verification.
""",
WAV_2_VEC_2_START_DOCSTRING,
)
class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.wav2vec2 = Wav2Vec2Model(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
self.tdnn = nn.ModuleList(tdnn_layers)
self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
self.init_weights()
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor so that its parameters
will not be updated during training.
"""
self.wav2vec2.feature_extractor._freeze_parameters()
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.wav2vec2.parameters():
param.requires_grad = False
def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the TDNN layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
for kernel_size in self.config.tdnn_kernel:
input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
return input_lengths
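# Worked example, assuming x-vector-style kernels tdnn_kernel = (5, 3, 3, 1, 1):
# 100 input frames give 100 -> 96 -> 94 -> 92 -> 92 -> 92 output frames.
# Note the formula assumes stride 1 and dilation 1; for TDNN layers with
# dilation > 1 the actual output can be slightly shorter.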
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_XVECTOR_CHECKPOINT,
output_type=XVectorOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the speaker classification (AM-Softmax) loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.use_weighted_layer_sum:
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = outputs[0]
hidden_states = self.projector(hidden_states)
for tdnn_layer in self.tdnn:
hidden_states = tdnn_layer(hidden_states)
# Statistic Pooling
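# Pool each utterance into a fixed-size vector by concatenating the mean and
# standard deviation of the TDNN features over time; this is why
# self.feature_extractor expects tdnn_dim[-1] * 2 inputs. With an attention
# mask, the statistics are computed over each utterance's unpadded frames only.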
if attention_mask is None:
mean_features = hidden_states.mean(dim=1)
std_features = hidden_states.std(dim=1)
else:
feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
mean_features = []
std_features = []
for i, length in enumerate(tdnn_output_lengths):
mean_features.append(hidden_states[i, :length].mean(dim=0))
std_features.append(hidden_states[i, :length].std(dim=0))
mean_features = torch.stack(mean_features)
std_features = torch.stack(std_features)
statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
output_embeddings = self.feature_extractor(statistic_pooling)
logits = self.classifier(output_embeddings)
loss = None
if labels is not None:
loss = self.objective(logits, labels)
if not return_dict:
output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
return ((loss,) + output) if loss is not None else output
return XVectorOutput(
loss=loss,
logits=logits,
embeddings=output_embeddings,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
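To close the loop on speaker verification, a hedged end-to-end sketch: embed two utterances with the new head and compare them by cosine similarity (the zero waveforms and the 0.7 threshold are placeholders to tune per application, not values fixed by this PR):

import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForXVector

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-sv")
model = Wav2Vec2ForXVector.from_pretrained("superb/wav2vec2-base-superb-sv")

# Two dummy 16 kHz utterances; replace with real waveforms.
utterances = [torch.zeros(32000).numpy(), torch.zeros(48000).numpy()]
inputs = feature_extractor(utterances, sampling_rate=16000, padding=True, return_tensors="pt")

with torch.no_grad():
    embeddings = model(**inputs).embeddings

similarity = torch.nn.functional.cosine_similarity(embeddings[0], embeddings[1], dim=-1)
print("same speaker" if similarity > 0.7 else "different speakers")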
...@@ -422,6 +422,30 @@ class AutoModelForAudioClassification:
requires_backends(self, ["torch"])
class AutoModelForAudioFrameClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["torch"])
class AutoModelForAudioXVector:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["torch"])
class AutoModelForCausalLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
...@@ -4896,6 +4920,11 @@ class UniSpeechPreTrainedModel:
UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST = None
class UniSpeechSatForAudioFrameClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class UniSpeechSatForCTC:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
...@@ -4918,6 +4947,11 @@ class UniSpeechSatForSequenceClassification:
requires_backends(self, ["torch"])
class UniSpeechSatForXVector:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class UniSpeechSatModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
...@@ -5072,6 +5106,11 @@ class ViTPreTrainedModel:
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None
class Wav2Vec2ForAudioFrameClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Wav2Vec2ForCTC:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
...@@ -5106,6 +5145,11 @@ class Wav2Vec2ForSequenceClassification:
requires_backends(self, ["torch"])
class Wav2Vec2ForXVector:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Wav2Vec2Model:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
......