Commit 69f4ba48 authored by chenych's avatar chenych
Browse files

update readme

parent e0491117
......@@ -4,8 +4,9 @@
## 模型结构
CenterFace是一种人脸检测算法,采用了轻量级网络mobileNetV2作为主干网络,结合特征金字塔网络(FPN)实现anchor free的人脸检测。
![Architecture of the CenterFace](Architecture of the CenterFace.png)
<div align=center>
<img src="./Architecture of the CenterFace.png"/>
</div>
## 算法原理
CenterFace模型是一种单阶段人脸检测算法,作者借鉴了CenterNet的思想,将人脸检测转换为标准的关键点估计问题,根据人脸中心点来回归人脸框的大小和五个标志点。
......@@ -53,7 +54,9 @@ pip3 install -r requirements.txt
WIDER_FACE:http://shuoyang1213.me/WIDERFACE/index.html
![datasets](datasets.png)
<div align=center>
<img src="./datasets.png"/>
</div>
下载图片红框中的三个数据包并解压,也可点击下面的链接直接下载:
......@@ -99,25 +102,15 @@ WIDER_FACE:http://shuoyang1213.me/WIDERFACE/index.html
│ └── 61--Street_Battle
```
2. 如果是使用WIDER_train、WIDER_val数据, 可直接将./datasets/labels/下的train_wider_face.json重命名为train_face.json, val_wider_face.json重命名为val_face.json即可,无需进行标注文件格式转换;
反之,需要将训练图片/验证图片对应的人脸标注信息文件train.txt/val.txt,放置于 ./datasets/annotations/下(train存放训练图片的标注文件,val存放验证图片的标注文件),存放目录结构如下:
```
├── annotations
│ ├── train
│ ├── train.txt
│ ├── val
│ ├── val.txt
```
特别地,标注信息的格式为:
......@@ -128,7 +121,8 @@ x, y, w, h, left_eye_x, left_eye_y, flag, right_eye_x, right_eye_y, flag, nose_x
```
举个例子:
train_keypoints_widerface.txt是wider_face训练数据集的标注信息
./datasets/annotations/train/train.txt是wider_face训练数据集的标注信息
```
# 0--Parade/0_Parade_marchingband_1_849.jpg
449 330 122 149 488.906 373.643 0.0 542.089 376.442 0.0 515.031 412.83 0.0 485.174 425.893 0.0 538.357 431.491 0.0 0.82
......@@ -142,11 +136,7 @@ cd ./datasets
python gen_data.py
```
执行完成后会在./datasets/labels下生成训练/验证数据的标注文件 train_face.json、val_face.json
## 训练
......@@ -166,18 +156,23 @@ bash train_multi.sh
## 推理
#### 单卡推理
```
cd ./src
python test_wider_face.py
```
## result
![Result](draw_img.jpg)
<div align=center>
<img src="./draw_img.jpg"/>
</div>
### 精度
WIDER_FACE验证集上的测试结果如下
| Method | Easy(p) | Medium(p) | Hard(p)|
| Method | Easy(AP) | Medium(AP) | Hard(AP)|
|:--------:| :--------:| :---------:| :------:|
| ours(one scale) | 0.9264 | 0.9133 | 0.7479 |
| original | 0.922 | 0.911 | 0.782|
......
......@@ -2,13 +2,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pycocotools.coco as coco
from pycocotools.cocoeval import COCOeval
import numpy as np
import json
import os
from pycocotools.cocoeval import COCOeval
import torch.utils.data as data
import numpy as np
import pycocotools.coco as coco
class FACEHP(data.Dataset):
......@@ -19,7 +20,8 @@ class FACEHP(data.Dataset):
dtype=np.float32).reshape(1, 1, 3)
std = np.array([0.28863828, 0.27408164, 0.27809835],
dtype=np.float32).reshape(1, 1, 3)
flip_idx = [[0, 1], [3, 4]] # 翻转的关键点在关键点矩阵中的索引
# 翻转的关键点在关键点矩阵中的索引
flip_idx = [[0, 1], [3, 4]]
def __init__(self, opt, split):
super(FACEHP, self).__init__()
......@@ -32,7 +34,8 @@ class FACEHP(data.Dataset):
self.acc_idxs = [1, 2, 3, 4]
self.data_dir = opt.data_dir
_ann_name = {'train': 'train', 'val': 'val'}
self.img_dir = os.path.join(self.data_dir, f'images/{_ann_name[split]}') # 训练图片所在地址
# 图片所在地址
self.img_dir = os.path.join(self.data_dir, f'images/{_ann_name[split]}')
print('===>', self.img_dir)
if split == 'val':
self.annot_path = os.path.join(
......
......@@ -19,33 +19,33 @@ def _left_aggregate(heat):
'''
heat: batchsize x channels x h x w
'''
shape = heat.shape
shape = heat.shape
heat = heat.reshape(-1, heat.shape[3])
heat = heat.transpose(1, 0).contiguous()
ret = heat.clone()
for i in range(1, heat.shape[0]):
inds = (heat[i] >= heat[i - 1])
ret[i] += ret[i - 1] * inds.float()
return (ret - heat).transpose(1, 0).reshape(shape)
return (ret - heat).transpose(1, 0).reshape(shape)
def _right_aggregate(heat):
'''
heat: batchsize x channels x h x w
'''
shape = heat.shape
shape = heat.shape
heat = heat.reshape(-1, heat.shape[3])
heat = heat.transpose(1, 0).contiguous()
ret = heat.clone()
for i in range(heat.shape[0] - 2, -1, -1):
inds = (heat[i] >= heat[i +1])
ret[i] += ret[i + 1] * inds.float()
return (ret - heat).transpose(1, 0).reshape(shape)
return (ret - heat).transpose(1, 0).reshape(shape)
def _top_aggregate(heat):
'''
heat: batchsize x channels x h x w
'''
heat = heat.transpose(3, 2)
heat = heat.transpose(3, 2)
shape = heat.shape
heat = heat.reshape(-1, heat.shape[3])
heat = heat.transpose(1, 0).contiguous()
......@@ -59,7 +59,7 @@ def _bottom_aggregate(heat):
'''
heat: batchsize x channels x h x w
'''
heat = heat.transpose(3, 2)
heat = heat.transpose(3, 2)
shape = heat.shape
heat = heat.reshape(-1, heat.shape[3])
heat = heat.transpose(1, 0).contiguous()
......@@ -92,7 +92,7 @@ def _topk(scores, K=40):
'''
def _topk_channel(scores, K=40):
batch, cat, height, width = scores.size()
topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)
topk_inds = topk_inds % (height * width)
......@@ -103,13 +103,13 @@ def _topk_channel(scores, K=40):
def _topk(scores, K=40):
batch, cat, height, width = scores.size()
topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K) # 前100个点
# 前100个点
topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)
topk_inds = topk_inds % (height * width)
topk_ys = (topk_inds / width).int().float()
topk_xs = (topk_inds % width).int().float()
topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
topk_clses = (topk_ind / K).int()
topk_inds = _gather_feat(
......@@ -121,8 +121,8 @@ def _topk(scores, K=40):
def agnex_ct_decode(
t_heat, l_heat, b_heat, r_heat, ct_heat,
t_regr=None, l_regr=None, b_regr=None, r_regr=None,
t_heat, l_heat, b_heat, r_heat, ct_heat,
t_regr=None, l_regr=None, b_regr=None, r_regr=None,
K=40, scores_thresh=0.1, center_thresh=0.1, aggr_weight=0.0, num_dets=1000
):
batch, cat, height, width = t_heat.size()
......@@ -134,19 +134,19 @@ def agnex_ct_decode(
r_heat = torch.sigmoid(r_heat)
ct_heat = torch.sigmoid(ct_heat)
'''
if aggr_weight > 0:
if aggr_weight > 0:
t_heat = _h_aggregate(t_heat, aggr_weight=aggr_weight)
l_heat = _v_aggregate(l_heat, aggr_weight=aggr_weight)
b_heat = _h_aggregate(b_heat, aggr_weight=aggr_weight)
r_heat = _v_aggregate(r_heat, aggr_weight=aggr_weight)
# perform nms on heatmaps
t_heat = _nms(t_heat)
l_heat = _nms(l_heat)
b_heat = _nms(b_heat)
r_heat = _nms(r_heat)
t_heat[t_heat > 1] = 1
l_heat[l_heat > 1] = 1
b_heat[b_heat > 1] = 1
......@@ -156,9 +156,9 @@ def agnex_ct_decode(
l_scores, l_inds, _, l_ys, l_xs = _topk(l_heat, K=K)
b_scores, b_inds, _, b_ys, b_xs = _topk(b_heat, K=K)
r_scores, r_inds, _, r_ys, r_xs = _topk(r_heat, K=K)
ct_heat_agn, ct_clses = torch.max(ct_heat, dim=1, keepdim=True)
# import pdb; pdb.set_trace()
t_ys = t_ys.view(batch, K, 1, 1, 1).expand(batch, K, K, K, K)
......@@ -240,7 +240,7 @@ def agnex_ct_decode(
b_ys = b_ys + 0.5
r_xs = r_xs + 0.5
r_ys = r_ys + 0.5
bboxes = torch.stack((l_xs, t_ys, r_xs, b_ys), dim=5)
bboxes = bboxes.view(batch, -1, 4)
bboxes = _gather_feat(bboxes, inds)
......@@ -266,14 +266,14 @@ def agnex_ct_decode(
r_ys = _gather_feat(r_ys, inds).float()
detections = torch.cat([bboxes, scores, t_xs, t_ys, l_xs, l_ys,
detections = torch.cat([bboxes, scores, t_xs, t_ys, l_xs, l_ys,
b_xs, b_ys, r_xs, r_ys, clses], dim=2)
return detections
def exct_decode(
t_heat, l_heat, b_heat, r_heat, ct_heat,
t_regr=None, l_regr=None, b_regr=None, r_regr=None,
t_heat, l_heat, b_heat, r_heat, ct_heat,
t_regr=None, l_regr=None, b_regr=None, r_regr=None,
K=40, scores_thresh=0.1, center_thresh=0.1, aggr_weight=0.0, num_dets=1000
):
batch, cat, height, width = t_heat.size()
......@@ -285,18 +285,18 @@ def exct_decode(
ct_heat = torch.sigmoid(ct_heat)
'''
if aggr_weight > 0:
if aggr_weight > 0:
t_heat = _h_aggregate(t_heat, aggr_weight=aggr_weight)
l_heat = _v_aggregate(l_heat, aggr_weight=aggr_weight)
b_heat = _h_aggregate(b_heat, aggr_weight=aggr_weight)
r_heat = _v_aggregate(r_heat, aggr_weight=aggr_weight)
# perform nms on heatmaps
t_heat = _nms(t_heat)
l_heat = _nms(l_heat)
b_heat = _nms(b_heat)
r_heat = _nms(r_heat)
t_heat[t_heat > 1] = 1
l_heat[l_heat > 1] = 1
b_heat[b_heat > 1] = 1
......@@ -392,7 +392,7 @@ def exct_decode(
b_ys = b_ys + 0.5
r_xs = r_xs + 0.5
r_ys = r_ys + 0.5
bboxes = torch.stack((l_xs, t_ys, r_xs, b_ys), dim=5)
bboxes = bboxes.view(batch, -1, 4)
bboxes = _gather_feat(bboxes, inds)
......@@ -418,7 +418,7 @@ def exct_decode(
r_ys = _gather_feat(r_ys, inds).float()
detections = torch.cat([bboxes, scores, t_xs, t_ys, l_xs, l_ys,
detections = torch.cat([bboxes, scores, t_xs, t_ys, l_xs, l_ys,
b_xs, b_ys, r_xs, r_ys, clses], dim=2)
......@@ -429,7 +429,7 @@ def ddd_decode(heat, rot, depth, dim, wh=None, reg=None, K=40):
# heat = torch.sigmoid(heat)
# perform nms on heatmaps
heat = _nms(heat)
scores, inds, clses, ys, xs = _topk(heat, K=K)
if reg is not None:
reg = _tranpose_and_gather_feat(reg, inds)
......@@ -439,7 +439,7 @@ def ddd_decode(heat, rot, depth, dim, wh=None, reg=None, K=40):
else:
xs = xs.view(batch, K, 1) + 0.5
ys = ys.view(batch, K, 1) + 0.5
rot = _tranpose_and_gather_feat(rot, inds)
rot = rot.view(batch, K, 8)
depth = _tranpose_and_gather_feat(depth, inds)
......@@ -450,7 +450,7 @@ def ddd_decode(heat, rot, depth, dim, wh=None, reg=None, K=40):
scores = scores.view(batch, K, 1)
xs = xs.view(batch, K, 1)
ys = ys.view(batch, K, 1)
if wh is not None:
wh = _tranpose_and_gather_feat(wh, inds)
wh = wh.view(batch, K, 2)
......@@ -459,7 +459,7 @@ def ddd_decode(heat, rot, depth, dim, wh=None, reg=None, K=40):
else:
detections = torch.cat(
[xs, ys, scores, rot, depth, dim, clses], dim=2)
return detections
def ctdet_decode(heat, wh, reg=None, cat_spec_wh=False, K=100):
......@@ -468,7 +468,7 @@ def ctdet_decode(heat, wh, reg=None, cat_spec_wh=False, K=100):
# heat = torch.sigmoid(heat)
# perform nms on heatmaps
heat = _nms(heat) # 3 * 3 区域的最大值滤波
scores, inds, clses, ys, xs = _topk(heat, K=K)
if reg is not None:
reg = _tranpose_and_gather_feat(reg, inds)
......@@ -487,12 +487,12 @@ def ctdet_decode(heat, wh, reg=None, cat_spec_wh=False, K=100):
wh = wh.view(batch, K, 2)
clses = clses.view(batch, K, 1).float()
scores = scores.view(batch, K, 1)
bboxes = torch.cat([xs - wh[..., 0:1] / 2,
bboxes = torch.cat([xs - wh[..., 0:1] / 2,
ys - wh[..., 1:2] / 2,
xs + wh[..., 0:1] / 2,
xs + wh[..., 0:1] / 2,
ys + wh[..., 1:2] / 2], dim=2)
detections = torch.cat([bboxes, scores, clses], dim=2)
return detections
def multi_pose_decode(
......@@ -521,9 +521,9 @@ def multi_pose_decode(
clses = clses.view(batch, K, 1).float()
scores = scores.view(batch, K, 1)
bboxes = torch.cat([xs - wh[..., 0:1] / 2,
bboxes = torch.cat([xs - wh[..., 0:1] / 2,
ys - wh[..., 1:2] / 2,
xs + wh[..., 0:1] / 2,
xs + wh[..., 0:1] / 2,
ys + wh[..., 1:2] / 2], dim=2)
if hm_hp is not None:
hm_hp = _nms(hm_hp) # 第二次:通过关节点热力图求得关节点的中心点
......@@ -541,7 +541,7 @@ def multi_pose_decode(
else:
hm_xs = hm_xs + 0.5
hm_ys = hm_ys + 0.5
mask = (hm_score > thresh).float() # 选置信度大于0.1的
hm_score = (1 - mask) * -1 + mask * hm_score
hm_ys = (1 - mask) * (-10000) + mask * hm_ys
......@@ -570,7 +570,7 @@ def multi_pose_decode(
kps = kps.permute(0, 2, 1, 3).contiguous().view(
batch, K, num_joints * 2)
detections = torch.cat([bboxes, scores, kps, clses], dim=2) # box:4+score:1+kpoints:10+class:1=16
return detections
......
......@@ -15,61 +15,63 @@ import torch.nn.functional as F
def _slow_neg_loss(pred, gt):
'''focal loss from CornerNet'''
pos_inds = gt.eq(1)
neg_inds = gt.lt(1)
'''focal loss from CornerNet'''
pos_inds = gt.eq(1)
neg_inds = gt.lt(1)
neg_weights = torch.pow(1 - gt[neg_inds], 4)
neg_weights = torch.pow(1 - gt[neg_inds], 4)
loss = 0
pos_pred = pred[pos_inds]
neg_pred = pred[neg_inds]
loss = 0
pos_pred = pred[pos_inds]
neg_pred = pred[neg_inds]
pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2)
neg_loss = torch.log(1 - neg_pred) * torch.pow(neg_pred, 2) * neg_weights
pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2)
neg_loss = torch.log(1 - neg_pred) * torch.pow(neg_pred, 2) * neg_weights
num_pos = pos_inds.float().sum()
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
num_pos = pos_inds.float().sum()
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
if pos_pred.nelement() == 0:
loss = loss - neg_loss
else:
loss = loss - (pos_loss + neg_loss) / num_pos
return loss
if pos_pred.nelement() == 0:
loss = loss - neg_loss
else:
loss = loss - (pos_loss + neg_loss) / num_pos
return loss
def _neg_loss(pred, gt):
''' Modified focal loss. Exactly the same as CornerNet.
Runs faster and costs a little bit more memory
Arguments:
pred (batch x c x h x w)
gt_regr (batch x c x h x w)
'''
pos_inds = gt.eq(1).float()
neg_inds = gt.lt(1).float()
''' Modified focal loss. Exactly the same as CornerNet.
Runs faster and costs a little bit more memory
Arguments:
pred (batch x c x h x w)
gt_regr (batch x c x h x w)
'''
pos_inds = gt.eq(1).float()
neg_inds = gt.lt(1).float()
neg_weights = torch.pow(1 - gt, 4)
neg_weights = torch.pow(1 - gt, 4)
loss = 0
loss = 0
pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * \
neg_weights * neg_inds
pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
num_pos = pos_inds.float().sum()
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
num_pos = pos_inds.float().sum()
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
if num_pos == 0:
loss = loss - neg_loss
else:
loss = loss - (pos_loss + neg_loss) / num_pos
return loss
if num_pos == 0:
loss = loss - neg_loss
else:
loss = loss - (pos_loss + neg_loss) / num_pos
return loss
def _not_faster_neg_loss(pred, gt):
pos_inds = gt.eq(1).float()
neg_inds = gt.lt(1).float()
num_pos = pos_inds.float().sum()
num_pos = pos_inds.float().sum()
neg_weights = torch.pow(1 - gt, 4)
loss = 0
......@@ -80,141 +82,159 @@ def _not_faster_neg_loss(pred, gt):
if num_pos > 0:
all_loss /= num_pos
loss -= all_loss
loss -= all_loss
return loss
def _slow_reg_loss(regr, gt_regr, mask):
num = mask.float().sum()
num = mask.float().sum()
mask = mask.unsqueeze(2).expand_as(gt_regr)
regr = regr[mask]
regr = regr[mask]
gt_regr = gt_regr[mask]
# regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, size_average=False)
regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, reduction='sum')
regr_loss = regr_loss / (num + 1e-4)
return regr_loss
def _reg_loss(regr, gt_regr, mask, wight_=None):
''' L1 regression loss
Arguments:
regr (batch x max_objects x dim)
gt_regr (batch x max_objects x dim)
mask (batch x max_objects)
'''
num = mask.float().sum()
mask = mask.unsqueeze(2).expand_as(gt_regr).float()
regr = regr * mask
gt_regr = gt_regr * mask
if wight_ is not None:
wight_ = wight_.unsqueeze(2).expand_as(gt_regr).float()
# regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, reduce=False)
regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, reduction='none')
regr_loss *= wight_
regr_loss = regr_loss.sum()
else:
regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, reduction='sum')
# regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, size_average=False)
regr_loss = regr_loss / (num + 1e-4)
return regr_loss
''' L1 regression loss
Arguments:
regr (batch x max_objects x dim)
gt_regr (batch x max_objects x dim)
mask (batch x max_objects)
'''
num = mask.float().sum()
mask = mask.unsqueeze(2).expand_as(gt_regr).float()
regr = regr * mask
gt_regr = gt_regr * mask
if wight_ is not None:
wight_ = wight_.unsqueeze(2).expand_as(gt_regr).float()
# regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, reduce=False)
regr_loss = nn.functional.smooth_l1_loss(
regr, gt_regr, reduction='none')
regr_loss *= wight_
regr_loss = regr_loss.sum()
else:
regr_loss = nn.functional.smooth_l1_loss(
regr, gt_regr, reduction='sum')
# regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, size_average=False)
regr_loss = regr_loss / (num + 1e-4)
return regr_loss
class FocalLoss(nn.Module):
    '''nn.Module wrapper for focal loss'''

    def __init__(self):
        super(FocalLoss, self).__init__()
        self.neg_loss = _neg_loss

    def forward(self, out, target):
        '''Return ``_neg_loss(out, target)``.'''
        return self.neg_loss(out, target)
class RegLoss(nn.Module):
    '''Regression loss for an output tensor

    Gathers predictions from ``output`` at ``ind`` and scores them against
    ``target`` with ``_reg_loss``.

    Arguments:
      output (batch x dim x h x w)
      mask (batch x max_objects)
      ind (batch x max_objects)
      target (batch x max_objects x dim)
      wight_ (optional): forwarded unchanged to ``_reg_loss``
    '''

    def __init__(self):
        super(RegLoss, self).__init__()

    def forward(self, output, mask, ind, target, wight_=None):
        # presumably yields (batch x max_objects x dim) -- helper is defined
        # outside this file section, confirm there
        pred = _tranpose_and_gather_feat(output, ind)
        loss = _reg_loss(pred, target, mask, wight_)
        return loss
class RegL1Loss(nn.Module):
    '''Masked L1 loss on predictions gathered at ``ind``.

    The summed absolute error is divided by (sum of the expanded mask + 1e-4).
    '''

    def __init__(self):
        super(RegL1Loss, self).__init__()

    def forward(self, output, mask, ind, target):
        pred = _tranpose_and_gather_feat(output, ind)
        # Give the mask a trailing axis and expand it to the shape of pred.
        mask = mask.unsqueeze(2).expand_as(pred).float()
        loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
        # The small constant keeps the division finite when the mask is empty.
        loss = loss / (mask.sum() + 1e-4)
        return loss
class NormRegL1Loss(nn.Module):
    '''Masked L1 loss on the ratio of prediction to target.

    Predictions gathered at ``ind`` are divided by ``target + 1e-4`` and
    compared against 1. The summed absolute error is divided by
    (sum of the expanded mask + 1e-4).
    '''

    def __init__(self):
        super(NormRegL1Loss, self).__init__()

    def forward(self, output, mask, ind, target):
        pred = _tranpose_and_gather_feat(output, ind)
        mask = mask.unsqueeze(2).expand_as(pred).float()
        # Normalise by the target so the loss measures relative error.
        pred = pred / (target + 1e-4)
        target = target * 0 + 1
        loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
        loss = loss / (mask.sum() + 1e-4)
        return loss
class RegWeightedL1Loss(nn.Module):
    '''L1 loss on predictions gathered at ``ind``, multiplied by ``mask``.

    Unlike ``RegL1Loss`` the mask is not expanded here; it is multiplied
    with the prediction directly, so it presumably already matches (or
    broadcasts to) the prediction shape -- confirm with caller.
    '''

    def __init__(self):
        super(RegWeightedL1Loss, self).__init__()

    def forward(self, output, mask, ind, target):
        pred = _tranpose_and_gather_feat(output, ind)
        mask = mask.float()
        loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
        # The small constant keeps the division finite when the mask sums to 0.
        loss = loss / (mask.sum() + 1e-4)
        return loss
class L1Loss(nn.Module):
    '''Mean L1 loss on masked predictions gathered at ``ind``.

    The mean is taken over every element, including those zeroed by the mask.
    '''

    def __init__(self):
        super(L1Loss, self).__init__()

    def forward(self, output, mask, ind, target):
        pred = _tranpose_and_gather_feat(output, ind)
        mask = mask.unsqueeze(2).expand_as(pred).float()
        # 'mean' is the current spelling of the deprecated 'elementwise_mean'.
        loss = F.l1_loss(pred * mask, target * mask, reduction='mean')
        return loss
class BinRotLoss(nn.Module):
    '''Rotation loss: gathers predictions at ``ind`` and delegates to
    ``compute_rot_loss``.
    '''

    def __init__(self):
        super(BinRotLoss, self).__init__()

    def forward(self, output, mask, ind, rotbin, rotres):
        pred = _tranpose_and_gather_feat(output, ind)
        loss = compute_rot_loss(pred, rotbin, rotres, mask)
        return loss
def compute_res_loss(output, target):
    '''Return the mean smooth-L1 loss between ``output`` and ``target``.'''
    # 'mean' is the current spelling of the deprecated 'elementwise_mean'.
    return F.smooth_l1_loss(output, target, reduction='mean')
# TODO: weight
def compute_bin_loss(output, target, mask):
    '''Return the mean cross-entropy of masked logits against ``target``.

    ``mask`` is expanded to the shape of ``output`` and multiplied in, so
    masked rows become all-zero logits; they are still included in the mean.
    '''
    mask = mask.expand_as(output)
    output = output * mask.float()
    # 'mean' is the current spelling of the deprecated 'elementwise_mean'.
    return F.cross_entropy(output, target, reduction='mean')
def compute_rot_loss(output, target_bin, target_res, mask):
# output: (B, 128, 8) [bin1_cls[0], bin1_cls[1], bin1_sin, bin1_cos,
# bin2_cls[0], bin2_cls[1], bin2_sin, bin2_cos]
......@@ -234,17 +254,17 @@ def compute_rot_loss(output, target_bin, target_res, mask):
valid_output1 = torch.index_select(output, 0, idx1.long())
valid_target_res1 = torch.index_select(target_res, 0, idx1.long())
loss_sin1 = compute_res_loss(
valid_output1[:, 2], torch.sin(valid_target_res1[:, 0]))
valid_output1[:, 2], torch.sin(valid_target_res1[:, 0]))
loss_cos1 = compute_res_loss(
valid_output1[:, 3], torch.cos(valid_target_res1[:, 0]))
valid_output1[:, 3], torch.cos(valid_target_res1[:, 0]))
loss_res += loss_sin1 + loss_cos1
if target_bin[:, 1].nonzero().shape[0] > 0:
idx2 = target_bin[:, 1].nonzero()[:, 0]
valid_output2 = torch.index_select(output, 0, idx2.long())
valid_target_res2 = torch.index_select(target_res, 0, idx2.long())
loss_sin2 = compute_res_loss(
valid_output2[:, 6], torch.sin(valid_target_res2[:, 1]))
valid_output2[:, 6], torch.sin(valid_target_res2[:, 1]))
loss_cos2 = compute_res_loss(
valid_output2[:, 7], torch.cos(valid_target_res2[:, 1]))
valid_output2[:, 7], torch.cos(valid_target_res2[:, 1]))
loss_res += loss_sin2 + loss_cos2
return loss_bin1 + loss_bin2 + loss_res
......@@ -17,7 +17,7 @@ class ModleWithLoss(torch.nn.Module):
def forward(self, batch):
    '''Run the model on ``batch['input']`` and compute the loss.

    Returns the last element of the model outputs, the loss, and the
    loss statistics returned by ``self.loss``.
    '''
    outputs = self.model(batch['input'])
    # self.loss receives all model outputs together with the full batch.
    loss, loss_stats = self.loss(outputs, batch)
    return outputs[-1], loss, loss_stats
......
......@@ -47,6 +47,7 @@ class MultiPoseLoss(torch.nn.Module):
batch['hps'].detach().cpu().numpy(),
batch['ind'].detach().cpu().numpy(),
opt.output_res, opt.output_res)).to(opt.device)
if opt.eval_oracle_hp_offset:
output['hp_offset'] = torch.from_numpy(gen_oracle_map(
batch['hp_offset'].detach().cpu().numpy(),
......@@ -56,10 +57,13 @@ class MultiPoseLoss(torch.nn.Module):
# 1. focal loss,求目标的中心,
hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks
if opt.wh_weight > 0:
wh_loss += self.crit_reg(output['wh'], batch['reg_mask'], # 2. 人脸bbox高度和宽度的loss
# 2. 人脸bbox高度和宽度的loss
wh_loss += self.crit_reg(output['wh'], batch['reg_mask'],
batch['ind'], batch['wh'], batch['wight_mask']) / opt.num_stacks
if opt.reg_offset and opt.off_weight > 0:
off_loss += self.crit_reg(output['hm_offset'], batch['reg_mask'], # 3. 人脸bbox中心点下采样,所需要的偏差补偿
# 3. 人脸bbox中心点下采样,所需要的偏差补偿
off_loss += self.crit_reg(output['hm_offset'], batch['reg_mask'],
batch['ind'], batch['hm_offset'], batch['wight_mask']) / opt.num_stacks
if opt.dense_hp:
......@@ -68,14 +72,17 @@ class MultiPoseLoss(torch.nn.Module):
batch['dense_hps'] * batch['dense_hps_mask']) /
mask_weight) / opt.num_stacks
else:
lm_loss += self.crit_kp(output['landmarks'], batch['hps_mask'], # 4. 关节点的偏移
# 4. 关节点的偏移
lm_loss += self.crit_kp(output['landmarks'], batch['hps_mask'],
batch['ind'], batch['landmarks']) / opt.num_stacks
# if opt.reg_hp_offset and opt.off_weight > 0: # 关节点的中心偏移
# 关节点的中心偏移
# if opt.reg_hp_offset and opt.off_weight > 0:
# hp_offset_loss += self.crit_reg(
# output['hp_offset'], batch['hp_mask'],
# batch['hp_ind'], batch['hp_offset']) / opt.num_stacks
# if opt.hm_hp and opt.hm_hp_weight > 0: # 关节点的热力图
# 关节点的热力图
# if opt.hm_hp and opt.hm_hp_weight > 0:
# hm_hp_loss += self.crit_hm_hp(
# output['hm_hp'], batch['hm_hp']) / opt.num_stacks
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment