Commit c0bb49b5 authored by Roman Shapovalov's avatar Roman Shapovalov Committed by Facebook GitHub Bot
Browse files

API for accessing frames in order in Implicitron dataset.

Summary: We often want to iterate over frames in the sequence in temporal order. This diff provides the API to do that. `seq_to_idx` should probably be considered to have `protected` visibility.

Reviewed By: davnov134

Differential Revision: D35012121

fbshipit-source-id: 41896672ec35cd62f3ed4be3aa119efd33adada1
parent 05f656c0
......@@ -18,6 +18,8 @@ from pathlib import Path
from typing import (
ClassVar,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
......@@ -203,11 +205,11 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
This means they have a __getitem__ which returns an instance of a FrameData,
which will describe one frame in one sequence.
Members:
seq_to_idx: For each sequence, the indices of its frames.
"""
# Maps sequence name to the sequence's global frame indices.
# It is used for the default implementations of some functions in this class.
# Implementations which override them are free to ignore this member.
seq_to_idx: Dict[str, List[int]] = field(init=False)
def __len__(self) -> int:
......@@ -240,6 +242,43 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
def get_eval_batches(self) -> Optional[List[List[int]]]:
return None
def sequence_names(self) -> Iterable[str]:
"""Returns an iterator over sequence names in the dataset."""
return self.seq_to_idx.keys()
def sequence_frames_in_order(
self, seq_name: str
) -> Iterator[Tuple[float, int, int]]:
"""Returns an iterator over the frame indices in a given sequence.
We attempt to first sort by timestamp (if they are available),
then by frame number.
Args:
seq_name: the name of the sequence.
Returns:
an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
where `frame_no` is the index within the sequence, and
`dataset_idx` is the index within the dataset.
`None` timestamps are replaced with 0s.
"""
seq_frame_indices = self.seq_to_idx[seq_name]
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
yield from sorted(
[
(timestamp, frame_no, idx)
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
]
)
def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
"""Same as `sequence_frames_in_order` but returns the iterator over
only dataset indices.
"""
for _, _, idx in self.sequence_frames_in_order(seq_name):
yield idx
class FrameAnnotsEntry(TypedDict):
subset: Optional[str]
......
......@@ -7,7 +7,7 @@
import warnings
from dataclasses import dataclass, field
from typing import Iterator, List, Sequence, Tuple
from typing import Iterable, Iterator, List, Sequence, Tuple
import numpy as np
from torch.utils.data.sampler import Sampler
......@@ -54,7 +54,7 @@ class SceneBatchSampler(Sampler[List[int]]):
if len(self.images_per_seq_options) < 1:
raise ValueError("n_per_seq_posibilities list cannot be empty")
self.seq_names = list(self.dataset.seq_to_idx.keys())
self.seq_names = list(self.dataset.sequence_names())
def __len__(self) -> int:
return self.num_batches
......@@ -72,9 +72,7 @@ class SceneBatchSampler(Sampler[List[int]]):
if self.sample_consecutive_frames:
frame_idx = []
for seq in chosen_seq:
segment_index = self._build_segment_index(
list(self.dataset.seq_to_idx[seq]), n_per_seq
)
segment_index = self._build_segment_index(seq, n_per_seq)
segment, idx = segment_index[np.random.randint(len(segment_index))]
if len(segment) <= n_per_seq:
......@@ -86,7 +84,9 @@ class SceneBatchSampler(Sampler[List[int]]):
else:
frame_idx = [
_capped_random_choice(
self.dataset.seq_to_idx[seq], n_per_seq, replace=False
list(self.dataset.sequence_indices_in_order(seq)),
n_per_seq,
replace=False,
)
for seq in chosen_seq
]
......@@ -98,9 +98,7 @@ class SceneBatchSampler(Sampler[List[int]]):
)
return frame_idx
def _build_segment_index(
self, seq_frame_indices: List[int], size: int
) -> List[Tuple[List[int], int]]:
def _build_segment_index(self, seq: str, size: int) -> List[Tuple[List[int], int]]:
"""
Returns a list of (segment, index) tuples, one per eligible frame, where
segment is a list of frame indices in the contiguous segment the frame
......@@ -111,16 +109,14 @@ class SceneBatchSampler(Sampler[List[int]]):
self.consecutive_frames_max_gap > 0
or self.consecutive_frames_max_gap_seconds > 0.0
):
sequence_timestamps = _sort_frames_by_timestamps_then_numbers(
seq_frame_indices, self.dataset
segments = self._split_to_segments(
self.dataset.sequence_frames_in_order(seq)
)
# TODO: use new API to access frame numbers / timestamps
segments = self._split_to_segments(sequence_timestamps)
segments = _cull_short_segments(segments, size)
if not segments:
raise AssertionError("Empty segments after culling")
else:
segments = [seq_frame_indices]
segments = [list(self.dataset.sequence_indices_in_order(seq))]
# build an index of segment for random selection of a pivot frame
segment_index = [
......@@ -130,7 +126,7 @@ class SceneBatchSampler(Sampler[List[int]]):
return segment_index
def _split_to_segments(
self, sequence_timestamps: List[Tuple[float, int, int]]
self, sequence_timestamps: Iterable[Tuple[float, int, int]]
) -> List[List[int]]:
if (
self.consecutive_frames_max_gap <= 0
......@@ -144,7 +140,7 @@ class SceneBatchSampler(Sampler[List[int]]):
for ts, no, idx in sequence_timestamps:
if ts <= 0.0 and no <= last_no:
raise AssertionError(
"Frames are not ordered in seq_to_idx while timestamps are not given"
"Sequence frames are not ordered while timestamps are not given"
)
if (
......@@ -161,23 +157,6 @@ class SceneBatchSampler(Sampler[List[int]]):
return segments
def _sort_frames_by_timestamps_then_numbers(
seq_frame_indices: List[int], dataset: ImplicitronDatasetBase
) -> List[Tuple[float, int, int]]:
"""Build the list of triplets (timestamp, frame_no, dataset_idx).
We attempt to first sort by timestamp, then by frame number.
Timestamps are coalesced with 0s.
"""
nos_timestamps = dataset.get_frame_numbers_and_timestamps(seq_frame_indices)
return sorted(
[
(timestamp, frame_no, idx)
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
]
)
def _cull_short_segments(segments: List[List[int]], min_size: int) -> List[List[int]]:
lengths = [(len(segment), segment) for segment in segments]
max_len, longest_segment = max(lengths)
......
......@@ -9,6 +9,7 @@ import unittest
from collections import defaultdict
from dataclasses import dataclass
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDatasetBase
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
......@@ -18,7 +19,7 @@ class MockFrameAnnotation:
frame_timestamp: float = 0.0
class MockDataset:
class MockDataset(ImplicitronDatasetBase):
def __init__(self, num_seq, max_frame_gap=1):
"""
Makes a gap of max_frame_gap frame numbers in the middle of each sequence
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment