Commit ff933ab8 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

make visdom optional

Summary: Make Implicitron run without visdom installed.

Reviewed By: shapovalov

Differential Revision: D40587974

fbshipit-source-id: dc319596c7a4d10a4c54c556dabc89ad9d25c2fb
parent 46cb5aaa
...@@ -41,7 +41,7 @@ The outputs of the experiment are saved and logged in multiple ways: ...@@ -41,7 +41,7 @@ The outputs of the experiment are saved and logged in multiple ways:
Stats are logged and plotted to the file "train_stats.pdf" in the Stats are logged and plotted to the file "train_stats.pdf" in the
same directory. The stats are also saved as part of the checkpoint file. same directory. The stats are also saved as part of the checkpoint file.
- Visualizations - Visualizations
Prredictions are plotted to a visdom server running at the Predictions are plotted to a visdom server running at the
port specified by the `visdom_server` and `visdom_port` keys in the port specified by the `visdom_server` and `visdom_port` keys in the
config file. config file.
......
...@@ -9,7 +9,7 @@ import copy ...@@ -9,7 +9,7 @@ import copy
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
import numpy as np import numpy as np
import torch import torch
...@@ -27,7 +27,9 @@ from pytorch3d.renderer.camera_utils import join_cameras_as_batch ...@@ -27,7 +27,9 @@ from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.vis.plotly_vis import plot_scene from pytorch3d.vis.plotly_vis import plot_scene
from tabulate import tabulate from tabulate import tabulate
from visdom import Visdom
if TYPE_CHECKING:
from visdom import Visdom
EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9] EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9]
...@@ -43,7 +45,7 @@ class _Visualizer: ...@@ -43,7 +45,7 @@ class _Visualizer:
visdom_env: str = "eval_debug" visdom_env: str = "eval_debug"
_viz: Visdom = field(init=False) _viz: Optional["Visdom"] = field(init=False)
def __post_init__(self): def __post_init__(self):
self._viz = vis_utils.get_visdom_connection() self._viz = vis_utils.get_visdom_connection()
...@@ -51,6 +53,8 @@ class _Visualizer: ...@@ -51,6 +53,8 @@ class _Visualizer:
def show_rgb( def show_rgb(
self, loss_value: float, metric_name: str, loss_mask_now: torch.Tensor self, loss_value: float, metric_name: str, loss_mask_now: torch.Tensor
): ):
if self._viz is None:
return
self._viz.images( self._viz.images(
torch.cat( torch.cat(
( (
...@@ -68,7 +72,10 @@ class _Visualizer: ...@@ -68,7 +72,10 @@ class _Visualizer:
def show_depth( def show_depth(
self, depth_loss: float, name_postfix: str, loss_mask_now: torch.Tensor self, depth_loss: float, name_postfix: str, loss_mask_now: torch.Tensor
): ):
self._viz.images( if self._viz is None:
return
viz = self._viz
viz.images(
torch.cat( torch.cat(
( (
make_depth_image(self.depth_render, loss_mask_now), make_depth_image(self.depth_render, loss_mask_now),
...@@ -80,13 +87,13 @@ class _Visualizer: ...@@ -80,13 +87,13 @@ class _Visualizer:
win="depth_abs" + name_postfix, win="depth_abs" + name_postfix,
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}"}, opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}"},
) )
self._viz.images( viz.images(
loss_mask_now, loss_mask_now,
env=self.visdom_env, env=self.visdom_env,
win="depth_abs" + name_postfix + "_mask", win="depth_abs" + name_postfix + "_mask",
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"}, opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"},
) )
self._viz.images( viz.images(
self.depth_mask, self.depth_mask,
env=self.visdom_env, env=self.visdom_env,
win="depth_abs" + name_postfix + "_maskd", win="depth_abs" + name_postfix + "_maskd",
...@@ -126,7 +133,7 @@ class _Visualizer: ...@@ -126,7 +133,7 @@ class _Visualizer:
pointcloud_max_points=10000, pointcloud_max_points=10000,
pointcloud_marker_size=1, pointcloud_marker_size=1,
) )
self._viz.plotlyplot( viz.plotlyplot(
plotlyplot, plotlyplot,
env=self.visdom_env, env=self.visdom_env,
win=f"pcl{name_postfix}", win=f"pcl{name_postfix}",
......
...@@ -12,7 +12,7 @@ import logging ...@@ -12,7 +12,7 @@ import logging
import math import math
import warnings import warnings
from dataclasses import field from dataclasses import field
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch import torch
import tqdm import tqdm
...@@ -34,7 +34,9 @@ from pytorch3d.implicitron.tools.utils import cat_dataclass ...@@ -34,7 +34,9 @@ from pytorch3d.implicitron.tools.utils import cat_dataclass
from pytorch3d.renderer import utils as rend_utils from pytorch3d.renderer import utils as rend_utils
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from visdom import Visdom
if TYPE_CHECKING:
from visdom import Visdom
from .base_model import ImplicitronModelBase, ImplicitronRender from .base_model import ImplicitronModelBase, ImplicitronRender
from .feature_extractor import FeatureExtractorBase from .feature_extractor import FeatureExtractorBase
...@@ -544,7 +546,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 ...@@ -544,7 +546,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
def visualize( def visualize(
self, self,
viz: Visdom, viz: Optional["Visdom"],
visdom_env_imgs: str, visdom_env_imgs: str,
preds: Dict[str, Any], preds: Dict[str, Any],
prefix: str, prefix: str,
...@@ -559,7 +561,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 ...@@ -559,7 +561,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
preds: predictions dict like returned by forward() preds: predictions dict like returned by forward()
prefix: prepended to the names of images prefix: prepended to the names of images
""" """
if not viz.check_connection(): if viz is None or not viz.check_connection():
logger.info("no visdom server! -> skipping batch vis") logger.info("no visdom server! -> skipping batch vis")
return return
......
...@@ -10,7 +10,7 @@ import logging ...@@ -10,7 +10,7 @@ import logging
import math import math
import os import os
import random import random
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
import numpy as np import numpy as np
import torch import torch
...@@ -27,7 +27,9 @@ from pytorch3d.implicitron.tools.vis_utils import ( ...@@ -27,7 +27,9 @@ from pytorch3d.implicitron.tools.vis_utils import (
make_depth_image, make_depth_image,
) )
from tqdm import tqdm from tqdm import tqdm
from visdom import Visdom
if TYPE_CHECKING:
from visdom import Visdom
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -272,7 +274,7 @@ def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]) -> torch.T ...@@ -272,7 +274,7 @@ def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]) -> torch.T
def _show_predictions( def _show_predictions(
preds: List[Dict[str, Any]], preds: List[Dict[str, Any]],
sequence_name: str, sequence_name: str,
viz: Visdom, viz: "Visdom",
viz_env: str = "visualizer", viz_env: str = "visualizer",
predicted_keys: Sequence[str] = ( predicted_keys: Sequence[str] = (
"images_render", "images_render",
...@@ -318,7 +320,7 @@ def _show_predictions( ...@@ -318,7 +320,7 @@ def _show_predictions(
def _generate_prediction_videos( def _generate_prediction_videos(
preds: List[Dict[str, Any]], preds: List[Dict[str, Any]],
sequence_name: str, sequence_name: str,
viz: Optional[Visdom] = None, viz: Optional["Visdom"] = None,
viz_env: str = "visualizer", viz_env: str = "visualizer",
predicted_keys: Sequence[str] = ( predicted_keys: Sequence[str] = (
"images_render", "images_render",
......
...@@ -337,7 +337,7 @@ class Stats(object): ...@@ -337,7 +337,7 @@ class Stats(object):
novisdom = False novisdom = False
viz = get_visdom_connection(server=visdom_server, port=visdom_port) viz = get_visdom_connection(server=visdom_server, port=visdom_port)
if not viz.check_connection(): if viz is None or not viz.check_connection():
print("no visdom server! -> skipping visdom plots") print("no visdom server! -> skipping visdom plots")
novisdom = True novisdom = True
......
...@@ -5,10 +5,12 @@ ...@@ -5,10 +5,12 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import logging import logging
from typing import Any, Dict, Tuple from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
import torch import torch
from visdom import Visdom
if TYPE_CHECKING:
from visdom import Visdom
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -40,9 +42,9 @@ _viz_singleton = None ...@@ -40,9 +42,9 @@ _viz_singleton = None
def get_visdom_connection( def get_visdom_connection(
server: str = "http://localhost", server: str = "http://localhost",
port: int = 8097, port: int = 8097,
) -> Visdom: ) -> Optional["Visdom"]:
""" """
Obtain a connection to a visdom server. Obtain a connection to a visdom server if visdom is installed.
Args: Args:
server: Server address. server: Server address.
...@@ -51,6 +53,15 @@ def get_visdom_connection( ...@@ -51,6 +53,15 @@ def get_visdom_connection(
Returns: Returns:
connection: The connection object. connection: The connection object.
""" """
try:
from visdom import Visdom
except ImportError:
logger.debug("Cannot load visdom")
return None
if server == "None":
return None
global _viz_singleton global _viz_singleton
if _viz_singleton is None: if _viz_singleton is None:
_viz_singleton = Visdom(server=server, port=port) _viz_singleton = Visdom(server=server, port=port)
...@@ -58,7 +69,7 @@ def get_visdom_connection( ...@@ -58,7 +69,7 @@ def get_visdom_connection(
def visualize_basics( def visualize_basics(
viz: Visdom, viz: "Visdom",
preds: Dict[str, Any], preds: Dict[str, Any],
visdom_env_imgs: str, visdom_env_imgs: str,
title: str = "", title: str = "",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment