Commit b2dc5202 authored by Jeremy Reizenstein, committed by Facebook GitHub Bot

lints

Summary: lint issues (mostly flake) in implicitron

Reviewed By: patricklabatut

Differential Revision: D37920948

fbshipit-source-id: 8cb3c2a2838d111c80a211c98a404c210d4649ed
parent 8597d4c5
@@ -833,7 +833,7 @@ def _load_1bit_png_mask(file: str) -> np.ndarray:
     return mask


-def _load_depth_mask(path) -> np.ndarray:
+def _load_depth_mask(path: str) -> np.ndarray:
     if not path.lower().endswith(".png"):
         raise ValueError('unsupported depth mask file name "%s"' % path)
     m = _load_1bit_png_mask(path)
@@ -5,8 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import logging
-from dataclasses import field
-from typing import List, Optional
+from typing import Optional, Tuple

 import torch
 from pytorch3d.common.linear_with_repeat import LinearWithRepeat
@@ -206,7 +205,7 @@ class NeuralRadianceFieldImplicitFunction(NeuralRadianceFieldBase):
     transformer_dim_down_factor: float = 1.0
     n_hidden_neurons_xyz: int = 256
     n_layers_xyz: int = 8
-    append_xyz: List[int] = field(default_factory=lambda: [5])
+    append_xyz: Tuple[int, ...] = (5,)

     def _construct_xyz_encoder(self, input_dim: int):
         return MLPWithInputSkips(
@@ -224,7 +223,7 @@ class NeRFormerImplicitFunction(NeuralRadianceFieldBase):
     transformer_dim_down_factor: float = 2.0
     n_hidden_neurons_xyz: int = 80
     n_layers_xyz: int = 2
-    append_xyz: List[int] = field(default_factory=lambda: [1])
+    append_xyz: Tuple[int, ...] = (1,)

     def _construct_xyz_encoder(self, input_dim: int):
         return TransformerWithInputSkips(
@@ -286,7 +285,7 @@ class MLPWithInputSkips(torch.nn.Module):
         output_dim: int = 256,
         skip_dim: int = 39,
         hidden_dim: int = 256,
-        input_skips: List[int] = [5],
+        input_skips: Tuple[int, ...] = (5,),
         skip_affine_trans: bool = False,
         no_last_relu=False,
     ):
@@ -362,7 +361,7 @@ class TransformerWithInputSkips(torch.nn.Module):
         output_dim: int = 256,
         skip_dim: int = 39,
         hidden_dim: int = 64,
-        input_skips: List[int] = [5],
+        input_skips: Tuple[int, ...] = (5,),
         dim_down_factor: float = 1,
     ):
         """
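Note on the append_xyz / input_skips changes above: the following is a sketch, not part of the commit, assuming the motivation is to avoid mutable defaults. A dataclass-style config rejects a bare list default and forces the field(default_factory=...) detour, while an immutable tuple can be assigned directly; the ConfigWithList / ConfigWithTuple classes below are hypothetical and only illustrate the pattern.

# Illustrative only -- these are not implicitron classes.
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class ConfigWithList:
    # A bare `List[int] = [5]` default raises
    # "ValueError: mutable default <class 'list'> ... use default_factory",
    # hence the original field(default_factory=lambda: [5]) workaround.
    append_xyz: List[int] = field(default_factory=lambda: [5])


@dataclass
class ConfigWithTuple:
    # A tuple is immutable, so it can be used as a plain default.
    append_xyz: Tuple[int, ...] = (5,)


print(ConfigWithList().append_xyz)   # [5]
print(ConfigWithTuple().append_xyz)  # (5,)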
@@ -7,11 +7,10 @@
 from typing import List

 import torch
-from pytorch3d.implicitron.models.renderer.base import ImplicitFunctionWrapper
 from pytorch3d.implicitron.tools.config import registry, run_auto_creation
 from pytorch3d.renderer import RayBundle

-from .base import BaseRenderer, EvaluationMode, RendererOutput
+from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
 from .ray_point_refiner import RayPointRefiner
 from .raymarcher import RaymarcherBase
@@ -107,7 +106,7 @@ class MultiPassEmissionAbsorptionRenderer(  # pyre-ignore: 13
     def forward(
         self,
         ray_bundle: RayBundle,
-        implicit_functions: List[ImplicitFunctionWrapper] = [],
+        implicit_functions: List[ImplicitFunctionWrapper],
         evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
         **kwargs,
     ) -> RendererOutput:
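The forward() signature above drops the `= []` default and makes implicit_functions a required argument. As a side note (not from the commit), a mutable default like `= []` is evaluated once at function definition time and shared across calls, which is why flake-style linters flag it; the toy collect_* functions below are hypothetical and only demonstrate that behaviour.

from typing import List, Optional


def collect_bad(item: int, out: List[int] = []) -> List[int]:
    # `out` is the same list object on every call that relies on the default.
    out.append(item)
    return out


def collect_good(item: int, out: Optional[List[int]] = None) -> List[int]:
    # Create a fresh list per call instead of sharing one default.
    if out is None:
        out = []
    out.append(item)
    return out


print(collect_bad(1))   # [1]
print(collect_bad(2))   # [1, 2]  <- state leaked from the previous call
print(collect_good(1))  # [1]
print(collect_good(2))  # [2]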
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from dataclasses import field
 from typing import Optional, Tuple

 import torch
@@ -59,7 +59,7 @@ def cleanup_eval_depth(
     good_df_thr = std * sigma
     good_depth = (df <= good_df_thr).float() * pcl_mask
-    perc_kept = good_depth.sum(dim=1) / pcl_mask.sum(dim=1).clamp(1)
+    # perc_kept = good_depth.sum(dim=1) / pcl_mask.sum(dim=1).clamp(1)
     # print(f'Kept {100.0 * perc_kept.mean():1.3f} % points')

     good_depth_raster = torch.zeros_like(depth).view(ba, -1)
@@ -200,9 +200,6 @@ def _visdom_plot_scene(
     viz = Visdom()
     viz.plotlyplot(p, env="cam_traj_dbg", win="cam_trajs")

-    import pdb
-
-    pdb.set_trace()


 def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5):
@@ -202,7 +202,7 @@ def neg_iou_loss(
     return 1.0 - iou(predict, target, mask=mask)


-def safe_sqrt(A: torch.Tensor, eps: float = float(1e-4)) -> torch.Tensor:
+def safe_sqrt(A: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
     """
     performs safe differentiable sqrt
     """
@@ -20,12 +20,10 @@ logger = logging.getLogger(__name__)
 def load_stats(flstats):
     from pytorch3d.implicitron.tools.stats import Stats

-    try:
-        stats = Stats.load(flstats)
-    except:
-        logger.info("Cant load stats! %s" % flstats)
-        stats = None
-    return stats
+    if not os.path.isfile(flstats):
+        return None
+    return Stats.load(flstats)


 def get_model_path(fl) -> str:
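The load_stats rewrite above replaces a bare `except:` with an explicit existence check. The sketch below is not from the commit and uses a hypothetical load_json helper: a bare except (flake8 E722) also swallows KeyboardInterrupt, SystemExit and plain typos, whereas checking the precondition, or catching a narrow exception, keeps real errors visible.

import json
import os
from typing import Optional


def load_json_bad(path: str) -> Optional[dict]:
    try:
        with open(path) as f:
            return json.load(f)
    except:  # noqa: E722 -- hides Ctrl-C, SystemExit and genuine bugs alike
        return None


def load_json_good(path: str) -> Optional[dict]:
    # Missing file is the expected failure; handle it explicitly and let
    # anything else propagate.
    if not os.path.isfile(path):
        return None
    with open(path) as f:
        return json.load(f)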
@@ -40,7 +38,7 @@ def get_optimizer_path(fl) -> str:
     return flopt


-def get_stats_path(fl, eval_results: bool = False):
+def get_stats_path(fl, eval_results: bool = False) -> str:
     fl = os.path.splitext(fl)[0]
     if eval_results:
         for postfix in ("_2", ""):
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import logging
-from typing import Any, Dict, List
+from typing import Any, Dict, Tuple

 import torch
 from visdom import Visdom
@@ -60,14 +60,14 @@ def visualize_basics(
     preds: Dict[str, Any],
     visdom_env_imgs: str,
     title: str = "",
-    visualize_preds_keys: List[str] = [
+    visualize_preds_keys: Tuple[str, ...] = (
         "image_rgb",
         "images_render",
         "fg_probability",
         "masks_render",
         "depths_render",
         "depth_map",
-    ],
+    ),
     store_history: bool = False,
 ) -> None:
     """