Commit bcd78085 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add missing imports

parent e3daf724
......@@ -14,7 +14,7 @@
# limitations under the License.
from __future__ import annotations
from typing import Tuple, Any, Sequence, Callable
from typing import Tuple, Any, Sequence, Callable, Optional
import numpy as np
import torch
......@@ -716,7 +716,7 @@ class Rotation:
return Rotation(rot_mats=rot_mats, quats=None)
def map_tensor_fn(self,
fn: Callable[tensor.Tensor, tensor.Tensor]
fn: Callable[torch.Tensor, torch.Tensor]
) -> Rotation:
"""
Apply a Tensor -> Tensor function to underlying rotation tensors,
......@@ -1074,7 +1074,7 @@ class Rigid:
return Rigid(rot_inv, -1 * trn_inv)
def map_tensor_fn(self,
fn: Callable[tensor.Tensor, tensor.Tensor]
fn: Callable[torch.Tensor, torch.Tensor]
) -> Rigid:
"""
Apply a Tensor -> Tensor function to underlying translation and
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment