# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
fromdataclassesimportdataclass
fromtypingimportList,Optional,Tuple,Union
importtorch
fromPILimportImageasPILImage
fromtensordictimporttensorclass
@tensorclass
classBatchedVideoMetaData:
"""
This class represents metadata about a batch of videos.
Attributes:
unique_objects_identifier: A tensor of shape Bx3 containing unique identifiers for each object in the batch. Index consists of (video_id, obj_id, frame_id)
frame_orig_size: A tensor of shape Bx2 containing the original size of each frame in the batch.
"""
unique_objects_identifier:torch.LongTensor
frame_orig_size:torch.LongTensor
@tensorclass
classBatchedVideoDatapoint:
"""
This class represents a batch of videos with associated annotations and metadata.
Attributes:
img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch.
obj_to_frame_idx: A [TxOx2] tensor containing the image_batch index which the object belongs to. O is the number of objects in the batch.
masks: A [TxOxHxW] tensor containing binary masks for each object in the batch.
metadata: An instance of BatchedVideoMetaData containing metadata about the batch.
dict_key: A string key used to identify the batch.