Commit f0bc00c9 authored by Sean Kim, committed by Facebook GitHub Bot

Remove possible manual seeds from test files. (#2436)

Summary:
Removed manual seeds from test files where applicable. Part of the refactoring tracked in https://github.com/pytorch/audio/issues/2267.

Pull Request resolved: https://github.com/pytorch/audio/pull/2436

Reviewed By: carolineechen

Differential Revision: D36896854

Pulled By: skim0514

fbshipit-source-id: 7b4dd8a8dbfbef271f5cc56564dc83a760407e6c
parent b68864ca
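The diff below only shows the per-test torch.manual_seed / torch.random.manual_seed calls being deleted; it does not show where seeding happens afterwards. As a rough illustration of the pattern such a refactor typically follows, here is a minimal sketch assuming seeding is centralized in a shared test base class; the class name and seed value are illustrative assumptions, not torchaudio's actual implementation:

import unittest

import torch


class SeededTestCase(unittest.TestCase):
    """Hypothetical base class that seeds the RNG once for every test."""

    def setUp(self):
        super().setUp()
        # One central seed replaces per-test manual_seed calls
        # like the ones removed in this commit.
        torch.random.manual_seed(2434)


class ExampleBatchTest(SeededTestCase):
    def test_waveform_shape(self):
        # No manual seed here; setUp in the base class already seeded the RNG,
        # so torch.rand produces the same values on every run.
        waveforms = torch.rand(3, 2, 100) - 0.5
        self.assertEqual(waveforms.shape, (3, 2, 100))


if __name__ == "__main__":
    unittest.main()

With this layout, deleting a stray manual_seed call from an individual test does not change its reproducibility, because the base class re-seeds the generator before every test method.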
@@ -55,7 +55,6 @@ class TestLibriSpeechRNNTModule(TorchaudioTestCase):
     @classmethod
     def setUpClass(cls) -> None:
         super().setUpClass()
-        torch.random.manual_seed(31)

     @parameterized.expand(
         [
...
@@ -49,7 +49,6 @@ class TestMuSTCRNNTModule(TorchaudioTestCase):
     @classmethod
     def setUpClass(cls) -> None:
         super().setUpClass()
-        torch.random.manual_seed(31)

     @parameterized.expand(
         [
...
@@ -53,7 +53,6 @@ class TestTEDLIUM3RNNTModule(TorchaudioTestCase):
     @classmethod
     def setUpClass(cls) -> None:
         super().setUpClass()
-        torch.random.manual_seed(31)

     @parameterized.expand(
         [
...
@@ -9,7 +9,6 @@ class TestCropAudioLabel(TorchaudioTestCase):
     @classmethod
     def setUpClass(cls) -> None:
         super().setUpClass()
-        torch.random.manual_seed(31)

     @parameterized.expand(
         [
...
@@ -232,28 +232,23 @@ class Autograd(TestBaseMixin):
         self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))

     def test_deemph_biquad(self):
-        torch.random.manual_seed(2434)
         x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
         self.assert_grad(F.deemph_biquad, (x, 44100))

     def test_flanger(self):
-        torch.random.manual_seed(2434)
         x = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
         self.assert_grad(F.flanger, (x, 44100))

     def test_gain(self):
-        torch.random.manual_seed(2434)
         x = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
         self.assert_grad(F.gain, (x, 1.1))

     def test_overdrive(self):
-        torch.random.manual_seed(2434)
         x = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
         self.assert_grad(F.gain, (x,))

     @parameterized.expand([(True,), (False,)])
     def test_phaser(self, sinusoidal):
-        torch.random.manual_seed(2434)
         sr = 8000
         x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
         self.assert_grad(F.phaser, (x, sr, sinusoidal))
...
@@ -52,7 +52,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         momentum = 0.99
         n_iter = 32
         length = 1000
-        torch.random.manual_seed(0)
         batch = torch.rand(self.batch_size, 1, 201, 6)
         kwargs = {
             "window": window,
@@ -80,7 +79,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
     def test_detect_pitch_frequency(self, sample_rate, n_channels):
         # Use different frequencies to ensure each item in the batch returns a
         # different answer.
-        torch.manual_seed(0)
         frequencies = torch.randint(100, 1000, [self.batch_size])
         waveforms = torch.stack(
             [
@@ -103,7 +101,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         ]
     )
     def test_amplitude_to_DB(self, top_db):
-        torch.manual_seed(0)
         spec = torch.rand(self.batch_size, 2, 100, 100) * 200
         amplitude_mult = 20.0
@@ -137,7 +134,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         top_db = 20.0
         # Make a batch of noise
-        torch.manual_seed(0)
         spec = torch.rand([2, 2, 100, 100]) * 200
         # Make one item blow out the other
         spec[0] += 50
@@ -158,7 +154,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         db_mult = math.log10(max(amin, ref))
         top_db = 40.0
-        torch.manual_seed(0)
         spec = torch.rand([1, 2, 100, 100]) * 200
         # Make one channel blow out the other
         spec[:, 0] += 50
@@ -173,7 +168,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         assert (difference >= 1e-5).any()

     def test_contrast(self):
-        torch.random.manual_seed(0)
         waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
         kwargs = {
             "enhancement_amount": 80.0,
@@ -182,7 +176,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         self.assert_batch_consistency(func, inputs=(waveforms,))

     def test_dcshift(self):
-        torch.random.manual_seed(0)
         waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
         kwargs = {
             "shift": 0.5,
@@ -192,7 +185,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         self.assert_batch_consistency(func, inputs=(waveforms,))

     def test_overdrive(self):
-        torch.random.manual_seed(0)
         waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
         kwargs = {
             "gain": 45,
@@ -215,7 +207,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         self.assert_batch_consistency(func, inputs=(batch,))

     def test_flanger(self):
-        torch.random.manual_seed(0)
         waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
         sample_rate = 44100
         kwargs = {
@@ -234,7 +225,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         name_func=_name_from_args,
     )
     def test_sliding_window_cmn(self, center, norm_vars):
-        torch.manual_seed(0)
         spectrogram = torch.rand(self.batch_size, 2, 1024, 1024) * 200
         kwargs = {
             "center": center,
@@ -281,7 +271,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
     def test_lfilter(self):
         signal_length = 2048
-        torch.manual_seed(2434)
         x = torch.randn(self.batch_size, signal_length)
         a = torch.rand(self.batch_size, 3)
         b = torch.rand(self.batch_size, 3)
@@ -289,7 +278,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
     def test_filtfilt(self):
         signal_length = 2048
-        torch.manual_seed(2434)
         x = torch.randn(self.batch_size, signal_length)
         a = torch.rand(self.batch_size, 3)
         b = torch.rand(self.batch_size, 3)
@@ -319,7 +307,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         self.assert_batch_consistency(F.psd, (specgram, mask))

     def test_mvdr_weights_souden(self):
-        torch.random.manual_seed(2434)
         batch_size = 2
         channel = 4
         n_fft_bin = 10
@@ -332,7 +319,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         self.assert_batch_consistency(func, (psd_noise, psd_speech))

     def test_mvdr_weights_souden_with_tensor(self):
-        torch.random.manual_seed(2434)
         batch_size = 2
         channel = 4
         n_fft_bin = 10
@@ -343,7 +329,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         self.assert_batch_consistency(F.mvdr_weights_souden, (psd_noise, psd_speech, reference_channel))

     def test_mvdr_weights_rtf(self):
-        torch.random.manual_seed(2434)
         batch_size = 2
         channel = 4
         n_fft_bin = 129
@@ -356,7 +341,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         self.assert_batch_consistency(func, (rtf, psd_noise))

     def test_mvdr_weights_rtf_with_tensor(self):
-        torch.random.manual_seed(2434)
         batch_size = 2
         channel = 4
         n_fft_bin = 129
@@ -367,7 +351,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         self.assert_batch_consistency(F.mvdr_weights_rtf, (rtf, psd_noise, reference_channel))

     def test_rtf_evd(self):
-        torch.random.manual_seed(2434)
         batch_size = 2
         channel = 4
         n_fft_bin = 5
@@ -382,7 +365,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         ]
     )
     def test_rtf_power(self, n_iter):
-        torch.random.manual_seed(2434)
         channel = 4
         batch_size = 2
         n_fft_bin = 10
@@ -402,7 +384,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         ]
     )
     def test_rtf_power_with_tensor(self, n_iter):
-        torch.random.manual_seed(2434)
         channel = 4
         batch_size = 2
         n_fft_bin = 10
@@ -417,7 +398,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         self.assert_batch_consistency(func, (psd_speech, psd_noise, reference_channel))

     def test_apply_beamforming(self):
-        torch.random.manual_seed(2434)
         sr = 8000
         n_fft = 400
         batch_size, num_channels = 2, 3
...
@@ -35,7 +35,6 @@ class TestApplyCodec(TorchaudioTestCase):
         The purpose of this test suite is to verify that apply_codec functionalities do not exhibit
         abnormal behaviors.
         """
-        torch.random.manual_seed(42)
         sample_rate = 8000
         num_frames = 3 * sample_rate
         num_channels = 2
...
@@ -131,7 +131,6 @@ class FunctionalComplex(TestBaseMixin):
         hop_length = 256
         num_freq = 1025
         num_frames = 400
-        torch.random.manual_seed(42)

         # Due to cummulative sum, numerical error in using torch.float32 will
         # result in bottom right values of the stretched sectrogram to not
...
@@ -269,7 +269,6 @@ class Functional(TempDirMixin, TestBaseMixin):
         self._assert_consistency(F.lfilter, (waveform, a_coeffs, b_coeffs, True, True))

     def test_filtfilt(self):
-        torch.manual_seed(296)
         waveform = common_utils.get_whitenoise(sample_rate=8000)
         b_coeffs = torch.rand(4, device=waveform.device, dtype=waveform.dtype)
         a_coeffs = torch.rand(4, device=waveform.device, dtype=waveform.dtype)
@@ -531,7 +530,6 @@ class Functional(TempDirMixin, TestBaseMixin):
         self._assert_consistency(func, (waveform,))

     def test_flanger(self):
-        torch.random.manual_seed(40)
         waveform = torch.rand(2, 100) - 0.5

         def func(tensor):
...
@@ -26,7 +26,6 @@ class ConformerTestImpl(TestBaseMixin):
     def setUp(self):
         super().setUp()
-        torch.random.manual_seed(31)

     def test_torchscript_consistency_forward(self):
         r"""Verify that scripting Conformer does not change the behavior of method `forward`."""
...
@@ -9,7 +9,6 @@ from torchaudio_unittest.common_utils import (
     TorchaudioTestCase,
 )

 NUM_TOKENS = 8
@@ -38,7 +37,6 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
     def _get_emissions(self):
         B, T, N = 4, 15, NUM_TOKENS
-        torch.manual_seed(0)
         emissions = torch.rand(B, T, N)
         return emissions
...
@@ -54,7 +54,6 @@ class RNNTTestImpl(TestBaseMixin):
         input_dim = input_config["input_dim"]
         right_context_length = input_config["right_context_length"]
-        torch.random.manual_seed(31)
         input = torch.rand(batch_size, max_input_length + right_context_length, input_dim).to(
             device=self.device, dtype=self.dtype
         )
@@ -68,7 +67,6 @@ class RNNTTestImpl(TestBaseMixin):
         input_dim = input_config["input_dim"]
         right_context_length = input_config["right_context_length"]
-        torch.random.manual_seed(31)
         input = torch.rand(batch_size, segment_length + right_context_length, input_dim).to(
             device=self.device, dtype=self.dtype
         )
@@ -83,7 +81,6 @@ class RNNTTestImpl(TestBaseMixin):
         num_symbols = input_config["num_symbols"]
         max_target_length = input_config["max_target_length"]
-        torch.random.manual_seed(31)
         input = torch.randint(0, num_symbols, (batch_size, max_target_length)).to(device=self.device, dtype=torch.int32)
         lengths = torch.randint(1, max_target_length + 1, (batch_size,)).to(device=self.device, dtype=torch.int32)
         return input, lengths
@@ -95,7 +92,6 @@ class RNNTTestImpl(TestBaseMixin):
         max_target_length = input_config["max_target_length"]
         input_dim = input_config["encoding_dim"]
-        torch.random.manual_seed(31)
         utterance_encodings = torch.rand(batch_size, joiner_max_input_length, input_dim).to(
             device=self.device, dtype=self.dtype
         )
...
@@ -46,8 +46,6 @@ class RNNTBeamSearchTestImpl(TestBaseMixin):
     def test_torchscript_consistency_forward(self):
         r"""Verify that scripting RNNTBeamSearch does not change the behavior of method `forward`."""
-        torch.random.manual_seed(31)
         input_config = self._get_input_config()
         batch_size = input_config["batch_size"]
         max_input_length = input_config["max_input_length"]
@@ -74,8 +72,6 @@ class RNNTBeamSearchTestImpl(TestBaseMixin):
     def test_torchscript_consistency_infer(self):
         r"""Verify that scripting RNNTBeamSearch does not change the behavior of method `infer`."""
-        torch.random.manual_seed(31)
         input_config = self._get_input_config()
         segment_length = input_config["segment_length"]
         right_context_length = input_config["right_context_length"]
...
@@ -134,7 +134,6 @@ class TestFairseqIntegration(TorchaudioTestCase):
         """Wav2vec2 pretraining models from fairseq can be imported and yields the same results"""
         batch_size, num_frames = 3, 1024
-        torch.manual_seed(0)
         original = self._get_model(config).eval()
         imported = import_fairseq_model(original).eval()
@@ -149,7 +148,6 @@ class TestFairseqIntegration(TorchaudioTestCase):
         """HuBERT pretraining models from fairseq can be imported and yields the same results"""
         batch_size, num_frames = 3, 1024
-        torch.manual_seed(0)
         original = self._get_model(config).eval()
         imported = import_fairseq_model(original).eval()
@@ -241,7 +239,6 @@ class TestFairseqIntegration(TorchaudioTestCase):
         reloaded.eval()

         # Without mask
-        torch.manual_seed(0)
         x = torch.randn(batch_size, num_frames)
         ref, _ = imported(x)
         hyp, _ = reloaded(x)
...
@@ -89,7 +89,6 @@ class TestHFIntegration(TorchaudioTestCase):
         raise ValueError(f'Unexpected arch: {config["architectures"]}')

     def _test_import_pretrain(self, original, imported, config):
-        torch.manual_seed(0)
         # FeatureExtractor
         x = torch.randn(3, 1024)
         ref = original.feature_extractor(x).transpose(1, 2)
@@ -173,7 +172,6 @@ class TestHFIntegration(TorchaudioTestCase):
         self._test_import_finetune(original, imported, config)

     def _test_recreate(self, imported, reloaded, config):
-        torch.manual_seed(0)
         # FeatureExtractor
         x = torch.randn(3, 1024)
         ref, _ = imported.feature_extractor(x, None)
...
@@ -48,7 +48,6 @@ class TestWav2Vec2Model(TorchaudioTestCase):
         model = model.to(device=device, dtype=dtype)
         model = model.eval()

-        torch.manual_seed(0)
         batch_size, num_frames = 3, 1024
         waveforms = torch.randn(batch_size, num_frames, device=device, dtype=dtype)
@@ -84,7 +83,6 @@ class TestWav2Vec2Model(TorchaudioTestCase):
         model.eval()
         num_layers = len(model.encoder.transformer.layers)

-        torch.manual_seed(0)
         waveforms = torch.randn(batch_size, num_frames)
         lengths = torch.randint(
             low=0,
@@ -119,7 +117,6 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     def _test_batch_consistency(self, model):
         model.eval()
         batch_size, max_frames = 5, 5 * 1024
-        torch.manual_seed(0)
         waveforms = torch.randn(batch_size, max_frames)
         input_lengths = torch.tensor([i * 3200 for i in range(1, 6)])
@@ -148,7 +145,6 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     def _test_zero_length(self, model):
         model.eval()
-        torch.manual_seed(0)
         batch_size = 3
         waveforms = torch.randn(batch_size, 1024)
         input_lengths = torch.zeros(batch_size)
@@ -172,7 +168,6 @@ class TestWav2Vec2Model(TorchaudioTestCase):
         batch_size, num_frames = 3, 1024

-        torch.manual_seed(0)
         waveforms = torch.randn(batch_size, num_frames)
         lengths = torch.randint(
             low=0,
@@ -220,7 +215,6 @@ class TestWav2Vec2Model(TorchaudioTestCase):
         # A lazy way to check that Modules are different
         assert str(quantized) != str(model), "Dynamic quantization did not modify the module."

-        torch.manual_seed(0)
         waveforms = torch.randn(batch_size, num_frames)
         lengths = torch.randint(
             low=0,
@@ -250,7 +244,6 @@ class TestWav2Vec2Model(TorchaudioTestCase):
         # A lazy way to check that Modules are different
         assert str(quantized) != str(model), "Dynamic quantization did not modify the module."

-        torch.manual_seed(0)
         waveforms = torch.randn(batch_size, num_frames)
         lengths = torch.randint(
             low=0,
...
@@ -44,7 +44,6 @@ class TransformsTestBase(TestBaseMixin):
         # Run transform
         transform = T.InverseMelScale(n_stft, n_mels=n_mels, sample_rate=sample_rate).to(self.device, self.dtype)
-        torch.random.manual_seed(0)
         result = transform(input)

         # Compare
...