Commit 2f4eb4ac authored by Zhaoheng Ni, committed by Facebook GitHub Bot
Browse files

Fix return dtype in MVDR module (#2376)

Summary:
Address https://github.com/pytorch/audio/issues/2375
The MVDR module internally converts complex tensors to `torch.complex128` for computation and is meant to convert the result back to the original dtype before returning it. However, the conversion back had no effect: the code called `specgram_enhanced.to(dtype)` and discarded the result, and `Tensor.to` returns a converted tensor rather than modifying the tensor in place. The statement should have been `specgram_enhanced = specgram_enhanced.to(dtype)`. This change fixes it so that the output dtype is consistent with the original input dtype.

Pull Request resolved: https://github.com/pytorch/audio/pull/2376

Reviewed By: hwangjeff

Differential Revision: D36280851

Pulled By: nateanl

fbshipit-source-id: 553d1b98f899547209a4e3ebc59920c7ef1f3112
parent eab2f39d
......@@ -131,3 +131,20 @@ class TransformsTestBase(TestBaseMixin):
psd_np = psd_numpy(spectrogram.detach().numpy(), mask, multi_mask)
psd = transform(spectrogram, mask)
self.assertEqual(psd, psd_np, atol=1e-5, rtol=1e-5)
@parameterized.expand(
    [
        param(torch.complex64),
        param(torch.complex128),
    ]
)
def test_mvdr(self, dtype):
    """Check that ``T.MVDR`` returns a spectrogram in the dtype it was given.

    Runs once for each complex dtype listed in the ``parameterized.expand``
    decorator above.
    """
    beamformer = T.MVDR()

    # Three-channel white noise turned into a spectrogram, then cast to the
    # dtype under test.
    noise = get_whitenoise(sample_rate=8000, duration=0.5, n_channels=3)
    spec_in = get_spectrogram(noise, n_fft=400).to(dtype)  # (channel, freq, time)

    # Two independent random masks sized by the trailing two spectrogram dims.
    mask_shape = spec_in.shape[-2:]
    first_mask = torch.rand(mask_shape)
    second_mask = torch.rand(mask_shape)

    spec_out = beamformer(spec_in, first_mask, second_mask)
    assert spec_out.dtype == dtype
......@@ -2087,8 +2087,7 @@ class MVDR(torch.nn.Module):
# unpack batch
specgram_enhanced = specgram_enhanced.reshape(shape[:-3] + shape[-2:])
specgram_enhanced.to(dtype)
return specgram_enhanced
return specgram_enhanced.to(dtype)
class RTFMVDR(torch.nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment