"vscode:/vscode.git/clone" did not exist on "b170646991a06cb18b1bd4e74efcd095f5b00c18"
__init__.py 1.4 KB
Newer Older
1
import torch
2
3
from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS

4
from ._bounding_box import BoundingBoxes, BoundingBoxFormat
5
6
from ._datapoint import Datapoint
from ._image import Image
7
from ._mask import Mask
8
from ._torch_function_helpers import set_return_type
9
from ._video import Video
10

11
12
13
14
# Emit the beta-transforms notice at import time, gated by the flag imported
# from the top-level torchvision package.
if _WARN_ABOUT_BETA_TRANSFORMS:
    import warnings

    # Spelled-out defaults: a plain message is reported as a UserWarning.
    warnings.warn(message=_BETA_TRANSFORMS_WARNING, category=UserWarning)


def wrap(wrappee, *, like, **kwargs):
    """[BETA] Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.datapoints.Datapoint` subclass as ``like``.

    If ``like`` is a :class:`~torchvision.datapoints.BoundingBoxes`, the ``format`` and ``canvas_size`` of
    ``like`` are assigned to ``wrappee``, unless they are passed as ``kwargs``.

    Args:
        wrappee (Tensor): The tensor to convert.
        like (:class:`~torchvision.datapoints.Datapoint`): The reference.
            ``wrappee`` will be converted into the same subclass as ``like``.
        kwargs: Can contain "format" and "canvas_size" if ``like`` is a :class:`~torchvision.datapoints.BoundingBoxes`.
            Ignored otherwise. Any other keys are ignored in all cases.

    Returns:
        If ``like`` is a :class:`~torchvision.datapoints.BoundingBoxes`, the result of
        ``BoundingBoxes._wrap`` applied to ``wrappee``. Otherwise ``wrappee.as_subclass(type(like))``,
        i.e. ``wrappee`` viewed as ``type(like)`` without copying its data.
    """
    if isinstance(like, BoundingBoxes):
        # Bounding boxes carry metadata; explicit kwargs override the reference's values.
        return BoundingBoxes._wrap(
            wrappee,
            format=kwargs.get("format", like.format),
            canvas_size=kwargs.get("canvas_size", like.canvas_size),
        )
    else:
        # Other datapoints carry no extra metadata here, so only the class is transferred.
        return wrappee.as_subclass(type(like))