Commit fb2f9538 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Improve MVDR stability (#2004)

Summary:
Division first, multiplication second. This helps avoid the value overflow issue. It also helps the ``stv_evd`` solution pass the gradient check.

Pull Request resolved: https://github.com/pytorch/audio/pull/2004

Reviewed By: mthrok

Differential Revision: D32539827

Pulled By: nateanl

fbshipit-source-id: 70a386608324bb6e1b1c7238c78d403698590f22
parent 3ff46bfa
......@@ -276,9 +276,8 @@ class AutogradTestMixin(TestBaseMixin):
@parameterized.expand([
"ref_channel",
# stv_power test time too long, comment for now
# stv_power and stv_evd test time too long, comment for now
# "stv_power",
# stv_evd will fail since the eigenvalues are not distinct
# "stv_evd",
])
def test_mvdr(self, solution):
......
......@@ -1829,7 +1829,7 @@ class MVDR(torch.nn.Module):
denominator = torch.einsum("...d,...d->...", [stv.conj().squeeze(-1), numerator])
# normalzie the numerator
scale = stv.squeeze(-1)[..., self.ref_channel, None].conj()
beamform_vector = numerator * scale / (denominator.real.unsqueeze(-1) + eps)
beamform_vector = numerator / (denominator.real.unsqueeze(-1) + eps) * scale
return beamform_vector
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment