Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
b8428a7f
Commit
b8428a7f
authored
Jul 15, 2022
by
peng xu
Browse files
Merge branch 'main' into beam_search
parents
e5034150
3f4e71df
Changes
67
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1414 additions
and
2 deletions
+1414
-2
tasks/vision/segmentation/metrics.py
tasks/vision/segmentation/metrics.py
+594
-0
tasks/vision/segmentation/seg_heads.py
tasks/vision/segmentation/seg_heads.py
+140
-0
tasks/vision/segmentation/seg_models.py
tasks/vision/segmentation/seg_models.py
+92
-0
tasks/vision/segmentation/transforms.py
tasks/vision/segmentation/transforms.py
+433
-0
tasks/vision/segmentation/utils.py
tasks/vision/segmentation/utils.py
+85
-0
tools/merge_datasets.py
tools/merge_datasets.py
+66
-0
tools/preprocess_data.py
tools/preprocess_data.py
+4
-2
No files found.
tasks/vision/segmentation/metrics.py
0 → 100644
View file @
b8428a7f
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
#copyright (c) go-hiroaki & Chokurei
#email: guangmingwu2010@gmail.com
# guozhilingty@gmail.com
#
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
eps
=
1e-6
def
_binarize
(
y_data
,
threshold
):
"""
args:
y_data : [float] 4-d tensor in [batch_size, channels, img_rows, img_cols]
threshold : [float] [0.0, 1.0]
return 4-d binarized y_data
"""
y_data
[
y_data
<
threshold
]
=
0.0
y_data
[
y_data
>=
threshold
]
=
1.0
return
y_data
def
_argmax
(
y_data
,
dim
):
"""
args:
y_data : 4-d tensor in [batch_size, chs, img_rows, img_cols]
dim : int
return 3-d [int] y_data
"""
return
torch
.
argmax
(
y_data
,
dim
).
int
()
def
_get_tp
(
y_pred
,
y_true
):
"""
args:
y_true : [int] 3-d in [batch_size, img_rows, img_cols]
y_pred : [int] 3-d in [batch_size, img_rows, img_cols]
return [float] true_positive
"""
return
torch
.
sum
(
y_true
*
y_pred
).
float
()
def
_get_fp
(
y_pred
,
y_true
):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] false_positive
"""
return
torch
.
sum
((
1
-
y_true
)
*
y_pred
).
float
()
def
_get_tn
(
y_pred
,
y_true
):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] true_negative
"""
return
torch
.
sum
((
1
-
y_true
)
*
(
1
-
y_pred
)).
float
()
def
_get_fn
(
y_pred
,
y_true
):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] false_negative
"""
return
torch
.
sum
(
y_true
*
(
1
-
y_pred
)).
float
()
def
_get_weights
(
y_true
,
nb_ch
):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
nb_ch : int
return [float] weights
"""
batch_size
,
img_rows
,
img_cols
=
y_true
.
shape
pixels
=
batch_size
*
img_rows
*
img_cols
weights
=
[
torch
.
sum
(
y_true
==
ch
).
item
()
/
pixels
for
ch
in
range
(
nb_ch
)]
return
weights
class CFMatrix(object):
    """Confusion-matrix metric: per-class (tp, fp, tn, fn) counts.

    Fixes over the original:
      * per-channel scratch tensors are now allocated on ``y_true.device``
        (they were CPU tensors, so indexing them with a CUDA mask crashed);
      * the ``__call__`` docstring now matches the code, which unpacks a
        3-d ``y_pred`` (class-index map), not a 4-d one.
    """

    def __init__(self, des=None):
        self.des = des

    def __repr__(self):
        return "ConfusionMatrix"

    def __call__(self, y_pred, y_true, ignore_index, threshold=0.5):
        """
        args:
            y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
            y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
            ignore_index : int; doubles as the number of classes (labels
                >= ignore_index are treated as "ignore")
            threshold : [0.0, 1.0], used only in the binary (chs == 1) case
        return (mperforms, performs): weighted counts and per-class counts
        """
        batch_size, img_rows, img_cols = y_pred.shape
        chs = ignore_index  # number of valid classes
        device = y_true.device
        if chs == 1:
            # Binary case: threshold both maps, then plain tp/fp/tn/fn.
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_tn = _get_tn(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            mperforms = [nb_tp, nb_fp, nb_tn, nb_fn]
            performs = None
        else:
            performs = torch.zeros(chs, 4).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks for this class; allocate on the input's
                # device so boolean-mask indexing works under CUDA.
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_false_ch = torch.zeros(batch_size, img_rows, img_cols,
                                         device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_true_ch[y_true == ch] = 1
                # "False" pixels exclude the ignore label so ignored regions
                # count in neither fp nor tn.
                y_false_ch[torch.logical_and((y_true != ch),
                                             (y_true != ignore_index))] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = torch.sum(y_false_ch * y_pred_ch).float()
                nb_tn = torch.sum(y_false_ch * (1 - y_pred_ch)).float()
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[int(ch), :] = torch.FloatTensor(
                    [nb_tp, nb_fp, nb_tn, nb_fn])
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class OAAcc(object):
    """Overall pixel accuracy: (tp + tn) / total pixels."""

    def __init__(self, des="Overall Accuracy"):
        self.des = des

    def __repr__(self):
        return "OAcc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return (mperforms, performs): scalar accuracy and None
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        # Reduce both maps to comparable label maps first.
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
        matches = torch.sum(y_true == y_pred).float()
        total = batch_size * img_rows * img_cols
        mperforms = matches / total
        performs = None
        return mperforms, performs
class Precision(object):
    """Precision metric: tp / (tp + fp).

    Fixes over the original:
      * the smoothing constant was spelled ``esp`` — a name that exists
        nowhere in this module (the module defines ``eps``), so every call
        raised NameError; it now uses ``eps``;
      * per-channel scratch tensors are allocated on ``y_true.device`` so
        CUDA inputs no longer crash on boolean-mask indexing.
    """

    def __init__(self, des="Precision"):
        self.des = des

    def __repr__(self):
        return "Prec"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return tp/(tp+fp) as (mperforms, performs)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            mperforms = nb_tp / (nb_tp + nb_fp + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks for this class.
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                performs[int(ch)] = nb_tp / (nb_tp + nb_fp + eps)
            # Frequency-weighted mean over classes.
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class Recall(object):
    """Recall metric: tp / (tp + fn).

    Fixes over the original: undefined ``esp`` -> module-level ``eps``
    (the old name raised NameError on every call), and per-channel scratch
    tensors now live on ``y_true.device`` so CUDA inputs work.
    """

    def __init__(self, des="Recall"):
        self.des = des

    def __repr__(self):
        return "Reca"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return tp/(tp+fn) as (mperforms, performs)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            mperforms = nb_tp / (nb_tp + nb_fn + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks for this class.
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[int(ch)] = nb_tp / (nb_tp + nb_fn + eps)
            # Frequency-weighted mean over classes.
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class F1Score(object):
    """F1 score: harmonic mean of precision and recall.

    Fixes over the original: undefined ``esp`` -> module-level ``eps``
    (the old name raised NameError on every call), and per-channel scratch
    tensors now live on ``y_true.device`` so CUDA inputs work.
    """

    def __init__(self, des="F1Score"):
        self.des = des

    def __repr__(self):
        return "F1Sc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return 2*precision*recall/(precision+recall) as (mperforms, performs)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            _precision = nb_tp / (nb_tp + nb_fp + eps)
            _recall = nb_tp / (nb_tp + nb_fn + eps)
            mperforms = 2 * _precision * _recall / (_precision + _recall + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks for this class.
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                _precision = nb_tp / (nb_tp + nb_fp + eps)
                _recall = nb_tp / (nb_tp + nb_fn + eps)
                performs[int(ch)] = 2 * _precision * \
                    _recall / (_precision + _recall + eps)
            # Frequency-weighted mean over classes.
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class Kappa(object):
    """Cohen's kappa: (Po - Pe) / (1 - Pe).

    Fixes over the original: undefined ``esp`` -> module-level ``eps``
    (the old name raised NameError on every call), and per-channel scratch
    tensors now live on ``y_true.device`` so CUDA inputs work.
    """

    def __init__(self, des="Kappa"):
        self.des = des

    def __repr__(self):
        return "Kapp"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return (Po-Pe)/(1-Pe) as (mperforms, performs)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_tn = _get_tn(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            nb_total = nb_tp + nb_fp + nb_tn + nb_fn
            # Po: observed agreement; Pe: chance agreement.
            Po = (nb_tp + nb_tn) / nb_total
            Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn) +
                  (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total ** 2)
            mperforms = (Po - Pe) / (1 - Pe + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks for this class.
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                nb_tn = _get_tn(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                nb_total = nb_tp + nb_fp + nb_tn + nb_fn
                Po = (nb_tp + nb_tn) / nb_total
                Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn) +
                      (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total ** 2)
                performs[int(ch)] = (Po - Pe) / (1 - Pe + eps)
            # Frequency-weighted mean over classes.
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class Jaccard(object):
    """Jaccard index (IoU): intersection / (sum - intersection).

    Fixes over the original: undefined ``esp`` -> module-level ``eps``
    (the old name raised NameError on every call), and per-channel scratch
    tensors now live on ``y_true.device`` so CUDA inputs work.
    """

    def __init__(self, des="Jaccard"):
        self.des = des

    def __repr__(self):
        return "Jacc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return intersection / (sum-intersection) as (mperforms, performs)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            _intersec = torch.sum(y_true * y_pred).float()
            _sum = torch.sum(y_true + y_pred).float()
            mperforms = _intersec / (_sum - _intersec + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks for this class.
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                _intersec = torch.sum(y_true_ch * y_pred_ch).float()
                _sum = torch.sum(y_true_ch + y_pred_ch).float()
                performs[int(ch)] = _intersec / (_sum - _intersec + eps)
            # Frequency-weighted mean over classes.
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class MSE(object):
    """Mean squared error between prediction and target; smaller is better."""

    def __init__(self, des="Mean Square Error"):
        self.des = des

    def __repr__(self):
        return "MSE"

    def __call__(self, y_pred, y_true, dim=1, threshold=None):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            threshold : [0.0, 1.0] or None (no binarization)
        return mean squared error (scalar tensor)
        """
        if threshold:
            y_pred = _binarize(y_pred, threshold)
        squared_error = (y_pred - y_true) ** 2
        return torch.mean(squared_error)
class PSNR(object):
    """Peak signal-to-noise ratio (assumes unit peak signal); larger is better."""

    def __init__(self, des="Peak Signal to Noise Ratio"):
        self.des = des

    def __repr__(self):
        return "PSNR"

    def __call__(self, y_pred, y_true, dim=1, threshold=None):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            threshold : [0.0, 1.0] or None (no binarization)
        return PSNR = 10 * log10(1 / MSE) (scalar tensor)
        """
        if threshold:
            y_pred = _binarize(y_pred, threshold)
        mse = torch.mean((y_pred - y_true) ** 2)
        return 10 * torch.log10(1 / mse)
class SSIM(object):
    '''
    Structural similarity index; larger is better.
    modified from https://github.com/jorge-pessoa/pytorch-msssim
    '''

    def __init__(self, des="structural similarity index"):
        self.des = des

    def __repr__(self):
        return "SSIM"

    def gaussian(self, w_size, sigma):
        """Return a normalized 1-D Gaussian kernel of length *w_size*."""
        center = w_size // 2
        raw = torch.Tensor(
            [math.exp(-(x - center) ** 2 / float(2 * sigma ** 2))
             for x in range(w_size)])
        return raw / raw.sum()

    def create_window(self, w_size, channel=1):
        """Build a [channel, 1, w_size, w_size] separable Gaussian window."""
        col = self.gaussian(w_size, 1.5).unsqueeze(1)
        plane = col.mm(col.t()).float().unsqueeze(0).unsqueeze(0)
        return plane.expand(channel, 1, w_size, w_size).contiguous()

    def __call__(self, y_pred, y_true, w_size=11, size_average=True, full=False):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            w_size : int, default 11
            size_average : boolean, default True
            full : boolean, default False (also return contrast sensitivity)
        return ssim, larger the better
        """
        # Value range can be different from 255. Other common ranges are
        # 1 (sigmoid) and 2 (tanh); infer it from the prediction.
        max_val = 255 if torch.max(y_pred) > 128 else 1
        min_val = -1 if torch.min(y_pred) < -0.5 else 0
        L = max_val - min_val

        padd = 0
        (_, channel, height, width) = y_pred.size()
        window = self.create_window(w_size, channel=channel).to(y_pred.device)

        # Local means via depthwise Gaussian filtering.
        mu1 = F.conv2d(y_pred, window, padding=padd, groups=channel)
        mu2 = F.conv2d(y_true, window, padding=padd, groups=channel)
        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # Local (co)variances.
        sigma1_sq = F.conv2d(y_pred * y_pred, window,
                             padding=padd, groups=channel) - mu1_sq
        sigma2_sq = F.conv2d(y_true * y_true, window,
                             padding=padd, groups=channel) - mu2_sq
        sigma12 = F.conv2d(y_pred * y_true, window,
                           padding=padd, groups=channel) - mu1_mu2

        C1 = (0.01 * L) ** 2
        C2 = (0.03 * L) ** 2
        v1 = 2.0 * sigma12 + C2
        v2 = sigma1_sq + sigma2_sq + C2
        cs = torch.mean(v1 / v2)  # contrast sensitivity

        ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
        if size_average:
            ret = ssim_map.mean()
        else:
            ret = ssim_map.mean(1).mean(1).mean(1)
        if full:
            return ret, cs
        return ret
class AE(object):
    """
    Average angular error between per-pixel color vectors, in degrees.
    Modified from matlab : colorangle.m, MATLAB V2019b
    angle = acos(RGB1' * RGB2 / (norm(RGB1) * norm(RGB2)));
    angle = 180 / pi * angle;
    """

    def __init__(self, des='average Angular Error'):
        self.des = des

    def __repr__(self):
        return "AE"

    def __call__(self, y_pred, y_true):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
        return average AE per batch element, smaller the better
        """
        # Per-pixel dot product and norms over the channel axis.
        dot = torch.sum(y_pred * y_true, dim=1)
        norm_pred = torch.sqrt(torch.sum(y_pred * y_pred, dim=1))
        norm_true = torch.sqrt(torch.sum(y_true * y_true, dim=1))
        # eps guards against division by zero for black pixels.
        angles = 180 / math.pi * torch.acos(dot / (norm_pred * norm_true + eps))
        return angles.mean(1).mean(1)
if __name__ == "__main__":
    # Smoke-test every metric on random data, for multi-channel and binary maps.
    #
    # Fixes over the original:
    #   * the original called ``LPIPS(cuda)``, but no LPIPS class is defined
    #     or imported anywhere in this module, so the script always died with
    #     NameError — the call has been removed;
    #   * the CUDA pass is now only attempted when a GPU is present; the
    #     original unconditionally ran ``cuda=True`` and crashed on CPU-only
    #     machines.
    for ch in [3, 1]:
        batch_size, img_row, img_col = 1, 224, 224
        y_true = torch.rand(batch_size, ch, img_row, img_col)
        noise = torch.zeros(y_true.size()).data.normal_(0, std=0.1)
        y_pred = y_true + noise
        cuda_options = [False, True] if torch.cuda.is_available() else [False]
        for cuda in cuda_options:
            if cuda:
                y_pred = y_pred.cuda()
                y_true = y_true.cuda()
            print('#' * 20, 'Cuda : {} ; size : {}'.format(cuda, y_true.size()))
            ########### similarity metrics
            for metric in (MSE(), PSNR(), SSIM(), AE()):
                acc = metric(y_pred, y_true).item()
                print("{} ==> {}".format(repr(metric), acc))
            ########### accuracy metrics
            maccu, accu = OAAcc()(y_pred, y_true)
            print('mAccu:', maccu, 'Accu', accu)
            mprec, prec = Precision()(y_pred, y_true)
            print('mPrec:', mprec, 'Prec', prec)
            mreca, reca = Recall()(y_pred, y_true)
            print('mReca:', mreca, 'Reca', reca)
            mf1sc, f1sc = F1Score()(y_pred, y_true)
            print('mF1sc:', mf1sc, 'F1sc', f1sc)
            mkapp, kapp = Kappa()(y_pred, y_true)
            print('mKapp:', mkapp, 'Kapp', kapp)
            mjacc, jacc = Jaccard()(y_pred, y_true)
            print('mJacc:', mjacc, 'Jacc', jacc)
tasks/vision/segmentation/seg_heads.py
0 → 100644
View file @
b8428a7f
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
einops
import
torch
import
apex
import
torch.nn.functional
as
F
from
megatron
import
get_args
from
megatron.model
import
LayerNorm
from
megatron.model.module
import
MegatronModule
from
megatron.model.vision.utils
import
resize
class SetrSegmentationHead(MegatronModule):
    """SETR-style segmentation head: layernorm -> 1x1 convs -> upsample to image size."""

    def __init__(self, hidden_size, num_classes):
        super(SetrSegmentationHead, self).__init__()
        args = get_args()
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.patch_dim = args.patch_dim
        self.layernorm = LayerNorm(hidden_size, eps=args.layernorm_epsilon)
        self.conv_0 = torch.nn.Conv2d(hidden_size, hidden_size,
                                      1, 1, bias=False)
        self.norm_0 = apex.parallel.SyncBatchNorm(hidden_size)
        self.conv_1 = torch.nn.Conv2d(hidden_size, num_classes, 1, 1)

    def to_2D(self, x):
        """Reshape [n, hw, c] token sequence into a [n, c, h, w] feature map."""
        n, hw, c = x.shape
        h = self.img_h // self.patch_dim
        w = self.img_w // self.patch_dim
        assert (hw == h * w)
        return x.transpose(1, 2).reshape(n, c, h, w)

    def forward(self, hidden_states):
        # [b c h w]
        out = self.layernorm(hidden_states)
        out = self.to_2D(out)
        out = self.conv_0(out)
        out = self.norm_0(out)
        out = torch.tanh(out)
        out = self.conv_1(out)
        # [b c h w] -> upsample logits to full image resolution.
        return F.interpolate(out,
                             size=(self.img_h, self.img_w),
                             mode='bilinear')
class MLP(torch.nn.Module):
    """
    Linear Embedding: flatten a [n, c, h, w] map to [n, h*w, c] tokens
    and project each token to *embed_dim*.
    """

    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = torch.nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        tokens = x.flatten(2)          # [n, c, h*w]
        tokens = tokens.transpose(1, 2)  # [n, h*w, c]
        return self.proj(tokens)
class SegformerSegmentationHead(MegatronModule):
    """SegFormer all-MLP decoder: embed C1-C4, upsample to C1 size, fuse, predict."""

    def __init__(self, feature_strides, in_channels,
                 embedding_dim, dropout_ratio):
        super(SegformerSegmentationHead, self).__init__()
        assert len(feature_strides) == len(in_channels)
        assert min(feature_strides) == feature_strides[0]
        args = get_args()
        self.feature_strides = feature_strides
        self.in_channels = in_channels
        self.embedding_dim = embedding_dim
        self.num_classes = args.num_classes
        self.dropout_ratio = dropout_ratio

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = \
            self.in_channels
        # One linear embedding per backbone stage.
        self.linear_c4 = MLP(input_dim=c4_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels,
                             embed_dim=self.embedding_dim)
        self.conv_fuse = torch.nn.Conv2d(self.embedding_dim * 4,
                                         self.embedding_dim, 1, 1)
        self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
        self.dropout = torch.nn.Dropout2d(self.dropout_ratio)
        self.linear_pred = torch.nn.Conv2d(self.embedding_dim,
                                           self.num_classes,
                                           kernel_size=1)

    def forward(self, inputs):
        c1, c2, c3, c4 = inputs
        ############## MLP decoder on C1-C4 ###########
        n, _, h, w = c4.shape
        target_hw = c1.size()[2:]

        def tokens_to_map(tokens, ref):
            # [n, hw, e] -> [n, e, h_ref, w_ref]
            return tokens.permute(0, 2, 1).reshape(
                n, -1, ref.shape[2], ref.shape[3])

        _c4 = tokens_to_map(self.linear_c4(c4), c4)
        _c4 = resize(_c4, size=target_hw, mode='bilinear',
                     align_corners=False)
        _c3 = tokens_to_map(self.linear_c3(c3), c3)
        _c3 = resize(_c3, size=target_hw, mode='bilinear',
                     align_corners=False)
        _c2 = tokens_to_map(self.linear_c2(c2), c2)
        _c2 = resize(_c2, size=target_hw, mode='bilinear',
                     align_corners=False)
        # C1 is already at the target resolution; no resize needed.
        _c1 = tokens_to_map(self.linear_c1(c1), c1)

        fused = self.conv_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
        out = self.norm(fused)
        out = F.relu(out, inplace=True)
        out = self.dropout(out)
        return self.linear_pred(out)
tasks/vision/segmentation/seg_models.py
0 → 100644
View file @
b8428a7f
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
einops
import
torch
import
apex
import
torch.nn.functional
as
F
from
megatron
import
get_args
from
megatron.model.module
import
MegatronModule
from
megatron.model.vision.vit_backbone
import
VitBackbone
,
VitMlpHead
from
megatron.model.vision.mit_backbone
import
mit_b3
,
mit_b5
from
tasks.vision.segmentation.seg_heads
import
SetrSegmentationHead
,
SegformerSegmentationHead
class SetrSegmentationModel(MegatronModule):
    """SETR semantic segmentation model: ViT backbone + SETR head.

    Fix over the original: the sanity assert used bitwise ``&`` on the two
    flags, which only behaves like a boolean conjunction for real ``bool``
    values; it now uses logical ``and``.
    """

    def __init__(self, num_classes, pre_process=True, post_process=True):
        super(SetrSegmentationModel, self).__init__()
        args = get_args()
        # Pipeline splitting is not supported: both stages must be present.
        assert post_process and pre_process
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.backbone = VitBackbone(
            pre_process=pre_process,
            post_process=post_process,
            class_token=False,
            post_layer_norm=False,
            drop_path_rate=0.1
        )
        self.head = SetrSegmentationHead(self.hidden_size,
                                         self.num_classes)

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        pass

    def forward(self, input):
        # [b hw c]
        hidden_states = self.backbone(input)
        result_final = self.head(hidden_states)
        return result_final
class SegformerSegmentationModel(MegatronModule):
    """SegFormer model: MiT-B5 backbone + all-MLP decoder head."""

    def __init__(self, num_classes, pre_process=True, post_process=True):
        super(SegformerSegmentationModel, self).__init__()
        args = get_args()
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.pre_process = pre_process
        self.post_process = post_process
        self.backbone = mit_b5()
        # Strides/channels match the MiT-B5 stage outputs.
        self.head = SegformerSegmentationHead(
            feature_strides=[4, 8, 16, 32],
            in_channels=[64, 128, 320, 512],
            embedding_dim=768,
            dropout_ratio=0.1
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        pass

    def forward(self, input):
        # [b hw c]
        features = self.backbone(input)
        return self.head(features)
tasks/vision/segmentation/transforms.py
0 → 100644
View file @
b8428a7f
# Copyright (c) 2020 The MMSegmentation Authors.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
import
random
import
os
import
math
import
mmcv
import
torch
import
numpy
as
np
import
torchvision.transforms
as
T
from
torchvision
import
datasets
from
torch.utils.data
import
Dataset
from
megatron
import
print_rank_0
from
megatron
import
get_args
from
PIL
import
Image
,
ImageOps
,
ImageEnhance
import
torchvision.transforms
as
torch_tr
def _is_pil_image(img):
    """Return True when *img* is a PIL Image instance."""
    is_pil = isinstance(img, Image.Image)
    return is_pil
class PhotoMetricDistortion(object):
    """Apply photometric distortion to image sequentially, every transformation
    is applied with a probability of 0.5. The position of random contrast is in
    second or second to last.
    1. random brightness
    2. random contrast (mode 0)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 1)
    8. randomly swap channels
    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue.
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def convert(self, img, alpha=1, beta=0):
        """Scale by *alpha*, shift by *beta*, clip to [0, 255], return uint8."""
        scaled = img.astype(np.float32) * alpha + beta
        clipped = np.clip(scaled, 0, 255)
        return clipped.astype(np.uint8)

    def brightness(self, img):
        """Brightness distortion (applied with probability 0.5)."""
        if random.randint(0, 1):
            shift = random.uniform(-self.brightness_delta,
                                   self.brightness_delta)
            return self.convert(img, beta=shift)
        return img

    def contrast(self, img):
        """Contrast distortion (applied with probability 0.5)."""
        if random.randint(0, 1):
            gain = random.uniform(self.contrast_lower, self.contrast_upper)
            return self.convert(img, alpha=gain)
        return img

    def saturation(self, img):
        """Saturation distortion in HSV space (applied with probability 0.5)."""
        if random.randint(0, 1):
            img = mmcv.bgr2hsv(img)
            gain = random.uniform(self.saturation_lower,
                                  self.saturation_upper)
            img[:, :, 1] = self.convert(img[:, :, 1], alpha=gain)
            img = mmcv.hsv2bgr(img)
        return img

    def hue(self, img):
        """Hue distortion in HSV space (applied with probability 0.5)."""
        if random.randint(0, 1):
            img = mmcv.bgr2hsv(img)
            shift = random.randint(-self.hue_delta, self.hue_delta)
            # Hue channel wraps around at 180 (OpenCV convention).
            img[:, :, 0] = (img[:, :, 0].astype(int) + shift) % 180
            img = mmcv.hsv2bgr(img)
        return img

    def __call__(self, img):
        """Call function to perform photometric distortion on images.
        Args:
            results (dict): Result dict from loading pipeline.
        Returns:
            dict: Result dict with images distorted.
        """
        img = np.array(img)
        # random brightness
        img = self.brightness(img)
        # mode == 0 --> do random contrast first
        # mode == 1 --> do random contrast last
        mode = random.randint(0, 1)
        if mode == 1:
            img = self.contrast(img)
        # random saturation
        img = self.saturation(img)
        # random hue
        img = self.hue(img)
        # random contrast
        if mode == 0:
            img = self.contrast(img)
        return Image.fromarray(img.astype(np.uint8)).convert('RGB')
class RandomCrop(object):
    """
    Take a random crop from the image.
    First the image or crop size may need to be adjusted if the incoming image
    is too small...
    If the image is smaller than the crop, then:
        the image is padded up to the size of the crop
        unless 'nopad', in which case the crop size is shrunk to fit the image
    A random crop is taken such that the crop fits within the image.
    if cfg.DATASET.TRANSLATION_AUG_FIX is set, we insure that there's always
    translation randomness of at least that value around the image.
    if image < crop_size:
        # slide crop within image, random offset
    else:
        # slide image within crop
    """

    def __init__(self, crop_size):
        args = get_args()
        self.size = crop_size
        self.cat_max_ratio = 0.75
        self.ignore_index = args.ignore_index
        self.pad_color = (0, 0, 0)

    def get_crop_bbox(self, img):
        """Randomly get a crop bounding box."""
        img_w, img_h = img.size
        target_h, target_w = self.size  # [H W]
        margin_h = max(img_h - target_h, 0)
        margin_w = max(img_w - target_w, 0)
        offset_h = random.randint(0, margin_h)
        offset_w = random.randint(0, margin_w)
        crop_y1, crop_y2 = offset_h, offset_h + target_h
        crop_x1, crop_x2 = offset_w, offset_w + target_w
        return crop_y1, crop_y2, crop_x1, crop_x2

    def crop(self, img, crop_bbox):
        """Crop from ``img``"""
        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        return img.crop((crop_x1, crop_y1, crop_x2, crop_y2))

    @staticmethod
    def crop_in_image(target_w, target_h, w, h, img, mask):
        # Random offset when the image is larger than the target; 0 otherwise.
        x1 = 0 if w == target_w else random.randint(0, w - target_w)
        y1 = 0 if h == target_h else random.randint(0, h - target_h)
        box = (x1, y1, x1 + target_w, y1 + target_h)
        return [img.crop(box), mask.crop(box)]

    def __call__(self, img, mask):
        w, h = img.size
        target_h, target_w = self.size  # ASSUME H, W
        if w == target_w and h == target_h:
            return img, mask

        # Pad image if image < crop.
        pad_h = (target_h - h) // 2 + 1 if target_h > h else 0
        pad_w = (target_w - w) // 2 + 1 if target_w > w else 0
        border = (pad_w, pad_h, pad_w, pad_h)
        if pad_h or pad_w:
            img = ImageOps.expand(img, border=border, fill=(0, 0, 0))
            mask = ImageOps.expand(mask, border=border,
                                   fill=self.ignore_index)
            w, h = img.size

        crop_bbox = self.get_crop_bbox(img)
        if self.cat_max_ratio < 1.:
            # Repeat 10 times: resample until no single class dominates the crop.
            for _ in range(10):
                seg_temp = self.crop(mask, crop_bbox)
                labels, cnt = np.unique(seg_temp, return_counts=True)
                cnt = cnt[labels != self.ignore_index]
                if len(cnt) > 1 and \
                        np.max(cnt) / np.sum(cnt) < self.cat_max_ratio:
                    break
                crop_bbox = self.get_crop_bbox(img)

        # crop the image
        img = self.crop(img, crop_bbox)
        # crop semantic seg
        mask = self.crop(mask, crop_bbox)
        assert (img.size[0] == self.size[1] and img.size[1] == self.size[0])
        return img, mask
class RandomSizeAndCrop(object):
    """Randomly rescale an (image, mask) pair, then take a random crop.

    The scale factor is drawn uniformly from [scale_min, scale_max]; the
    crop itself is delegated to ``RandomCrop(crop_size)``.
    """

    def __init__(self, crop_size, scale_min=0.5, scale_max=2.0):
        self.crop = RandomCrop(crop_size)
        self.scale_min = scale_min
        self.scale_max = scale_max

    def __call__(self, img, mask):
        factor = random.uniform(self.scale_min, self.scale_max)
        new_size = tuple(int(side * factor) for side in img.size)
        # Bicubic for the image, nearest for the label mask so class ids
        # are never interpolated into non-existent values.
        scaled_img = img.resize(new_size, Image.BICUBIC)
        scaled_mask = mask.resize(new_size, Image.NEAREST)
        return self.crop(scaled_img, scaled_mask)
class RandomHorizontallyFlip(object):
    """Flip an (image, mask) pair left-right with probability 0.5."""

    def __call__(self, img, mask):
        flip = random.random() < 0.5
        if not flip:
            return img, mask
        return (img.transpose(Image.FLIP_LEFT_RIGHT),
                mask.transpose(Image.FLIP_LEFT_RIGHT))
def adjust_brightness(img, brightness_factor):
    """Adjust brightness of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        brightness_factor (float): How much to adjust the brightness. Can be
            any non negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        PIL Image: Brightness adjusted image.

    Raises:
        TypeError: if ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageEnhance.Brightness(img).enhance(brightness_factor)
def adjust_contrast(img, contrast_factor):
    """Adjust contrast of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        PIL Image: Contrast adjusted image.

    Raises:
        TypeError: if ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageEnhance.Contrast(img).enhance(contrast_factor)
def adjust_saturation(img, saturation_factor):
    """Adjust color saturation of an image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        saturation_factor (float): How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        PIL Image: Saturation adjusted image.

    Raises:
        TypeError: if ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageEnhance.Color(img).enhance(saturation_factor)
def adjust_hue(img, hue_factor):
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See https://en.wikipedia.org/wiki/Hue for more details on Hue.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image: Hue adjusted image.

    Raises:
        ValueError: if ``hue_factor`` is outside [-0.5, 0.5].
        TypeError: if ``img`` is not a PIL Image.
    """
    if not(-0.5 <= hue_factor <= 0.5):
        # Bug fix: the original message had no '{}' placeholder, so the
        # offending value passed to .format() was silently dropped.
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    input_mode = img.mode
    # Single-channel / non-color modes have no hue channel; return unchanged.
    if input_mode in {'L', '1', 'I', 'F'}:
        return img

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype=np.uint8)
    # uint8 wrap-around addition takes care of rotation across boundaries.
    with np.errstate(over='ignore'):
        np_h += np.uint8(hue_factor * 255)
    h = Image.fromarray(np_h, 'L')

    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img
class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation of an image.

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue (float): How much to jitter hue. hue_factor is chosen uniformly
            from [-hue, hue]. Should be >= 0 and <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Build a randomized jitter pipeline.

        Each non-zero strength contributes one adjustment whose factor is
        sampled here; the adjustments are then applied in a random order.

        Returns:
            A ``torch_tr.Compose`` applying the sampled adjustments.
        """
        ops = []
        # Each lambda binds its sampled factor via a default argument so the
        # value is captured at creation time.
        if brightness > 0:
            b = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
            ops.append(torch_tr.Lambda(lambda img, f=b: adjust_brightness(img, f)))
        if contrast > 0:
            c = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
            ops.append(torch_tr.Lambda(lambda img, f=c: adjust_contrast(img, f)))
        if saturation > 0:
            s = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
            ops.append(torch_tr.Lambda(lambda img, f=s: adjust_saturation(img, f)))
        if hue > 0:
            hf = np.random.uniform(-hue, hue)
            ops.append(torch_tr.Lambda(lambda img, f=hf: adjust_hue(img, f)))
        np.random.shuffle(ops)
        return torch_tr.Compose(ops)

    def __call__(self, img):
        """Apply a freshly sampled color jitter to ``img``.

        Args:
            img (PIL Image): Input image.

        Returns:
            PIL Image: Color jittered image.
        """
        pipeline = self.get_params(self.brightness, self.contrast,
                                   self.saturation, self.hue)
        return pipeline(img)
tasks/vision/segmentation/utils.py
0 → 100644
View file @
b8428a7f
import
math
import
torch
import
numpy
as
np
from
megatron
import
get_args
def slidingcrops(img, mask):
    """Split a batch into overlapping crop_size x crop_size sliding windows.

    Window geometry comes from global args (``img_h``/``img_w`` as the square
    crop size, ``seg_stride`` as the stride).

    Args:
        img: tensor of shape [b, c, h, w].
        mask: tensor of shape [b, h, w], aligned with ``img``.

    Returns:
        (img_crops, mask_crops, slices_info, (h, w)): the per-window slices
        concatenated along the batch dimension, plus one
        [sy, ey, sx, ex, sub_h, sub_w] record per window so that
        ``slidingjoins`` can stitch predictions back together.
    """
    # img: [b c h w]
    # mask: [b h w]
    args = get_args()
    assert args.img_h == args.img_w
    crop_size = args.img_h
    stride = args.seg_stride
    ignore_index = args.ignore_index
    n, c, h, w = img.shape
    assert h >= crop_size
    assert w >= crop_size
    long_size = max(h, w)

    img_slices, mask_slices, slices_info = [], [], []
    if long_size > crop_size:
        assert stride <= crop_size
        # Number of steps so the final window still reaches the far edge.
        h_step_num = int(math.ceil((h - crop_size) / float(stride))) + 1
        w_step_num = int(math.ceil((w - crop_size) / float(stride))) + 1
        for yy in range(h_step_num):
            for xx in range(w_step_num):
                sy, sx = yy * stride, xx * stride
                ey, ex = sy + crop_size, sx + crop_size
                img_sub = img[:, :, sy: ey, sx: ex]
                mask_sub = mask[:, sy: ey, sx: ex]

                # padding: edge windows can run past the image; pad them on
                # the bottom/right back up to crop_size.
                sub_h, sub_w = img_sub.shape[2:]
                pad_h = max(crop_size - sub_h, 0)
                pad_w = max(crop_size - sub_w, 0)
                # NOTE(review): the image is padded with ignore_index while
                # the mask is padded with the default 0 -- this looks swapped
                # (masks are usually the ones padded with ignore_index);
                # confirm against the loss/metric code before changing.
                img_sub = torch.nn.functional.pad(img_sub, pad=(0, pad_w, 0, pad_h), value=ignore_index)
                mask_sub = torch.nn.functional.pad(mask_sub, pad=(0, pad_w, 0, pad_h))

                img_slices.append(img_sub)
                mask_slices.append(mask_sub)
                slices_info.append([sy, ey, sx, ex, sub_h, sub_w])

        return torch.cat(img_slices), torch.cat(mask_slices), slices_info, (h, w)
    else:
        # Image fits a single window: return it unchanged with one slice record.
        return img, mask, [[0, h, 0, w, h, w]], (h, w)
def slidingjoins(preds, probs, labels, slices_info, img_size):
    """Stitch per-window predictions from ``slidingcrops`` back together.

    Where windows overlap, the prediction backed by the higher probability
    wins.

    Args:
        preds: per-window predicted labels, concatenated along batch.
        probs: per-window max probabilities, same layout as ``preds``.
        labels: per-window ground-truth labels, same layout.
        slices_info: [sy, ey, sx, ex, sub_h, sub_w] record per window.
        img_size: (h, w) of the original, un-cropped image.

    Returns:
        (total_preds, total_labels): full-size stitched tensors on CUDA.
    """
    args = get_args()
    num_slices = len(slices_info)
    if num_slices == 1:
        # Single window: nothing to stitch.
        return preds, labels
    h, w = img_size
    split_size = args.micro_batch_size
    # Undo the torch.cat from slidingcrops: one chunk per window.
    preds_split = torch.split(preds, split_size)
    probs_split = torch.split(probs, split_size)
    labels_split = torch.split(labels, split_size)
    assert(len(preds_split) == num_slices)
    total_max_probs = torch.zeros((split_size, h, w), dtype=torch.float, device='cuda')
    total_preds = torch.zeros((split_size, h, w), dtype=torch.int, device='cuda')
    total_labels = torch.zeros((split_size, h, w), dtype=torch.int, device='cuda')
    for i in range(num_slices):
        sy, ey, sx, ex, sub_h, sub_w = slices_info[i]
        assert sy + sub_h <= h
        assert sx + sub_w <= w
        # Current best values in this window's region of the canvas.
        curr_max_probs = total_max_probs[:, sy:sy + sub_h, sx:sx + sub_w]
        curr_preds = total_preds[:, sy:sy + sub_h, sx:sx + sub_w]
        # This window's values, with bottom/right padding trimmed off.
        local_max_probs = probs_split[i][:, :sub_h, :sub_w]
        local_preds = preds_split[i][:, :sub_h, :sub_w]
        # Keep whichever prediction has the higher probability; ties go to
        # the value already on the canvas.
        result_max_probs = torch.maximum(curr_max_probs, local_max_probs)
        result_preds = torch.where(curr_max_probs >= local_max_probs, curr_preds, local_preds)
        total_max_probs[:, sy:sy + sub_h, sx:sx + sub_w] = result_max_probs
        total_preds[:, sy:sy + sub_h, sx:sx + sub_w] = result_preds
        # NOTE(review): only labels_split[i][0] (first item of the chunk) is
        # written and broadcast across the batch dimension -- confirm this is
        # intended rather than labels_split[i][:, :sub_h, :sub_w].
        total_labels[:, sy:sy + sub_h, sx:sx + sub_w] = labels_split[i][0, :sub_h, :sub_w]
    return total_preds, total_labels
tools/merge_datasets.py
0 → 100644
View file @
b8428a7f
import
os
import
sys
import
json
import
argparse
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
os
.
path
.
pardir
)))
from
megatron.data
import
indexed_dataset
def main(args):
    """Merge every indexed dataset (.bin/.idx pair) found in ``args.input``
    into a single dataset written to ``args.output_prefix``.
    """
    # Collect unique dataset prefixes, verifying each has both files.
    prefixes = set()
    for entry in os.listdir(args.input):
        prefix, ext = os.path.splitext(entry)
        if prefix in prefixes or not os.path.isfile(os.path.join(args.input, entry)):
            continue
        ext_pair = '.bin' if ext == '.idx' else '.idx'
        assert os.path.isfile(os.path.join(args.input, prefix) + ext_pair), \
            f'ERROR: {ext_pair} file not provided for {os.path.join(args.input, prefix)}'
        prefixes.add(prefix)

    builder = None
    for prefix in sorted(prefixes):
        path = os.path.join(args.input, prefix)
        if builder is None:
            # Infer the dataset flavour from the first prefix and create a
            # matching builder for the merged output.
            probe = indexed_dataset.make_dataset(path, 'infer')
            if isinstance(probe, indexed_dataset.MMapIndexedDataset):
                builder = indexed_dataset.MMapIndexedDatasetBuilder(
                    args.output_prefix + '.bin', dtype=probe._index.dtype)
            else:
                builder = indexed_dataset.IndexedDatasetBuilder(
                    args.output_prefix + '.bin')
            del probe
        builder.merge_file_(path)
    builder.finalize(args.output_prefix + '.idx')
if __name__ == '__main__':
    # Command-line interface: an input directory of .bin/.idx pairs and an
    # output prefix for the merged dataset.
    cli = argparse.ArgumentParser()
    in_group = cli.add_argument_group(title='input data')
    in_group.add_argument('--input', type=str, required=True,
                          help='Path to directory containing all document files to merge')
    out_group = cli.add_argument_group(title='output data')
    out_group.add_argument('--output-prefix', type=str, required=True,
                           help='Path to binary output file without suffix')
    args = cli.parse_args()

    # Validate paths up front so failures are immediate and readable.
    assert os.path.isdir(args.input), \
        f'ERROR: {args.input} is not a directory or does not exist'
    assert os.path.isdir(os.path.dirname(args.output_prefix)), \
        f'ERROR: {os.path.dirname(args.output_prefix)} is not a directory or does not exist'

    main(args)
tools/preprocess_data.py
View file @
b8428a7f
...
...
@@ -122,8 +122,10 @@ def get_args():
choices
=
[
'lazy'
,
'cached'
,
'mmap'
])
group
=
parser
.
add_argument_group
(
title
=
'runtime'
)
group
.
add_argument
(
'--workers'
,
type
=
int
,
default
=
1
,
group
.
add_argument
(
'--workers'
,
type
=
int
,
required
=
True
,
help
=
'Number of worker processes to launch'
)
group
.
add_argument
(
'--chunk-size'
,
type
=
int
,
required
=
True
,
help
=
'Chunk size assigned to each worker process'
)
group
.
add_argument
(
'--log-interval'
,
type
=
int
,
default
=
100
,
help
=
'Interval between progress updates'
)
args
=
parser
.
parse_args
()
...
...
@@ -154,7 +156,7 @@ def main():
encoder
=
Encoder
(
args
)
tokenizer
=
build_tokenizer
(
args
)
pool
=
multiprocessing
.
Pool
(
args
.
workers
,
initializer
=
encoder
.
initializer
)
encoded_docs
=
pool
.
imap
(
encoder
.
encode
,
fin
,
25
)
encoded_docs
=
pool
.
imap
(
encoder
.
encode
,
fin
,
args
.
chunk_size
)
#encoded_docs = map(encoder.encode, fin)
level
=
"document"
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment