Unverified Commit 768aa3d9 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[Wav2Vec2 and Co] Update init tests for PT 2.1 (#26494)

parent b5ca8fcd
......@@ -419,6 +419,7 @@ class HubertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
for name, param in model.named_parameters():
uniform_init_parms = [
"conv.weight",
"conv.parametrizations.weight",
"masked_spec_embed",
"quantizer.weight_proj.weight",
]
......@@ -680,6 +681,7 @@ class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase):
for name, param in model.named_parameters():
uniform_init_parms = [
"conv.weight",
"conv.parametrizations.weight",
"masked_spec_embed",
"quantizer.weight_proj.weight",
]
......
......@@ -421,6 +421,7 @@ class UniSpeechRobustModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.T
for name, param in model.named_parameters():
uniform_init_parms = [
"conv.weight",
"conv.parametrizations.weight",
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
......
......@@ -471,6 +471,7 @@ class UniSpeechSatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
for name, param in model.named_parameters():
uniform_init_parms = [
"conv.weight",
"conv.parametrizations.weight",
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
......@@ -682,6 +683,7 @@ class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase):
for name, param in model.named_parameters():
uniform_init_parms = [
"conv.weight",
"conv.parametrizations.weight",
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
......
......@@ -625,6 +625,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
for name, param in model.named_parameters():
uniform_init_parms = [
"conv.weight",
"conv.parametrizations.weight",
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
......
......@@ -424,6 +424,7 @@ class WavLMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
for name, param in model.named_parameters():
uniform_init_parms = [
"conv.weight",
"conv.parametrizations.weight",
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment