Commit c4c512ce authored by Maayan Frid-Adar's avatar Maayan Frid-Adar Committed by Facebook GitHub Bot
Browse files

Fix TB train visualization

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/465

Training visualization was basically activated only for the first training iterations if TRAIN_LOADER_VIS_MAX_IMAGES and TRAIN_LOADER_VIS_WRITE_PERIOD were set to be > 0. because the MAX_IMAGES was taken as the number of samples to log + the allowed number of samples to load overall. So after the first log to TB it was set to 0 and the visualization was not activated for later training steps (ignoring WRITE_PERIOD).

I've added a TRAIN_LOADER_VIS_MAX_BATCH_IMAGES parameter to set a number of samples to visualize each write period up to the max images defined with TRAIN_LOADER_VIS_MAX_IMAGES

Reviewed By: tglik

Differential Revision: D42832903

fbshipit-source-id: 02a0d9aa4ea6d0ee725120916d26b77843a3e8ab
parent 8311dc45
...@@ -19,6 +19,8 @@ def add_tensorboard_default_configs(_C): ...@@ -19,6 +19,8 @@ def add_tensorboard_default_configs(_C):
# This controls max number of images over all batches, be considerate when # This controls max number of images over all batches, be considerate when
# increasing this number because it takes disk space and slows down the training # increasing this number because it takes disk space and slows down the training
_C.TENSORBOARD.TRAIN_LOADER_VIS_MAX_IMAGES = 16 _C.TENSORBOARD.TRAIN_LOADER_VIS_MAX_IMAGES = 16
# This controls the max number of images to visualize each write period
_C.TENSORBOARD.TRAIN_LOADER_VIS_MAX_BATCH_IMAGES = 16
# Max number of images per dataset to visualize in tensorboard during evaluation # Max number of images per dataset to visualize in tensorboard during evaluation
_C.TENSORBOARD.TEST_VIS_MAX_IMAGES = 16 _C.TENSORBOARD.TEST_VIS_MAX_IMAGES = 16
# Frequency of sending data to tensorboard during evaluation # Frequency of sending data to tensorboard during evaluation
...@@ -135,8 +137,10 @@ class DataLoaderVisWrapper: ...@@ -135,8 +137,10 @@ class DataLoaderVisWrapper:
self.log_frequency = cfg.TENSORBOARD.TRAIN_LOADER_VIS_WRITE_PERIOD self.log_frequency = cfg.TENSORBOARD.TRAIN_LOADER_VIS_WRITE_PERIOD
self.log_limit = cfg.TENSORBOARD.TRAIN_LOADER_VIS_MAX_IMAGES self.log_limit = cfg.TENSORBOARD.TRAIN_LOADER_VIS_MAX_IMAGES
self.batch_log_limit = cfg.TENSORBOARD.TRAIN_LOADER_VIS_MAX_BATCH_IMAGES
assert self.log_frequency >= 0 assert self.log_frequency >= 0
assert self.log_limit >= 0 assert self.log_limit >= 0
assert self.batch_log_limit >= 0
self._remaining = self.log_limit self._remaining = self.log_limit
def __iter__(self): def __iter__(self):
...@@ -159,7 +163,7 @@ class DataLoaderVisWrapper: ...@@ -159,7 +163,7 @@ class DataLoaderVisWrapper:
): ):
return return
length = min(len(data), self._remaining) length = min(len(data), min(self.batch_log_limit, self._remaining))
data = data[:length] data = data[:length]
self._remaining -= length self._remaining -= length
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment