# download_cifar10.py
import os

from torchvision.datasets import CIFAR10


def main():
    """Download the CIFAR-10 dataset into a ``data`` directory beside this script.

    The target directory is derived from this file's real path (symlinks
    resolved), so the download location does not depend on the current
    working directory the script is launched from.
    """
    dir_path = os.path.dirname(os.path.realpath(__file__))
    data_root = os.path.join(dir_path, "data")
    # Constructing the dataset with download=True is what triggers the fetch;
    # the resulting dataset object itself is not used by this script.
    CIFAR10(root=data_root, download=True)


# Run the download only when executed as a script, not when imported.
if __name__ == "__main__":
    main()