Commit 392a03c8 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Relax dtype for MVDR (#2024)

Summary:
Allow users to use `torch.cfloat` dtype input for the MVDR module. It internally converts the spectrogram to `torch.cdouble` and outputs a tensor with the original dtype of the spectrogram.

Pull Request resolved: https://github.com/pytorch/audio/pull/2024

Reviewed By: carolineechen

Differential Revision: D32594051

Pulled By: nateanl

fbshipit-source-id: e32609ccdc881b36300d579c90daba41c9234b46
parent 358354aa
......@@ -203,7 +203,6 @@ class TestTransforms(common_utils.TorchaudioTestCase):
])
def test_MVDR(self, multi_mask):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
waveform = waveform.to(torch.double)
specgram = common_utils.get_spectrogram(waveform, n_fft=400)
specgram = specgram.reshape(3, 2, specgram.shape[-2], specgram.shape[-1])
if multi_mask:
......
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Transforms, TransformsFloat32Only, TransformsFloat64Only
from .torchscript_consistency_impl import Transforms, TransformsFloat32Only
class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
......@@ -9,6 +9,6 @@ class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
device = torch.device('cpu')
class TestTransformsFloat64(Transforms, TransformsFloat64Only, PytorchTestCase):
# Torchscript-consistency tests for transforms on CPU with float64 tensors.
# NOTE(review): the former ``TransformsFloat64Only`` mixin was dropped from the
# bases here — presumably its cases moved into the shared ``Transforms`` suite;
# confirm against the impl module.
class TestTransformsFloat64(Transforms, PytorchTestCase):
    dtype = torch.float64
    device = torch.device('cpu')
import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Transforms, TransformsFloat32Only, TransformsFloat64Only
from .torchscript_consistency_impl import Transforms, TransformsFloat32Only
@skipIfNoCuda
......@@ -11,6 +11,6 @@ class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
@skipIfNoCuda
class TestTransformsFloat64(Transforms, TransformsFloat64Only, PytorchTestCase):
# Torchscript-consistency tests for transforms on CUDA with float64 tensors.
# The ``@skipIfNoCuda`` decorator above skips the whole class when no GPU is
# available.
class TestTransformsFloat64(Transforms, PytorchTestCase):
    dtype = torch.float64
    device = torch.device('cuda')
......@@ -157,24 +157,6 @@ class Transforms(TestBaseMixin):
mask = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(T.PSD(), spectrogram, mask)
# Mixin of torchscript-consistency tests that only run with ``torch.float32``
# (note the explicit float32 cast below).
class TransformsFloat32Only(TestBaseMixin):
    def test_rnnt_loss(self):
        """Check that RNNTLoss produces the same result in eager and scripted mode."""
        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.6, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
                                [[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.1, 0.1],
                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])
        # NOTE(review): ``tensor`` is converted to the test device/dtype but the
        # untouched ``logits`` is what gets passed below — presumably the
        # harness performs its own conversion; confirm, otherwise pass ``tensor``.
        tensor = logits.to(device=self.device, dtype=torch.float32)
        targets = torch.tensor([[1, 2]], device=tensor.device, dtype=torch.int32)
        logit_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
        target_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
        self._assert_consistency(T.RNNTLoss(), logits, targets, logit_lengths, target_lengths)
class TransformsFloat64Only(TestBaseMixin):
@parameterized.expand([
["ref_channel", True],
["stv_evd", True],
......@@ -186,10 +168,25 @@ class TransformsFloat64Only(TestBaseMixin):
def test_MVDR(self, solution, online):
    """Torchscript-consistency check for the MVDR beamforming transform.

    Builds a multi-channel ``torch.cdouble`` spectrogram from white noise
    and compares eager vs. scripted execution with random speech/noise masks.
    """
    noise = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
    specgram = common_utils.get_spectrogram(noise, n_fft=400, hop_length=100)
    specgram = specgram.to(device=self.device, dtype=torch.cdouble)
    freq_time = specgram.shape[-2:]
    speech_mask = torch.rand(freq_time, device=self.device)
    noise_mask = torch.rand(freq_time, device=self.device)
    transform = T.MVDR(solution=solution, online=online)
    self._assert_consistency_complex(transform, specgram, speech_mask, noise_mask)
# Mixin of torchscript-consistency tests that only run with ``torch.float32``.
class TransformsFloat32Only(TestBaseMixin):
    def test_rnnt_loss(self):
        """Check that RNNTLoss produces the same result in eager and scripted mode."""
        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.6, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
                                [[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.1, 0.1],
                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])
        # Move the logits onto the test device as float32 (this mixin is
        # float32-only).
        tensor = logits.to(device=self.device, dtype=torch.float32)
        targets = torch.tensor([[1, 2]], device=tensor.device, dtype=torch.int32)
        logit_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
        target_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
        # Fix: pass the converted ``tensor``; previously the unconverted
        # ``logits`` was passed, leaving ``tensor`` dead and the device/dtype
        # move without effect.
        self._assert_consistency(T.RNNTLoss(), tensor, targets, logit_lengths, target_lengths)
......@@ -1665,9 +1665,9 @@ class MVDR(torch.nn.Module):
(Default: ``False``)
Note:
The MVDR Module requires the input STFT to be double precision (``torch.complex128`` or ``torch.cdouble``),
to improve the numerical stability. You can downgrade the precision to ``torch.float`` after generating the
enhanced waveform for ASR joint training.
To improve the numerical stability, the input spectrogram will be converted to double precision
(``torch.complex128`` or ``torch.cdouble``) dtype for internal computation. The output spectrogram
is converted to the dtype of the input spectrogram to be compatible with other modules.
Note:
If you use ``stv_evd`` solution, the gradient of the same input may not be identical if the
......@@ -1944,14 +1944,18 @@ class MVDR(torch.nn.Module):
torch.Tensor: The single-channel STFT of the enhanced speech.
Tensor of dimension `(..., freq, time)`
"""
dtype = specgram.dtype
if specgram.ndim < 3:
raise ValueError(
f"Expected at least 3D tensor (..., channel, freq, time). Found: {specgram.shape}"
)
if specgram.dtype != torch.cdouble:
if not specgram.is_complex():
raise ValueError(
f"The type of ``specgram`` tensor must be ``torch.cdouble``. Found: {specgram.dtype}"
f"The type of ``specgram`` tensor must be ``torch.cfloat`` or ``torch.cdouble``.\
Found: {specgram.dtype}"
)
if specgram.dtype == torch.cfloat:
specgram = specgram.cdouble() # Convert specgram to ``torch.cdouble``.
if mask_n is None:
warnings.warn(
......@@ -2006,4 +2010,5 @@ class MVDR(torch.nn.Module):
# unpack batch
specgram_enhanced = specgram_enhanced.reshape(shape[:-3] + shape[-2:])
specgram_enhanced.to(dtype)
return specgram_enhanced
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment