# image_list.py
from typing import List, Tuple

import torch
from torch import Tensor


class ImageList:
    """
    Structure that holds a list of images (of possibly
    varying sizes) as a single tensor.

    The images are expected to already be padded to a common size by the
    caller; this class does no padding itself. It only stores the batched
    tensor together with the original size of each image.

    Args:
        tensors (tensor): Tensor containing images.
        image_sizes (list[tuple[int, int]]): List of Tuples each containing size of images.
    """

    def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]) -> None:
        # Both values are stored as given. There is no check that
        # len(image_sizes) matches the number of images in `tensors`.
        self.tensors = tensors
        self.image_sizes = image_sizes

    def to(self, device: torch.device) -> "ImageList":
        """
        Return a new ImageList whose tensor has been moved to ``device``.

        Args:
            device (torch.device): Target device for the batched tensor.

        Returns:
            ImageList: A new instance wrapping ``self.tensors.to(device)``.
            The ``image_sizes`` list is shared with this instance, not copied.
            ``Tensor.to`` may return the same tensor object when it is
            already on ``device``.
        """
        cast_tensor = self.tensors.to(device)
        return ImageList(cast_tensor, self.image_sizes)