mvdr.py 18.1 KB
Newer Older
nateanl's avatar
nateanl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
"""Implementation of MVDR Beamforming Module

Based on https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/beamformer.py

We provide three solutions of MVDR beamforming. One is based on reference channel selection:
Souden, Mehrez, Jacob Benesty, and Sofiene Affes.
"On optimal frequency-domain multichannel linear filtering for noise reduction."
IEEE Transactions on audio, speech, and language processing 18.2 (2009): 260-276.

The other two solutions are based on the steering vector. We apply either eigenvalue decomposition
or the power method to get the steering vector from the PSD matrices.

For eigenvalue decomposistion method, please refer:
Higuchi, Takuya, et al. "Robust MVDR beamforming using time-frequency masks for online/offline ASR in noise."
2016 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2016.

For power method, please refer:
Mises, R. V., and Hilda Pollaczek‐Geiringer.
"Praktische Verfahren der Gleichungsauflösung."
ZAMM‐Journal of Applied Mathematics and Mechanics/Zeitschrift für Angewandte Mathematik und Mechanik 9.1 (1929): 58-77.

For online streaming audio, we provide a recursive method to update PSD matrices based on:
Higuchi, Takuya, et al.
"Online MVDR beamformer based on complex Gaussian mixture model with spatial prior for noise robust ASR."
IEEE/ACM Transactions on Audio, Speech, and Language Processing 25.4 (2017): 780-793.
"""

from typing import Optional

import torch


def mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor:
    r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.

    Args:
        input (torch.Tensor): Tensor of dimension (..., channel, channel)
        dim1 (int, optional): the first dimension of the diagonal matrix
            (Default: -1)
        dim2 (int, optional): the second dimension of the diagonal matrix
            (Default: -2)

    Returns:
        torch.Tensor: trace of the input Tensor
    """
    assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
    assert input.shape[dim1] == input.shape[dim2],\
        "The size of ``dim1`` and ``dim2`` must be the same."
    input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
    return input.sum(dim=-1)


class PSD(torch.nn.Module):
    r"""Compute cross-channel power spectral density (PSD) matrix.

    Args:
        multi_mask (bool, optional): whether to use multi-channel Time-Frequency masks (Default: ``False``)
        normalize (bool, optional): whether normalize the mask along the time dimension
        eps (float, optional): a value added to the denominator in mask normalization. Default: 1e-15
    """

    def __init__(self, multi_mask: bool = False, normalize: bool = True, eps: float = 1e-15):
        super().__init__()
        self.multi_mask = multi_mask
        self.normalize = normalize
        self.eps = eps

    def forward(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Args:
            X (torch.Tensor): multi-channel complex-valued STFT matrix.
                Tensor of dimension (..., channel, freq, time)
            mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
                Tensor of dimension (..., freq, time) if multi_mask is ``False`` or
                of dimension (..., channel, freq, time) if multi_mask is ``True``

        Returns:
            torch.Tensor: PSD matrix of the input STFT matrix.
                Tensor of dimension (..., freq, channel, channel)
        """
        # outer product:
        # (..., ch_1, freq, time) x (..., ch_2, freq, time) -> (..., time, ch_1, ch_2)
        psd_X = torch.einsum("...cft,...eft->...ftce", [X, X.conj()])

        if mask is None:
            psd = psd_X
        else:
            if self.multi_mask:
                # Averaging mask along channel dimension
                mask = mask.mean(dim=-3)  # (..., freq, time)

            # Normalized mask along time dimension:
            if self.normalize:
                mask = mask / (mask.sum(dim=-1, keepdim=True) + self.eps)

            psd = psd_X * mask.unsqueeze(-1).unsqueeze(-1)

        psd = psd.sum(dim=-3)
        return psd


class MVDR(torch.nn.Module):
    """MVDR module that performs MVDR beamforming with Time-Frequency masks.

    Args:
        ref_channel (int, optional): the reference channel for beamforming. (Default: ``0``)
        solution (str, optional): the solution to get MVDR weight.
            Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
        multi_mask (bool, optional): whether to use multi-channel Time-Frequency masks (Default: ``False``)
        diag_loading (bool, optional): whether apply diagonal loading on the psd matrix of noise
            (Default: ``True``)
        diag_eps (float, optional): the coefficient multipied to the identity matrix for diagonal loading
            (Default: 1e-7)
        online (bool, optional): whether to update the mvdr vector based on the previous psd matrices.
            (Default: ``False``)

    Note:
        If you use ``stv_evd`` solution, the gradient of the same input may not be identical if the
        eigenvalues of the PSD matrix are not distinct (i.e. some eigenvalues are close or identical).
    """

    def __init__(
        self,
        ref_channel: int = 0,
        solution: str = "ref_channel",
        multi_mask: bool = False,
        diag_loading: bool = True,
        diag_eps: float = 1e-7,
        online: bool = False,
    ):
        super().__init__()
        assert solution in ["ref_channel", "stv_evd", "stv_power"],\
            "Unknown solution provided. Must be one of [``ref_channel``, ``stv_evd``, ``stv_power``]."
        self.ref_channel = ref_channel
        self.solution = solution
        self.multi_mask = multi_mask
        self.diag_loading = diag_loading
        self.diag_eps = diag_eps
        self.online = online
        self.psd = PSD(multi_mask)

        psd_s: torch.Tensor = torch.zeros(1)
        psd_n: torch.Tensor = torch.zeros(1)
        mask_sum_s: torch.Tensor = torch.zeros(1)
        mask_sum_n: torch.Tensor = torch.zeros(1)
        self.register_buffer('psd_s', psd_s)
        self.register_buffer('psd_n', psd_n)
        self.register_buffer('mask_sum_s', mask_sum_s)
        self.register_buffer('mask_sum_n', mask_sum_n)

    def _get_updated_mvdr_vector(
        self,
        psd_s: torch.Tensor,
        psd_n: torch.Tensor,
        mask_s: torch.Tensor,
        mask_n: torch.Tensor,
        reference_vector: torch.Tensor,
        solution: str = 'ref_channel',
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        r"""Recursively update the MVDR beamforming vector.

        Args:
            psd_s (torch.Tensor): psd matrix of target speech
            psd_n (torch.Tensor): psd matrix of noise
            mask_s (torch.Tensor): T-F mask of target speech
            mask_n (torch.Tensor): T-F mask of noise
            reference_vector (torch.Tensor): one-hot reference channel matrix
            solution (str, optional): the solution to estimate the beamforming weight
                (Default: ``ref_channel``)
            diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
                (Default: ``True``)
            diag_eps (float, optional): The coefficient multipied to the identity matrix for diagonal loading
                (Default: 1e-7)
            eps (float, optional): a value added to the denominator in mask normalization. (Default: 1e-8)

        Returns:
            torch.Tensor: the mvdr beamforming weight matrix
        """
        if self.multi_mask:
            # Averaging mask along channel dimension
            mask_s = mask_s.mean(dim=-3)  # (..., freq, time)
            mask_n = mask_n.mean(dim=-3)  # (..., freq, time)
        if self.psd_s.ndim == 1:
            self.psd_s = psd_s
            self.psd_n = psd_n
            self.mask_sum_s = mask_s.sum(dim=-1)
            self.mask_sum_n = mask_n.sum(dim=-1)
            return self._get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)
        else:
            psd_s = self._get_updated_psd_speech(psd_s, mask_s)
            psd_n = self._get_updated_psd_noise(psd_n, mask_n)
            self.psd_s = psd_s
            self.psd_n = psd_n
            self.mask_sum_s = self.mask_sum_s + mask_s.sum(dim=-1)
            self.mask_sum_n = self.mask_sum_n + mask_n.sum(dim=-1)
            return self._get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)

    def _get_updated_psd_speech(self, psd_s: torch.Tensor, mask_s: torch.Tensor) -> torch.Tensor:
        r"""Update psd of speech recursively.

        Args:
            psd_s (torch.Tensor): psd matrix of target speech
            mask_s (torch.Tensor): T-F mask of target speech

        Returns:
            torch.Tensor: the updated psd of speech
        """
        numerator = self.mask_sum_s / (self.mask_sum_s + mask_s.sum(dim=-1))
        denominator = 1 / (self.mask_sum_s + mask_s.sum(dim=-1))
        psd_s = self.psd_s * numerator[..., None, None] + psd_s * denominator[..., None, None]
        return psd_s

    def _get_updated_psd_noise(self, psd_n: torch.Tensor, mask_n: torch.Tensor) -> torch.Tensor:
        r"""Update psd of noise recursively.

        Args:
            psd_n (torch.Tensor): psd matrix of target noise
            mask_n (torch.Tensor): T-F mask of target noise

        Returns:
            torch.Tensor: the updated psd of noise
        """
        numerator = self.mask_sum_n / (self.mask_sum_n + mask_n.sum(dim=-1))
        denominator = 1 / (self.mask_sum_n + mask_n.sum(dim=-1))
        psd_n = self.psd_n * numerator[..., None, None] + psd_n * denominator[..., None, None]
        return psd_n

    def _get_mvdr_vector(
        self,
        psd_s: torch.Tensor,
        psd_n: torch.Tensor,
        reference_vector: torch.Tensor,
        solution: str = 'ref_channel',
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        r"""Compute beamforming vector by the reference channel selection method.

        Args:
            psd_s (torch.Tensor): psd matrix of target speech
            psd_n (torch.Tensor): psd matrix of noise
            reference_vector (torch.Tensor): one-hot reference channel matrix
            solution (str, optional): the solution to estimate the beamforming weight
                (Default: ``ref_channel``)
            diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
                (Default: ``True``)
            diag_eps (float, optional): The coefficient multipied to the identity matrix for diagonal loading
                (Default: 1e-7)
            eps (float, optional): a value added to the denominator in mask normalization. Default: 1e-8

        Returns:
            torch.Tensor: the mvdr beamforming weight matrix
        """
        if diagonal_loading:
            psd_n = self._tik_reg(psd_n, reg=diag_eps, eps=eps)
        if solution == "ref_channel":
            numerator = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
            # ws: (..., C, C) / (...,) -> (..., C, C)
            ws = numerator / (mat_trace(numerator)[..., None, None] + eps)
            # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
            beamform_vector = torch.einsum("...fec,...c->...fe", [ws, reference_vector])
        else:
            if solution == "stv_evd":
                stv = self._get_steering_vector_evd(psd_s)
            else:
                stv = self._get_steering_vector_power(psd_s, psd_n, reference_vector)
            # numerator = psd_n.inv() @ stv
            numerator = torch.linalg.solve(psd_n, stv).squeeze(-1)  # (..., freq, channel)
            # denominator = stv^H @ psd_n.inv() @ stv
            denominator = torch.einsum("...d,...d->...", [stv.conj().squeeze(-1), numerator])
275
276
277
            # normalzie the numerator
            scale = stv.squeeze(-1)[..., self.ref_channel, None].conj()
            beamform_vector = numerator * scale / (denominator.real.unsqueeze(-1) + eps)
nateanl's avatar
nateanl committed
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319

        return beamform_vector

    def _get_steering_vector_evd(self, psd_s: torch.Tensor) -> torch.Tensor:
        r"""Estimate the steering vector by eigenvalue decomposition.

        Args:
            psd_s (torch.tensor): covariance matrix of speech
                Tensor of dimension (..., freq, channel, channel)

        Returns:
            torch.Tensor: the enhanced STFT
                Tensor of dimension (..., freq, channel, 1)
        """
        w, v = torch.linalg.eig(psd_s)  # (..., freq, channel, channel)
        _, indices = torch.max(w.abs(), dim=-1, keepdim=True)
        indices = indices.unsqueeze(-1)
        stv = v.gather(-1, indices.expand(psd_s.shape[:-1] + (1,)))  # (..., freq, channel, 1)
        return stv

    def _get_steering_vector_power(
            self,
            psd_s: torch.Tensor,
            psd_n: torch.Tensor,
            reference_vector: torch.Tensor
    ) -> torch.Tensor:
        r"""Estimate the steering vector by the power method.

        Args:
            psd_s (torch.tensor): covariance matrix of speech
                Tensor of dimension (..., freq, channel, channel)
            psd_n (torch.Tensor): covariance matrix of noise
                Tensor of dimension (..., freq, channel, channel)
            reference_vector (torch.Tensor): one-hot reference channel matrix

        Returns:
            torch.Tensor: the enhanced STFT
                Tensor of dimension (..., freq, channel, 1)
        """
        phi = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
        stv = torch.einsum("...fec,...c->...fe", [phi, reference_vector])
        stv = stv.unsqueeze(-1)
320
        stv = torch.matmul(phi, stv)
nateanl's avatar
nateanl committed
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
        stv = torch.matmul(psd_s, stv)
        return stv

    def _apply_beamforming_vector(
        self,
        X: torch.Tensor,
        beamform_vector: torch.Tensor
    ) -> torch.Tensor:
        r"""Apply the beamforming weight to the noisy STFT
        Args:
            X (torch.tensor): multi-channel noisy STFT
                Tensor of dimension (..., channel, freq, time)
            beamform_vector (torch.Tensor): beamforming weight matrix
                Tensor of dimension (..., freq, channel)

        Returns:
            torch.Tensor: the enhanced STFT
                Tensor of dimension (..., freq, time)
        """
        # (..., channel) x (..., channel, freq, time) -> (..., freq, time)
        Y = torch.einsum("...fc,...cft->...ft", [beamform_vector.conj(), X])
        return Y

    def _tik_reg(
        self,
        mat: torch.Tensor,
        reg: float = 1e-7,
        eps: float = 1e-8
    ) -> torch.Tensor:
        """Perform Tikhonov regularization (only modifying real part).
        Args:
            mat (torch.Tensor): input matrix (..., channel, channel)
            reg (float, optional): regularization factor (Default: 1e-8)
            eps (float, optional): a value to avoid the correlation matrix is all-zero (Default: 1e-8)

        Returns:
            torch.Tensor: regularized matrix (..., channel, channel)
        """
        # Add eps
        C = mat.size(-1)
        eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
        with torch.no_grad():
            epsilon = mat_trace(mat).real[..., None, None] * reg
            # in case that correlation_matrix is all-zero
            epsilon = epsilon + eps
        mat = mat + epsilon * eye[..., :, :]
        return mat

    def forward(self, X: torch.Tensor, mask_s: torch.Tensor, mask_n: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Perform MVDR beamforming.

        Args:
            X (torch.Tensor): the multi-channel STF of the noisy speech.
                Tensor of dimension (..., channel, freq, time)
            mask_s (torch.Tensor): Time-Frequency mask of target speech
                Tensor of dimension (..., freq, time) if multi_mask is ``False``
                or or dimension (..., channel, freq, time) if multi_mask is ``True``
            mask_n (torch.Tensor or None, optional): Time-Frequency mask of noise
                Tensor of dimension (..., freq, time) if multi_mask is ``False``
                or or dimension (..., channel, freq, time) if multi_mask is ``True``
                (Default: None)

        Returns:
            torch.Tensor: The single-channel STFT of the enhanced speech.
                Tensor of dimension (..., freq, time)
        """
        if X.ndim < 3:
            raise ValueError(
                f"Expected at least 3D tensor (..., channel, freq, time). Found: {X.shape}"
            )
        if X.dtype != torch.cdouble:
            raise ValueError(
                f"The type of the input STFT tensor must be ``torch.cdouble``. Found: {X.dtype}"
            )

        if mask_n is None:
            mask_n = 1 - mask_s

        shape = X.size()

        # pack batch
        X = X.reshape(-1, shape[-3], shape[-2], shape[-1])
        if self.multi_mask:
            mask_s = mask_s.reshape(-1, shape[-3], shape[-2], shape[-1])
            mask_n = mask_n.reshape(-1, shape[-3], shape[-2], shape[-1])
        else:
            mask_s = mask_s.reshape(-1, shape[-2], shape[-1])
            mask_n = mask_n.reshape(-1, shape[-2], shape[-1])

        psd_s = self.psd(X, mask_s)  # (..., freq, time, channel, channel)
        psd_n = self.psd(X, mask_n)  # (..., freq, time, channel, channel)

        u = torch.zeros(
            X.size()[:-2],
            device=X.device,
            dtype=torch.cdouble
        )  # (..., channel)
        u[..., self.ref_channel].fill_(1)

        if self.online:
            w_mvdr = self._get_updated_mvdr_vector(
                psd_s,
                psd_n,
                mask_s,
                mask_n,
                u,
                self.solution,
                self.diag_loading,
                self.diag_eps
            )
        else:
            w_mvdr = self._get_mvdr_vector(
                psd_s,
                psd_n,
                u,
                self.solution,
                self.diag_loading,
                self.diag_eps
            )

        Y = self._apply_beamforming_vector(X, w_mvdr)

        # unpack batch
        Y = Y.reshape(shape[:-3] + shape[-2:])

        return Y