Commit c0bb49b5, authored by Roman Shapovalov and committed by Facebook GitHub Bot
Browse files

API for accessing frames in order in Implicitron dataset.

Summary: We often want to iterate over frames in the sequence in temporal order. This diff provides the API to do that. `seq_to_idx` should probably be considered to have `protected` visibility.

Reviewed By: davnov134

Differential Revision: D35012121

fbshipit-source-id: 41896672ec35cd62f3ed4be3aa119efd33adada1
parent 05f656c0
...@@ -18,6 +18,8 @@ from pathlib import Path ...@@ -18,6 +18,8 @@ from pathlib import Path
from typing import ( from typing import (
ClassVar, ClassVar,
Dict, Dict,
Iterable,
Iterator,
List, List,
Optional, Optional,
Sequence, Sequence,
...@@ -203,11 +205,11 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]): ...@@ -203,11 +205,11 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
This means they have a __getitem__ which returns an instance of a FrameData, This means they have a __getitem__ which returns an instance of a FrameData,
which will describe one frame in one sequence. which will describe one frame in one sequence.
Members:
seq_to_idx: For each sequence, the indices of its frames.
""" """
# Maps sequence name to the sequence's global frame indices.
# It is used for the default implementations of some functions in this class.
# Implementations which override them are free to ignore this member.
seq_to_idx: Dict[str, List[int]] = field(init=False) seq_to_idx: Dict[str, List[int]] = field(init=False)
def __len__(self) -> int: def __len__(self) -> int:
...@@ -240,6 +242,43 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]): ...@@ -240,6 +242,43 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
def get_eval_batches(self) -> Optional[List[List[int]]]: def get_eval_batches(self) -> Optional[List[List[int]]]:
return None return None
def sequence_names(self) -> Iterable[str]:
    """Return the names of all sequences in the dataset.

    Returns:
        An iterable over the keys of `seq_to_idx`, i.e. the sequence names.
    """
    name_view = self.seq_to_idx.keys()
    return name_view
def sequence_frames_in_order(
    self, seq_name: str
) -> Iterator[Tuple[float, int, int]]:
    """Iterate over the frames of one sequence in temporal order.

    Frames are ordered by timestamp first (if they are available) and by
    frame number second.

    Args:
        seq_name: the name of the sequence.

    Returns:
        an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
        where `frame_no` is the index within the sequence, and
        `dataset_idx` is the index within the dataset.
        `None` timestamps are replaced with 0s.
    """
    dataset_indices = self.seq_to_idx[seq_name]
    numbers_and_times = self.get_frame_numbers_and_timestamps(dataset_indices)
    # Put the timestamp first so that plain tuple ordering sorts by
    # (timestamp, frame_no, dataset_idx).
    triplets = [
        (frame_time, frame_number, dataset_idx)
        for dataset_idx, (frame_number, frame_time) in zip(
            dataset_indices, numbers_and_times
        )
    ]
    triplets.sort()
    yield from triplets
def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
    """Iterate over the dataset indices of one sequence's frames.

    The order is the one produced by `sequence_frames_in_order`; only the
    last element of each `(timestamp, frame_no, dataset_idx)` triplet is
    yielded.
    """
    ordered_frames = self.sequence_frames_in_order(seq_name)
    yield from (dataset_idx for _, _, dataset_idx in ordered_frames)
class FrameAnnotsEntry(TypedDict): class FrameAnnotsEntry(TypedDict):
subset: Optional[str] subset: Optional[str]
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
import warnings import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Iterator, List, Sequence, Tuple from typing import Iterable, Iterator, List, Sequence, Tuple
import numpy as np import numpy as np
from torch.utils.data.sampler import Sampler from torch.utils.data.sampler import Sampler
...@@ -54,7 +54,7 @@ class SceneBatchSampler(Sampler[List[int]]): ...@@ -54,7 +54,7 @@ class SceneBatchSampler(Sampler[List[int]]):
if len(self.images_per_seq_options) < 1: if len(self.images_per_seq_options) < 1:
raise ValueError("n_per_seq_posibilities list cannot be empty") raise ValueError("n_per_seq_posibilities list cannot be empty")
self.seq_names = list(self.dataset.seq_to_idx.keys()) self.seq_names = list(self.dataset.sequence_names())
def __len__(self) -> int: def __len__(self) -> int:
return self.num_batches return self.num_batches
...@@ -72,9 +72,7 @@ class SceneBatchSampler(Sampler[List[int]]): ...@@ -72,9 +72,7 @@ class SceneBatchSampler(Sampler[List[int]]):
if self.sample_consecutive_frames: if self.sample_consecutive_frames:
frame_idx = [] frame_idx = []
for seq in chosen_seq: for seq in chosen_seq:
segment_index = self._build_segment_index( segment_index = self._build_segment_index(seq, n_per_seq)
list(self.dataset.seq_to_idx[seq]), n_per_seq
)
segment, idx = segment_index[np.random.randint(len(segment_index))] segment, idx = segment_index[np.random.randint(len(segment_index))]
if len(segment) <= n_per_seq: if len(segment) <= n_per_seq:
...@@ -86,7 +84,9 @@ class SceneBatchSampler(Sampler[List[int]]): ...@@ -86,7 +84,9 @@ class SceneBatchSampler(Sampler[List[int]]):
else: else:
frame_idx = [ frame_idx = [
_capped_random_choice( _capped_random_choice(
self.dataset.seq_to_idx[seq], n_per_seq, replace=False list(self.dataset.sequence_indices_in_order(seq)),
n_per_seq,
replace=False,
) )
for seq in chosen_seq for seq in chosen_seq
] ]
...@@ -98,9 +98,7 @@ class SceneBatchSampler(Sampler[List[int]]): ...@@ -98,9 +98,7 @@ class SceneBatchSampler(Sampler[List[int]]):
) )
return frame_idx return frame_idx
def _build_segment_index( def _build_segment_index(self, seq: str, size: int) -> List[Tuple[List[int], int]]:
self, seq_frame_indices: List[int], size: int
) -> List[Tuple[List[int], int]]:
""" """
Returns a list of (segment, index) tuples, one per eligible frame, where Returns a list of (segment, index) tuples, one per eligible frame, where
segment is a list of frame indices in the contiguous segment the frame segment is a list of frame indices in the contiguous segment the frame
...@@ -111,16 +109,14 @@ class SceneBatchSampler(Sampler[List[int]]): ...@@ -111,16 +109,14 @@ class SceneBatchSampler(Sampler[List[int]]):
self.consecutive_frames_max_gap > 0 self.consecutive_frames_max_gap > 0
or self.consecutive_frames_max_gap_seconds > 0.0 or self.consecutive_frames_max_gap_seconds > 0.0
): ):
sequence_timestamps = _sort_frames_by_timestamps_then_numbers( segments = self._split_to_segments(
seq_frame_indices, self.dataset self.dataset.sequence_frames_in_order(seq)
) )
# TODO: use new API to access frame numbers / timestamps
segments = self._split_to_segments(sequence_timestamps)
segments = _cull_short_segments(segments, size) segments = _cull_short_segments(segments, size)
if not segments: if not segments:
raise AssertionError("Empty segments after culling") raise AssertionError("Empty segments after culling")
else: else:
segments = [seq_frame_indices] segments = [list(self.dataset.sequence_indices_in_order(seq))]
# build an index of segment for random selection of a pivot frame # build an index of segment for random selection of a pivot frame
segment_index = [ segment_index = [
...@@ -130,7 +126,7 @@ class SceneBatchSampler(Sampler[List[int]]): ...@@ -130,7 +126,7 @@ class SceneBatchSampler(Sampler[List[int]]):
return segment_index return segment_index
def _split_to_segments( def _split_to_segments(
self, sequence_timestamps: List[Tuple[float, int, int]] self, sequence_timestamps: Iterable[Tuple[float, int, int]]
) -> List[List[int]]: ) -> List[List[int]]:
if ( if (
self.consecutive_frames_max_gap <= 0 self.consecutive_frames_max_gap <= 0
...@@ -144,7 +140,7 @@ class SceneBatchSampler(Sampler[List[int]]): ...@@ -144,7 +140,7 @@ class SceneBatchSampler(Sampler[List[int]]):
for ts, no, idx in sequence_timestamps: for ts, no, idx in sequence_timestamps:
if ts <= 0.0 and no <= last_no: if ts <= 0.0 and no <= last_no:
raise AssertionError( raise AssertionError(
"Frames are not ordered in seq_to_idx while timestamps are not given" "Sequence frames are not ordered while timestamps are not given"
) )
if ( if (
...@@ -161,23 +157,6 @@ class SceneBatchSampler(Sampler[List[int]]): ...@@ -161,23 +157,6 @@ class SceneBatchSampler(Sampler[List[int]]):
return segments return segments
def _sort_frames_by_timestamps_then_numbers(
    seq_frame_indices: List[int], dataset: ImplicitronDatasetBase
) -> List[Tuple[float, int, int]]:
    """Return the sorted triplets (timestamp, frame_no, dataset_idx) of a sequence.

    The triplets are ordered by timestamp first, then by frame number.
    Timestamps are coalesced with 0s.
    """
    numbers_and_times = dataset.get_frame_numbers_and_timestamps(seq_frame_indices)
    # Timestamp leads each tuple so the default tuple ordering gives the
    # desired (timestamp, frame_no, dataset_idx) sort key.
    triplets = [
        (frame_time, frame_number, dataset_idx)
        for dataset_idx, (frame_number, frame_time) in zip(
            seq_frame_indices, numbers_and_times
        )
    ]
    triplets.sort()
    return triplets
def _cull_short_segments(segments: List[List[int]], min_size: int) -> List[List[int]]: def _cull_short_segments(segments: List[List[int]], min_size: int) -> List[List[int]]:
lengths = [(len(segment), segment) for segment in segments] lengths = [(len(segment), segment) for segment in segments]
max_len, longest_segment = max(lengths) max_len, longest_segment = max(lengths)
......
...@@ -9,6 +9,7 @@ import unittest ...@@ -9,6 +9,7 @@ import unittest
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDatasetBase
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
...@@ -18,7 +19,7 @@ class MockFrameAnnotation: ...@@ -18,7 +19,7 @@ class MockFrameAnnotation:
frame_timestamp: float = 0.0 frame_timestamp: float = 0.0
class MockDataset: class MockDataset(ImplicitronDatasetBase):
def __init__(self, num_seq, max_frame_gap=1): def __init__(self, num_seq, max_frame_gap=1):
""" """
Makes a gap of max_frame_gap frame numbers in the middle of each sequence Makes a gap of max_frame_gap frame numbers in the middle of each sequence
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment