Unverified Commit 86d54160 authored by moto's avatar moto Committed by GitHub
Browse files

Use fused op in lfilter (#517)

This improves the speed of `lfilter` (and functions that use `lfilter`, such as `biquad`) by 10%.

* Before (23.4369 seconds for `lfilter` call)

Breakdown

```
   720|    220501|       4.4464|   2.0165e-05| 18.97%|    for i_sample, o0 in enumerate(input_signal_windows.t()):
(call)|         1|  7.86781e-05|  7.86781e-05|  0.00%|# /scratch/moto/pytorch/torch/tensor.py:460 __iter__
(call)|    220500|      2.72458|  1.23564e-05| 11.62%|# /scratch/moto/pytorch/torch/tensor.py:474 <lambda>
   721|    220500|      2.80982|   1.2743e-05| 11.99%|        windowed_output_signal = padded_output_waveform[:, i_sample:(i_sample + n_order)]
   722|    220500|      4.92106|  2.23177e-05| 21.00%|        o0.sub_(torch.mv(windowed_output_signal, a_coeffs_flipped))
   723|    220500|      3.72974|  1.69149e-05| 15.91%|        o0.div_(a_coeffs[0])
   724|         0|            0|            0|  0.00%|
   725|    220500|      4.77714|   2.1665e-05| 20.38%|        padded_output_waveform[:, i_sample + n_order - 1] = o0
```

* After (20.8405 seconds for `lfilter` call)

Breakdown

```
   720|    220501|      4.40834|  1.99924e-05| 21.15%|    for i_sample, o0 in enumerate(input_signal_windows.t()):
(call)|         1|  7.31945e-05|  7.31945e-05|  0.00%|# /scratch/moto/pytorch/torch/tensor.py:460 __iter__
(call)|    220500|      2.68595|  1.21812e-05| 12.89%|# /scratch/moto/pytorch/torch/tensor.py:474 <lambda>
   721|    220500|      2.97357|  1.34856e-05| 14.27%|        windowed_output_signal = padded_output_waveform[:, i_sample:(i_sample + n_order)]
   722|    220500|      2.63567|  1.19531e-05| 12.65%|        o0.addmv_(windowed_output_signal, a_coeffs_flipped)
   723|    220500|       3.4228|  1.55229e-05| 16.42%|        o0.div_(a_coeffs[0])
   724|         0|            0|            0|  0.00%|
   725|    220500|      4.68726|  2.12574e-05| 22.49%|        padded_output_waveform[:, i_sample + n_order - 1] = o0
```

* Script

```python
import pprofile

import torch
import torchaudio
import torchaudio.functional as F

def profile_biquad():
    waveform, sr = torchaudio.load('test/assets/100Hz_44100Hz_16bit_05sec.wav', normalization=True)

    prof = pprofile.Profile()
    with prof():
        F.equalizer_biquad(waveform, sr, 3000, 1, 0.707)
    prof.print_stats()

profile_biquad()
```

* See also

https://github.com/pytorch/audio/issues/260#issuecomment-610074110
parent f37d37d6
...@@ -745,7 +745,7 @@ def lfilter( ...@@ -745,7 +745,7 @@ def lfilter(
for i_sample, o0 in enumerate(input_signal_windows.t()): for i_sample, o0 in enumerate(input_signal_windows.t()):
windowed_output_signal = padded_output_waveform[:, i_sample:(i_sample + n_order)] windowed_output_signal = padded_output_waveform[:, i_sample:(i_sample + n_order)]
o0.sub_(torch.mv(windowed_output_signal, a_coeffs_flipped)) o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1)
o0.div_(a_coeffs[0]) o0.div_(a_coeffs[0])
padded_output_waveform[:, i_sample + n_order - 1] = o0 padded_output_waveform[:, i_sample + n_order - 1] = o0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment