Commit bcd78085 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add missing imports

parent e3daf724
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from __future__ import annotations from __future__ import annotations
from typing import Tuple, Any, Sequence, Callable from typing import Tuple, Any, Sequence, Callable, Optional
import numpy as np import numpy as np
import torch import torch
...@@ -716,7 +716,7 @@ class Rotation: ...@@ -716,7 +716,7 @@ class Rotation:
return Rotation(rot_mats=rot_mats, quats=None) return Rotation(rot_mats=rot_mats, quats=None)
def map_tensor_fn(self, def map_tensor_fn(self,
fn: Callable[tensor.Tensor, tensor.Tensor] fn: Callable[torch.Tensor, torch.Tensor]
) -> Rotation: ) -> Rotation:
""" """
Apply a Tensor -> Tensor function to underlying rotation tensors, Apply a Tensor -> Tensor function to underlying rotation tensors,
...@@ -1074,7 +1074,7 @@ class Rigid: ...@@ -1074,7 +1074,7 @@ class Rigid:
return Rigid(rot_inv, -1 * trn_inv) return Rigid(rot_inv, -1 * trn_inv)
def map_tensor_fn(self, def map_tensor_fn(self,
fn: Callable[tensor.Tensor, tensor.Tensor] fn: Callable[torch.Tensor, torch.Tensor]
) -> Rigid: ) -> Rigid:
""" """
Apply a Tensor -> Tensor function to underlying translation and Apply a Tensor -> Tensor function to underlying translation and
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment