Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
CenterFace_pytorch
Commits
69f4ba48
Commit
69f4ba48
authored
Nov 07, 2023
by
chenych
Browse files
update readme
parent
e0491117
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
235 additions
and
210 deletions
+235
-210
README.md
README.md
+16
-21
src/lib/datasets/dataset/centerface_hp.py
src/lib/datasets/dataset/centerface_hp.py
+8
-5
src/lib/models/decode.py
src/lib/models/decode.py
+39
-39
src/lib/models/losses.py
src/lib/models/losses.py
+158
-138
src/lib/trains/base_trainer.py
src/lib/trains/base_trainer.py
+1
-1
src/lib/trains/multi_pose.py
src/lib/trains/multi_pose.py
+13
-6
No files found.
README.md
View file @
69f4ba48
...
@@ -4,8 +4,9 @@
...
@@ -4,8 +4,9 @@
## 模型结构
## 模型结构
CenterFace是一种人脸检测算法,采用了轻量级网络mobileNetV2作为主干网络,结合特征金字塔网络(FPN)实现anchor free的人脸检测。
CenterFace是一种人脸检测算法,采用了轻量级网络mobileNetV2作为主干网络,结合特征金字塔网络(FPN)实现anchor free的人脸检测。

<div
align=
center
>
<img
src=
"./Architecture of the CenterFace.png"
/>
</div>
## 算法原理
## 算法原理
CenterFace模型是一种基于单阶段人脸检测算法,作者借鉴了CenterNet的思想,将人脸检测转换为标准点问题,根据人脸中心点来回归人脸框的大小和五个标志点。
CenterFace模型是一种基于单阶段人脸检测算法,作者借鉴了CenterNet的思想,将人脸检测转换为标准点问题,根据人脸中心点来回归人脸框的大小和五个标志点。
...
@@ -53,7 +54,9 @@ pip3 install -r requirements.txt
...
@@ -53,7 +54,9 @@ pip3 install -r requirements.txt
WIDER_FACE:http://shuoyang1213.me/WIDERFACE/index.html
WIDER_FACE:http://shuoyang1213.me/WIDERFACE/index.html

<div
align=
center
>
<img
src=
"./datasets.png"
/>
</div>
下载图片红框中三个数据并解压,也可直接点击下面链接直接下载:
下载图片红框中三个数据并解压,也可直接点击下面链接直接下载:
...
@@ -99,25 +102,15 @@ WIDER_FACE:http://shuoyang1213.me/WIDERFACE/index.html
...
@@ -99,25 +102,15 @@ WIDER_FACE:http://shuoyang1213.me/WIDERFACE/index.html
│ └── 61--Street_Battle
│ └── 61--Street_Battle
```
```
<<<<<<< HEAD
2.
如果是使用WIDER_train、WIDER_val数据, 可直接将./datasets/labels/下的train_wider_face.json重命名为train_face.json, val_wider_face.json重命名为val_face.json即可,无需进行标注文件格式转换;
2.
如果是使用WIDER_train、WIDER_val数据, 可直接将./datasets/labels/下的train_wider_face.json重命名为train_face.json, val_wider_face.json重命名为val_face.json即可,无需进行标注文件格式转换;
反之,需要将训练图片/验证图片对应的人脸关键点标注信息文件train.txt/val.txt,放置于 ./datasets/annotations/下(train存放训练图片的标注文件,val存放验证图片的标注文件),存放目录结构如下:
反之,需要将训练图片/验证图片对应的人脸标注信息文件train.txt/val.txt,放置于 ./datasets/annotations/下(train存放训练图片的标注文件,val存放验证图片的标注文件),存放目录结构如下:
=======
2.
如果是使用WIDER_train数据, 可直接将./datasets/labels/下的train_wider_face.json重命名为train_face.json即可,无需进行标注文件格式转换;反之,需要将训练图片对应的人脸关键点标注信息文件(xxxx.txt),放置于 ./datasets/annotations/下(train存放训练图片的标注文件,val存放验证图片的标注文件),存放目录结构如下:
>>>>>>> c9fac124caf657ec86fd6a58ca90e1da5f88f6cb
```
```
├── annotations
├── annotations
│ ├── train
│ ├── train
<<<<<<< HEAD
│ ├── train.txt
│ ├── train.txt
│ ├── val
│ ├── val
│ ├── val.txt
│ ├── val.txt
=======
│ ├── xxx.txt
│ ├── val
│ ├── xxx.txt
>>>>>>> c9fac124caf657ec86fd6a58ca90e1da5f88f6cb
```
```
特别地,标注信息的格式为:
特别地,标注信息的格式为:
...
@@ -128,7 +121,8 @@ x, y, w, h, left_eye_x, left_eye_y, flag, right_eye_x, right_eye_y, flag, nose_x
...
@@ -128,7 +121,8 @@ x, y, w, h, left_eye_x, left_eye_y, flag, right_eye_x, right_eye_y, flag, nose_x
```
```
举个例子:
举个例子:
train_keypoints_widerface.txt是wider_face训练数据集的标注信息
./datasets/annotations/train/train.txt是wider_face训练数据集的标注信息
```
```
# 0--Parade/0_Parade_marchingband_1_849.jpg
# 0--Parade/0_Parade_marchingband_1_849.jpg
449 330 122 149 488.906 373.643 0.0 542.089 376.442 0.0 515.031 412.83 0.0 485.174 425.893 0.0 538.357 431.491 0.0 0.82
449 330 122 149 488.906 373.643 0.0 542.089 376.442 0.0 515.031 412.83 0.0 485.174 425.893 0.0 538.357 431.491 0.0 0.82
...
@@ -142,11 +136,7 @@ cd ./datasets
...
@@ -142,11 +136,7 @@ cd ./datasets
python gen_data.py
python gen_data.py
```
```
<<<<<<< HEAD
执行完成后会在./datasets/labels下生成训练数据的标注文件 train_face.json、val_face.json
执行完成后会在./datasets/labels下生成训练数据的标注文件 train_face.json、val_face.json
=======
执行完成后会在./datasets/labels下生成训练数据的标注文件 train_face.json
>>>>>>> c9fac124caf657ec86fd6a58ca90e1da5f88f6cb
## 训练
## 训练
...
@@ -166,18 +156,23 @@ bash train_multi.sh
...
@@ -166,18 +156,23 @@ bash train_multi.sh
## 推理
## 推理
#### 单卡推理
#### 单卡推理
```
```
cd ./src
cd ./src
python test_wider_face.py
python test_wider_face.py
```
```
## result
## result

<div
align=
center
>
<img
src=
"./draw_img.jpg"
/>
</div>
### 精度
### 精度
WIDER_FACE验证集上的测试结果如下
WIDER_FACE验证集上的测试结果如下
| Method | Easy(
p
) | Medium(
p
) | Hard(
p
)|
| Method | Easy(
AP
) | Medium(
AP
) | Hard(
AP
)|
|:--------:| :--------:| :---------:| :------:|
|:--------:| :--------:| :---------:| :------:|
| ours(one scale) | 0.9264 | 0.9133 | 0.7479 |
| ours(one scale) | 0.9264 | 0.9133 | 0.7479 |
| original | 0.922 | 0.911 | 0.782|
| original | 0.922 | 0.911 | 0.782|
...
...
src/lib/datasets/dataset/centerface_hp.py
View file @
69f4ba48
...
@@ -2,13 +2,14 @@ from __future__ import absolute_import
...
@@ -2,13 +2,14 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
pycocotools.coco
as
coco
from
pycocotools.cocoeval
import
COCOeval
import
numpy
as
np
import
json
import
json
import
os
import
os
from
pycocotools.cocoeval
import
COCOeval
import
torch.utils.data
as
data
import
torch.utils.data
as
data
import
numpy
as
np
import
pycocotools.coco
as
coco
class
FACEHP
(
data
.
Dataset
):
class
FACEHP
(
data
.
Dataset
):
...
@@ -19,7 +20,8 @@ class FACEHP(data.Dataset):
...
@@ -19,7 +20,8 @@ class FACEHP(data.Dataset):
dtype
=
np
.
float32
).
reshape
(
1
,
1
,
3
)
dtype
=
np
.
float32
).
reshape
(
1
,
1
,
3
)
std
=
np
.
array
([
0.28863828
,
0.27408164
,
0.27809835
],
std
=
np
.
array
([
0.28863828
,
0.27408164
,
0.27809835
],
dtype
=
np
.
float32
).
reshape
(
1
,
1
,
3
)
dtype
=
np
.
float32
).
reshape
(
1
,
1
,
3
)
flip_idx
=
[[
0
,
1
],
[
3
,
4
]]
# 翻转的关键点在关键点矩阵中的索引
# 翻转的关键点在关键点矩阵中的索引
flip_idx
=
[[
0
,
1
],
[
3
,
4
]]
def
__init__
(
self
,
opt
,
split
):
def
__init__
(
self
,
opt
,
split
):
super
(
FACEHP
,
self
).
__init__
()
super
(
FACEHP
,
self
).
__init__
()
...
@@ -32,7 +34,8 @@ class FACEHP(data.Dataset):
...
@@ -32,7 +34,8 @@ class FACEHP(data.Dataset):
self
.
acc_idxs
=
[
1
,
2
,
3
,
4
]
self
.
acc_idxs
=
[
1
,
2
,
3
,
4
]
self
.
data_dir
=
opt
.
data_dir
self
.
data_dir
=
opt
.
data_dir
_ann_name
=
{
'train'
:
'train'
,
'val'
:
'val'
}
_ann_name
=
{
'train'
:
'train'
,
'val'
:
'val'
}
self
.
img_dir
=
os
.
path
.
join
(
self
.
data_dir
,
f
'images/
{
_ann_name
[
split
]
}
'
)
# 训练图片所在地址
# 图片所在地址
self
.
img_dir
=
os
.
path
.
join
(
self
.
data_dir
,
f
'images/
{
_ann_name
[
split
]
}
'
)
print
(
'===>'
,
self
.
img_dir
)
print
(
'===>'
,
self
.
img_dir
)
if
split
==
'val'
:
if
split
==
'val'
:
self
.
annot_path
=
os
.
path
.
join
(
self
.
annot_path
=
os
.
path
.
join
(
...
...
src/lib/models/decode.py
View file @
69f4ba48
...
@@ -103,8 +103,8 @@ def _topk_channel(scores, K=40):
...
@@ -103,8 +103,8 @@ def _topk_channel(scores, K=40):
def
_topk
(
scores
,
K
=
40
):
def
_topk
(
scores
,
K
=
40
):
batch
,
cat
,
height
,
width
=
scores
.
size
()
batch
,
cat
,
height
,
width
=
scores
.
size
()
# 前100个点
topk_scores
,
topk_inds
=
torch
.
topk
(
scores
.
view
(
batch
,
cat
,
-
1
),
K
)
# 前100个点
topk_scores
,
topk_inds
=
torch
.
topk
(
scores
.
view
(
batch
,
cat
,
-
1
),
K
)
topk_inds
=
topk_inds
%
(
height
*
width
)
topk_inds
=
topk_inds
%
(
height
*
width
)
topk_ys
=
(
topk_inds
/
width
).
int
().
float
()
topk_ys
=
(
topk_inds
/
width
).
int
().
float
()
...
...
src/lib/models/losses.py
View file @
69f4ba48
...
@@ -54,7 +54,8 @@ def _neg_loss(pred, gt):
...
@@ -54,7 +54,8 @@ def _neg_loss(pred, gt):
loss
=
0
loss
=
0
pos_loss
=
torch
.
log
(
pred
)
*
torch
.
pow
(
1
-
pred
,
2
)
*
pos_inds
pos_loss
=
torch
.
log
(
pred
)
*
torch
.
pow
(
1
-
pred
,
2
)
*
pos_inds
neg_loss
=
torch
.
log
(
1
-
pred
)
*
torch
.
pow
(
pred
,
2
)
*
neg_weights
*
neg_inds
neg_loss
=
torch
.
log
(
1
-
pred
)
*
torch
.
pow
(
pred
,
2
)
*
\
neg_weights
*
neg_inds
num_pos
=
pos_inds
.
float
().
sum
()
num_pos
=
pos_inds
.
float
().
sum
()
pos_loss
=
pos_loss
.
sum
()
pos_loss
=
pos_loss
.
sum
()
...
@@ -66,6 +67,7 @@ def _neg_loss(pred, gt):
...
@@ -66,6 +67,7 @@ def _neg_loss(pred, gt):
loss
=
loss
-
(
pos_loss
+
neg_loss
)
/
num_pos
loss
=
loss
-
(
pos_loss
+
neg_loss
)
/
num_pos
return
loss
return
loss
def
_not_faster_neg_loss
(
pred
,
gt
):
def
_not_faster_neg_loss
(
pred
,
gt
):
pos_inds
=
gt
.
eq
(
1
).
float
()
pos_inds
=
gt
.
eq
(
1
).
float
()
neg_inds
=
gt
.
lt
(
1
).
float
()
neg_inds
=
gt
.
lt
(
1
).
float
()
...
@@ -83,6 +85,7 @@ def _not_faster_neg_loss(pred, gt):
...
@@ -83,6 +85,7 @@ def _not_faster_neg_loss(pred, gt):
loss
-=
all_loss
loss
-=
all_loss
return
loss
return
loss
def
_slow_reg_loss
(
regr
,
gt_regr
,
mask
):
def
_slow_reg_loss
(
regr
,
gt_regr
,
mask
):
num
=
mask
.
float
().
sum
()
num
=
mask
.
float
().
sum
()
mask
=
mask
.
unsqueeze
(
2
).
expand_as
(
gt_regr
)
mask
=
mask
.
unsqueeze
(
2
).
expand_as
(
gt_regr
)
...
@@ -94,6 +97,7 @@ def _slow_reg_loss(regr, gt_regr, mask):
...
@@ -94,6 +97,7 @@ def _slow_reg_loss(regr, gt_regr, mask):
regr_loss
=
regr_loss
/
(
num
+
1e-4
)
regr_loss
=
regr_loss
/
(
num
+
1e-4
)
return
regr_loss
return
regr_loss
def
_reg_loss
(
regr
,
gt_regr
,
mask
,
wight_
=
None
):
def
_reg_loss
(
regr
,
gt_regr
,
mask
,
wight_
=
None
):
''' L1 regression loss
''' L1 regression loss
Arguments:
Arguments:
...
@@ -110,18 +114,22 @@ def _reg_loss(regr, gt_regr, mask, wight_=None):
...
@@ -110,18 +114,22 @@ def _reg_loss(regr, gt_regr, mask, wight_=None):
if
wight_
is
not
None
:
if
wight_
is
not
None
:
wight_
=
wight_
.
unsqueeze
(
2
).
expand_as
(
gt_regr
).
float
()
wight_
=
wight_
.
unsqueeze
(
2
).
expand_as
(
gt_regr
).
float
()
# regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, reduce=False)
# regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, reduce=False)
regr_loss
=
nn
.
functional
.
smooth_l1_loss
(
regr
,
gt_regr
,
reduction
=
'none'
)
regr_loss
=
nn
.
functional
.
smooth_l1_loss
(
regr
,
gt_regr
,
reduction
=
'none'
)
regr_loss
*=
wight_
regr_loss
*=
wight_
regr_loss
=
regr_loss
.
sum
()
regr_loss
=
regr_loss
.
sum
()
else
:
else
:
regr_loss
=
nn
.
functional
.
smooth_l1_loss
(
regr
,
gt_regr
,
reduction
=
'sum'
)
regr_loss
=
nn
.
functional
.
smooth_l1_loss
(
regr
,
gt_regr
,
reduction
=
'sum'
)
# regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, size_average=False)
# regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, size_average=False)
regr_loss
=
regr_loss
/
(
num
+
1e-4
)
regr_loss
=
regr_loss
/
(
num
+
1e-4
)
return
regr_loss
return
regr_loss
class
FocalLoss
(
nn
.
Module
):
class
FocalLoss
(
nn
.
Module
):
'''nn.Module warpper for focal loss'''
'''nn.Module warpper for focal loss'''
def
__init__
(
self
):
def
__init__
(
self
):
super
(
FocalLoss
,
self
).
__init__
()
super
(
FocalLoss
,
self
).
__init__
()
self
.
neg_loss
=
_neg_loss
self
.
neg_loss
=
_neg_loss
...
@@ -129,6 +137,7 @@ class FocalLoss(nn.Module):
...
@@ -129,6 +137,7 @@ class FocalLoss(nn.Module):
def
forward
(
self
,
out
,
target
):
def
forward
(
self
,
out
,
target
):
return
self
.
neg_loss
(
out
,
target
)
return
self
.
neg_loss
(
out
,
target
)
class
RegLoss
(
nn
.
Module
):
class
RegLoss
(
nn
.
Module
):
'''Regression loss for an output tensor
'''Regression loss for an output tensor
Arguments:
Arguments:
...
@@ -137,6 +146,7 @@ class RegLoss(nn.Module):
...
@@ -137,6 +146,7 @@ class RegLoss(nn.Module):
ind (batch x max_objects)
ind (batch x max_objects)
target (batch x max_objects x dim)
target (batch x max_objects x dim)
'''
'''
def
__init__
(
self
):
def
__init__
(
self
):
super
(
RegLoss
,
self
).
__init__
()
super
(
RegLoss
,
self
).
__init__
()
...
@@ -145,6 +155,7 @@ class RegLoss(nn.Module):
...
@@ -145,6 +155,7 @@ class RegLoss(nn.Module):
loss
=
_reg_loss
(
pred
,
target
,
mask
,
wight_
)
loss
=
_reg_loss
(
pred
,
target
,
mask
,
wight_
)
return
loss
return
loss
class
RegL1Loss
(
nn
.
Module
):
class
RegL1Loss
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
RegL1Loss
,
self
).
__init__
()
super
(
RegL1Loss
,
self
).
__init__
()
...
@@ -158,6 +169,7 @@ class RegL1Loss(nn.Module):
...
@@ -158,6 +169,7 @@ class RegL1Loss(nn.Module):
loss
=
loss
/
(
mask
.
sum
()
+
1e-4
)
loss
=
loss
/
(
mask
.
sum
()
+
1e-4
)
return
loss
return
loss
class
NormRegL1Loss
(
nn
.
Module
):
class
NormRegL1Loss
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
NormRegL1Loss
,
self
).
__init__
()
super
(
NormRegL1Loss
,
self
).
__init__
()
...
@@ -174,6 +186,7 @@ class NormRegL1Loss(nn.Module):
...
@@ -174,6 +186,7 @@ class NormRegL1Loss(nn.Module):
loss
=
loss
/
(
mask
.
sum
()
+
1e-4
)
loss
=
loss
/
(
mask
.
sum
()
+
1e-4
)
return
loss
return
loss
class
RegWeightedL1Loss
(
nn
.
Module
):
class
RegWeightedL1Loss
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
RegWeightedL1Loss
,
self
).
__init__
()
super
(
RegWeightedL1Loss
,
self
).
__init__
()
...
@@ -187,6 +200,7 @@ class RegWeightedL1Loss(nn.Module):
...
@@ -187,6 +200,7 @@ class RegWeightedL1Loss(nn.Module):
loss
=
loss
/
(
mask
.
sum
()
+
1e-4
)
loss
=
loss
/
(
mask
.
sum
()
+
1e-4
)
return
loss
return
loss
class
L1Loss
(
nn
.
Module
):
class
L1Loss
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
L1Loss
,
self
).
__init__
()
super
(
L1Loss
,
self
).
__init__
()
...
@@ -194,9 +208,11 @@ class L1Loss(nn.Module):
...
@@ -194,9 +208,11 @@ class L1Loss(nn.Module):
def
forward
(
self
,
output
,
mask
,
ind
,
target
):
def
forward
(
self
,
output
,
mask
,
ind
,
target
):
pred
=
_tranpose_and_gather_feat
(
output
,
ind
)
pred
=
_tranpose_and_gather_feat
(
output
,
ind
)
mask
=
mask
.
unsqueeze
(
2
).
expand_as
(
pred
).
float
()
mask
=
mask
.
unsqueeze
(
2
).
expand_as
(
pred
).
float
()
loss
=
F
.
l1_loss
(
pred
*
mask
,
target
*
mask
,
reduction
=
'elementwise_mean'
)
loss
=
F
.
l1_loss
(
pred
*
mask
,
target
*
mask
,
reduction
=
'elementwise_mean'
)
return
loss
return
loss
class
BinRotLoss
(
nn
.
Module
):
class
BinRotLoss
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
BinRotLoss
,
self
).
__init__
()
super
(
BinRotLoss
,
self
).
__init__
()
...
@@ -206,15 +222,19 @@ class BinRotLoss(nn.Module):
...
@@ -206,15 +222,19 @@ class BinRotLoss(nn.Module):
loss
=
compute_rot_loss
(
pred
,
rotbin
,
rotres
,
mask
)
loss
=
compute_rot_loss
(
pred
,
rotbin
,
rotres
,
mask
)
return
loss
return
loss
def
compute_res_loss
(
output
,
target
):
def
compute_res_loss
(
output
,
target
):
return
F
.
smooth_l1_loss
(
output
,
target
,
reduction
=
'elementwise_mean'
)
return
F
.
smooth_l1_loss
(
output
,
target
,
reduction
=
'elementwise_mean'
)
# TODO: weight
# TODO: weight
def
compute_bin_loss
(
output
,
target
,
mask
):
def
compute_bin_loss
(
output
,
target
,
mask
):
mask
=
mask
.
expand_as
(
output
)
mask
=
mask
.
expand_as
(
output
)
output
=
output
*
mask
.
float
()
output
=
output
*
mask
.
float
()
return
F
.
cross_entropy
(
output
,
target
,
reduction
=
'elementwise_mean'
)
return
F
.
cross_entropy
(
output
,
target
,
reduction
=
'elementwise_mean'
)
def
compute_rot_loss
(
output
,
target_bin
,
target_res
,
mask
):
def
compute_rot_loss
(
output
,
target_bin
,
target_res
,
mask
):
# output: (B, 128, 8) [bin1_cls[0], bin1_cls[1], bin1_sin, bin1_cos,
# output: (B, 128, 8) [bin1_cls[0], bin1_cls[1], bin1_sin, bin1_cos,
# bin2_cls[0], bin2_cls[1], bin2_sin, bin2_cos]
# bin2_cls[0], bin2_cls[1], bin2_sin, bin2_cos]
...
...
src/lib/trains/base_trainer.py
View file @
69f4ba48
src/lib/trains/multi_pose.py
View file @
69f4ba48
...
@@ -47,6 +47,7 @@ class MultiPoseLoss(torch.nn.Module):
...
@@ -47,6 +47,7 @@ class MultiPoseLoss(torch.nn.Module):
batch
[
'hps'
].
detach
().
cpu
().
numpy
(),
batch
[
'hps'
].
detach
().
cpu
().
numpy
(),
batch
[
'ind'
].
detach
().
cpu
().
numpy
(),
batch
[
'ind'
].
detach
().
cpu
().
numpy
(),
opt
.
output_res
,
opt
.
output_res
)).
to
(
opt
.
device
)
opt
.
output_res
,
opt
.
output_res
)).
to
(
opt
.
device
)
if
opt
.
eval_oracle_hp_offset
:
if
opt
.
eval_oracle_hp_offset
:
output
[
'hp_offset'
]
=
torch
.
from_numpy
(
gen_oracle_map
(
output
[
'hp_offset'
]
=
torch
.
from_numpy
(
gen_oracle_map
(
batch
[
'hp_offset'
].
detach
().
cpu
().
numpy
(),
batch
[
'hp_offset'
].
detach
().
cpu
().
numpy
(),
...
@@ -56,10 +57,13 @@ class MultiPoseLoss(torch.nn.Module):
...
@@ -56,10 +57,13 @@ class MultiPoseLoss(torch.nn.Module):
# 1. focal loss,求目标的中心,
# 1. focal loss,求目标的中心,
hm_loss
+=
self
.
crit
(
output
[
'hm'
],
batch
[
'hm'
])
/
opt
.
num_stacks
hm_loss
+=
self
.
crit
(
output
[
'hm'
],
batch
[
'hm'
])
/
opt
.
num_stacks
if
opt
.
wh_weight
>
0
:
if
opt
.
wh_weight
>
0
:
wh_loss
+=
self
.
crit_reg
(
output
[
'wh'
],
batch
[
'reg_mask'
],
# 2. 人脸bbox高度和宽度的loss
# 2. 人脸bbox高度和宽度的loss
wh_loss
+=
self
.
crit_reg
(
output
[
'wh'
],
batch
[
'reg_mask'
],
batch
[
'ind'
],
batch
[
'wh'
],
batch
[
'wight_mask'
])
/
opt
.
num_stacks
batch
[
'ind'
],
batch
[
'wh'
],
batch
[
'wight_mask'
])
/
opt
.
num_stacks
if
opt
.
reg_offset
and
opt
.
off_weight
>
0
:
if
opt
.
reg_offset
and
opt
.
off_weight
>
0
:
off_loss
+=
self
.
crit_reg
(
output
[
'hm_offset'
],
batch
[
'reg_mask'
],
# 3. 人脸bbox中心点下采样,所需要的偏差补偿
# 3. 人脸bbox中心点下采样,所需要的偏差补偿
off_loss
+=
self
.
crit_reg
(
output
[
'hm_offset'
],
batch
[
'reg_mask'
],
batch
[
'ind'
],
batch
[
'hm_offset'
],
batch
[
'wight_mask'
])
/
opt
.
num_stacks
batch
[
'ind'
],
batch
[
'hm_offset'
],
batch
[
'wight_mask'
])
/
opt
.
num_stacks
if
opt
.
dense_hp
:
if
opt
.
dense_hp
:
...
@@ -68,14 +72,17 @@ class MultiPoseLoss(torch.nn.Module):
...
@@ -68,14 +72,17 @@ class MultiPoseLoss(torch.nn.Module):
batch
[
'dense_hps'
]
*
batch
[
'dense_hps_mask'
])
/
batch
[
'dense_hps'
]
*
batch
[
'dense_hps_mask'
])
/
mask_weight
)
/
opt
.
num_stacks
mask_weight
)
/
opt
.
num_stacks
else
:
else
:
lm_loss
+=
self
.
crit_kp
(
output
[
'landmarks'
],
batch
[
'hps_mask'
],
# 4. 关节点的偏移
# 4. 关节点的偏移
lm_loss
+=
self
.
crit_kp
(
output
[
'landmarks'
],
batch
[
'hps_mask'
],
batch
[
'ind'
],
batch
[
'landmarks'
])
/
opt
.
num_stacks
batch
[
'ind'
],
batch
[
'landmarks'
])
/
opt
.
num_stacks
# 关节点的中心偏移
# if opt.reg_hp_offset and opt.off_weight > 0:
# 关节点的中心偏移
# if opt.reg_hp_offset and opt.off_weight > 0:
# hp_offset_loss += self.crit_reg(
# hp_offset_loss += self.crit_reg(
# output['hp_offset'], batch['hp_mask'],
# output['hp_offset'], batch['hp_mask'],
# batch['hp_ind'], batch['hp_offset']) / opt.num_stacks
# batch['hp_ind'], batch['hp_offset']) / opt.num_stacks
# if opt.hm_hp and opt.hm_hp_weight > 0: # 关节点的热力图
# 关节点的热力图
# if opt.hm_hp and opt.hm_hp_weight > 0:
# hm_hp_loss += self.crit_hm_hp(
# hm_hp_loss += self.crit_hm_hp(
# output['hm_hp'], batch['hm_hp']) / opt.num_stacks
# output['hm_hp'], batch['hm_hp']) / opt.num_stacks
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment