"vscode:/vscode.git/clone" did not exist on "bfd20f98d34daeba38c619d79fe3ed9235c746de"
Commit 54eb76d4 authored by Roman Shapovalov's avatar Roman Shapovalov Committed by Facebook GitHub Bot
Browse files

Loosening the checks in eval script for CO3Dv2 style eval

Summary:
V2 dataset does not have the concept of known/unseen frames. Test-time conditining is done with train-set frames, which violates the previous check.

Also fixing a corner case in VideoWriter.

Reviewed By: bottler

Differential Revision: D42706976

fbshipit-source-id: d43be3dd3060d18cb9f46d5dcf6252d9f084110f
parent 9dc28f5d
...@@ -219,17 +219,10 @@ def eval_batch( ...@@ -219,17 +219,10 @@ def eval_batch(
frame_type = [frame_type] frame_type = [frame_type]
is_train = is_train_frame(frame_type) is_train = is_train_frame(frame_type)
if not (is_train[0] == is_train).all(): if len(is_train) > 1 and (is_train[1] != is_train[1:]).any():
raise ValueError("All frames in the eval batch have to be either train/test.")
# pyre-fixme[16]: `Optional` has no attribute `device`.
is_known = is_known_frame(frame_type, device=frame_data.image_rgb.device)
if not ((is_known[1:] == 1).all() and (is_known[0] == 0).all()):
raise ValueError( raise ValueError(
"For evaluation the first element of the batch has to be" "All (conditioning) frames in the eval batch have to be either train/test."
+ " a target view while the rest should be source views." )
) # TODO: do we need to enforce this?
for k in [ for k in [
"depth_map", "depth_map",
...@@ -362,7 +355,7 @@ def eval_batch( ...@@ -362,7 +355,7 @@ def eval_batch(
results["meta"] = { results["meta"] = {
# store the size of the batch (corresponds to n_src_views+1) # store the size of the batch (corresponds to n_src_views+1)
"batch_size": int(is_known.numel()), "batch_size": len(frame_type),
# store the type of the target frame # store the type of the target frame
# pyre-fixme[16]: `None` has no attribute `__getitem__`. # pyre-fixme[16]: `None` has no attribute `__getitem__`.
"frame_type": str(frame_data.frame_type[0]), "frame_type": str(frame_data.frame_type[0]),
......
...@@ -124,8 +124,11 @@ class VideoWriter: ...@@ -124,8 +124,11 @@ class VideoWriter:
quiet: If `True`, suppresses logging messages. quiet: If `True`, suppresses logging messages.
Returns: Returns:
video_path: The path to the generated video. video_path: The path to the generated video if any frames were added.
Otherwise returns an empty string.
""" """
if self.frame_num == 0:
return ""
regexp = os.path.join(self.cache_dir, self.regexp) regexp = os.path.join(self.cache_dir, self.regexp)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment