cifar.py 884 Bytes
Newer Older
Soumith Chintala's avatar
Soumith Chintala committed
1
2
import torch
import torchvision.datasets as dset
3
import torchvision.transforms as transforms
Soumith Chintala's avatar
Soumith Chintala committed
4

Soumith Chintala's avatar
Soumith Chintala committed
5
6
# Smoke test: build the CIFAR-10 dataset under the given root (download=True
# asks torchvision to fetch the archive when it is not already present) and
# print one sample.
print('\n\nCifar 10')
a = dset.CIFAR10(root="abc/def/ghi", download=True)

# No transform is passed above, so the sample is printed exactly as the
# dataset returns it.
print(a[3])
Soumith Chintala's avatar
Soumith Chintala committed
9

10
11
# print('\n\nCifar 100')
# a = dset.CIFAR100(root="abc/def/ghi", download=True)
Soumith Chintala's avatar
Soumith Chintala committed
12

13
14
15
16
# print(a[3])


# The DataLoader below starts worker processes (num_workers > 0). On platforms
# that spawn instead of fork (Windows, macOS) each worker re-imports this
# module, so the loading code must only run when the file is executed as a
# script; otherwise every worker would try to start workers of its own.
if __name__ == "__main__":
    # ToTensor converts each image to a tensor so the loader can batch it.
    dataset = dset.CIFAR10(root='cifar', download=True,
                           transform=transforms.ToTensor())
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                             shuffle=True, num_workers=2)

    # Print the first 11 batches (indices 0 through 10), then stop.
    for i, data in enumerate(dataloader):
        print(data)
        if i == 10:
            break

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
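# Disabled alternative: fetch batches endlessly, restarting the loader's
# iterator once it is exhausted.
# miter = dataloader.__iter__()
# def getBatch():
#     global miter
#     try:
#         return next(miter)
#     except StopIteration:
#         miter = dataloader.__iter__()
#         return next(miter)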
    
# i=0
# while True:
#     print(i)
#     img, target = getBatch()
#     i+=1