Unverified Commit 52972e70 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Wav2Vec2] Fix torch srcipt (#24062)

* [Wav2Vec2] Fix torch srcipt

* fix more
parent 612b2a1a
...@@ -1178,8 +1178,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): ...@@ -1178,8 +1178,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder)): if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder)):
module.gradient_checkpointing = value module.gradient_checkpointing = value
@property def _get_adapters(self):
def _adapters(self):
if self.config.adapter_attn_dim is None: if self.config.adapter_attn_dim is None:
raise ValueError(f"{self.__class__} has no adapter layers. Make sure to define `config.adapter_attn_dim`.") raise ValueError(f"{self.__class__} has no adapter layers. Make sure to define `config.adapter_attn_dim`.")
...@@ -1339,7 +1338,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): ...@@ -1339,7 +1338,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
f" directory containing a file named {filepath}." f" directory containing a file named {filepath}."
) )
adapter_weights = self._adapters adapter_weights = self._get_adapters()
unexpected_keys = set(state_dict.keys()) - set(adapter_weights.keys()) unexpected_keys = set(state_dict.keys()) - set(adapter_weights.keys())
missing_keys = set(adapter_weights.keys()) - set(state_dict.keys()) missing_keys = set(adapter_weights.keys()) - set(state_dict.keys())
......
...@@ -297,7 +297,7 @@ class Wav2Vec2ModelTester: ...@@ -297,7 +297,7 @@ class Wav2Vec2ModelTester:
config.adapter_attn_dim = 16 config.adapter_attn_dim = 16
model = Wav2Vec2ForCTC(config=config) model = Wav2Vec2ForCTC(config=config)
self.parent.assertIsNotNone(model._adapters) self.parent.assertIsNotNone(model._get_adapters())
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
...@@ -1146,7 +1146,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -1146,7 +1146,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
model = Wav2Vec2ForCTC.from_pretrained(tempdir) model = Wav2Vec2ForCTC.from_pretrained(tempdir)
logits = get_logits(model, input_features) logits = get_logits(model, input_features)
adapter_weights = model._adapters adapter_weights = model._get_adapters()
# save safe weights # save safe weights
safe_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_SAFE_FILE.format("eng")) safe_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_SAFE_FILE.format("eng"))
...@@ -1168,7 +1168,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -1168,7 +1168,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
model = Wav2Vec2ForCTC.from_pretrained(tempdir) model = Wav2Vec2ForCTC.from_pretrained(tempdir)
logits = get_logits(model, input_features) logits = get_logits(model, input_features)
adapter_weights = model._adapters adapter_weights = model._get_adapters()
# save pt weights # save pt weights
pt_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_PT_FILE.format("eng")) pt_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_PT_FILE.format("eng"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment