mnist.py 21.2 KB
Newer Older
1
import codecs
Tian Qi Chen's avatar
Tian Qi Chen committed
2
3
import os
import os.path
4
import shutil
5
import string
6
import sys
7
import warnings
8
9
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
10
from urllib.error import URLError
11
12
13
14
15

import numpy as np
import torch
from PIL import Image

Philip Meier's avatar
Philip Meier committed
16
from .utils import _flip_byte_order, check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
17
from .vision import VisionDataset
Tian Qi Chen's avatar
Tian Qi Chen committed
18

19

20
class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
            and ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Note:
        The ``MNIST`` folder component is the class name, so subclasses read from
        ``<root>/<SubclassName>/raw``. If both legacy files ``processed/training.pt`` and
        ``processed/test.pt`` are present, they are loaded instead of the raw files.
    """

    # Base URLs tried in order; each resource filename is appended to a mirror.
    mirrors = [
        "http://yann.lecun.com/exdb/mnist/",
        "https://ossci-datasets.s3.amazonaws.com/mnist/",
    ]

    # (archive filename, md5 of the archive) for every file the dataset needs.
    resources = [
        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
    ]

    # Names of the legacy cached files under ``processed_folder`` (backward compatibility only).
    training_file = "training.pt"
    test_file = "test.pt"
    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    # Deprecated aliases. ``stacklevel=2`` attributes each warning to the caller's line
    # instead of to this property.
    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets", stacklevel=2)
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets", stacklevel=2)
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data", stacklevel=2)
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data", stacklevel=2)
        return self.data

    def __init__(
        self,
        root: Union[str, Path],
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.train = train  # training set or test set

        # Backward compatibility: prefer previously cached ``.pt`` files when both are present.
        if self._check_legacy_exist():
            self.data, self.targets = self._load_legacy_data()
            return

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        self.data, self.targets = self._load_data()

    def _check_legacy_exist(self) -> bool:
        """Return True if both legacy ``.pt`` files pass ``check_integrity`` (no md5) in ``processed_folder``."""
        processed_folder_exists = os.path.exists(self.processed_folder)
        if not processed_folder_exists:
            return False

        return all(
            check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
        )

    def _load_legacy_data(self):
        """Load the legacy cached file for the current split; ``__init__`` unpacks it as ``(data, targets)``."""
        # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
        # directly.
        data_file = self.training_file if self.train else self.test_file
        # weights_only=True avoids unpickling arbitrary objects from the cached file.
        return torch.load(os.path.join(self.processed_folder, data_file), weights_only=True)

    def _load_data(self):
        """Read the raw image and label files for the current split and return ``(data, targets)``."""
        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
        data = read_image_file(os.path.join(self.raw_folder, image_file))

        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
        targets = read_label_file(os.path.join(self.raw_folder, label_file))

        return data, targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode="L")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def raw_folder(self) -> str:
        """Directory holding the extracted raw files: ``<root>/<ClassName>/raw``."""
        return os.path.join(self.root, self.__class__.__name__, "raw")

    @property
    def processed_folder(self) -> str:
        """Directory holding legacy cached files: ``<root>/<ClassName>/processed``."""
        return os.path.join(self.root, self.__class__.__name__, "processed")

    @property
    def class_to_idx(self) -> Dict[str, int]:
        """Map each class name to its position in ``classes``."""
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self) -> bool:
        """Return True if every resource exists in ``raw_folder`` with its archive extension stripped."""
        return all(
            check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
            for url, _ in self.resources
        )

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist already.

        Mirrors are tried in ``self.mirrors`` order for each resource. Only a ``URLError``
        makes the download fall back to the next mirror.

        Raises:
            RuntimeError: If no mirror succeeded for some resource.
        """
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        # download files
        for filename, md5 in self.resources:
            for mirror in self.mirrors:
                url = f"{mirror}{filename}"
                try:
                    print(f"Downloading {url}")
                    download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
                except URLError as error:
                    print(f"Failed to download (trying next):\n{error}")
                    continue
                finally:
                    print()
                break
            else:
                # Reached only when the inner loop never hit ``break``, i.e. every mirror failed.
                raise RuntimeError(f"Error downloading {filename}")

    def extra_repr(self) -> str:
        split = "Train" if self.train is True else "Test"
        return f"Split: {split}"
201

202

203
class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.

    Reuses the loading logic of :class:`MNIST`. Only the download mirror, the
    archive checksums and the class names are overridden here.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where
            ``FashionMNIST/raw/train-images-idx3-ubyte`` and
            ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    # Single download location; filenames from ``resources`` are appended to it.
    mirrors = [
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
    ]

    # (archive filename, md5 of the archive)
    resources = [
        ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
        ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
        ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
        ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"),
    ]

    # Label index -> human-readable class name.
    classes = [
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ]
229
230


hysts's avatar
hysts committed
231
232
233
234
class KMNIST(MNIST):
    """`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.

    Reuses the loading logic of :class:`MNIST`. Only the download mirror, the
    archive checksums and the class names are overridden here.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where
            ``KMNIST/raw/train-images-idx3-ubyte`` and
            ``KMNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    # Single download location; filenames from ``resources`` are appended to it.
    mirrors = [
        "http://codh.rois.ac.jp/kmnist/dataset/kmnist/",
    ]

    # (archive filename, md5 of the archive)
    resources = [
        ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
        ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
        ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
        ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"),
    ]

    # Label index -> romanized character name.
    classes = [
        "o",
        "ki",
        "su",
        "tsu",
        "na",
        "ha",
        "ma",
        "ya",
        "re",
        "wo",
    ]
hysts's avatar
hysts committed
257
258


259
class EMNIST(MNIST):
    """`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where
            ``EMNIST/raw/emnist-<split>-train-images-idx3-ubyte`` and
            ``EMNIST/raw/emnist-<split>-test-images-idx3-ubyte`` (plus the matching
            ``-labels-idx1-ubyte`` files) exist.
        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
            which one to use.
        train (bool, optional): If True, creates dataset from the ``emnist-<split>-train-*``
            files, otherwise from the ``emnist-<split>-test-*`` files.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    ``train``, ``download``, ``transform`` and ``target_transform`` are forwarded to
    :class:`MNIST` through ``**kwargs``.
    """

    # A single zip archive contains the files for all splits.
    url = "https://biometrics.nist.gov/cs_links/EMNIST/gzip.zip"
    md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
    splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist")
    # Lowercase letters assumed to share the same shape as their uppercase version;
    # the merged splits drop these lowercase classes.
    _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"}
    _all_classes = set(string.digits + string.ascii_letters)
    classes_split_dict = {
        "byclass": sorted(_all_classes),
        "bymerge": sorted(_all_classes - _merged_classes),
        "balanced": sorted(_all_classes - _merged_classes),
        "letters": ["N/A"] + list(string.ascii_lowercase),
        "digits": list(string.digits),
        "mnist": list(string.digits),
    }

    def __init__(self, root: Union[str, Path], split: str, **kwargs: Any) -> None:
        self.split = verify_str_arg(split, "split", self.splits)
        # Legacy cache names are per split, so they must be set before MNIST.__init__ checks them.
        self.training_file = self._training_file(split)
        self.test_file = self._test_file(split)
        super().__init__(root, **kwargs)
        self.classes = self.classes_split_dict[self.split]

    @staticmethod
    def _training_file(split) -> str:
        """Legacy cached training file name for ``split``."""
        return f"training_{split}.pt"

    @staticmethod
    def _test_file(split) -> str:
        """Legacy cached test file name for ``split``."""
        return f"test_{split}.pt"

    @property
    def _file_prefix(self) -> str:
        """Common prefix of the raw files for the current split and train/test selection."""
        return f"emnist-{self.split}-{'train' if self.train else 'test'}"

    @property
    def images_file(self) -> str:
        return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")

    @property
    def labels_file(self) -> str:
        return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")

    def _load_data(self):
        """Return ``(data, targets)`` read from the raw files of the current split."""
        return read_image_file(self.images_file), read_label_file(self.labels_file)

    def _check_exists(self) -> bool:
        return all(check_integrity(file) for file in (self.images_file, self.labels_file))

    def download(self) -> None:
        """Download the EMNIST data if it doesn't exist already.

        The zip archive is extracted into ``raw_folder``. Every ``.gz`` file in its
        ``gzip`` sub-directory is then extracted into ``raw_folder``, and that
        sub-directory is removed.
        """
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
        gzip_folder = os.path.join(self.raw_folder, "gzip")
        for gzip_file in os.listdir(gzip_folder):
            if gzip_file.endswith(".gz"):
                extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
        shutil.rmtree(gzip_folder)
341
342


343
344
345
346
class QMNIST(MNIST):
    """`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset whose ``raw``
            subdir contains binary files of the datasets.
        what (string,optional): Can be 'train', 'test', 'test10k',
            'test50k', or 'nist' for respectively the mnist compatible
            training set, the 60k qmnist testing set, the 10k qmnist
            examples that match the mnist testing set, the 50k
            remaining qmnist testing examples, or all the nist
            digits. The default is to select 'train' or 'test'
            according to the compatibility argument 'train'.
        compat (bool,optional): A boolean that says whether the target
            for each example is class number (for compatibility with
            the MNIST dataloader) or a torch vector containing the
            full qmnist information. Default=True.
        download (bool, optional): If True, downloads the dataset from
            the internet and puts it in root directory. If dataset is
            already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that
            takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform
            that takes in the target and transforms it.
        train (bool,optional,compatibility): When argument 'what' is
            not specified, this boolean decides whether to load the
            training set or the testing set.  Default: True.
    """

    # Maps each accepted ``what`` value to the ``resources`` key whose files hold it.
    subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
    # Each entry is [(images url, md5), (labels url, md5)].
    resources: Dict[str, List[Tuple[str, str]]] = {  # type: ignore[assignment]
        "train": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz",
                "ed72d4157d28c017586c42bc6afe6370",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz",
                "0058f8dd561b90ffdd0f734c6a30e5e4",
            ),
        ],
        "test": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz",
                "1394631089c404de565df7b7aeaf9412",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz",
                "5b5b05890a5e13444e108efe57b788aa",
            ),
        ],
        "nist": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz",
                "7f124b3b8ab81486c9d8c2749c17f834",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz",
                "5ed0e788978e45d4a8bd4b7caec3d79d",
            ),
        ],
    }
    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    def __init__(
        self, root: Union[str, Path], what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any
    ) -> None:
        if what is None:
            what = "train" if train else "test"
        self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
        self.compat = compat
        self.data_file = what + ".pt"
        # Both legacy lookups in MNIST resolve to the single cached file for this subset.
        self.training_file = self.data_file
        self.test_file = self.data_file
        super().__init__(root, train, **kwargs)

    @property
    def images_file(self) -> str:
        """Path of the decompressed image file for the selected subset."""
        (url, _), _ = self.resources[self.subsets[self.what]]
        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])

    @property
    def labels_file(self) -> str:
        """Path of the decompressed label file for the selected subset."""
        _, (url, _) = self.resources[self.subsets[self.what]]
        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])

    def _check_exists(self) -> bool:
        return all(check_integrity(file) for file in (self.images_file, self.labels_file))

    def _load_data(self):
        """Read images and full (2-D) targets for the selected subset.

        Raises:
            TypeError: If the images are not ``torch.uint8``.
            ValueError: If the images are not 3-D or the targets are not 2-D.
        """
        data = read_sn3_pascalvincent_tensor(self.images_file)
        if data.dtype != torch.uint8:
            raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}")
        if data.ndimension() != 3:
            raise ValueError(f"data should have 3 dimensions instead of {data.ndimension()}")

        targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
        if targets.ndimension() != 2:
            raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}")

        # 'test10k' and 'test50k' share the 'test' files: the first 10000 rows form
        # test10k, the rest test50k. clone() so the slice does not keep the whole
        # tensor's storage alive.
        if self.what == "test10k":
            data = data[0:10000, :, :].clone()
            targets = targets[0:10000, :].clone()
        elif self.what == "test50k":
            data = data[10000:, :, :].clone()
            targets = targets[10000:, :].clone()

        return data, targets

    def download(self) -> None:
        """Download the QMNIST data if it doesn't exist already.
        Note that we only download what has been asked for (argument 'what').
        """
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
        split = self.resources[self.subsets[self.what]]

        for url, md5 in split:
            download_and_extract_archive(url, self.raw_folder, md5=md5)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        # redefined to handle the compat flag
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img.numpy(), mode="L")
        if self.transform is not None:
            img = self.transform(img)
        if self.compat:
            # In compat mode only the first column of the target row (the class number) is returned.
            target = int(target[0])
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def extra_repr(self) -> str:
        return f"Split: {self.what}"
491
492


493
def get_int(b: bytes) -> int:
    """Decode ``b`` as an unsigned integer, most significant byte first.

    The result does not depend on the host's byte order. An empty ``b`` decodes to ``0``.
    """
    return int.from_bytes(b, byteorder="big", signed=False)
Tian Qi Chen's avatar
Tian Qi Chen committed
495

496

497
# Maps the element-type code decoded from an IDX ("SN3 Pascal Vincent") header to the
# torch dtype used to interpret the payload. Code 10 has no entry.
SN3_PASCALVINCENT_TYPEMAP = dict(
    (
        (8, torch.uint8),
        (9, torch.int8),
        (11, torch.int16),
        (12, torch.int32),
        (13, torch.float32),
        (14, torch.float64),
    )
)


507
def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').

    ``path`` must name an uncompressed file on disk. It is opened with
    ``open(path, "rb")``, so compressed files and file objects are not accepted.

    Args:
        path: Path of the file to read.
        strict: If True, assert that the number of decoded elements equals the
            product of the dimension sizes found in the header.

    Returns:
        A tensor shaped by the header's dimension sizes, with the dtype looked up in
        ``SN3_PASCALVINCENT_TYPEMAP``.
    """
    # read
    with open(path, "rb") as f:
        data = f.read()

    # parse: the first 4 bytes encode the element type code and the number of dimensions.
    if sys.byteorder == "little":
        magic = get_int(data[0:4])
        nd = magic % 256
        ty = magic // 256
    else:
        # NOTE(review): get_int() parses bytes most-significant-first on every host, so the
        # branch above would already be host-independent; here nd is taken from data[0]
        # instead. Verify this path (and the size byte-swap below) on a big-endian machine.
        nd = get_int(data[0:1])
        ty = get_int(data[1:2]) + get_int(data[2:3]) * 256 + get_int(data[3:4]) * 256 * 256

    assert 1 <= nd <= 3
    assert 8 <= ty <= 14
    torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
    # One 4-byte size per dimension follows the 4-byte magic number.
    s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]

    if sys.byteorder == "big":
        for i in range(len(s)):
            s[i] = int.from_bytes(s[i].to_bytes(4, byteorder="little"), byteorder="big", signed=False)

    # Copy into a bytearray so torch.frombuffer gets a writable buffer; the payload starts
    # right after the magic number and the nd size fields.
    parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))

    # The MNIST format uses the big endian byte order, while `torch.frombuffer` uses whatever the system uses. In case
    # that is little endian and the dtype has more than one byte, we need to flip them.
    if sys.byteorder == "little" and parsed.element_size() > 1:
        parsed = _flip_byte_order(parsed)

    assert parsed.shape[0] == np.prod(s) or not strict
    return parsed.view(*s)
542
543


544
def read_label_file(path: str) -> torch.Tensor:
    """Read an IDX label file and return its labels as a 1-D ``torch.long`` tensor.

    Raises:
        TypeError: If the stored labels are not ``torch.uint8``.
        ValueError: If the stored tensor is not one-dimensional.
    """
    labels = read_sn3_pascalvincent_tensor(path, strict=False)
    if labels.dtype != torch.uint8:
        raise TypeError(f"x should be of dtype torch.uint8 instead of {labels.dtype}")
    ndim = labels.ndimension()
    if ndim != 1:
        raise ValueError(f"x should have 1 dimension instead of {ndim}")
    # Widen to int64 so labels can be used directly as class indices.
    return labels.to(torch.long)
Tian Qi Chen's avatar
Tian Qi Chen committed
551

552

553
def read_image_file(path: str) -> torch.Tensor:
    """Read an IDX image file and return it as a 3-D ``torch.uint8`` tensor.

    Raises:
        TypeError: If the stored values are not ``torch.uint8``.
        ValueError: If the stored tensor is not 3-dimensional.
    """
    x = read_sn3_pascalvincent_tensor(path, strict=False)
    if x.dtype != torch.uint8:
        raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
    if x.ndimension() != 3:
        raise ValueError(f"x should have 3 dimensions instead of {x.ndimension()}")
    return x