Commit eeba91dc authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Add complex dtype support in functional autograd test (#2244)

Summary:
In autograd tests, to guarantee the precision, the dtype of Tensors are converted to `torch.float64` if they are real. However, the complex dtype is not considered. This PR adds `self.complex_dtype` support to the inputs.

Pull Request resolved: https://github.com/pytorch/audio/pull/2244

Reviewed By: mthrok

Differential Revision: D34272998

Pulled By: nateanl

fbshipit-source-id: e8698a74d7b8d99ee0fcb5f5cb5f2ffc8c80b9b5
parent c2decba4
......@@ -24,7 +24,7 @@ class Autograd(TestBaseMixin):
inputs_ = []
for i in inputs:
if torch.is_tensor(i):
i = i.to(dtype=self.dtype, device=self.device)
i = i.to(dtype=self.complex_dtype if i.is_complex() else self.dtype, device=self.device)
if enable_all_grad:
i.requires_grad = True
inputs_.append(i)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment