    Fix return dtype in MVDR module (#2376) · 2f4eb4ac
    Zhaoheng Ni authored
    Summary:
    Address https://github.com/pytorch/audio/issues/2375
    The MVDR module internally converts complex tensors to `torch.complex128` for computation and is meant to convert the result back to the original dtype before returning it. However, the conversion back never took effect: the code called `specgram_enhanced.to(dtype)` without assigning the result. Because `Tensor.to` returns a converted tensor instead of modifying the existing one in place, the `torch.complex128` tensor was returned unchanged. This change replaces the call with `specgram_enhanced = specgram_enhanced.to(dtype)` so that the output dtype matches the input dtype.
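
The bug pattern can be reproduced in isolation. The sketch below is not the torchaudio MVDR code: `enhance_buggy` and `enhance_fixed` are hypothetical helpers that only mimic the upcast-then-downcast flow described above, with the beamforming math itself omitted. It shows that a bare `.to(dtype)` call discards its result, while rebinding the name restores the input dtype.

```python
import torch


def enhance_buggy(specgram: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: upcasts, then *tries* to restore the input dtype."""
    dtype = specgram.dtype
    # Upcast to complex128 for computation (the real MVDR math would go here).
    specgram_enhanced = specgram.to(torch.complex128)
    # BUG: Tensor.to returns a new tensor; the converted result is discarded.
    specgram_enhanced.to(dtype)
    return specgram_enhanced


def enhance_fixed(specgram: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: same flow, with the conversion result assigned."""
    dtype = specgram.dtype
    specgram_enhanced = specgram.to(torch.complex128)
    # FIX: rebind the name so the returned tensor has the original dtype.
    specgram_enhanced = specgram_enhanced.to(dtype)
    return specgram_enhanced


x = torch.randn(2, 3, dtype=torch.complex64)
print(enhance_buggy(x).dtype)  # torch.complex128
print(enhance_fixed(x).dtype)  # torch.complex64
```

The same pitfall applies to other non-in-place tensor methods: their return value must be assigned for the change to be kept.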
    
    Pull Request resolved: https://github.com/pytorch/audio/pull/2376
    
    Reviewed By: hwangjeff
    
    Differential Revision: D36280851
    
    Pulled By: nateanl
    
    fbshipit-source-id: 553d1b98f899547209a4e3ebc59920c7ef1f3112