Commit e9fb6c27 authored by Pyre Bot Jr's avatar Pyre Bot Jr Committed by Facebook GitHub Bot
Browse files

Add annotations to `vision/fair/pytorch3d`

Reviewed By: shannonzhu

Differential Revision: D33970393

fbshipit-source-id: 9b4dfaccfc3793fd37705a923d689cb14c9d26ba
parent c2862ff4
...@@ -98,7 +98,9 @@ def collate_batched_R2N2(batch: List[Dict]): # pragma: no cover ...@@ -98,7 +98,9 @@ def collate_batched_R2N2(batch: List[Dict]): # pragma: no cover
return collated_dict return collated_dict
def compute_extrinsic_matrix(azimuth, elevation, distance): # pragma: no cover def compute_extrinsic_matrix(
azimuth: float, elevation: float, distance: float
): # pragma: no cover
""" """
Copied from meshrcnn codebase: Copied from meshrcnn codebase:
https://github.com/facebookresearch/meshrcnn/blob/main/shapenet/utils/coords.py#L96 https://github.com/facebookresearch/meshrcnn/blob/main/shapenet/utils/coords.py#L96
...@@ -138,6 +140,7 @@ def compute_extrinsic_matrix(azimuth, elevation, distance): # pragma: no cover ...@@ -138,6 +140,7 @@ def compute_extrinsic_matrix(azimuth, elevation, distance): # pragma: no cover
# rotates the model 90 degrees about the x axis. To compensate for this quirk we # rotates the model 90 degrees about the x axis. To compensate for this quirk we
# roll that rotation into the extrinsic matrix here # roll that rotation into the extrinsic matrix here
rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
# pyre-fixme[16]: `Tensor` has no attribute `mm`.
RT = RT.mm(rot.to(RT)) RT = RT.mm(rot.to(RT))
return RT return RT
...@@ -384,7 +387,7 @@ def voxelize(voxel_coords, P, V): # pragma: no cover ...@@ -384,7 +387,7 @@ def voxelize(voxel_coords, P, V): # pragma: no cover
return voxels return voxels
def project_verts(verts, P, eps=1e-1): # pragma: no cover def project_verts(verts, P, eps: float = 1e-1): # pragma: no cover
""" """
Copied from meshrcnn codebase: Copied from meshrcnn codebase:
https://github.com/facebookresearch/meshrcnn/blob/main/shapenet/utils/coords.py#L159 https://github.com/facebookresearch/meshrcnn/blob/main/shapenet/utils/coords.py#L159
......
...@@ -32,7 +32,7 @@ _Aux = namedtuple( ...@@ -32,7 +32,7 @@ _Aux = namedtuple(
) )
def _format_faces_indices(faces_indices, max_index, device, pad_value=None): def _format_faces_indices(faces_indices, max_index: int, device, pad_value=None):
""" """
Format indices and check for invalid values. Indices can refer to Format indices and check for invalid values. Indices can refer to
values in one of the face properties: vertices, textures or normals. values in one of the face properties: vertices, textures or normals.
...@@ -57,6 +57,7 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None): ...@@ -57,6 +57,7 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
) )
if pad_value is not None: if pad_value is not None:
# pyre-fixme[28]: Unexpected keyword argument `dim`.
mask = faces_indices.eq(pad_value).all(dim=-1) mask = faces_indices.eq(pad_value).all(dim=-1)
# Change to 0 based indexing. # Change to 0 based indexing.
...@@ -66,6 +67,7 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None): ...@@ -66,6 +67,7 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
faces_indices[(faces_indices < 0)] += max_index faces_indices[(faces_indices < 0)] += max_index
if pad_value is not None: if pad_value is not None:
# pyre-fixme[61]: `mask` is undefined, or not always defined.
faces_indices[mask] = pad_value faces_indices[mask] = pad_value
return _check_faces_indices(faces_indices, max_index, pad_value) return _check_faces_indices(faces_indices, max_index, pad_value)
...@@ -73,7 +75,7 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None): ...@@ -73,7 +75,7 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
def load_obj( def load_obj(
f, f,
load_textures=True, load_textures: bool = True,
create_texture_atlas: bool = False, create_texture_atlas: bool = False,
texture_atlas_size: int = 4, texture_atlas_size: int = 4,
texture_wrap: Optional[str] = "repeat", texture_wrap: Optional[str] = "repeat",
...@@ -351,7 +353,7 @@ def _parse_face( ...@@ -351,7 +353,7 @@ def _parse_face(
faces_normals_idx, faces_normals_idx,
faces_textures_idx, faces_textures_idx,
faces_materials_idx, faces_materials_idx,
): ) -> None:
face = tokens[1:] face = tokens[1:]
face_list = [f.split("/") for f in face] face_list = [f.split("/") for f in face]
face_verts = [] face_verts = []
...@@ -546,7 +548,7 @@ def _load_materials( ...@@ -546,7 +548,7 @@ def _load_materials(
def _load_obj( def _load_obj(
f_obj, f_obj,
*, *,
data_dir, data_dir: str,
load_textures: bool = True, load_textures: bool = True,
create_texture_atlas: bool = False, create_texture_atlas: bool = False,
texture_atlas_size: int = 4, texture_atlas_size: int = 4,
......
...@@ -463,7 +463,9 @@ def _read_ply_element_ascii(f, definition: _PlyElementType): ...@@ -463,7 +463,9 @@ def _read_ply_element_ascii(f, definition: _PlyElementType):
return data return data
def _read_raw_array(f, aim: str, length: int, dtype: type = np.uint8, dtype_size=1): def _read_raw_array(
f, aim: str, length: int, dtype: type = np.uint8, dtype_size: int = 1
):
""" """
Read [length] elements from a file. Read [length] elements from a file.
......
...@@ -28,7 +28,7 @@ def nullcontext(x): ...@@ -28,7 +28,7 @@ def nullcontext(x):
PathOrStr = Union[pathlib.Path, str] PathOrStr = Union[pathlib.Path, str]
def _open_file(f, path_manager: PathManager, mode="r") -> ContextManager[IO]: def _open_file(f, path_manager: PathManager, mode: str = "r") -> ContextManager[IO]:
if isinstance(f, str): if isinstance(f, str):
f = path_manager.open(f, mode) f = path_manager.open(f, mode)
return contextlib.closing(f) return contextlib.closing(f)
......
...@@ -14,7 +14,7 @@ from pytorch3d.structures.pointclouds import Pointclouds ...@@ -14,7 +14,7 @@ from pytorch3d.structures.pointclouds import Pointclouds
def _validate_chamfer_reduction_inputs( def _validate_chamfer_reduction_inputs(
batch_reduction: Union[str, None], point_reduction: str batch_reduction: Union[str, None], point_reduction: str
): ) -> None:
"""Check the requested reductions are valid. """Check the requested reductions are valid.
Args: Args:
......
...@@ -106,7 +106,7 @@ def knn_points( ...@@ -106,7 +106,7 @@ def knn_points(
version: int = -1, version: int = -1,
return_nn: bool = False, return_nn: bool = False,
return_sorted: bool = True, return_sorted: bool = True,
): ) -> _KNN:
""" """
K-Nearest neighbors on point clouds. K-Nearest neighbors on point clouds.
......
...@@ -166,7 +166,7 @@ def estimate_pointcloud_local_coord_frames( ...@@ -166,7 +166,7 @@ def estimate_pointcloud_local_coord_frames(
return curvatures, local_coord_frames return curvatures, local_coord_frames
def _disambiguate_vector_directions(pcl, knns, vecs): def _disambiguate_vector_directions(pcl, knns, vecs: float) -> float:
""" """
Disambiguates normal directions according to [1]. Disambiguates normal directions according to [1].
...@@ -180,6 +180,7 @@ def _disambiguate_vector_directions(pcl, knns, vecs): ...@@ -180,6 +180,7 @@ def _disambiguate_vector_directions(pcl, knns, vecs):
# each element of the neighborhood # each element of the neighborhood
df = knns - pcl[:, :, None] df = knns - pcl[:, :, None]
# projection of the difference on the principal direction # projection of the difference on the principal direction
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
proj = (vecs[:, :, None] * df).sum(3) proj = (vecs[:, :, None] * df).sum(3)
# check how many projections are positive # check how many projections are positive
n_pos = (proj > 0).type_as(knns).sum(2, keepdim=True) n_pos = (proj > 0).type_as(knns).sum(2, keepdim=True)
......
...@@ -479,7 +479,7 @@ def _check_points_to_volumes_inputs( ...@@ -479,7 +479,7 @@ def _check_points_to_volumes_inputs(
volume_features: torch.Tensor, volume_features: torch.Tensor,
grid_sizes: torch.LongTensor, grid_sizes: torch.LongTensor,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
): ) -> None:
max_grid_size = grid_sizes.max(dim=0).values max_grid_size = grid_sizes.max(dim=0).values
if torch.prod(max_grid_size) > volume_densities.shape[1]: if torch.prod(max_grid_size) > volume_densities.shape[1]:
......
...@@ -400,7 +400,7 @@ def create_verts_index(verts_per_mesh, edges_per_mesh, device=None): ...@@ -400,7 +400,7 @@ def create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
return verts_idx return verts_idx
def create_faces_index(faces_per_mesh, device=None): def create_faces_index(faces_per_mesh: int, device=None):
""" """
Helper function to group the faces indices for each mesh. New faces are Helper function to group the faces indices for each mesh. New faces are
stacked at the end of the original faces tensor, so in order to have stacked at the end of the original faces tensor, so in order to have
...@@ -417,7 +417,9 @@ def create_faces_index(faces_per_mesh, device=None): ...@@ -417,7 +417,9 @@ def create_faces_index(faces_per_mesh, device=None):
""" """
# e.g. faces_per_mesh = [2, 5, 3] # e.g. faces_per_mesh = [2, 5, 3]
# pyre-fixme[16]: `int` has no attribute `sum`.
F = faces_per_mesh.sum() # e.g. 10 F = faces_per_mesh.sum() # e.g. 10
# pyre-fixme[16]: `int` has no attribute `cumsum`.
faces_per_mesh_cumsum = faces_per_mesh.cumsum(dim=0) # (N,) e.g. (2, 7, 10) faces_per_mesh_cumsum = faces_per_mesh.cumsum(dim=0) # (N,) e.g. (2, 7, 10)
switch1_idx = faces_per_mesh_cumsum.clone() switch1_idx = faces_per_mesh_cumsum.clone()
......
...@@ -150,7 +150,7 @@ def convert_pointclouds_to_tensor(pcl: Union[torch.Tensor, "Pointclouds"]): ...@@ -150,7 +150,7 @@ def convert_pointclouds_to_tensor(pcl: Union[torch.Tensor, "Pointclouds"]):
return X, num_points return X, num_points
def is_pointclouds(pcl: Union[torch.Tensor, "Pointclouds"]): def is_pointclouds(pcl: Union[torch.Tensor, "Pointclouds"]) -> bool:
"""Checks whether the input `pcl` is an instance of `Pointclouds` """Checks whether the input `pcl` is an instance of `Pointclouds`
by checking the existence of `points_padded` and `num_points_per_cloud` by checking the existence of `points_padded` and `num_points_per_cloud`
functions. functions.
......
...@@ -427,10 +427,10 @@ class CamerasBase(TensorProperties): ...@@ -427,10 +427,10 @@ class CamerasBase(TensorProperties):
def OpenGLPerspectiveCameras( def OpenGLPerspectiveCameras(
znear=1.0, znear: float = 1.0,
zfar=100.0, zfar: float = 100.0,
aspect_ratio=1.0, aspect_ratio: float = 1.0,
fov=60.0, fov: float = 60.0,
degrees: bool = True, degrees: bool = True,
R: torch.Tensor = _R, R: torch.Tensor = _R,
T: torch.Tensor = _T, T: torch.Tensor = _T,
...@@ -709,12 +709,12 @@ class FoVPerspectiveCameras(CamerasBase): ...@@ -709,12 +709,12 @@ class FoVPerspectiveCameras(CamerasBase):
def OpenGLOrthographicCameras( def OpenGLOrthographicCameras(
znear=1.0, znear: float = 1.0,
zfar=100.0, zfar: float = 100.0,
top=1.0, top: float = 1.0,
bottom=-1.0, bottom: float = -1.0,
left=-1.0, left: float = -1.0,
right=1.0, right: float = 1.0,
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3) scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
R: torch.Tensor = _R, R: torch.Tensor = _R,
T: torch.Tensor = _T, T: torch.Tensor = _T,
...@@ -956,7 +956,7 @@ Note that the MultiView Cameras accept parameters in NDC space. ...@@ -956,7 +956,7 @@ Note that the MultiView Cameras accept parameters in NDC space.
def SfMPerspectiveCameras( def SfMPerspectiveCameras(
focal_length=1.0, focal_length: float = 1.0,
principal_point=((0.0, 0.0),), principal_point=((0.0, 0.0),),
R: torch.Tensor = _R, R: torch.Tensor = _R,
T: torch.Tensor = _T, T: torch.Tensor = _T,
...@@ -1194,7 +1194,7 @@ class PerspectiveCameras(CamerasBase): ...@@ -1194,7 +1194,7 @@ class PerspectiveCameras(CamerasBase):
def SfMOrthographicCameras( def SfMOrthographicCameras(
focal_length=1.0, focal_length: float = 1.0,
principal_point=((0.0, 0.0),), principal_point=((0.0, 0.0),),
R: torch.Tensor = _R, R: torch.Tensor = _R,
T: torch.Tensor = _T, T: torch.Tensor = _T,
...@@ -1645,9 +1645,9 @@ def look_at_rotation( ...@@ -1645,9 +1645,9 @@ def look_at_rotation(
def look_at_view_transform( def look_at_view_transform(
dist=1.0, dist: float = 1.0,
elev=0.0, elev: float = 0.0,
azim=0.0, azim: float = 0.0,
degrees: bool = True, degrees: bool = True,
eye: Optional[Sequence] = None, eye: Optional[Sequence] = None,
at=((0, 0, 0),), # (1, 3) at=((0, 0, 0),), # (1, 3)
......
...@@ -162,7 +162,7 @@ class AbsorptionOnlyRaymarcher(torch.nn.Module): ...@@ -162,7 +162,7 @@ class AbsorptionOnlyRaymarcher(torch.nn.Module):
return opacities return opacities
def _shifted_cumprod(x, shift=1): def _shifted_cumprod(x, shift: int = 1):
""" """
Computes `torch.cumprod(x, dim=-1)` and prepends `shift` number of Computes `torch.cumprod(x, dim=-1)` and prepends `shift` number of
ones and removes `shift` trailing elements to/from the last dimension ones and removes `shift` trailing elements to/from the last dimension
...@@ -177,7 +177,7 @@ def _shifted_cumprod(x, shift=1): ...@@ -177,7 +177,7 @@ def _shifted_cumprod(x, shift=1):
def _check_density_bounds( def _check_density_bounds(
rays_densities: torch.Tensor, bounds: Tuple[float, float] = (0.0, 1.0) rays_densities: torch.Tensor, bounds: Tuple[float, float] = (0.0, 1.0)
): ) -> None:
""" """
Checks whether the elements of `rays_densities` range within `bounds`. Checks whether the elements of `rays_densities` range within `bounds`.
If not issues a warning. If not issues a warning.
...@@ -197,7 +197,7 @@ def _check_raymarcher_inputs( ...@@ -197,7 +197,7 @@ def _check_raymarcher_inputs(
features_can_be_none: bool = False, features_can_be_none: bool = False,
z_can_be_none: bool = False, z_can_be_none: bool = False,
density_1d: bool = True, density_1d: bool = True,
): ) -> None:
""" """
Checks the validity of the inputs to raymarching algorithms. Checks the validity of the inputs to raymarching algorithms.
""" """
......
...@@ -98,7 +98,7 @@ def _validate_ray_bundle_variables( ...@@ -98,7 +98,7 @@ def _validate_ray_bundle_variables(
rays_origins: torch.Tensor, rays_origins: torch.Tensor,
rays_directions: torch.Tensor, rays_directions: torch.Tensor,
rays_lengths: torch.Tensor, rays_lengths: torch.Tensor,
): ) -> None:
""" """
Validate the shapes of RayBundle variables Validate the shapes of RayBundle variables
`rays_origins`, `rays_directions`, and `rays_lengths`. `rays_origins`, `rays_directions`, and `rays_lengths`.
......
...@@ -323,7 +323,7 @@ class AmbientLights(TensorProperties): ...@@ -323,7 +323,7 @@ class AmbientLights(TensorProperties):
return torch.zeros_like(points) return torch.zeros_like(points)
def _validate_light_properties(obj): def _validate_light_properties(obj) -> None:
props = ("ambient_color", "diffuse_color", "specular_color") props = ("ambient_color", "diffuse_color", "specular_color")
for n in props: for n in props:
t = getattr(obj, n) t = getattr(obj, n)
......
...@@ -301,7 +301,7 @@ def TexturedSoftPhongShader( ...@@ -301,7 +301,7 @@ def TexturedSoftPhongShader(
lights: Optional[TensorProperties] = None, lights: Optional[TensorProperties] = None,
materials: Optional[Materials] = None, materials: Optional[Materials] = None,
blend_params: Optional[BlendParams] = None, blend_params: Optional[BlendParams] = None,
): ) -> SoftPhongShader:
""" """
TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead. TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead.
Preserving TexturedSoftPhongShader as a function for backwards compatibility. Preserving TexturedSoftPhongShader as a function for backwards compatibility.
......
...@@ -1557,7 +1557,7 @@ class Meshes: ...@@ -1557,7 +1557,7 @@ class Meshes:
raise ValueError("Meshes does not have textures") raise ValueError("Meshes does not have textures")
def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True): def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True) -> Meshes:
""" """
Merge multiple Meshes objects, i.e. concatenate the meshes objects. They Merge multiple Meshes objects, i.e. concatenate the meshes objects. They
must all be on the same device. If include_textures is true, they must all must all be on the same device. If include_textures is true, they must all
......
...@@ -1224,7 +1224,7 @@ class Pointclouds: ...@@ -1224,7 +1224,7 @@ class Pointclouds:
return coord_inside.all(dim=-1) return coord_inside.all(dim=-1)
def join_pointclouds_as_batch(pointclouds: Sequence[Pointclouds]): def join_pointclouds_as_batch(pointclouds: Sequence[Pointclouds]) -> Pointclouds:
""" """
Merge a list of Pointclouds objects into a single batched Pointclouds Merge a list of Pointclouds objects into a single batched Pointclouds
object. All pointclouds must be on the same device. object. All pointclouds must be on the same device.
......
...@@ -10,7 +10,7 @@ from typing import Tuple ...@@ -10,7 +10,7 @@ from typing import Tuple
import torch import torch
DEFAULT_ACOS_BOUND = 1.0 - 1e-4 DEFAULT_ACOS_BOUND: float = 1.0 - 1e-4
def acos_linear_extrapolation( def acos_linear_extrapolation(
......
...@@ -754,7 +754,7 @@ def _broadcast_bmm(a, b): ...@@ -754,7 +754,7 @@ def _broadcast_bmm(a, b):
@torch.no_grad() @torch.no_grad()
def _check_valid_rotation_matrix(R, tol: float = 1e-7): def _check_valid_rotation_matrix(R, tol: float = 1e-7) -> None:
""" """
Determine if R is a valid rotation matrix by checking it satisfies the Determine if R is a valid rotation matrix by checking it satisfies the
following conditions: following conditions:
......
...@@ -24,7 +24,7 @@ from pytorch3d.structures import Meshes, Pointclouds, join_meshes_as_scene ...@@ -24,7 +24,7 @@ from pytorch3d.structures import Meshes, Pointclouds, join_meshes_as_scene
Struct = Union[CamerasBase, Meshes, Pointclouds, RayBundle] Struct = Union[CamerasBase, Meshes, Pointclouds, RayBundle]
def _get_struct_len(struct: Struct): # pragma: no cover def _get_struct_len(struct: Struct) -> int: # pragma: no cover
""" """
Returns the length (usually corresponds to the batch size) of the input structure. Returns the length (usually corresponds to the batch size) of the input structure.
""" """
...@@ -358,8 +358,14 @@ def plot_scene( ...@@ -358,8 +358,14 @@ def plot_scene(
up_y = _scale_camera_to_bounds(up_y, y_range, False) up_y = _scale_camera_to_bounds(up_y, y_range, False)
up_z = _scale_camera_to_bounds(up_z, z_range, False) up_z = _scale_camera_to_bounds(up_z, z_range, False)
# pyre-fixme[6]: For 2nd param expected `Dict[str, int]` but got
# `Dict[str, float]`.
camera["eye"] = {"x": eye_x, "y": eye_y, "z": eye_z} camera["eye"] = {"x": eye_x, "y": eye_y, "z": eye_z}
# pyre-fixme[6]: For 2nd param expected `Dict[str, int]` but got
# `Dict[str, float]`.
camera["center"] = {"x": at_x, "y": at_y, "z": at_z} camera["center"] = {"x": at_x, "y": at_y, "z": at_z}
# pyre-fixme[6]: For 2nd param expected `Dict[str, int]` but got
# `Dict[str, float]`.
camera["up"] = {"x": up_x, "y": up_y, "z": up_z} camera["up"] = {"x": up_x, "y": up_y, "z": up_z}
current_layout.update( current_layout.update(
...@@ -510,7 +516,7 @@ def _add_struct_from_batch( ...@@ -510,7 +516,7 @@ def _add_struct_from_batch(
subplot_title: str, subplot_title: str,
scene_dictionary: Dict[str, Dict[str, Struct]], scene_dictionary: Dict[str, Dict[str, Struct]],
trace_idx: int = 1, trace_idx: int = 1,
): # pragma: no cover ) -> None: # pragma: no cover
""" """
Adds the struct corresponding to the given scene_num index to Adds the struct corresponding to the given scene_num index to
a provided scene_dictionary to be passed in to plot_scene a provided scene_dictionary to be passed in to plot_scene
...@@ -567,7 +573,7 @@ def _add_mesh_trace( ...@@ -567,7 +573,7 @@ def _add_mesh_trace(
subplot_idx: int, subplot_idx: int,
ncols: int, ncols: int,
lighting: Lighting, lighting: Lighting,
): # pragma: no cover ) -> None: # pragma: no cover
""" """
Adds a trace rendering a Meshes object to the passed in figure, with Adds a trace rendering a Meshes object to the passed in figure, with
a given name and in a specific subplot. a given name and in a specific subplot.
...@@ -641,7 +647,7 @@ def _add_pointcloud_trace( ...@@ -641,7 +647,7 @@ def _add_pointcloud_trace(
ncols: int, ncols: int,
max_points_per_pointcloud: int, max_points_per_pointcloud: int,
marker_size: int, marker_size: int,
): # pragma: no cover ) -> None: # pragma: no cover
""" """
Adds a trace rendering a Pointclouds object to the passed in figure, with Adds a trace rendering a Pointclouds object to the passed in figure, with
a given name and in a specific subplot. a given name and in a specific subplot.
...@@ -703,7 +709,7 @@ def _add_camera_trace( ...@@ -703,7 +709,7 @@ def _add_camera_trace(
subplot_idx: int, subplot_idx: int,
ncols: int, ncols: int,
camera_scale: float, camera_scale: float,
): # pragma: no cover ) -> None: # pragma: no cover
""" """
Adds a trace rendering a Cameras object to the passed in figure, with Adds a trace rendering a Cameras object to the passed in figure, with
a given name and in a specific subplot. a given name and in a specific subplot.
...@@ -761,7 +767,7 @@ def _add_ray_bundle_trace( ...@@ -761,7 +767,7 @@ def _add_ray_bundle_trace(
max_points_per_ray: int, max_points_per_ray: int,
marker_size: int, marker_size: int,
line_width: int, line_width: int,
): # pragma: no cover ) -> None: # pragma: no cover
""" """
Adds a trace rendering a RayBundle object to the passed in figure, with Adds a trace rendering a RayBundle object to the passed in figure, with
a given name and in a specific subplot. a given name and in a specific subplot.
...@@ -918,7 +924,7 @@ def _update_axes_bounds( ...@@ -918,7 +924,7 @@ def _update_axes_bounds(
verts_center: torch.Tensor, verts_center: torch.Tensor,
max_expand: float, max_expand: float,
current_layout: go.Scene, # pyre-ignore[11] current_layout: go.Scene, # pyre-ignore[11]
): # pragma: no cover ) -> None: # pragma: no cover
""" """
Takes in the vertices' center point and max spread, and the current plotly figure Takes in the vertices' center point and max spread, and the current plotly figure
layout and updates the layout to have bounds that include all traces for that subplot. layout and updates the layout to have bounds that include all traces for that subplot.
...@@ -956,7 +962,7 @@ def _update_axes_bounds( ...@@ -956,7 +962,7 @@ def _update_axes_bounds(
def _scale_camera_to_bounds( def _scale_camera_to_bounds(
coordinate: float, axis_bounds: Tuple[float, float], is_position: bool coordinate: float, axis_bounds: Tuple[float, float], is_position: bool
): # pragma: no cover ) -> float: # pragma: no cover
""" """
We set our plotly plot's axes' bounding box to [-1,1]x[-1,1]x[-1,1]. As such, We set our plotly plot's axes' bounding box to [-1,1]x[-1,1]x[-1,1]. As such,
the plotly camera location has to be scaled accordingly to have its world coordinates the plotly camera location has to be scaled accordingly to have its world coordinates
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment