Unverified Commit b0513b01 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Wav2Vec2 - MMS] Correct directly loading adapters weights (#24335)

* Correct direct lang loading

* correct more

* revert black

* Use tie weights instead=

* add tests

* add tests

* make style
parent e5c760d6
......@@ -1134,12 +1134,14 @@ class HubertModel(HubertPreTrainedModel):
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
class HubertForCTC(HubertPreTrainedModel):
def __init__(self, config, target_lang=None):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
self.hubert = HubertModel(config)
self.dropout = nn.Dropout(config.final_dropout)
self.target_lang = target_lang
if config.vocab_size is None:
raise ValueError(
f"You are trying to instantiate {self.__class__} with a configuration that "
......@@ -1152,15 +1154,29 @@ class HubertForCTC(HubertPreTrainedModel):
)
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
# Initialize weights and apply final processing
self.post_init()
def tie_weights(self):
"""
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
passing `target_lang=...` to `from_pretrained(...)`.
This method is **not** supposed to be called by the user and is prone to be changed in the future.
"""
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
# correctly load adapter layers for Hubert so that we do not have to introduce a new API to
# [`PreTrainedModel`]. While slightly hacky, Hubert never has to tie input and output embeddings, so that it is
# ok to repurpose this function here.
target_lang = self.target_lang
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
logger.info("By default `target_lang` is set to 'eng'.")
elif target_lang is not None:
self.load_adapter(target_lang)
# Initialize weights and apply final processing
self.post_init()
self.load_adapter(target_lang, force_load=True)
def freeze_feature_extractor(self):
"""
......
......@@ -969,12 +969,14 @@ class SEWModel(SEWPreTrainedModel):
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV_2_VEC_2->SEW
class SEWForCTC(SEWPreTrainedModel):
def __init__(self, config, target_lang=None):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
self.sew = SEWModel(config)
self.dropout = nn.Dropout(config.final_dropout)
self.target_lang = target_lang
if config.vocab_size is None:
raise ValueError(
f"You are trying to instantiate {self.__class__} with a configuration that "
......@@ -987,15 +989,29 @@ class SEWForCTC(SEWPreTrainedModel):
)
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
# Initialize weights and apply final processing
self.post_init()
def tie_weights(self):
"""
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
passing `target_lang=...` to `from_pretrained(...)`.
This method is **not** supposed to be called by the user and is prone to be changed in the future.
"""
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
# correctly load adapter layers for SEW so that we do not have to introduce a new API to
# [`PreTrainedModel`]. While slightly hacky, SEW never has to tie input and output embeddings, so that it is
# ok to repurpose this function here.
target_lang = self.target_lang
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
logger.info("By default `target_lang` is set to 'eng'.")
elif target_lang is not None:
self.load_adapter(target_lang)
# Initialize weights and apply final processing
self.post_init()
self.load_adapter(target_lang, force_load=True)
def freeze_feature_extractor(self):
"""
......
......@@ -1509,12 +1509,14 @@ class SEWDModel(SEWDPreTrainedModel):
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV_2_VEC_2->SEWD
class SEWDForCTC(SEWDPreTrainedModel):
def __init__(self, config, target_lang=None):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
self.sew_d = SEWDModel(config)
self.dropout = nn.Dropout(config.final_dropout)
self.target_lang = target_lang
if config.vocab_size is None:
raise ValueError(
f"You are trying to instantiate {self.__class__} with a configuration that "
......@@ -1527,15 +1529,29 @@ class SEWDForCTC(SEWDPreTrainedModel):
)
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
# Initialize weights and apply final processing
self.post_init()
def tie_weights(self):
"""
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
passing `target_lang=...` to `from_pretrained(...)`.
This method is **not** supposed to be called by the user and is prone to be changed in the future.
"""
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
# correctly load adapter layers for SEWD so that we do not have to introduce a new API to
# [`PreTrainedModel`]. While slightly hacky, SEWD never has to tie input and output embeddings, so that it is
# ok to repurpose this function here.
target_lang = self.target_lang
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
logger.info("By default `target_lang` is set to 'eng'.")
elif target_lang is not None:
self.load_adapter(target_lang)
# Initialize weights and apply final processing
self.post_init()
self.load_adapter(target_lang, force_load=True)
def freeze_feature_extractor(self):
"""
......
......@@ -1378,12 +1378,14 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeech, wav2vec2->unispeech, WAV_2_VEC_2->UNISPEECH
class UniSpeechForCTC(UniSpeechPreTrainedModel):
def __init__(self, config, target_lang=None):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
self.unispeech = UniSpeechModel(config)
self.dropout = nn.Dropout(config.final_dropout)
self.target_lang = target_lang
if config.vocab_size is None:
raise ValueError(
f"You are trying to instantiate {self.__class__} with a configuration that "
......@@ -1396,15 +1398,29 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel):
)
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
# Initialize weights and apply final processing
self.post_init()
def tie_weights(self):
"""
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
passing `target_lang=...` to `from_pretrained(...)`.
This method is **not** supposed to be called by the user and is prone to be changed in the future.
"""
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
# correctly load adapter layers for UniSpeech so that we do not have to introduce a new API to
# [`PreTrainedModel`]. While slightly hacky, UniSpeech never has to tie input and output embeddings, so that it is
# ok to repurpose this function here.
target_lang = self.target_lang
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
logger.info("By default `target_lang` is set to 'eng'.")
elif target_lang is not None:
self.load_adapter(target_lang)
# Initialize weights and apply final processing
self.post_init()
self.load_adapter(target_lang, force_load=True)
def freeze_feature_extractor(self):
"""
......
......@@ -1385,12 +1385,14 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
def __init__(self, config, target_lang=None):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
self.unispeech_sat = UniSpeechSatModel(config)
self.dropout = nn.Dropout(config.final_dropout)
self.target_lang = target_lang
if config.vocab_size is None:
raise ValueError(
f"You are trying to instantiate {self.__class__} with a configuration that "
......@@ -1403,15 +1405,29 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
)
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
# Initialize weights and apply final processing
self.post_init()
def tie_weights(self):
"""
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
passing `target_lang=...` to `from_pretrained(...)`.
This method is **not** supposed to be called by the user and is prone to be changed in the future.
"""
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
# correctly load adapter layers for UniSpeechSat so that we do not have to introduce a new API to
# [`PreTrainedModel`]. While slightly hacky, UniSpeechSat never has to tie input and output embeddings, so that it is
# ok to repurpose this function here.
target_lang = self.target_lang
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
logger.info("By default `target_lang` is set to 'eng'.")
elif target_lang is not None:
self.load_adapter(target_lang)
# Initialize weights and apply final processing
self.post_init()
self.load_adapter(target_lang, force_load=True)
def freeze_feature_extractor(self):
"""
......
......@@ -1207,7 +1207,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
if isinstance(self, Wav2Vec2ForCTC):
self._init_weights(self.lm_head)
def load_adapter(self, target_lang: str, **kwargs):
def load_adapter(self, target_lang: str, force_load=True, **kwargs):
r"""
Load a language adapter model from a pre-trained adapter model.
......@@ -1215,6 +1215,8 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
target_lang (`str`):
Has to be a language id of an existing adapter weight. Adapter weights are stored in the format
adapter.<lang>.safetensors or adapter.<lang>.bin
force_load (`bool`, defaults to `True`):
Whether the weights shall be loaded even if `target_lang` matches `self.target_lang`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
......@@ -1271,6 +1273,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
if self.config.adapter_attn_dim is None:
raise ValueError(f"Cannot load_adapter for {target_lang} if `config.adapter_attn_dim` is not defined.")
if target_lang == self.target_lang and not force_load:
logger.warn(f"Adapter weights are already set to {target_lang}.")
return
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
......@@ -1372,6 +1378,9 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
state_dict = {k: v.to(adapter_weights[k]) for k, v in state_dict.items()}
self.load_state_dict(state_dict, strict=False)
# set target language corectly
self.target_lang = target_lang
WAV_2_VEC_2_START_DOCSTRING = r"""
Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
......@@ -1854,12 +1863,14 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
WAV_2_VEC_2_START_DOCSTRING,
)
class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
def __init__(self, config, target_lang=None):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
self.wav2vec2 = Wav2Vec2Model(config)
self.dropout = nn.Dropout(config.final_dropout)
self.target_lang = target_lang
if config.vocab_size is None:
raise ValueError(
f"You are trying to instantiate {self.__class__} with a configuration that "
......@@ -1872,15 +1883,29 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
)
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
# Initialize weights and apply final processing
self.post_init()
def tie_weights(self):
"""
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
passing `target_lang=...` to `from_pretrained(...)`.
This method is **not** supposed to be called by the user and is prone to be changed in the future.
"""
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
# correctly load adapter layers for Wav2Vec2 so that we do not have to introduce a new API to
# [`PreTrainedModel`]. While slightly hacky, Wav2Vec2 never has to tie input and output embeddings, so that it is
# ok to repurpose this function here.
target_lang = self.target_lang
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
logger.info("By default `target_lang` is set to 'eng'.")
elif target_lang is not None:
self.load_adapter(target_lang)
# Initialize weights and apply final processing
self.post_init()
self.load_adapter(target_lang, force_load=True)
def freeze_feature_extractor(self):
"""
......
......@@ -1610,6 +1610,8 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
self.dropout = nn.Dropout(config.final_dropout)
self.target_lang = target_lang
if config.vocab_size is None:
raise ValueError(
f"You are trying to instantiate {self.__class__} with a configuration that "
......@@ -1622,13 +1624,6 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
)
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
logger.info("By default `target_lang` is set to 'eng'.")
elif target_lang is not None:
self.load_adapter(target_lang)
# Initialize weights and apply final processing
self.post_init()
......
......@@ -1272,12 +1272,14 @@ class WavLMModel(WavLMPreTrainedModel):
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
class WavLMForCTC(WavLMPreTrainedModel):
def __init__(self, config, target_lang=None):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
self.wavlm = WavLMModel(config)
self.dropout = nn.Dropout(config.final_dropout)
self.target_lang = target_lang
if config.vocab_size is None:
raise ValueError(
f"You are trying to instantiate {self.__class__} with a configuration that "
......@@ -1290,15 +1292,29 @@ class WavLMForCTC(WavLMPreTrainedModel):
)
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
# Initialize weights and apply final processing
self.post_init()
def tie_weights(self):
"""
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
passing `target_lang=...` to `from_pretrained(...)`.
This method is **not** supposed to be called by the user and is prone to be changed in the future.
"""
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
# correctly load adapter layers for WavLM so that we do not have to introduce a new API to
# [`PreTrainedModel`]. While slightly hacky, WavLM never has to tie input and output embeddings, so that it is
# ok to repurpose this function here.
target_lang = self.target_lang
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
logger.info("By default `target_lang` is set to 'eng'.")
elif target_lang is not None:
self.load_adapter(target_lang)
# Initialize weights and apply final processing
self.post_init()
self.load_adapter(target_lang, force_load=True)
def freeze_feature_extractor(self):
"""
......
......@@ -1117,6 +1117,40 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
def test_feed_forward_chunking(self):
pass
def test_load_and_set_attn_adapter(self):
processor = Wav2Vec2Processor.from_pretrained(
"hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
)
def get_logits(model, input_features):
model = model.to(torch_device)
batch = processor(
input_features,
padding=True,
sampling_rate=processor.feature_extractor.sampling_rate,
return_tensors="pt",
)
with torch.no_grad():
logits = model(
input_values=batch["input_values"].to(torch_device),
attention_mask=batch["attention_mask"].to(torch_device),
).logits
return logits
input_features = [np.random.random(16_000 * s) for s in [1, 3, 2, 6]]
model = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter", target_lang="it")
logits = get_logits(model, input_features)
model_2 = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter")
model_2.load_adapter("it")
logits_2 = get_logits(model_2, input_features)
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
def test_load_attn_adapter(self):
processor = Wav2Vec2Processor.from_pretrained(
"hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment