Commit 392a03c8 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Relax dtype for MVDR (#2024)

Summary:
Allow users to use `torch.cfloat` dtype input for the MVDR module. It internally converts the spectrogram into `torch.cdouble` and outputs the tensor with the original dtype of the spectrogram.

Pull Request resolved: https://github.com/pytorch/audio/pull/2024

Reviewed By: carolineechen

Differential Revision: D32594051

Pulled By: nateanl

fbshipit-source-id: e32609ccdc881b36300d579c90daba41c9234b46
parent 358354aa
...@@ -203,7 +203,6 @@ class TestTransforms(common_utils.TorchaudioTestCase): ...@@ -203,7 +203,6 @@ class TestTransforms(common_utils.TorchaudioTestCase):
]) ])
def test_MVDR(self, multi_mask): def test_MVDR(self, multi_mask):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6) waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
waveform = waveform.to(torch.double)
specgram = common_utils.get_spectrogram(waveform, n_fft=400) specgram = common_utils.get_spectrogram(waveform, n_fft=400)
specgram = specgram.reshape(3, 2, specgram.shape[-2], specgram.shape[-1]) specgram = specgram.reshape(3, 2, specgram.shape[-2], specgram.shape[-1])
if multi_mask: if multi_mask:
......
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Transforms, TransformsFloat32Only, TransformsFloat64Only from .torchscript_consistency_impl import Transforms, TransformsFloat32Only
class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase): class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
...@@ -9,6 +9,6 @@ class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase): ...@@ -9,6 +9,6 @@ class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
device = torch.device('cpu') device = torch.device('cpu')
class TestTransformsFloat64(Transforms, TransformsFloat64Only, PytorchTestCase): class TestTransformsFloat64(Transforms, PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cpu') device = torch.device('cpu')
import torch import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Transforms, TransformsFloat32Only, TransformsFloat64Only from .torchscript_consistency_impl import Transforms, TransformsFloat32Only
@skipIfNoCuda @skipIfNoCuda
...@@ -11,6 +11,6 @@ class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase): ...@@ -11,6 +11,6 @@ class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
@skipIfNoCuda @skipIfNoCuda
class TestTransformsFloat64(Transforms, TransformsFloat64Only, PytorchTestCase): class TestTransformsFloat64(Transforms, PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cuda') device = torch.device('cuda')
...@@ -157,24 +157,6 @@ class Transforms(TestBaseMixin): ...@@ -157,24 +157,6 @@ class Transforms(TestBaseMixin):
mask = torch.rand(spectrogram.shape[-2:], device=self.device) mask = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(T.PSD(), spectrogram, mask) self._assert_consistency_complex(T.PSD(), spectrogram, mask)
class TransformsFloat32Only(TestBaseMixin):
def test_rnnt_loss(self):
logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.6, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.8, 0.1]],
[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.1, 0.1],
[0.7, 0.1, 0.2, 0.1, 0.1]]]])
tensor = logits.to(device=self.device, dtype=torch.float32)
targets = torch.tensor([[1, 2]], device=tensor.device, dtype=torch.int32)
logit_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
target_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
self._assert_consistency(T.RNNTLoss(), logits, targets, logit_lengths, target_lengths)
class TransformsFloat64Only(TestBaseMixin):
@parameterized.expand([ @parameterized.expand([
["ref_channel", True], ["ref_channel", True],
["stv_evd", True], ["stv_evd", True],
...@@ -186,10 +168,25 @@ class TransformsFloat64Only(TestBaseMixin): ...@@ -186,10 +168,25 @@ class TransformsFloat64Only(TestBaseMixin):
def test_MVDR(self, solution, online): def test_MVDR(self, solution, online):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4) tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100) spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
spectrogram = spectrogram.to(device=self.device, dtype=torch.cdouble)
mask_s = torch.rand(spectrogram.shape[-2:], device=self.device) mask_s = torch.rand(spectrogram.shape[-2:], device=self.device)
mask_n = torch.rand(spectrogram.shape[-2:], device=self.device) mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex( self._assert_consistency_complex(
T.MVDR(solution=solution, online=online), T.MVDR(solution=solution, online=online),
spectrogram, mask_s, mask_n spectrogram, mask_s, mask_n
) )
class TransformsFloat32Only(TestBaseMixin):
def test_rnnt_loss(self):
logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.6, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.8, 0.1]],
[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.1, 0.1],
[0.7, 0.1, 0.2, 0.1, 0.1]]]])
tensor = logits.to(device=self.device, dtype=torch.float32)
targets = torch.tensor([[1, 2]], device=tensor.device, dtype=torch.int32)
logit_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
target_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
self._assert_consistency(T.RNNTLoss(), logits, targets, logit_lengths, target_lengths)
...@@ -1665,9 +1665,9 @@ class MVDR(torch.nn.Module): ...@@ -1665,9 +1665,9 @@ class MVDR(torch.nn.Module):
(Default: ``False``) (Default: ``False``)
Note: Note:
The MVDR Module requires the input STFT to be double precision (``torch.complex128`` or ``torch.cdouble``), To improve the numerical stability, the input spectrogram will be converted to double precision
to improve the numerical stability. You can downgrade the precision to ``torch.float`` after generating the (``torch.complex128`` or ``torch.cdouble``) dtype for internal computation. The output spectrogram
enhanced waveform for ASR joint training. is converted to the dtype of the input spectrogram to be compatible with other modules.
Note: Note:
If you use ``stv_evd`` solution, the gradient of the same input may not be identical if the If you use ``stv_evd`` solution, the gradient of the same input may not be identical if the
...@@ -1944,14 +1944,18 @@ class MVDR(torch.nn.Module): ...@@ -1944,14 +1944,18 @@ class MVDR(torch.nn.Module):
torch.Tensor: The single-channel STFT of the enhanced speech. torch.Tensor: The single-channel STFT of the enhanced speech.
Tensor of dimension `(..., freq, time)` Tensor of dimension `(..., freq, time)`
""" """
dtype = specgram.dtype
if specgram.ndim < 3: if specgram.ndim < 3:
raise ValueError( raise ValueError(
f"Expected at least 3D tensor (..., channel, freq, time). Found: {specgram.shape}" f"Expected at least 3D tensor (..., channel, freq, time). Found: {specgram.shape}"
) )
if specgram.dtype != torch.cdouble: if not specgram.is_complex():
raise ValueError( raise ValueError(
f"The type of ``specgram`` tensor must be ``torch.cdouble``. Found: {specgram.dtype}" f"The type of ``specgram`` tensor must be ``torch.cfloat`` or ``torch.cdouble``.\
Found: {specgram.dtype}"
) )
if specgram.dtype == torch.cfloat:
specgram = specgram.cdouble() # Convert specgram to ``torch.cdouble``.
if mask_n is None: if mask_n is None:
warnings.warn( warnings.warn(
...@@ -2006,4 +2010,5 @@ class MVDR(torch.nn.Module): ...@@ -2006,4 +2010,5 @@ class MVDR(torch.nn.Module):
# unpack batch # unpack batch
specgram_enhanced = specgram_enhanced.reshape(shape[:-3] + shape[-2:]) specgram_enhanced = specgram_enhanced.reshape(shape[:-3] + shape[-2:])
specgram_enhanced.to(dtype)
return specgram_enhanced return specgram_enhanced
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment