# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import Tensor
from typing import List, Tuple, Union


class ImageList:
    """
    Structure that holds a list of images (of possibly varying sizes)
    as a single tensor.
    This works by padding the images to the same size,
    and storing in a field the original sizes of each image.
    """

    def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]):
        """
        Args:
            tensors (tensor): the padded, batched image tensor
                (presumably shaped (N, C, H, W) -- confirm against callers)
            image_sizes (list[tuple[int, int]]): the original, pre-padding
                size of each image in the batch
        """
        self.tensors = tensors
        self.image_sizes = image_sizes

    def to(
        self,
        device: Union[torch.device, str],
        *,
        non_blocking: bool = False,
    ) -> "ImageList":
        """
        Return a new ImageList whose tensor has been moved to ``device``.

        Args:
            device (torch.device or str): target device, forwarded to
                ``Tensor.to``.
            non_blocking (bool): forwarded to ``Tensor.to``; allows an
                asynchronous host/device copy where PyTorch supports it.

        Returns:
            ImageList: a new instance; ``self`` is left unchanged.
        """
        cast_tensor = self.tensors.to(device, non_blocking=non_blocking)
        # Copy the size list so that mutating one ImageList's sizes does not
        # silently alter the other (the original shared the same list object).
        return ImageList(cast_tensor, list(self.image_sizes))

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"tensors_shape={tuple(self.tensors.shape)}, "
            f"image_sizes={self.image_sizes})"
        )