Commit 69b27d16 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

reallow scalar background color for point rendering

Summary: A scalar background color is not meant to be allowed for the point renderer. It used to be ignored with a warning, but a recent code change made it an error. It was being used, at least in the black (value=0.0) case. Re-enable it.

Reviewed By: nikhilaravi

Differential Revision: D34519651

fbshipit-source-id: d37dcf145bb7b8999c9265cf8fc39b084059dd18
parent 84a569c0
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import warnings
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
...@@ -85,7 +84,10 @@ def _add_background_color_to_images(pix_idxs, images, background_color): ...@@ -85,7 +84,10 @@ def _add_background_color_to_images(pix_idxs, images, background_color):
if not torch.is_tensor(background_color): if not torch.is_tensor(background_color):
background_color = images.new_tensor(background_color) background_color = images.new_tensor(background_color)
if len(background_color.shape) != 1: if background_color.ndim == 0:
background_color = background_color.expand(images.shape[1])
if background_color.ndim > 1:
raise ValueError("Wrong shape of background_color") raise ValueError("Wrong shape of background_color")
background_color = background_color.to(images) background_color = background_color.to(images)
...@@ -98,7 +100,8 @@ def _add_background_color_to_images(pix_idxs, images, background_color): ...@@ -98,7 +100,8 @@ def _add_background_color_to_images(pix_idxs, images, background_color):
if images.shape[1] != background_color.shape[0]: if images.shape[1] != background_color.shape[0]:
raise ValueError( raise ValueError(
f"background color has {background_color.shape[0] } channels not {images.shape[1]}" "Background color has %s channels not %s"
% (background_color.shape[0], images.shape[1])
) )
num_background_pixels = background_mask.sum() num_background_pixels = background_mask.sum()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment