Commit 54c75b41 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

GM error for unbatched inputs

Summary: Error when sending an unbatched FrameData through GM.

Reviewed By: shapovalov

Differential Revision: D38036286

fbshipit-source-id: b8d280c61fbbefdc112c57ccd630ab3ccce7b44e
parent 3783437d
...@@ -765,6 +765,17 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 ...@@ -765,6 +765,17 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
Returns: Returns:
Modified image_rgb, fg_mask, depth_map Modified image_rgb, fg_mask, depth_map
""" """
if image_rgb is not None and image_rgb.ndim == 3:
# The FrameData object is used for both frames and batches of frames,
# and a user might get this error if those were confused.
# Perhaps a user has a FrameData `fd` representing a single frame and
# wrote something like `model(**fd)` instead of
# `model(**fd.collate([fd]))`.
raise ValueError(
"Model received unbatched inputs. "
+ "Perhaps they came from a FrameData which had not been collated."
)
fg_mask = fg_probability fg_mask = fg_probability
if fg_mask is not None and self.mask_threshold > 0.0: if fg_mask is not None and self.mask_threshold > 0.0:
# threshold masks # threshold masks
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment