Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
CenterFace_pytorch
Commits
69f4ba48
Commit
69f4ba48
authored
Nov 07, 2023
by
chenych
Browse files
update readme
parent
e0491117
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
235 additions
and
210 deletions
+235
-210
README.md
README.md
+16
-21
src/lib/datasets/dataset/centerface_hp.py
src/lib/datasets/dataset/centerface_hp.py
+8
-5
src/lib/models/decode.py
src/lib/models/decode.py
+39
-39
src/lib/models/losses.py
src/lib/models/losses.py
+158
-138
src/lib/trains/base_trainer.py
src/lib/trains/base_trainer.py
+1
-1
src/lib/trains/multi_pose.py
src/lib/trains/multi_pose.py
+13
-6
No files found.
README.md
View file @
69f4ba48
...
...
@@ -4,8 +4,9 @@
## 模型结构
CenterFace是一种人脸检测算法,采用了轻量级网络mobileNetV2作为主干网络,结合特征金字塔网络(FPN)实现anchor free的人脸检测。

<div align=center>
    <img src="./Architecture of the CenterFace.png"/>
</div>
## 算法原理
CenterFace模型是一种单阶段人脸检测算法,作者借鉴了CenterNet的思想,将人脸检测转换为标准的关键点估计问题,根据人脸中心点来回归人脸框的大小和五个标志点。
...
...
@@ -53,7 +54,9 @@ pip3 install -r requirements.txt
WIDER_FACE:http://shuoyang1213.me/WIDERFACE/index.html

<div align=center>
    <img src="./datasets.png"/>
</div>
下载图片红框中的三个数据包并解压,也可点击下面的链接直接下载:
...
...
@@ -99,25 +102,15 @@ WIDER_FACE:http://shuoyang1213.me/WIDERFACE/index.html
│ └── 61--Street_Battle
```
2. 如果是使用WIDER_train、WIDER_val数据, 可直接将./datasets/labels/下的train_wider_face.json重命名为train_face.json, val_wider_face.json重命名为val_face.json即可,无需进行标注文件格式转换;
反之,需要将训练图片/验证图片对应的人脸标注信息文件train.txt/val.txt,放置于 ./datasets/annotations/下(train存放训练图片的标注文件,val存放验证图片的标注文件),存放目录结构如下:
```
├── annotations
│ ├── train
│ │ ├── train.txt
│ ├── val
│ │ ├── val.txt
```
特别地,标注信息的格式为:
...
...
@@ -128,7 +121,8 @@ x, y, w, h, left_eye_x, left_eye_y, flag, right_eye_x, right_eye_y, flag, nose_x
```
举个例子:
train_keypoints_widerface.txt是wider_face训练数据集的标注信息
./datasets/annotations/train/train.txt是wider_face训练数据集的标注信息
```
# 0--Parade/0_Parade_marchingband_1_849.jpg
449 330 122 149 488.906 373.643 0.0 542.089 376.442 0.0 515.031 412.83 0.0 485.174 425.893 0.0 538.357 431.491 0.0 0.82
...
...
@@ -142,11 +136,7 @@ cd ./datasets
python gen_data.py
```
执行完成后会在./datasets/labels下生成训练数据和验证数据的标注文件 train_face.json、val_face.json
## 训练
...
...
@@ -166,18 +156,23 @@ bash train_multi.sh
## 推理
#### 单卡推理
```
cd ./src
python test_wider_face.py
```
## result

<div align=center>
    <img src="./draw_img.jpg"/>
</div>
### 精度
WIDER_FACE验证集上的测试结果如下
| Method | Easy(AP) | Medium(AP) | Hard(AP) |
|:--------:| :--------:| :---------:| :------:|
| ours(one scale) | 0.9264 | 0.9133 | 0.7479 |
| original | 0.922 | 0.911 | 0.782|
...
...
src/lib/datasets/dataset/centerface_hp.py
View file @
69f4ba48
...
...
@@ -2,13 +2,14 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
pycocotools.coco
as
coco
from
pycocotools.cocoeval
import
COCOeval
import
numpy
as
np
import
json
import
os
from
pycocotools.cocoeval
import
COCOeval
import
torch.utils.data
as
data
import
numpy
as
np
import
pycocotools.coco
as
coco
class
FACEHP
(
data
.
Dataset
):
...
...
@@ -19,7 +20,8 @@ class FACEHP(data.Dataset):
dtype
=
np
.
float32
).
reshape
(
1
,
1
,
3
)
std
=
np
.
array
([
0.28863828
,
0.27408164
,
0.27809835
],
dtype
=
np
.
float32
).
reshape
(
1
,
1
,
3
)
flip_idx
=
[[
0
,
1
],
[
3
,
4
]]
# 翻转的关键点在关键点矩阵中的索引
# 翻转的关键点在关键点矩阵中的索引
flip_idx
=
[[
0
,
1
],
[
3
,
4
]]
def
__init__
(
self
,
opt
,
split
):
super
(
FACEHP
,
self
).
__init__
()
...
...
@@ -32,7 +34,8 @@ class FACEHP(data.Dataset):
self
.
acc_idxs
=
[
1
,
2
,
3
,
4
]
self
.
data_dir
=
opt
.
data_dir
_ann_name
=
{
'train'
:
'train'
,
'val'
:
'val'
}
self
.
img_dir
=
os
.
path
.
join
(
self
.
data_dir
,
f
'images/
{
_ann_name
[
split
]
}
'
)
# 训练图片所在地址
# 图片所在地址
self
.
img_dir
=
os
.
path
.
join
(
self
.
data_dir
,
f
'images/
{
_ann_name
[
split
]
}
'
)
print
(
'===>'
,
self
.
img_dir
)
if
split
==
'val'
:
self
.
annot_path
=
os
.
path
.
join
(
...
...
src/lib/models/decode.py
View file @
69f4ba48
...
...
@@ -19,33 +19,33 @@ def _left_aggregate(heat):
'''
heat: batchsize x channels x h x w
'''
shape
=
heat
.
shape
shape
=
heat
.
shape
heat
=
heat
.
reshape
(
-
1
,
heat
.
shape
[
3
])
heat
=
heat
.
transpose
(
1
,
0
).
contiguous
()
ret
=
heat
.
clone
()
for
i
in
range
(
1
,
heat
.
shape
[
0
]):
inds
=
(
heat
[
i
]
>=
heat
[
i
-
1
])
ret
[
i
]
+=
ret
[
i
-
1
]
*
inds
.
float
()
return
(
ret
-
heat
).
transpose
(
1
,
0
).
reshape
(
shape
)
return
(
ret
-
heat
).
transpose
(
1
,
0
).
reshape
(
shape
)
def _right_aggregate(heat):
    '''
        Accumulate heat along the last axis, sweeping from right to left.

        Position i absorbs the running total of position i + 1 whenever
        heat[i] >= heat[i + 1]; the value returned is that running total
        with the position's own heat subtracted again.

        heat: batchsize x channels x h x w
        returns: tensor with the same shape as heat
    '''
    orig_shape = heat.shape
    # Flatten everything but the last axis, then put that axis first so
    # that cols[i] holds column i of every row.
    cols = heat.reshape(-1, heat.shape[3]).transpose(1, 0).contiguous()
    running = cols.clone()
    # Walk from the second-to-last column down to column 0.
    for col in reversed(range(cols.shape[0] - 1)):
        keep = cols[col] >= cols[col + 1]
        running[col] += running[col + 1] * keep.float()
    aggregated = running - cols
    return aggregated.transpose(1, 0).reshape(orig_shape)
def
_top_aggregate
(
heat
):
'''
heat: batchsize x channels x h x w
'''
heat
=
heat
.
transpose
(
3
,
2
)
heat
=
heat
.
transpose
(
3
,
2
)
shape
=
heat
.
shape
heat
=
heat
.
reshape
(
-
1
,
heat
.
shape
[
3
])
heat
=
heat
.
transpose
(
1
,
0
).
contiguous
()
...
...
@@ -59,7 +59,7 @@ def _bottom_aggregate(heat):
'''
heat: batchsize x channels x h x w
'''
heat
=
heat
.
transpose
(
3
,
2
)
heat
=
heat
.
transpose
(
3
,
2
)
shape
=
heat
.
shape
heat
=
heat
.
reshape
(
-
1
,
heat
.
shape
[
3
])
heat
=
heat
.
transpose
(
1
,
0
).
contiguous
()
...
...
@@ -92,7 +92,7 @@ def _topk(scores, K=40):
'''
def
_topk_channel
(
scores
,
K
=
40
):
batch
,
cat
,
height
,
width
=
scores
.
size
()
topk_scores
,
topk_inds
=
torch
.
topk
(
scores
.
view
(
batch
,
cat
,
-
1
),
K
)
topk_inds
=
topk_inds
%
(
height
*
width
)
...
...
@@ -103,13 +103,13 @@ def _topk_channel(scores, K=40):
def
_topk
(
scores
,
K
=
40
):
batch
,
cat
,
height
,
width
=
scores
.
size
()
topk_scores
,
topk_inds
=
torch
.
topk
(
scores
.
view
(
batch
,
cat
,
-
1
),
K
)
# 前100个点
# 前100个点
topk_scores
,
topk_inds
=
torch
.
topk
(
scores
.
view
(
batch
,
cat
,
-
1
),
K
)
topk_inds
=
topk_inds
%
(
height
*
width
)
topk_ys
=
(
topk_inds
/
width
).
int
().
float
()
topk_xs
=
(
topk_inds
%
width
).
int
().
float
()
topk_score
,
topk_ind
=
torch
.
topk
(
topk_scores
.
view
(
batch
,
-
1
),
K
)
topk_clses
=
(
topk_ind
/
K
).
int
()
topk_inds
=
_gather_feat
(
...
...
@@ -121,8 +121,8 @@ def _topk(scores, K=40):
def
agnex_ct_decode
(
t_heat
,
l_heat
,
b_heat
,
r_heat
,
ct_heat
,
t_regr
=
None
,
l_regr
=
None
,
b_regr
=
None
,
r_regr
=
None
,
t_heat
,
l_heat
,
b_heat
,
r_heat
,
ct_heat
,
t_regr
=
None
,
l_regr
=
None
,
b_regr
=
None
,
r_regr
=
None
,
K
=
40
,
scores_thresh
=
0.1
,
center_thresh
=
0.1
,
aggr_weight
=
0.0
,
num_dets
=
1000
):
batch
,
cat
,
height
,
width
=
t_heat
.
size
()
...
...
@@ -134,19 +134,19 @@ def agnex_ct_decode(
r_heat = torch.sigmoid(r_heat)
ct_heat = torch.sigmoid(ct_heat)
'''
if
aggr_weight
>
0
:
if
aggr_weight
>
0
:
t_heat
=
_h_aggregate
(
t_heat
,
aggr_weight
=
aggr_weight
)
l_heat
=
_v_aggregate
(
l_heat
,
aggr_weight
=
aggr_weight
)
b_heat
=
_h_aggregate
(
b_heat
,
aggr_weight
=
aggr_weight
)
r_heat
=
_v_aggregate
(
r_heat
,
aggr_weight
=
aggr_weight
)
# perform nms on heatmaps
t_heat
=
_nms
(
t_heat
)
l_heat
=
_nms
(
l_heat
)
b_heat
=
_nms
(
b_heat
)
r_heat
=
_nms
(
r_heat
)
t_heat
[
t_heat
>
1
]
=
1
l_heat
[
l_heat
>
1
]
=
1
b_heat
[
b_heat
>
1
]
=
1
...
...
@@ -156,9 +156,9 @@ def agnex_ct_decode(
l_scores
,
l_inds
,
_
,
l_ys
,
l_xs
=
_topk
(
l_heat
,
K
=
K
)
b_scores
,
b_inds
,
_
,
b_ys
,
b_xs
=
_topk
(
b_heat
,
K
=
K
)
r_scores
,
r_inds
,
_
,
r_ys
,
r_xs
=
_topk
(
r_heat
,
K
=
K
)
ct_heat_agn
,
ct_clses
=
torch
.
max
(
ct_heat
,
dim
=
1
,
keepdim
=
True
)
# import pdb; pdb.set_trace()
t_ys
=
t_ys
.
view
(
batch
,
K
,
1
,
1
,
1
).
expand
(
batch
,
K
,
K
,
K
,
K
)
...
...
@@ -240,7 +240,7 @@ def agnex_ct_decode(
b_ys
=
b_ys
+
0.5
r_xs
=
r_xs
+
0.5
r_ys
=
r_ys
+
0.5
bboxes
=
torch
.
stack
((
l_xs
,
t_ys
,
r_xs
,
b_ys
),
dim
=
5
)
bboxes
=
bboxes
.
view
(
batch
,
-
1
,
4
)
bboxes
=
_gather_feat
(
bboxes
,
inds
)
...
...
@@ -266,14 +266,14 @@ def agnex_ct_decode(
r_ys
=
_gather_feat
(
r_ys
,
inds
).
float
()
detections
=
torch
.
cat
([
bboxes
,
scores
,
t_xs
,
t_ys
,
l_xs
,
l_ys
,
detections
=
torch
.
cat
([
bboxes
,
scores
,
t_xs
,
t_ys
,
l_xs
,
l_ys
,
b_xs
,
b_ys
,
r_xs
,
r_ys
,
clses
],
dim
=
2
)
return
detections
def
exct_decode
(
t_heat
,
l_heat
,
b_heat
,
r_heat
,
ct_heat
,
t_regr
=
None
,
l_regr
=
None
,
b_regr
=
None
,
r_regr
=
None
,
t_heat
,
l_heat
,
b_heat
,
r_heat
,
ct_heat
,
t_regr
=
None
,
l_regr
=
None
,
b_regr
=
None
,
r_regr
=
None
,
K
=
40
,
scores_thresh
=
0.1
,
center_thresh
=
0.1
,
aggr_weight
=
0.0
,
num_dets
=
1000
):
batch
,
cat
,
height
,
width
=
t_heat
.
size
()
...
...
@@ -285,18 +285,18 @@ def exct_decode(
ct_heat = torch.sigmoid(ct_heat)
'''
if
aggr_weight
>
0
:
if
aggr_weight
>
0
:
t_heat
=
_h_aggregate
(
t_heat
,
aggr_weight
=
aggr_weight
)
l_heat
=
_v_aggregate
(
l_heat
,
aggr_weight
=
aggr_weight
)
b_heat
=
_h_aggregate
(
b_heat
,
aggr_weight
=
aggr_weight
)
r_heat
=
_v_aggregate
(
r_heat
,
aggr_weight
=
aggr_weight
)
# perform nms on heatmaps
t_heat
=
_nms
(
t_heat
)
l_heat
=
_nms
(
l_heat
)
b_heat
=
_nms
(
b_heat
)
r_heat
=
_nms
(
r_heat
)
t_heat
[
t_heat
>
1
]
=
1
l_heat
[
l_heat
>
1
]
=
1
b_heat
[
b_heat
>
1
]
=
1
...
...
@@ -392,7 +392,7 @@ def exct_decode(
b_ys
=
b_ys
+
0.5
r_xs
=
r_xs
+
0.5
r_ys
=
r_ys
+
0.5
bboxes
=
torch
.
stack
((
l_xs
,
t_ys
,
r_xs
,
b_ys
),
dim
=
5
)
bboxes
=
bboxes
.
view
(
batch
,
-
1
,
4
)
bboxes
=
_gather_feat
(
bboxes
,
inds
)
...
...
@@ -418,7 +418,7 @@ def exct_decode(
r_ys
=
_gather_feat
(
r_ys
,
inds
).
float
()
detections
=
torch
.
cat
([
bboxes
,
scores
,
t_xs
,
t_ys
,
l_xs
,
l_ys
,
detections
=
torch
.
cat
([
bboxes
,
scores
,
t_xs
,
t_ys
,
l_xs
,
l_ys
,
b_xs
,
b_ys
,
r_xs
,
r_ys
,
clses
],
dim
=
2
)
...
...
@@ -429,7 +429,7 @@ def ddd_decode(heat, rot, depth, dim, wh=None, reg=None, K=40):
# heat = torch.sigmoid(heat)
# perform nms on heatmaps
heat
=
_nms
(
heat
)
scores
,
inds
,
clses
,
ys
,
xs
=
_topk
(
heat
,
K
=
K
)
if
reg
is
not
None
:
reg
=
_tranpose_and_gather_feat
(
reg
,
inds
)
...
...
@@ -439,7 +439,7 @@ def ddd_decode(heat, rot, depth, dim, wh=None, reg=None, K=40):
else
:
xs
=
xs
.
view
(
batch
,
K
,
1
)
+
0.5
ys
=
ys
.
view
(
batch
,
K
,
1
)
+
0.5
rot
=
_tranpose_and_gather_feat
(
rot
,
inds
)
rot
=
rot
.
view
(
batch
,
K
,
8
)
depth
=
_tranpose_and_gather_feat
(
depth
,
inds
)
...
...
@@ -450,7 +450,7 @@ def ddd_decode(heat, rot, depth, dim, wh=None, reg=None, K=40):
scores
=
scores
.
view
(
batch
,
K
,
1
)
xs
=
xs
.
view
(
batch
,
K
,
1
)
ys
=
ys
.
view
(
batch
,
K
,
1
)
if
wh
is
not
None
:
wh
=
_tranpose_and_gather_feat
(
wh
,
inds
)
wh
=
wh
.
view
(
batch
,
K
,
2
)
...
...
@@ -459,7 +459,7 @@ def ddd_decode(heat, rot, depth, dim, wh=None, reg=None, K=40):
else
:
detections
=
torch
.
cat
(
[
xs
,
ys
,
scores
,
rot
,
depth
,
dim
,
clses
],
dim
=
2
)
return
detections
def
ctdet_decode
(
heat
,
wh
,
reg
=
None
,
cat_spec_wh
=
False
,
K
=
100
):
...
...
@@ -468,7 +468,7 @@ def ctdet_decode(heat, wh, reg=None, cat_spec_wh=False, K=100):
# heat = torch.sigmoid(heat)
# perform nms on heatmaps
heat
=
_nms
(
heat
)
# 3 * 3 区域的最大值滤波
scores
,
inds
,
clses
,
ys
,
xs
=
_topk
(
heat
,
K
=
K
)
if
reg
is
not
None
:
reg
=
_tranpose_and_gather_feat
(
reg
,
inds
)
...
...
@@ -487,12 +487,12 @@ def ctdet_decode(heat, wh, reg=None, cat_spec_wh=False, K=100):
wh
=
wh
.
view
(
batch
,
K
,
2
)
clses
=
clses
.
view
(
batch
,
K
,
1
).
float
()
scores
=
scores
.
view
(
batch
,
K
,
1
)
bboxes
=
torch
.
cat
([
xs
-
wh
[...,
0
:
1
]
/
2
,
bboxes
=
torch
.
cat
([
xs
-
wh
[...,
0
:
1
]
/
2
,
ys
-
wh
[...,
1
:
2
]
/
2
,
xs
+
wh
[...,
0
:
1
]
/
2
,
xs
+
wh
[...,
0
:
1
]
/
2
,
ys
+
wh
[...,
1
:
2
]
/
2
],
dim
=
2
)
detections
=
torch
.
cat
([
bboxes
,
scores
,
clses
],
dim
=
2
)
return
detections
def
multi_pose_decode
(
...
...
@@ -521,9 +521,9 @@ def multi_pose_decode(
clses
=
clses
.
view
(
batch
,
K
,
1
).
float
()
scores
=
scores
.
view
(
batch
,
K
,
1
)
bboxes
=
torch
.
cat
([
xs
-
wh
[...,
0
:
1
]
/
2
,
bboxes
=
torch
.
cat
([
xs
-
wh
[...,
0
:
1
]
/
2
,
ys
-
wh
[...,
1
:
2
]
/
2
,
xs
+
wh
[...,
0
:
1
]
/
2
,
xs
+
wh
[...,
0
:
1
]
/
2
,
ys
+
wh
[...,
1
:
2
]
/
2
],
dim
=
2
)
if
hm_hp
is
not
None
:
hm_hp
=
_nms
(
hm_hp
)
# 第二次:通过关节点热力图求得关节点的中心点
...
...
@@ -541,7 +541,7 @@ def multi_pose_decode(
else
:
hm_xs
=
hm_xs
+
0.5
hm_ys
=
hm_ys
+
0.5
mask
=
(
hm_score
>
thresh
).
float
()
# 选置信度大于0.1的
hm_score
=
(
1
-
mask
)
*
-
1
+
mask
*
hm_score
hm_ys
=
(
1
-
mask
)
*
(
-
10000
)
+
mask
*
hm_ys
...
...
@@ -570,7 +570,7 @@ def multi_pose_decode(
kps
=
kps
.
permute
(
0
,
2
,
1
,
3
).
contiguous
().
view
(
batch
,
K
,
num_joints
*
2
)
detections
=
torch
.
cat
([
bboxes
,
scores
,
kps
,
clses
],
dim
=
2
)
# box:4+score:1+kpoints:10+class:1=16
return
detections
...
...
src/lib/models/losses.py
View file @
69f4ba48
...
...
@@ -15,61 +15,63 @@ import torch.nn.functional as F
def _slow_neg_loss(pred, gt):
    '''Focal loss from CornerNet, computed on boolean-indexed elements.

    Arguments:
      pred: predicted scores; log(pred) and log(1 - pred) are taken directly
      gt: target map of the same shape; entries equal to 1 are positives,
          entries below 1 are negatives
    Returns the loss normalised by the number of positives, or the plain
    negative-term sum when there are no positives.
    '''
    is_pos = gt.eq(1)
    is_neg = gt.lt(1)

    pos_pred = pred[is_pos]
    neg_pred = pred[is_neg]
    # Negatives are down-weighted by (1 - gt)^4.
    neg_weights = torch.pow(1 - gt[is_neg], 4)

    pos_term = (torch.log(pos_pred) * torch.pow(1 - pos_pred, 2)).sum()
    neg_term = (torch.log(1 - neg_pred) * torch.pow(neg_pred, 2) * neg_weights).sum()

    if pos_pred.nelement() == 0:
        return 0 - neg_term
    positives = is_pos.float().sum()
    return 0 - (pos_term + neg_term) / positives
def _neg_loss(pred, gt):
    '''Modified focal loss, same formulation as CornerNet.

    Works on dense float masks instead of boolean indexing, which runs
    faster at the cost of a little more memory.

    Arguments:
      pred (batch x c x h x w)
      gt (batch x c x h x w)
    Returns the loss normalised by the number of positives, or the plain
    negative-term sum when there are no positives.
    '''
    pos_mask = gt.eq(1).float()
    neg_mask = gt.lt(1).float()
    # Negatives are down-weighted by (1 - gt)^4.
    neg_weights = torch.pow(1 - gt, 4)

    pos_term = (torch.log(pred) * torch.pow(1 - pred, 2) * pos_mask).sum()
    neg_term = (torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_mask).sum()

    positives = pos_mask.float().sum()
    if positives == 0:
        return 0 - neg_term
    return 0 - (pos_term + neg_term) / positives
def
_not_faster_neg_loss
(
pred
,
gt
):
pos_inds
=
gt
.
eq
(
1
).
float
()
neg_inds
=
gt
.
lt
(
1
).
float
()
num_pos
=
pos_inds
.
float
().
sum
()
num_pos
=
pos_inds
.
float
().
sum
()
neg_weights
=
torch
.
pow
(
1
-
gt
,
4
)
loss
=
0
...
...
@@ -80,141 +82,159 @@ def _not_faster_neg_loss(pred, gt):
if
num_pos
>
0
:
all_loss
/=
num_pos
loss
-=
all_loss
loss
-=
all_loss
return
loss
def _slow_reg_loss(regr, gt_regr, mask):
    '''Smooth-L1 regression loss over the entries selected by mask.

    Arguments:
      regr: predictions, indexed by the expanded mask
      gt_regr: targets with the same shape as regr
      mask: selector with one entry per object; it gains a trailing axis and
            is expanded to gt_regr's shape before indexing
            (assumes a boolean mask -- TODO confirm against callers)
    Returns the summed loss divided by (number of selected objects + 1e-4).
    '''
    selected_count = mask.float().sum()
    selector = mask.unsqueeze(2).expand_as(gt_regr)
    total = nn.functional.smooth_l1_loss(
        regr[selector], gt_regr[selector], reduction='sum')
    # The small epsilon keeps the division finite when nothing is selected.
    return total / (selected_count + 1e-4)
def _reg_loss(regr, gt_regr, mask, wight_=None):
    '''Masked smooth-L1 regression loss with optional per-object weights.

    Arguments:
      regr (batch x max_objects x dim)
      gt_regr (batch x max_objects x dim)
      mask (batch x max_objects)
      wight_ (batch x max_objects), optional: multiplies each object's
             element-wise loss before summation
    Returns the summed loss divided by (mask.sum() + 1e-4).
    '''
    valid_count = mask.float().sum()
    dense_mask = mask.unsqueeze(2).expand_as(gt_regr).float()
    # Zero out masked objects on both sides so they contribute no loss.
    masked_pred = regr * dense_mask
    masked_gt = gt_regr * dense_mask

    if wight_ is None:
        total = nn.functional.smooth_l1_loss(
            masked_pred, masked_gt, reduction='sum')
    else:
        dense_weight = wight_.unsqueeze(2).expand_as(masked_gt).float()
        per_elem = nn.functional.smooth_l1_loss(
            masked_pred, masked_gt, reduction='none')
        per_elem *= dense_weight
        total = per_elem.sum()

    return total / (valid_count + 1e-4)
class FocalLoss(nn.Module):
    '''nn.Module wrapper around the module-level _neg_loss focal loss.'''

    def __init__(self):
        super().__init__()
        # Keep the loss function as an attribute so forward() dispatches to it.
        self.neg_loss = _neg_loss

    def forward(self, out, target):
        '''Return the focal loss of prediction `out` against `target`.'''
        criterion = self.neg_loss
        return criterion(out, target)
class RegLoss(nn.Module):
    '''Regression loss for an output tensor.

    Gathers predictions from `output` at `ind` and delegates to _reg_loss.

    Arguments:
      output (batch x dim x h x w)
      mask (batch x max_objects)
      ind (batch x max_objects)
      target (batch x max_objects x dim)
      wight_: optional per-object weights forwarded to _reg_loss
    '''

    def __init__(self):
        super().__init__()

    def forward(self, output, mask, ind, target, wight_=None):
        gathered = _tranpose_and_gather_feat(output, ind)
        return _reg_loss(gathered, target, mask, wight_)
class RegL1Loss(nn.Module):
    '''Masked L1 regression loss on predictions gathered at `ind`.

    The summed absolute error is divided by the number of unmasked elements
    (plus 1e-4), not by the number of objects.
    Argument layout presumably matches RegLoss -- TODO confirm.
    '''

    def __init__(self):
        super().__init__()

    def forward(self, output, mask, ind, target):
        gathered = _tranpose_and_gather_feat(output, ind)
        # Give the mask a trailing axis and expand it to the prediction shape.
        dense_mask = mask.unsqueeze(2).expand_as(gathered).float()
        abs_err = F.l1_loss(
            gathered * dense_mask, target * dense_mask, reduction='sum')
        return abs_err / (dense_mask.sum() + 1e-4)
class NormRegL1Loss(nn.Module):
    '''Masked L1 loss on the ratio prediction / target.

    Predictions are divided by (target + 1e-4) and compared against 1, so
    the error is relative to the target's magnitude.
    Argument layout presumably matches RegLoss -- TODO confirm.
    '''

    def __init__(self):
        super().__init__()

    def forward(self, output, mask, ind, target):
        gathered = _tranpose_and_gather_feat(output, ind)
        dense_mask = mask.unsqueeze(2).expand_as(gathered).float()
        ratio = gathered / (target + 1e-4)
        # Built from target itself so any NaN/inf in target still propagates.
        unit_target = target * 0 + 1
        abs_err = F.l1_loss(
            ratio * dense_mask, unit_target * dense_mask, reduction='sum')
        return abs_err / (dense_mask.sum() + 1e-4)
class RegWeightedL1Loss(nn.Module):
    '''L1 regression loss weighted element-wise by `mask`.

    Unlike RegL1Loss, the mask is only cast to float and is not unsqueezed
    or expanded, so it must already multiply against the gathered
    prediction (presumably it has the same shape -- TODO confirm).
    '''

    def __init__(self):
        super().__init__()

    def forward(self, output, mask, ind, target):
        gathered = _tranpose_and_gather_feat(output, ind)
        weights = mask.float()
        abs_err = F.l1_loss(
            gathered * weights, target * weights, reduction='sum')
        return abs_err / (weights.sum() + 1e-4)
class L1Loss(nn.Module):
    '''Mean L1 loss between masked predictions (gathered at `ind`) and targets.

    The mean is taken over every element, including those zeroed out by
    the mask.
    Argument layout presumably matches RegLoss -- TODO confirm.
    '''

    def __init__(self):
        super(L1Loss, self).__init__()

    def forward(self, output, mask, ind, target):
        pred = _tranpose_and_gather_feat(output, ind)
        mask = mask.unsqueeze(2).expand_as(pred).float()
        # 'mean' replaces the deprecated spelling 'elementwise_mean';
        # the computed value is the same.
        loss = F.l1_loss(pred * mask, target * mask, reduction='mean')
        return loss
class BinRotLoss(nn.Module):
    '''Rotation loss wrapper: gathers predictions at `ind` and delegates
    to compute_rot_loss.'''

    def __init__(self):
        super().__init__()

    def forward(self, output, mask, ind, rotbin, rotres):
        gathered = _tranpose_and_gather_feat(output, ind)
        return compute_rot_loss(gathered, rotbin, rotres, mask)
def compute_res_loss(output, target):
    '''Return the mean smooth-L1 loss between `output` and `target`.'''
    # 'mean' replaces the deprecated spelling 'elementwise_mean';
    # the computed value is the same.
    return F.smooth_l1_loss(output, target, reduction='mean')
# TODO: weight
def compute_bin_loss(output, target, mask):
    '''Return the mean cross-entropy of masked logits against class targets.

    Arguments:
      output: logits passed to F.cross_entropy
      target: class indices for F.cross_entropy
      mask: expanded to output's shape and multiplied into the logits
    Masked rows are not dropped: their logits become all zeros, so they
    still contribute to the mean.
    '''
    mask = mask.expand_as(output)
    output = output * mask.float()
    # 'mean' replaces the deprecated spelling 'elementwise_mean';
    # the computed value is the same.
    return F.cross_entropy(output, target, reduction='mean')
def
compute_rot_loss
(
output
,
target_bin
,
target_res
,
mask
):
# output: (B, 128, 8) [bin1_cls[0], bin1_cls[1], bin1_sin, bin1_cos,
# bin2_cls[0], bin2_cls[1], bin2_sin, bin2_cos]
...
...
@@ -234,17 +254,17 @@ def compute_rot_loss(output, target_bin, target_res, mask):
valid_output1
=
torch
.
index_select
(
output
,
0
,
idx1
.
long
())
valid_target_res1
=
torch
.
index_select
(
target_res
,
0
,
idx1
.
long
())
loss_sin1
=
compute_res_loss
(
valid_output1
[:,
2
],
torch
.
sin
(
valid_target_res1
[:,
0
]))
valid_output1
[:,
2
],
torch
.
sin
(
valid_target_res1
[:,
0
]))
loss_cos1
=
compute_res_loss
(
valid_output1
[:,
3
],
torch
.
cos
(
valid_target_res1
[:,
0
]))
valid_output1
[:,
3
],
torch
.
cos
(
valid_target_res1
[:,
0
]))
loss_res
+=
loss_sin1
+
loss_cos1
if
target_bin
[:,
1
].
nonzero
().
shape
[
0
]
>
0
:
idx2
=
target_bin
[:,
1
].
nonzero
()[:,
0
]
valid_output2
=
torch
.
index_select
(
output
,
0
,
idx2
.
long
())
valid_target_res2
=
torch
.
index_select
(
target_res
,
0
,
idx2
.
long
())
loss_sin2
=
compute_res_loss
(
valid_output2
[:,
6
],
torch
.
sin
(
valid_target_res2
[:,
1
]))
valid_output2
[:,
6
],
torch
.
sin
(
valid_target_res2
[:,
1
]))
loss_cos2
=
compute_res_loss
(
valid_output2
[:,
7
],
torch
.
cos
(
valid_target_res2
[:,
1
]))
valid_output2
[:,
7
],
torch
.
cos
(
valid_target_res2
[:,
1
]))
loss_res
+=
loss_sin2
+
loss_cos2
return
loss_bin1
+
loss_bin2
+
loss_res
src/lib/trains/base_trainer.py
View file @
69f4ba48
...
...
@@ -17,7 +17,7 @@ class ModleWithLoss(torch.nn.Module):
def
forward
(
self
,
batch
):
outputs
=
self
.
model
(
batch
[
'input'
])
loss
,
loss_stats
=
self
.
loss
(
outputs
,
batch
)
# 输入
loss
,
loss_stats
=
self
.
loss
(
outputs
,
batch
)
# 输入
return
outputs
[
-
1
],
loss
,
loss_stats
...
...
src/lib/trains/multi_pose.py
View file @
69f4ba48
...
...
@@ -47,6 +47,7 @@ class MultiPoseLoss(torch.nn.Module):
batch
[
'hps'
].
detach
().
cpu
().
numpy
(),
batch
[
'ind'
].
detach
().
cpu
().
numpy
(),
opt
.
output_res
,
opt
.
output_res
)).
to
(
opt
.
device
)
if
opt
.
eval_oracle_hp_offset
:
output
[
'hp_offset'
]
=
torch
.
from_numpy
(
gen_oracle_map
(
batch
[
'hp_offset'
].
detach
().
cpu
().
numpy
(),
...
...
@@ -56,10 +57,13 @@ class MultiPoseLoss(torch.nn.Module):
# 1. focal loss,求目标的中心,
hm_loss
+=
self
.
crit
(
output
[
'hm'
],
batch
[
'hm'
])
/
opt
.
num_stacks
if
opt
.
wh_weight
>
0
:
wh_loss
+=
self
.
crit_reg
(
output
[
'wh'
],
batch
[
'reg_mask'
],
# 2. 人脸bbox高度和宽度的loss
# 2. 人脸bbox高度和宽度的loss
wh_loss
+=
self
.
crit_reg
(
output
[
'wh'
],
batch
[
'reg_mask'
],
batch
[
'ind'
],
batch
[
'wh'
],
batch
[
'wight_mask'
])
/
opt
.
num_stacks
if
opt
.
reg_offset
and
opt
.
off_weight
>
0
:
off_loss
+=
self
.
crit_reg
(
output
[
'hm_offset'
],
batch
[
'reg_mask'
],
# 3. 人脸bbox中心点下采样,所需要的偏差补偿
# 3. 人脸bbox中心点下采样,所需要的偏差补偿
off_loss
+=
self
.
crit_reg
(
output
[
'hm_offset'
],
batch
[
'reg_mask'
],
batch
[
'ind'
],
batch
[
'hm_offset'
],
batch
[
'wight_mask'
])
/
opt
.
num_stacks
if
opt
.
dense_hp
:
...
...
@@ -68,14 +72,17 @@ class MultiPoseLoss(torch.nn.Module):
batch
[
'dense_hps'
]
*
batch
[
'dense_hps_mask'
])
/
mask_weight
)
/
opt
.
num_stacks
else
:
lm_loss
+=
self
.
crit_kp
(
output
[
'landmarks'
],
batch
[
'hps_mask'
],
# 4. 关节点的偏移
# 4. 关节点的偏移
lm_loss
+=
self
.
crit_kp
(
output
[
'landmarks'
],
batch
[
'hps_mask'
],
batch
[
'ind'
],
batch
[
'landmarks'
])
/
opt
.
num_stacks
# if opt.reg_hp_offset and opt.off_weight > 0:
# 关节点的中心偏移
# 关节点的中心偏移
# if opt.reg_hp_offset and opt.off_weight > 0:
# hp_offset_loss += self.crit_reg(
# output['hp_offset'], batch['hp_mask'],
# batch['hp_ind'], batch['hp_offset']) / opt.num_stacks
# if opt.hm_hp and opt.hm_hp_weight > 0: # 关节点的热力图
# 关节点的热力图
# if opt.hm_hp and opt.hm_hp_weight > 0:
# hm_hp_loss += self.crit_hm_hp(
# output['hm_hp'], batch['hm_hp']) / opt.num_stacks
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment