Unverified Commit b00bacf7 authored by nateanl's avatar nateanl Committed by GitHub
Browse files

Add normalization to steering vector solutions in MVDR Module (#1765)

parent ddb04e7d
......@@ -272,7 +272,9 @@ class MVDR(torch.nn.Module):
numerator = torch.linalg.solve(psd_n, stv).squeeze(-1) # (..., freq, channel)
# denominator = stv^H @ psd_n.inv() @ stv
denominator = torch.einsum("...d,...d->...", [stv.conj().squeeze(-1), numerator])
beamform_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
# normalzie the numerator
scale = stv.squeeze(-1)[..., self.ref_channel, None].conj()
beamform_vector = numerator * scale / (denominator.real.unsqueeze(-1) + eps)
return beamform_vector
......@@ -315,8 +317,7 @@ class MVDR(torch.nn.Module):
phi = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
stv = torch.einsum("...fec,...c->...fe", [phi, reference_vector])
stv = stv.unsqueeze(-1)
for _ in range(3):
stv = torch.matmul(phi, stv)
stv = torch.matmul(phi, stv)
stv = torch.matmul(psd_s, stv)
return stv
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment