    Use fused op in lfilter (#517) · 86d54160
    moto authored
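    This improves the speed of `lfilter` (and functions that use `lfilter`, such as `biquad`) by about 11% (23.4 s to 20.8 s for the profiled call below) by replacing the separate `torch.mv` and `sub_` calls in the inner loop with a single fused `addmv_` call. The changed line (722) itself drops from 4.92 s to 2.64 s.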
    
    * Before (23.4369 seconds for `lfilter` call)
    
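    Breakdown (columns: line number, hits, total time in seconds, time per hit in seconds, share of total time, source)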
    
    ```
       720|    220501|       4.4464|   2.0165e-05| 18.97%|    for i_sample, o0 in enumerate(input_signal_windows.t()):
    (call)|         1|  7.86781e-05|  7.86781e-05|  0.00%|# /scratch/moto/pytorch/torch/tensor.py:460 __iter__
    (call)|    220500|      2.72458|  1.23564e-05| 11.62%|# /scratch/moto/pytorch/torch/tensor.py:474 <lambda>
       721|    220500|      2.80982|   1.2743e-05| 11.99%|        windowed_output_signal = padded_output_waveform[:, i_sample:(i_sample + n_order)]
       722|    220500|      4.92106|  2.23177e-05| 21.00%|        o0.sub_(torch.mv(windowed_output_signal, a_coeffs_flipped))
       723|    220500|      3.72974|  1.69149e-05| 15.91%|        o0.div_(a_coeffs[0])
       724|         0|            0|            0|  0.00%|
       725|    220500|      4.77714|   2.1665e-05| 20.38%|        padded_output_waveform[:, i_sample + n_order - 1] = o0
    ```
    
    * After (20.8405 seconds for `lfilter` call)
    
    Breakdown
    
    ```
       720|    220501|      4.40834|  1.99924e-05| 21.15%|    for i_sample, o0 in enumerate(input_signal_windows.t()):
    (call)|         1|  7.31945e-05|  7.31945e-05|  0.00%|# /scratch/moto/pytorch/torch/tensor.py:460 __iter__
    (call)|    220500|      2.68595|  1.21812e-05| 12.89%|# /scratch/moto/pytorch/torch/tensor.py:474 <lambda>
       721|    220500|      2.97357|  1.34856e-05| 14.27%|        windowed_output_signal = padded_output_waveform[:, i_sample:(i_sample + n_order)]
       722|    220500|      2.63567|  1.19531e-05| 12.65%|        o0.addmv_(windowed_output_signal, a_coeffs_flipped)
       723|    220500|       3.4228|  1.55229e-05| 16.42%|        o0.div_(a_coeffs[0])
       724|         0|            0|            0|  0.00%|
       725|    220500|      4.68726|  2.12574e-05| 22.49%|        padded_output_waveform[:, i_sample + n_order - 1] = o0
    ```
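
The difference on line 722 of the two breakdowns can be checked in isolation. The following is a minimal sketch, not the torchaudio implementation: the tensor shapes and random values are assumptions chosen for illustration, and only the variable names are borrowed from the loop body. It also shows that `addmv_` adds the matrix-vector product by default, so matching the original `sub_(torch.mv(...))` requires `alpha=-1`.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for the tensors used in the loop body (shapes are assumed).
n_channels, n_order = 2, 3
windowed_output_signal = torch.randn(n_channels, n_order)
a_coeffs_flipped = torch.randn(n_order)
o0 = torch.randn(n_channels)

# Unfused: torch.mv allocates a temporary tensor, then sub_ subtracts it in place.
unfused = o0.clone()
unfused.sub_(torch.mv(windowed_output_signal, a_coeffs_flipped))

# Fused: addmv_ computes beta * self + alpha * (mat @ vec) in a single call.
# alpha=-1 turns the addition into the subtraction performed above.
fused = o0.clone()
fused.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1)

# With the default alpha=1 the product is added instead of subtracted.
added = o0.clone()
added.addmv_(windowed_output_signal, a_coeffs_flipped)

results_match = torch.allclose(unfused, fused, atol=1e-6)
default_adds = torch.allclose(
    added, o0 + torch.mv(windowed_output_signal, a_coeffs_flipped), atol=1e-6
)
print(results_match, default_adds)
```

The fused form avoids the intermediate tensor and one extra operator dispatch per sample, which is consistent with the lower per-hit time reported for line 722.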
    
    * Script
    
    ```python
    import pprofile
    
    import torch
    import torchaudio
    import torchaudio.functional as F
    
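    def profile_biquad():
        # Load the test asset from the torchaudio repository.
        waveform, sr = torchaudio.load('test/assets/100Hz_44100Hz_16bit_05sec.wav', normalization=True)
    
        # Collect line-level timings for the biquad call, which runs `lfilter` internally.
        prof = pprofile.Profile()
        with prof():
            # Arguments: waveform, sample rate, center frequency (Hz), gain, Q.
            F.equalizer_biquad(waveform, sr, 3000, 1, 0.707)
        prof.print_stats()
    
    profile_biquad()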
    ```
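
Line-level profilers add noticeable overhead of their own, so it can be useful to time the changed statement on its own. The following is a rough micro-benchmark sketch using `timeit` from the standard library. The tensor shapes and the call count are assumptions, and absolute numbers will vary by machine and PyTorch build.

```python
import timeit

import torch

# Hypothetical shapes: two channels and a filter with three coefficients.
n_channels, n_order = 2, 3
windowed_output_signal = torch.randn(n_channels, n_order)
a_coeffs_flipped = torch.randn(n_order)
o0 = torch.zeros(n_channels)


def unfused():
    # Matrix-vector product into a temporary, then in-place subtraction.
    o0.sub_(torch.mv(windowed_output_signal, a_coeffs_flipped))


def fused():
    # Single in-place call; alpha=-1 makes it a subtraction.
    o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1)


n_calls = 10_000
t_unfused = timeit.timeit(unfused, number=n_calls)
t_fused = timeit.timeit(fused, number=n_calls)

print(f"unfused: {t_unfused / n_calls:.2e} s per call")
print(f"fused:   {t_fused / n_calls:.2e} s per call")
```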
    
    * See also
    
    https://github.com/pytorch/audio/issues/260#issuecomment-610074110