_deprecated.py 738 Bytes
Newer Older
1
import warnings
2
from typing import Any, Dict, Union
3

4
5
import numpy as np
import PIL.Image
6
import torch
7
from torchvision.transforms import functional as _F
8

9
10
from torchvision.transforms.v2 import Transform

11

12
class ToTensor(Transform):
    """[DEPRECATED] Convert a PIL Image or ``numpy.ndarray`` input to a ``torch.Tensor``.

    This transform only delegates to the v1 functional ``to_tensor`` and emits a
    deprecation warning when it is constructed. Prefer
    ``transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()])``.
    """

    # Input types this transform converts via ``_transform``. How inputs of other
    # types are handled is decided by the ``Transform`` base class (not visible
    # here) — presumably they are passed through unchanged; confirm.
    _transformed_types = (PIL.Image.Image, np.ndarray)

    def __init__(self) -> None:
        """Create the transform, warning that it is deprecated."""
        # stacklevel=2 attributes the warning to the caller that constructs the
        # transform instead of to this line inside the library. The category is
        # left as the default (UserWarning) so existing warning filters still match.
        warnings.warn(
            "The transform `ToTensor()` is deprecated and will be removed in a future release. "
            "Instead, please use `transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()])`.",
            stacklevel=2,
        )
        super().__init__()

    def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor:
        """Convert a single input with the v1 functional ``to_tensor``.

        Args:
            inpt: A PIL Image or ``numpy.ndarray`` (one of ``_transformed_types``).
            params: Per-call parameters supplied by the base class; unused here.

        Returns:
            The tensor produced by ``torchvision.transforms.functional.to_tensor``.
        """
        return _F.to_tensor(inpt)