Commit c18fba3d authored by Andrea Paris's avatar Andrea Paris Committed by Boris Bonev
Browse files

restored comment

parent 67370881
......@@ -139,6 +139,12 @@ def _disco_s2_transpose_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor
def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
"""
Reference implementation of the custom contraction as described in [1]. This requires repeated
shifting of the input tensor, which can potentially be costly. For an efficient implementation
on GPU, make sure to use the custom kernel written in CUDA.
"""
assert len(psi.shape) == 3
assert len(x.shape) == 4
psi = psi.to(x.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment