Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
a137e4f4
Commit
a137e4f4
authored
Feb 27, 2017
by
Luke Yeager
Committed by
Soumith Chintala
Feb 27, 2017
Browse files
[Lint] Fix most lint automatically with autopep8
git ls-files | grep '\.py$' | xargs -n1 -P`nproc` autopep8 -i
parent
e9ec6ac8
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
91 additions
and
49 deletions
+91
-49
setup.cfg
setup.cfg
+8
-0
test/cifar.py
test/cifar.py
+2
-3
test/preprocess-bench.py
test/preprocess-bench.py
+2
-4
torchvision/datasets/cifar.py
torchvision/datasets/cifar.py
+17
-17
torchvision/datasets/coco.py
torchvision/datasets/coco.py
+6
-2
torchvision/datasets/folder.py
torchvision/datasets/folder.py
+4
-0
torchvision/datasets/lsun.py
torchvision/datasets/lsun.py
+15
-11
torchvision/datasets/mnist.py
torchvision/datasets/mnist.py
+10
-4
torchvision/models/alexnet.py
torchvision/models/alexnet.py
+1
-0
torchvision/models/resnet.py
torchvision/models/resnet.py
+1
-0
torchvision/models/squeezenet.py
torchvision/models/squeezenet.py
+3
-1
torchvision/models/vgg.py
torchvision/models/vgg.py
+1
-0
torchvision/transforms.py
torchvision/transforms.py
+14
-1
torchvision/utils.py
torchvision/utils.py
+7
-6
No files found.
setup.cfg
View file @
a137e4f4
[bdist_wheel]
universal=1
[pep8]
max-line-length = 120
[flake8]
max-line-length = 120
ignore = F401,F403
exclude = venv
test/cifar.py
View file @
a137e4f4
...
...
@@ -14,7 +14,7 @@ print(a[3])
dataset
=
dset
.
CIFAR10
(
root
=
'cifar'
,
download
=
True
,
transform
=
transforms
.
ToTensor
())
dataloader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
batch_size
=
1
,
dataloader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
batch_size
=
1
,
shuffle
=
True
,
num_workers
=
2
)
...
...
@@ -31,10 +31,9 @@ for i, data in enumerate(dataloader, 0):
# except StopIteration:
# miter = dataloader.__iter__()
# return miter.next()
# i=0
# while True:
# print(i)
# img, target = getBatch()
# i+=1
test/preprocess-bench.py
View file @
a137e4f4
...
...
@@ -20,14 +20,13 @@ parser.add_argument('--batchSize', '-b', default=256, type=int, metavar='N',
if
__name__
==
"__main__"
:
args
=
parser
.
parse_args
()
# Data loading code
transform
=
transforms
.
Compose
([
transforms
.
RandomSizedCrop
(
224
),
transforms
.
RandomHorizontalFlip
(),
transforms
.
ToTensor
(),
transforms
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
]),
transforms
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
]),
])
traindir
=
os
.
path
.
join
(
args
.
data
,
'train'
)
...
...
@@ -47,4 +46,3 @@ if __name__ == "__main__":
dataset
=
(
end_time
-
start_time
)
*
(
float
(
len
(
train_loader
))
/
batch_count
/
60.0
),
batch
=
(
end_time
-
start_time
)
/
float
(
batch_count
),
image
=
(
end_time
-
start_time
)
/
(
batch_count
*
args
.
batchSize
)
*
1.0e+3
))
torchvision/datasets/cifar.py
View file @
a137e4f4
...
...
@@ -11,36 +11,37 @@ if sys.version_info[0] == 2:
else
:
import
pickle
class
CIFAR10
(
data
.
Dataset
):
base_folder
=
'cifar-10-batches-py'
url
=
"http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename
=
"cifar-10-python.tar.gz"
tgz_md5
=
'c58f30108f718f92721af3b95e74349a'
train_list
=
[
[
'data_batch_1'
,
'c99cafc152244af753f735de768cd75f'
],
[
'data_batch_2'
,
'd4bba439e000b95fd0a9bffe97cbabec'
],
[
'data_batch_3'
,
'54ebc095f3ab1f0389bbae665268c751'
],
[
'data_batch_4'
,
'634d18415352ddfa80567beed471001a'
],
[
'data_batch_5'
,
'482c414d41f54cd18b22e5b47cb7c3cb'
],
[
'data_batch_1'
,
'c99cafc152244af753f735de768cd75f'
],
[
'data_batch_2'
,
'd4bba439e000b95fd0a9bffe97cbabec'
],
[
'data_batch_3'
,
'54ebc095f3ab1f0389bbae665268c751'
],
[
'data_batch_4'
,
'634d18415352ddfa80567beed471001a'
],
[
'data_batch_5'
,
'482c414d41f54cd18b22e5b47cb7c3cb'
],
]
test_list
=
[
[
'test_batch'
,
'40351d587109b95175f43aff81a1287e'
],
[
'test_batch'
,
'40351d587109b95175f43aff81a1287e'
],
]
def
__init__
(
self
,
root
,
train
=
True
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
self
.
root
=
root
self
.
transform
=
transform
self
.
target_transform
=
target_transform
self
.
train
=
train
# training set or test set
self
.
train
=
train
# training set or test set
if
download
:
self
.
download
()
if
not
self
.
_check_integrity
():
raise
RuntimeError
(
'Dataset not found or corrupted.'
raise
RuntimeError
(
'Dataset not found or corrupted.'
+
' You can use download=True to download it'
)
# now load the picked numpy arrays
if
self
.
train
:
self
.
train_data
=
[]
...
...
@@ -83,10 +84,10 @@ class CIFAR10(data.Dataset):
img
,
target
=
self
.
train_data
[
index
],
self
.
train_labels
[
index
]
else
:
img
,
target
=
self
.
test_data
[
index
],
self
.
test_labels
[
index
]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img
=
Image
.
fromarray
(
np
.
transpose
(
img
,
(
1
,
2
,
0
)))
img
=
Image
.
fromarray
(
np
.
transpose
(
img
,
(
1
,
2
,
0
)))
if
self
.
transform
is
not
None
:
img
=
self
.
transform
(
img
)
...
...
@@ -134,7 +135,7 @@ class CIFAR10(data.Dataset):
if
self
.
_check_integrity
():
print
(
'Files already downloaded and verified'
)
return
# downloads file
if
os
.
path
.
isfile
(
fpath
)
and
\
hashlib
.
md5
(
open
(
fpath
,
'rb'
).
read
()).
hexdigest
()
==
self
.
tgz_md5
:
...
...
@@ -147,7 +148,7 @@ class CIFAR10(data.Dataset):
cwd
=
os
.
getcwd
()
print
(
'Extracting tar file'
)
tar
=
tarfile
.
open
(
fpath
,
"r:gz"
)
os
.
chdir
(
root
)
os
.
chdir
(
root
)
tar
.
extractall
()
tar
.
close
()
os
.
chdir
(
cwd
)
...
...
@@ -160,10 +161,9 @@ class CIFAR100(CIFAR10):
filename
=
"cifar-100-python.tar.gz"
tgz_md5
=
'eb9058c3a382ffc7106e4002c42a8d85'
train_list
=
[
[
'train'
,
'16019d7e3df5f24257cddd939b257f8d'
],
[
'train'
,
'16019d7e3df5f24257cddd939b257f8d'
],
]
test_list
=
[
[
'test'
,
'f0ef6b0ae62326f3e7ffdfab6717acfc'
],
[
'test'
,
'f0ef6b0ae62326f3e7ffdfab6717acfc'
],
]
torchvision/datasets/coco.py
View file @
a137e4f4
...
...
@@ -3,7 +3,9 @@ from PIL import Image
import
os
import
os.path
class
CocoCaptions
(
data
.
Dataset
):
def
__init__
(
self
,
root
,
annFile
,
transform
=
None
,
target_transform
=
None
):
from
pycocotools.coco
import
COCO
self
.
root
=
root
...
...
@@ -15,7 +17,7 @@ class CocoCaptions(data.Dataset):
def
__getitem__
(
self
,
index
):
coco
=
self
.
coco
img_id
=
self
.
ids
[
index
]
ann_ids
=
coco
.
getAnnIds
(
imgIds
=
img_id
)
ann_ids
=
coco
.
getAnnIds
(
imgIds
=
img_id
)
anns
=
coco
.
loadAnns
(
ann_ids
)
target
=
[
ann
[
'caption'
]
for
ann
in
anns
]
...
...
@@ -33,7 +35,9 @@ class CocoCaptions(data.Dataset):
def
__len__
(
self
):
return
len
(
self
.
ids
)
class
CocoDetection
(
data
.
Dataset
):
def
__init__
(
self
,
root
,
annFile
,
transform
=
None
,
target_transform
=
None
):
from
pycocotools.coco
import
COCO
self
.
root
=
root
...
...
@@ -45,7 +49,7 @@ class CocoDetection(data.Dataset):
def
__getitem__
(
self
,
index
):
coco
=
self
.
coco
img_id
=
self
.
ids
[
index
]
ann_ids
=
coco
.
getAnnIds
(
imgIds
=
img_id
)
ann_ids
=
coco
.
getAnnIds
(
imgIds
=
img_id
)
target
=
coco
.
loadAnns
(
ann_ids
)
path
=
coco
.
loadImgs
(
img_id
)[
0
][
'file_name'
]
...
...
torchvision/datasets/folder.py
View file @
a137e4f4
...
...
@@ -9,15 +9,18 @@ IMG_EXTENSIONS = [
'.png'
,
'.PNG'
,
'.ppm'
,
'.PPM'
,
'.bmp'
,
'.BMP'
,
]
def
is_image_file
(
filename
):
return
any
(
filename
.
endswith
(
extension
)
for
extension
in
IMG_EXTENSIONS
)
def
find_classes
(
dir
):
classes
=
os
.
listdir
(
dir
)
classes
.
sort
()
class_to_idx
=
{
classes
[
i
]:
i
for
i
in
range
(
len
(
classes
))}
return
classes
,
class_to_idx
def
make_dataset
(
dir
,
class_to_idx
):
images
=
[]
for
target
in
os
.
listdir
(
dir
):
...
...
@@ -39,6 +42,7 @@ def default_loader(path):
class
ImageFolder
(
data
.
Dataset
):
def
__init__
(
self
,
root
,
transform
=
None
,
target_transform
=
None
,
loader
=
default_loader
):
classes
,
class_to_idx
=
find_classes
(
root
)
...
...
torchvision/datasets/lsun.py
View file @
a137e4f4
...
...
@@ -10,7 +10,9 @@ if sys.version_info[0] == 2:
else
:
import
pickle
class
LSUNClass
(
data
.
Dataset
):
def
__init__
(
self
,
db_path
,
transform
=
None
,
target_transform
=
None
):
import
lmdb
self
.
db_path
=
db_path
...
...
@@ -20,11 +22,11 @@ class LSUNClass(data.Dataset):
self
.
length
=
txn
.
stat
()[
'entries'
]
cache_file
=
'_cache_'
+
db_path
.
replace
(
'/'
,
'_'
)
if
os
.
path
.
isfile
(
cache_file
):
self
.
keys
=
pickle
.
load
(
open
(
cache_file
,
"rb"
)
)
self
.
keys
=
pickle
.
load
(
open
(
cache_file
,
"rb"
)
)
else
:
with
self
.
env
.
begin
(
write
=
False
)
as
txn
:
self
.
keys
=
[
key
for
key
,
_
in
txn
.
cursor
()
]
pickle
.
dump
(
self
.
keys
,
open
(
cache_file
,
"wb"
)
)
self
.
keys
=
[
key
for
key
,
_
in
txn
.
cursor
()]
pickle
.
dump
(
self
.
keys
,
open
(
cache_file
,
"wb"
)
)
self
.
transform
=
transform
self
.
target_transform
=
target_transform
...
...
@@ -53,11 +55,13 @@ class LSUNClass(data.Dataset):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
' ('
+
self
.
db_path
+
')'
class
LSUN
(
data
.
Dataset
):
"""
db_path = root directory for the database files
classes = 'train' | 'val' | 'test' | ['bedroom_train', 'church_train', ...]
"""
def
__init__
(
self
,
db_path
,
classes
=
'train'
,
transform
=
None
,
target_transform
=
None
):
categories
=
[
'bedroom'
,
'bridge'
,
'church_outdoor'
,
'classroom'
,
...
...
@@ -73,13 +77,13 @@ class LSUN(data.Dataset):
c_short
.
pop
(
len
(
c_short
)
-
1
)
c_short
=
'_'
.
join
(
c_short
)
if
c_short
not
in
categories
:
raise
(
ValueError
(
'Unknown LSUN class: '
+
c_short
+
'.'
\
'Options are: '
+
str
(
categories
)))
raise
(
ValueError
(
'Unknown LSUN class: '
+
c_short
+
'.'
'Options are: '
+
str
(
categories
)))
c_short
=
c
.
split
(
'_'
)
c_short
=
c_short
.
pop
(
len
(
c_short
)
-
1
)
if
c_short
not
in
dset_opts
:
raise
(
ValueError
(
'Unknown postfix: '
+
c_short
+
'.'
\
'Options are: '
+
str
(
dset_opts
)))
raise
(
ValueError
(
'Unknown postfix: '
+
c_short
+
'.'
'Options are: '
+
str
(
dset_opts
)))
else
:
raise
(
ValueError
(
'Unknown option for classes'
))
self
.
classes
=
classes
...
...
@@ -88,8 +92,8 @@ class LSUN(data.Dataset):
self
.
dbs
=
[]
for
c
in
self
.
classes
:
self
.
dbs
.
append
(
LSUNClass
(
db_path
=
db_path
+
'/'
+
c
+
'_lmdb'
,
transform
=
transform
))
db_path
=
db_path
+
'/'
+
c
+
'_lmdb'
,
transform
=
transform
))
self
.
indices
=
[]
count
=
0
...
...
@@ -128,9 +132,9 @@ if __name__ == '__main__':
#lsun = LSUNClass(db_path='/home/soumith/local/lsun/train/bedroom_train_lmdb')
#a = lsun[0]
lsun
=
LSUN
(
db_path
=
'/home/soumith/local/lsun/train'
,
classes
=
[
'bedroom_train'
,
'church_outdoor_train'
])
classes
=
[
'bedroom_train'
,
'church_outdoor_train'
])
print
(
lsun
.
classes
)
print
(
lsun
.
dbs
)
a
,
t
=
lsun
[
len
(
lsun
)
-
1
]
a
,
t
=
lsun
[
len
(
lsun
)
-
1
]
print
(
a
)
print
(
t
)
torchvision/datasets/mnist.py
View file @
a137e4f4
...
...
@@ -9,6 +9,7 @@ import json
import
codecs
import
numpy
as
np
class
MNIST
(
data
.
Dataset
):
urls
=
[
'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
,
...
...
@@ -25,7 +26,7 @@ class MNIST(data.Dataset):
self
.
root
=
root
self
.
transform
=
transform
self
.
target_transform
=
target_transform
self
.
train
=
train
# training set or test set
self
.
train
=
train
# training set or test set
if
download
:
self
.
download
()
...
...
@@ -35,7 +36,8 @@ class MNIST(data.Dataset):
+
' You can use download=True to download it'
)
if
self
.
train
:
self
.
train_data
,
self
.
train_labels
=
torch
.
load
(
os
.
path
.
join
(
root
,
self
.
processed_folder
,
self
.
training_file
))
self
.
train_data
,
self
.
train_labels
=
torch
.
load
(
os
.
path
.
join
(
root
,
self
.
processed_folder
,
self
.
training_file
))
else
:
self
.
test_data
,
self
.
test_labels
=
torch
.
load
(
os
.
path
.
join
(
root
,
self
.
processed_folder
,
self
.
test_file
))
...
...
@@ -65,7 +67,7 @@ class MNIST(data.Dataset):
def
_check_exists
(
self
):
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
root
,
self
.
processed_folder
,
self
.
training_file
))
and
\
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
root
,
self
.
processed_folder
,
self
.
test_file
))
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
root
,
self
.
processed_folder
,
self
.
test_file
))
def
download
(
self
):
from
six.moves
import
urllib
...
...
@@ -92,7 +94,7 @@ class MNIST(data.Dataset):
with
open
(
file_path
,
'wb'
)
as
f
:
f
.
write
(
data
.
read
())
with
open
(
file_path
.
replace
(
'.gz'
,
''
),
'wb'
)
as
out_f
,
\
gzip
.
GzipFile
(
file_path
)
as
zip_f
:
gzip
.
GzipFile
(
file_path
)
as
zip_f
:
out_f
.
write
(
zip_f
.
read
())
os
.
unlink
(
file_path
)
...
...
@@ -114,14 +116,17 @@ class MNIST(data.Dataset):
print
(
'Done!'
)
def
get_int
(
b
):
return
int
(
codecs
.
encode
(
b
,
'hex'
),
16
)
def
parse_byte
(
b
):
if
isinstance
(
b
,
str
):
return
ord
(
b
)
return
b
def
read_label_file
(
path
):
with
open
(
path
,
'rb'
)
as
f
:
data
=
f
.
read
()
...
...
@@ -131,6 +136,7 @@ def read_label_file(path):
assert
len
(
labels
)
==
length
return
torch
.
LongTensor
(
labels
)
def
read_image_file
(
path
):
with
open
(
path
,
'rb'
)
as
f
:
data
=
f
.
read
()
...
...
torchvision/models/alexnet.py
View file @
a137e4f4
...
...
@@ -11,6 +11,7 @@ model_urls = {
class
AlexNet
(
nn
.
Module
):
def
__init__
(
self
,
num_classes
=
1000
):
super
(
AlexNet
,
self
).
__init__
()
self
.
features
=
nn
.
Sequential
(
...
...
torchvision/models/resnet.py
View file @
a137e4f4
...
...
@@ -94,6 +94,7 @@ class Bottleneck(nn.Module):
class
ResNet
(
nn
.
Module
):
def
__init__
(
self
,
block
,
layers
,
num_classes
=
1000
):
self
.
inplanes
=
64
super
(
ResNet
,
self
).
__init__
()
...
...
torchvision/models/squeezenet.py
View file @
a137e4f4
...
...
@@ -14,8 +14,9 @@ model_urls = {
class
Fire
(
nn
.
Module
):
def
__init__
(
self
,
inplanes
,
squeeze_planes
,
expand1x1_planes
,
expand3x3_planes
):
expand1x1_planes
,
expand3x3_planes
):
super
(
Fire
,
self
).
__init__
()
self
.
inplanes
=
inplanes
self
.
squeeze
=
nn
.
Conv2d
(
inplanes
,
squeeze_planes
,
kernel_size
=
1
)
...
...
@@ -36,6 +37,7 @@ class Fire(nn.Module):
class
SqueezeNet
(
nn
.
Module
):
def
__init__
(
self
,
version
=
1.0
,
num_classes
=
1000
):
super
(
SqueezeNet
,
self
).
__init__
()
if
version
not
in
[
1.0
,
1.1
]:
...
...
torchvision/models/vgg.py
View file @
a137e4f4
...
...
@@ -18,6 +18,7 @@ model_urls = {
class
VGG
(
nn
.
Module
):
def
__init__
(
self
,
features
):
super
(
VGG
,
self
).
__init__
()
self
.
features
=
features
...
...
torchvision/transforms.py
View file @
a137e4f4
...
...
@@ -7,6 +7,7 @@ import numpy as np
import
numbers
import
types
class
Compose
(
object
):
"""Composes several transforms together.
...
...
@@ -19,6 +20,7 @@ class Compose(object):
>>> transforms.ToTensor(),
>>> ])
"""
def
__init__
(
self
,
transforms
):
self
.
transforms
=
transforms
...
...
@@ -32,6 +34,7 @@ class ToTensor(object):
"""Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
"""
def
__call__
(
self
,
pic
):
if
isinstance
(
pic
,
np
.
ndarray
):
# handle numpy array
...
...
@@ -56,12 +59,13 @@ class ToPILImage(object):
or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C
to a PIL.Image of range [0, 255]
"""
def
__call__
(
self
,
pic
):
npimg
=
pic
mode
=
None
if
not
isinstance
(
npimg
,
np
.
ndarray
):
npimg
=
pic
.
mul
(
255
).
byte
().
numpy
()
npimg
=
np
.
transpose
(
npimg
,
(
1
,
2
,
0
))
npimg
=
np
.
transpose
(
npimg
,
(
1
,
2
,
0
))
if
npimg
.
shape
[
2
]
==
1
:
npimg
=
npimg
[:,
:,
0
]
...
...
@@ -75,6 +79,7 @@ class Normalize(object):
will normalize each channel of the torch.*Tensor, i.e.
channel = (channel - mean) / std
"""
def
__init__
(
self
,
mean
,
std
):
self
.
mean
=
mean
self
.
std
=
std
...
...
@@ -94,6 +99,7 @@ class Scale(object):
size: size of the smaller edge
interpolation: Default: PIL.Image.BILINEAR
"""
def
__init__
(
self
,
size
,
interpolation
=
Image
.
BILINEAR
):
self
.
size
=
size
self
.
interpolation
=
interpolation
...
...
@@ -117,6 +123,7 @@ class CenterCrop(object):
the given size. size can be a tuple (target_height, target_width)
or an integer, in which case the target will be of a square shape (size, size)
"""
def
__init__
(
self
,
size
):
if
isinstance
(
size
,
numbers
.
Number
):
self
.
size
=
(
int
(
size
),
int
(
size
))
...
...
@@ -133,6 +140,7 @@ class CenterCrop(object):
class
Pad
(
object
):
"""Pads the given PIL.Image on all sides with the given "pad" value"""
def
__init__
(
self
,
padding
,
fill
=
0
):
assert
isinstance
(
padding
,
numbers
.
Number
)
assert
isinstance
(
fill
,
numbers
.
Number
)
or
isinstance
(
fill
,
str
)
or
isinstance
(
fill
,
tuple
)
...
...
@@ -142,8 +150,10 @@ class Pad(object):
def
__call__
(
self
,
img
):
return
ImageOps
.
expand
(
img
,
border
=
self
.
padding
,
fill
=
self
.
fill
)
class
Lambda
(
object
):
"""Applies a lambda as a transform."""
def
__init__
(
self
,
lambd
):
assert
type
(
lambd
)
is
types
.
LambdaType
self
.
lambd
=
lambd
...
...
@@ -157,6 +167,7 @@ class RandomCrop(object):
the given size. size can be a tuple (target_height, target_width)
or an integer, in which case the target will be of a square shape (size, size)
"""
def
__init__
(
self
,
size
,
padding
=
0
):
if
isinstance
(
size
,
numbers
.
Number
):
self
.
size
=
(
int
(
size
),
int
(
size
))
...
...
@@ -181,6 +192,7 @@ class RandomCrop(object):
class
RandomHorizontalFlip
(
object
):
"""Randomly horizontally flips the given PIL.Image with a probability of 0.5
"""
def
__call__
(
self
,
img
):
if
random
.
random
()
<
0.5
:
return
img
.
transpose
(
Image
.
FLIP_LEFT_RIGHT
)
...
...
@@ -194,6 +206,7 @@ class RandomSizedCrop(object):
size: size of the smaller edge
interpolation: Default: PIL.Image.BILINEAR
"""
def
__init__
(
self
,
size
,
interpolation
=
Image
.
BILINEAR
):
self
.
size
=
size
self
.
interpolation
=
interpolation
...
...
torchvision/utils.py
View file @
a137e4f4
import
torch
import
math
def
make_grid
(
tensor
,
nrow
=
8
,
padding
=
2
):
"""
Given a 4D mini-batch Tensor of shape (B x C x H x W),
...
...
@@ -15,13 +16,13 @@ def make_grid(tensor, nrow=8, padding=2):
tensor
=
tensorlist
[
0
].
new
(
size
)
for
i
in
range
(
numImages
):
tensor
[
i
].
copy_
(
tensorlist
[
i
])
if
tensor
.
dim
()
==
2
:
# single image H x W
if
tensor
.
dim
()
==
2
:
# single image H x W
tensor
=
tensor
.
view
(
1
,
tensor
.
size
(
0
),
tensor
.
size
(
1
))
if
tensor
.
dim
()
==
3
:
# single image
if
tensor
.
dim
()
==
3
:
# single image
if
tensor
.
size
(
0
)
==
1
:
tensor
=
torch
.
cat
((
tensor
,
tensor
,
tensor
),
0
)
return
tensor
if
tensor
.
dim
()
==
4
and
tensor
.
size
(
1
)
==
1
:
# single-channel images
if
tensor
.
dim
()
==
4
and
tensor
.
size
(
1
)
==
1
:
# single-channel images
tensor
=
torch
.
cat
((
tensor
,
tensor
,
tensor
),
1
)
# make the mini-batch of images into a grid
nmaps
=
tensor
.
size
(
0
)
...
...
@@ -34,8 +35,8 @@ def make_grid(tensor, nrow=8, padding=2):
for
x
in
range
(
xmaps
):
if
k
>=
nmaps
:
break
grid
.
narrow
(
1
,
y
*
height
+
1
+
padding
//
2
,
height
-
padding
)
\
.
narrow
(
2
,
x
*
width
+
1
+
padding
//
2
,
width
-
padding
)
\
grid
.
narrow
(
1
,
y
*
height
+
1
+
padding
//
2
,
height
-
padding
)
\
.
narrow
(
2
,
x
*
width
+
1
+
padding
//
2
,
width
-
padding
)
\
.
copy_
(
tensor
[
k
])
k
=
k
+
1
return
grid
...
...
@@ -49,6 +50,6 @@ def save_image(tensor, filename, nrow=8, padding=2):
from
PIL
import
Image
tensor
=
tensor
.
cpu
()
grid
=
make_grid
(
tensor
,
nrow
=
nrow
,
padding
=
padding
)
ndarr
=
grid
.
mul
(
0.5
).
add
(
0.5
).
mul
(
255
).
byte
().
transpose
(
0
,
2
).
transpose
(
0
,
1
).
numpy
()
ndarr
=
grid
.
mul
(
0.5
).
add
(
0.5
).
mul
(
255
).
byte
().
transpose
(
0
,
2
).
transpose
(
0
,
1
).
numpy
()
im
=
Image
.
fromarray
(
ndarr
)
im
.
save
(
filename
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment