Commit cb9f9ca2 authored by zhanggzh's avatar zhanggzh
Browse files

add dcnv3 src code ande update readme file

parent 31c962dc
This diff is collapsed.
...@@ -5,3 +5,4 @@ ...@@ -5,3 +5,4 @@
# -------------------------------------------------------- # --------------------------------------------------------
from .dcnv3_func import DCNv3Function, dcnv3_core_pytorch from .dcnv3_func import DCNv3Function, dcnv3_core_pytorch
__version__ = '0.1'
...@@ -52,8 +52,13 @@ def check_forward_equal_with_pytorch_double(): ...@@ -52,8 +52,13 @@ def check_forward_equal_with_pytorch_double():
fwdok = torch.allclose(output_cuda, output_pytorch) fwdok = torch.allclose(output_cuda, output_pytorch)
max_abs_err = (output_cuda - output_pytorch).abs().max() max_abs_err = (output_cuda - output_pytorch).abs().max()
max_rel_err = ((output_cuda - output_pytorch).abs() / #max_rel_err = ((output_cuda - output_pytorch).abs() /
output_pytorch.abs()).max() # output_pytorch.abs()).max()
non_zero_mask = output_pytorch.abs() > 0
if non_zero_mask.any():
max_rel_err = ((output_cuda[non_zero_mask] - output_pytorch[non_zero_mask]).abs() / output_pytorch[non_zero_mask].abs()).max()
else:
max_rel_err = 0.0 # 如果所有值都为零,则相对误差为 0
print('>>> forward double') print('>>> forward double')
print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment