GitLab: OpenDAS / Pytorch-Encoding / Commits

Commit 9bc70531 (unverified)
add detail API and other fixes (#63)
Authored Jun 05, 2018 by Hang Zhang; committed via GitHub, Jun 05, 2018
Parent: 3ba8d2f7

Showing 11 changed files with 89 additions and 50 deletions (+89 -50)
Changed files:
  README.md                                 +1   -1
  encoding/datasets/base.py                 +0   -3
  encoding/models/base.py                   +15  -27
  encoding/models/fcn.py                    +1   -1
  encoding/nn/customize.py                  +0   -1
  experiments/recognition/main.py           +2   -2
  experiments/segmentation/demo.py          +21  -0
  experiments/segmentation/test.py          +3   -3
  experiments/segmentation/test_models.py   +19  -0
  experiments/segmentation/train.py         +15  -12
  scripts/prepare_pcontext.py               +12  -0
README.md  (+1 -1)

@@ -6,7 +6,7 @@ created by [Hang Zhang](http://hangzh.com/)
 - Please visit the [**Docs**](http://hangzh.com/PyTorch-Encoding/) for detail instructions of installation and usage.
-- How to use Synchronized Batch Normalization (SyncBN)? See the [examples](http://github.com/zhanghang1989/PyTorch-SyncBatchNorm).
+- How to use Synchronized Batch Normalization (SyncBN)? See the [examples](https://github.com/zhanghang1989/PyTorch-SyncBatchNorm).
 - Please visit the [link](http://hangzh.com/PyTorch-Encoding/experiments/segmentation.html) to examples of semantic segmentation.

 ## Citations
encoding/datasets/base.py  (+0 -3)

@@ -106,7 +106,4 @@ def test_batchify_fn(data):
     elif isinstance(data[0], (tuple, list)):
         data = zip(*data)
         return [test_batchify_fn(i) for i in data]
-    elif isinstance(data[0], mx.nd.NDArray):
-        data = np.asarray(data)
-        return mx.nd.array(data, dtype=data.dtype)
     raise TypeError((error_msg.format(type(batch[0]))))
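For reference, the surviving tuple/list branch above is a plain transpose-and-recurse pattern. A tiny standalone illustration (not repo code; the sample values are invented) of what zip(*data) does to a batch of (image, target) pairs:

# A test batch arrives as a list of per-sample tuples...
batch = [('img0', 'lbl0'), ('img1', 'lbl1'), ('img2', 'lbl2')]
# ...and zip(*batch) transposes it into one sequence per field,
# which test_batchify_fn then batchifies recursively.
images, targets = zip(*batch)
print(images)   # ('img0', 'img1', 'img2')
print(targets)  # ('lbl0', 'lbl1', 'lbl2')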
encoding/models/base.py  (+15 -27)

@@ -20,7 +20,7 @@ from ..utils import batch_pix_accuracy, batch_intersection_union
 up_kwargs = {'mode': 'bilinear', 'align_corners': True}

-__all__ = ['BaseNet', 'EvalModule', 'MultiEvalModule']
+__all__ = ['BaseNet', 'MultiEvalModule']

 class BaseNet(nn.Module):
     def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None,

@@ -65,16 +65,6 @@ class BaseNet(nn.Module):
         return correct, labeled, inter, union

-class EvalModule(nn.Module):
-    """Segmentation Eval Module"""
-    def __init__(self, module):
-        super(EvalModule, self).__init__()
-        self.module = module
-
-    def forward(self, *inputs, **kwargs):
-        return self.module.evaluate(*inputs, **kwargs)
-
 class MultiEvalModule(DataParallel):
     """Multi-size Segmentation Eavluator"""
     def __init__(self, module, nclass, device_ids=None,

@@ -125,11 +115,11 @@ class MultiEvalModule(DataParallel):
                 height = int(1.0 * h * long_size / w + 0.5)
                 short_size = height
             # resize image to current size
-            cur_img = resize_image(image, height, width)
-            if scale <= 1.25 or long_size <= crop_size: # #
+            cur_img = resize_image(image, height, width, **self.module._up_kwargs)
+            if long_size <= crop_size:
                 pad_img = pad_image(cur_img, self.module.mean, self.module.std, crop_size)
-                outputs = self.module_inference(pad_img)
+                outputs = module_inference(self.module, pad_img, self.flip)
                 outputs = crop_image(outputs, 0, height, 0, width)
             else:
                 if short_size < crop_size:

@@ -157,7 +147,7 @@ class MultiEvalModule(DataParallel):
                         # pad if needed
                         pad_crop_img = pad_image(crop_img, self.module.mean, self.module.std, crop_size)
-                        output = self.module_inference(pad_crop_img)
+                        output = module_inference(self.module, pad_crop_img, self.flip)
                         outputs[:,:,h0:h1,w0:w1] += crop_image(output, 0, h1-h0, 0, w1-w0)
                         count_norm[:,:,h0:h1,w0:w1] += 1

@@ -165,21 +155,21 @@ class MultiEvalModule(DataParallel):
                 outputs = outputs / count_norm
                 outputs = outputs[:,:,:height,:width]
-            score = resize_image(outputs, h, w)
+            score = resize_image(outputs, h, w, **self.module._up_kwargs)
             scores += score
         return scores

-    def module_inference(self, image):
-        output = self.module.evaluate(image)
-        if self.flip:
-            fimg = flip_image(image)
-            foutput = self.module.evaluate(fimg)
-            output += flip_image(foutput)
-        return output.exp()
+def module_inference(module, image, flip=True):
+    output = module.evaluate(image)
+    if flip:
+        fimg = flip_image(image)
+        foutput = module.evaluate(fimg)
+        output += flip_image(foutput)
+    return output.exp()

-def resize_image(img, h, w, mode='bilinear'):
+def resize_image(img, h, w, **up_kwargs):
     return F.upsample(img, (h, w), **up_kwargs)

 def pad_image(img, mean, std, crop_size):

@@ -189,11 +179,9 @@ def pad_image(img, mean, std, crop_size):
     padw = crop_size - w if w < crop_size else 0
     pad_values = -np.array(mean) / np.array(std)
     img_pad = img.new().resize_(b, c, h+padh, w+padw)
     #img_pad = F.pad(img, (0,padw,0,padh))
     for i in range(c):
         # note that pytorch pad params is in reversed orders
-        img_pad[:,i,:,:] = F.pad(img[:,i,:,:], (0, padw, 0, padh), value=pad_values[i])
+        img_pad[:,i,:,:] = F.pad(img[:,i,:,:], (0, padw, 0, padh), value=pad_values[i])
     assert(img_pad.size(2) >= crop_size and img_pad.size(3) >= crop_size)
     return img_pad
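A minimal, self-contained sketch (not part of this commit) of how the new module-level helpers are meant to be called now that EvalModule is gone. DummyNet is a hypothetical stand-in for any BaseNet-style model exposing .evaluate(); the class count, input size, and the direct import of the helpers are assumptions:

import torch
import torch.nn as nn
import torch.nn.functional as F
from encoding.models.base import module_inference, resize_image

class DummyNet(nn.Module):
    """Hypothetical model exposing the .evaluate() interface used by module_inference."""
    def __init__(self, nclass=21):
        super(DummyNet, self).__init__()
        self.conv = nn.Conv2d(3, nclass, 1)
    def evaluate(self, x):
        # return log-probabilities so the .exp() inside module_inference yields probabilities
        return F.log_softmax(self.conv(x), dim=1)

net = DummyNet().cuda()                      # the repo's scripts assume a GPU
image = torch.randn(1, 3, 480, 480).cuda()

# flip-averaged prediction, replacing the old self.module_inference(pad_img) method
probs = module_inference(net, image, flip=True)

# resize_image now forwards upsampling kwargs instead of a hard-coded mode='bilinear'
probs = resize_image(probs, 520, 520, mode='bilinear', align_corners=True)
print(probs.size())                          # torch.Size([1, 21, 520, 520])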
encoding/models/fcn.py  (+1 -1)

@@ -122,7 +122,7 @@ def get_fcn_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
     >>> model = get_fcn_resnet50_pcontext(pretrained=True)
     >>> print(model)
     """
-    return get_fcn('pcontext', 'resnet50', pretrained)
+    return get_fcn('pcontext', 'resnet50', pretrained, aux=False)

 def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
     r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
encoding/nn/customize.py  (+0 -1)

@@ -21,7 +21,6 @@ torch_ver = torch.__version__[:3]
 __all__ = ['GramMatrix', 'SegmentationLosses', 'View', 'Sum', 'Mean', 'Normalize']
-
 class GramMatrix(Module):
     r""" Gram Matrix for a 4D convolutional featuremaps as a mini-batch
experiments/recognition/main.py  (+2 -2)

@@ -45,8 +45,8 @@ def main():
     torch.cuda.manual_seed(args.seed)
     # init dataloader
     dataset = importlib.import_module('dataset.' + args.dataset)
-    Dataloder = dataset.Dataloder
-    train_loader, test_loader = Dataloder(args).getloader()
+    Dataloader = dataset.Dataloader
+    train_loader, test_loader = Dataloader(args).getloader()
     # init the model
     models = importlib.import_module('model.' + args.model)
     model = models.Net(args)
experiments/segmentation/demo.py  (new file, +21 -0)

import torch
import encoding

# Get the model
model = encoding.models.get_model('fcn_resnet50_ade', pretrained=True).cuda()
model.eval()

# Prepare the image
url = 'https://github.com/zhanghang1989/image-data/blob/master/' + \
    'encoding/segmentation/ade20k/ADE_val_00001142.jpg?raw=true'
filename = 'example.jpg'
img = encoding.utils.load_image(encoding.utils.download(url, filename)).cuda().unsqueeze(0)

# Make prediction
output = model.evaluate(img)
predict = torch.max(output, 1)[1].cpu().numpy() + 1

# Get color pallete for visualization
mask = encoding.utils.get_mask_pallete(predict, 'ade20k')
mask.save('output.png')
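A hypothetical variant of the demo above that runs on a local image instead of the downloaded sample; the file name my_photo.jpg is an assumption, everything else follows the API used in demo.py:

import torch
import encoding

model = encoding.models.get_model('fcn_resnet50_ade', pretrained=True).cuda()
model.eval()
# as in demo.py: load_image returns a tensor, unsqueeze(0) adds the batch dimension
img = encoding.utils.load_image('my_photo.jpg').cuda().unsqueeze(0)
output = model.evaluate(img)
# ADE20K class indices in the palette start at 1, hence the +1
predict = torch.max(output, 1)[1].cpu().numpy() + 1
encoding.utils.get_mask_pallete(predict, 'ade20k').save('my_photo_mask.png')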
experiments/segmentation/test.py  (+3 -3)

@@ -44,7 +44,7 @@ def test(args):
     # dataloader
     kwargs = {'num_workers': args.workers, 'pin_memory': True} \
         if args.cuda else {}
-    test_data = data.DataLoader(testset, batch_size=args.batch_size,
+    test_data = data.DataLoader(testset, batch_size=args.test_batch_size,
                                 drop_last=False, shuffle=False,
                                 collate_fn=test_batchify_fn, **kwargs)
     # model

@@ -105,8 +105,8 @@ def test(args):
         with torch.no_grad():
             correct, labeled, inter, union = eval_batch(image, dst, evaluator, args.eval)
         if args.eval:
-            total_correct += correct
-            total_label += labeled
+            total_correct += correct.astype('int64')
+            total_label += labeled.astype('int64')
             total_inter += inter.astype('int64')
             total_union += union.astype('int64')
             pixAcc = np.float64(1.0) * total_correct / (np.spacing(1, dtype=np.float64) + total_label)
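A small standalone illustration (not repo code; the counts are invented) of one plausible motivation for the new .astype('int64') casts: pixel totals accumulated over a long evaluation run can overflow a 32-bit accumulator.

import numpy as np

per_image = np.int32(480 * 480)       # pixels counted for one 480x480 crop
total_i32 = np.int32(0)
total_i64 = np.int64(0)
for _ in range(10000):                # e.g. a long validation sweep
    total_i32 += per_image            # wraps once the sum exceeds 2**31 - 1
    total_i64 += per_image.astype('int64')
print(total_i32)                      # wrapped (negative) value
print(total_i64)                      # 2304000000, the true count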
experiments/segmentation/test_models.py  (new file, +19 -0)

import importlib
import torch
import encoding

from option import Options
from torch.autograd import Variable

if __name__ == "__main__":
    args = Options().parse()
    model = encoding.models.get_segmentation_model(args.model, dataset=args.dataset,
                                                   aux=args.aux, se_loss=args.se_loss,
                                                   norm_layer=torch.nn.BatchNorm2d)
    print('Creating the model:')
    print(model)
    model.cuda()
    x = Variable(torch.Tensor(4, 3, 480, 480)).cuda()
    with torch.no_grad():
        out = model(x)
    for y in out:
        print(y.size())
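The same forward-shape check, sketched without the Options parser; the literal choices ('fcn', 'pcontext', batch of 2) are assumptions, and torch.randn replaces the uninitialized tensor purely for reproducibility:

import torch
import encoding

model = encoding.models.get_segmentation_model('fcn', dataset='pcontext',
                                                aux=False, se_loss=False,
                                                norm_layer=torch.nn.BatchNorm2d).cuda()
x = torch.randn(2, 3, 480, 480).cuda()
with torch.no_grad():
    out = model(x)
for y in out:
    print(y.size())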
experiments/segmentation/train.py  (+15 -12)

@@ -60,18 +60,6 @@ class Trainer():
                                    lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
-        # resuming checkpoint
-        if args.resume is not None:
-            if not os.path.isfile(args.resume):
-                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
-            checkpoint = torch.load(args.resume)
-            args.start_epoch = checkpoint['epoch']
-            model.load_state_dict(checkpoint['state_dict'])
-            if not args.ft:
-                optimizer.load_state_dict(checkpoint['optimizer'])
-            best_pred = checkpoint['best_pred']
-            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
         # clear start epoch if fine-tuning
         if args.ft:
             args.start_epoch = 0

@@ -82,6 +70,21 @@ class Trainer():
         if args.cuda:
             self.model = DataParallelModel(self.model).cuda()
             self.criterion = DataParallelCriterion(self.criterion).cuda()
+        # resuming checkpoint
+        if args.resume is not None:
+            if not os.path.isfile(args.resume):
+                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
+            checkpoint = torch.load(args.resume)
+            args.start_epoch = checkpoint['epoch']
+            if args.cuda:
+                self.model.module.load_state_dict(checkpoint['state_dict'])
+            else:
+                self.model.load_state_dict(checkpoint['state_dict'])
+            if not args.ft:
+                self.optimizer.load_state_dict(checkpoint['optimizer'])
+            self.best_pred = checkpoint['best_pred']
+            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
         # lr scheduler
         self.scheduler = utils.LR_Scheduler(args, len(self.trainloader))
         self.best_pred = 0.0
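For context, the relocated resume block loads the keys 'epoch', 'state_dict', 'optimizer', and 'best_pred', and unwraps DataParallelModel via .module on CUDA. A minimal sketch (the function name and default path are assumptions) of a save routine that produces checkpoints in that shape:

import torch

def save_checkpoint(model, optimizer, epoch, best_pred, use_cuda, path='checkpoint.pth.tar'):
    # When the model is wrapped in DataParallelModel, the real network lives
    # under .module, matching the load_state_dict branch in train.py above.
    state_dict = model.module.state_dict() if use_cuda else model.state_dict()
    torch.save({
        'epoch': epoch,
        'state_dict': state_dict,
        'optimizer': optimizer.state_dict(),
        'best_pred': best_pred,
    }, path)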
scripts/prepare_pcontext.py  (+12 -0)

@@ -32,6 +32,17 @@ def download_ade(path, overwrite=False):
         else:
             shutil.move(filename, os.path.join(path, 'VOCdevkit/VOC2010/'+os.path.basename(filename)))

+def install_pcontext_api():
+    repo_url = "https://github.com/zhanghang1989/detail-api"
+    os.system("git clone " + repo_url)
+    os.system("cd detail-api/PythonAPI/ && python setup.py install")
+    shutil.rmtree('detail-api')
+    try:
+        import detail
+    except Exception:
+        print("Installing PASCAL Context API failed, please install it manually %s" % (repo_url))
+
 if __name__ == '__main__':
     args = parse_args()
     mkdir(os.path.expanduser('~/.encoding/data'))

@@ -42,3 +53,4 @@ if __name__ == '__main__':
         os.symlink(args.download_dir, _TARGET_DIR)
     else:
         download_ade(_TARGET_DIR, overwrite=False)
+    install_pcontext_api()
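A quick follow-up check one might run after the script finishes (a sketch, not part of the commit), confirming that the Detail API installed by install_pcontext_api() is importable:

try:
    import detail
    print('PASCAL-in-Detail API available:', detail.__file__)
except ImportError:
    print('detail not found; clone https://github.com/zhanghang1989/detail-api and '
          'run "python setup.py install" inside detail-api/PythonAPI/ manually.')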