Commit c320b6ef authored by zhenyi's avatar zhenyi
Browse files

tf2 detection

parent 0fc002df
MIT License
Copyright (c) 2020 JiaQi Xu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
## SSD: Single-Shot MultiBox Detector目标检测模型在TF2当中的实现
---
**2021年10月12日更新:**
**进行了大幅度的更新,对代码的模块进行修改,加了大量注释。**
**2021年2月8日更新:**
**加入letterbox_image的选项,关闭letterbox_image后网络的map一般可以得到提升。**
## 目录
1. [性能情况 Performance](#性能情况)
2. [所需环境 Environment](#所需环境)
3. [文件下载 Download](#文件下载)
4. [预测步骤 How2predict](#预测步骤)
5. [训练步骤 How2train](#训练步骤)
6. [评估步骤 How2eval](#评估步骤)
7. [参考资料 Reference](#Reference)
## 性能情况
| 训练数据集 | 权值文件名称 | 测试数据集 | 输入图片大小 | mAP 0.5:0.95 | mAP 0.5 |
| :-----: | :-----: | :------: | :------: | :------: | :-----: |
| VOC07+12 | [ssd_weights.h5](https://github.com/bubbliiiing/ssd-tf2/releases/download/v1.0/ssd_weights.h5) | VOC-Test07 | 300x300| - | 77.1
| VOC07++12+COCO | [ssd_weights_coco_07+12.h5](https://github.com/bubbliiiing/ssd-tf2/releases/download/v1.0/ssd_weights_coco_07+12.h5) | VOC-Test12 | 300x300| - | 79.4
## 所需环境
tensorflow-gpu==2.2.0
## 文件下载
训练所需的ssd_weights.h5和主干的权值可以在百度云下载。
链接: https://pan.baidu.com/s/1Ddk5UcZS5Dm4qechwGJDlA
提取码: 1k5d
VOC数据集下载地址如下,里面已经包括了训练集、测试集、验证集(与测试集一样),无需再次划分:
链接: https://pan.baidu.com/s/1YuBbBKxm2FGgTU5OfaeC5A
提取码: uack
## 训练步骤
### a、训练VOC07+12数据集
1. 数据集的准备
**本文使用VOC格式进行训练,训练前需要下载好VOC07+12的数据集,解压后放在根目录**
2. 数据集的处理
修改voc_annotation.py里面的annotation_mode=2,运行voc_annotation.py生成根目录下的2007_train.txt和2007_val.txt。
3. 开始网络训练
train.py的默认参数用于训练VOC数据集,直接运行train.py即可开始训练。
4. 训练结果预测
训练结果预测需要用到两个文件,分别是ssd.py和predict.py。我们首先需要去ssd.py里面修改model_path以及classes_path,这两个参数必须要修改。
**model_path指向训练好的权值文件,在logs文件夹里。
classes_path指向检测类别所对应的txt。**
完成修改后就可以运行predict.py进行检测了。运行后输入图片路径即可检测。
### b、训练自己的数据集
1. 数据集的准备
**本文使用VOC格式进行训练,训练前需要自己制作好数据集,**
训练前将标签文件放在VOCdevkit文件夹下的VOC2007文件夹下的Annotation中。
训练前将图片文件放在VOCdevkit文件夹下的VOC2007文件夹下的JPEGImages中。
2. 数据集的处理
在完成数据集的摆放之后,我们需要利用voc_annotation.py获得训练用的2007_train.txt和2007_val.txt。
修改voc_annotation.py里面的参数。第一次训练可以仅修改classes_path,classes_path用于指向检测类别所对应的txt。
训练自己的数据集时,可以自己建立一个cls_classes.txt,里面写自己所需要区分的类别。
model_data/cls_classes.txt文件内容为:
```python
cat
dog
...
```
修改voc_annotation.py中的classes_path,使其对应cls_classes.txt,并运行voc_annotation.py。
3. 开始网络训练
**训练的参数较多,均在train.py中,大家可以在下载库后仔细看注释,其中最重要的部分依然是train.py里的classes_path。**
**classes_path用于指向检测类别所对应的txt,这个txt和voc_annotation.py里面的txt一样!训练自己的数据集必须要修改!**
修改完classes_path后就可以运行train.py开始训练了,在训练多个epoch后,权值会生成在logs文件夹中。
4. 训练结果预测
训练结果预测需要用到两个文件,分别是ssd.py和predict.py。在ssd.py里面修改model_path以及classes_path。
**model_path指向训练好的权值文件,在logs文件夹里。
classes_path指向检测类别所对应的txt。**
完成修改后就可以运行predict.py进行检测了。运行后输入图片路径即可检测。
## 预测步骤
### a、使用预训练权重
1. 下载完库后解压,在百度网盘下载,放入model_data,运行predict.py,输入
```python
img/street.jpg
```
2. 在predict.py里面进行设置可以进行fps测试和video视频检测。
### b、使用自己训练的权重
1. 按照训练步骤训练。
2. 在ssd.py文件里面,在如下部分修改model_path和classes_path使其对应训练好的文件;**model_path对应logs文件夹下面的权值文件,classes_path是model_path对应分的类**
```python
_defaults = {
#--------------------------------------------------------------------------#
# 使用自己训练好的模型进行预测一定要修改model_path和classes_path!
# model_path指向logs文件夹下的权值文件,classes_path指向model_data下的txt
# 如果出现shape不匹配,同时要注意训练时的model_path和classes_path参数的修改
#--------------------------------------------------------------------------#
"model_path" : 'model_data/ssd_weights.h5',
"classes_path" : 'model_data/voc_classes.txt',
#---------------------------------------------------------------------#
# 用于预测的图像大小,和train时使用同一个即可
#---------------------------------------------------------------------#
"input_shape" : [300, 300],
#---------------------------------------------------------------------#
# 只有得分大于置信度的预测框会被保留下来
#---------------------------------------------------------------------#
"confidence" : 0.5,
#---------------------------------------------------------------------#
# 非极大抑制所用到的nms_iou大小
#---------------------------------------------------------------------#
"nms_iou" : 0.45,
#---------------------------------------------------------------------#
# 用于指定先验框的大小
#---------------------------------------------------------------------#
'anchors_size' : [30, 60, 111, 162, 213, 264, 315],
#---------------------------------------------------------------------#
# 该变量用于控制是否使用letterbox_image对输入图像进行不失真的resize,
# 在多次测试后,发现关闭letterbox_image直接resize的效果更好
#---------------------------------------------------------------------#
"letterbox_image" : False,
}
```
3. 运行predict.py,输入
```python
img/street.jpg
```
4. 在predict.py里面进行设置可以进行fps测试和video视频检测。
## 评估步骤
### a、评估VOC07+12的测试集
1. 本文使用VOC格式进行评估。VOC07+12已经划分好了测试集,无需利用voc_annotation.py生成ImageSets文件夹下的txt。
2. 在ssd.py里面修改model_path以及classes_path。**model_path指向训练好的权值文件,在logs文件夹里。classes_path指向检测类别所对应的txt。**
3. 运行get_map.py即可获得评估结果,评估结果会保存在map_out文件夹中。
### b、评估自己的数据集
1. 本文使用VOC格式进行评估。
2. 如果在训练前已经运行过voc_annotation.py文件,代码会自动将数据集划分成训练集、验证集和测试集。如果想要修改测试集的比例,可以修改voc_annotation.py文件下的trainval_percent。trainval_percent用于指定(训练集+验证集)与测试集的比例,默认情况下 (训练集+验证集):测试集 = 9:1。train_percent用于指定(训练集+验证集)中训练集与验证集的比例,默认情况下 训练集:验证集 = 9:1。
3. 利用voc_annotation.py划分测试集后,前往get_map.py文件修改classes_path,classes_path用于指向检测类别所对应的txt,这个txt和训练时的txt一样。评估自己的数据集必须要修改。
4. 在ssd.py里面修改model_path以及classes_path。**model_path指向训练好的权值文件,在logs文件夹里。classes_path指向检测类别所对应的txt。**
5. 运行get_map.py即可获得评估结果,评估结果会保存在map_out文件夹中。
## Reference
https://github.com/Cartucho/mAP
https://github.com/pierluigiferrari/ssd_keras
https://github.com/kuhung/SSD_keras
import os
import xml.etree.ElementTree as ET
import tensorflow as tf
from PIL import Image
from tqdm import tqdm
from ssd import SSD
from utils.utils import get_classes
from utils.utils_map import get_coco_map, get_map
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
if __name__ == "__main__":
'''
Recall和Precision不像AP是一个面积的概念,在门限值不同时,网络的Recall和Precision值是不同的。
map计算结果中的Recall和Precision代表的是当预测时,门限置信度为0.5时,所对应的Recall和Precision值。
此处获得的./map_out/detection-results/里面的txt的框的数量会比直接predict多一些,这是因为这里的门限低,
目的是为了计算不同门限条件下的Recall和Precision值,从而实现map的计算。
'''
#------------------------------------------------------------------------------------------------------------------#
# map_mode用于指定该文件运行时计算的内容
# map_mode为0代表整个map计算流程,包括获得预测结果、获得真实框、计算VOC_map。
# map_mode为1代表仅仅获得预测结果。
# map_mode为2代表仅仅获得真实框。
# map_mode为3代表仅仅计算VOC_map。
# map_mode为4代表利用COCO工具箱计算当前数据集的0.50:0.95map。需要获得预测结果、获得真实框后并安装pycocotools才行
#-------------------------------------------------------------------------------------------------------------------#
map_mode = 0
#-------------------------------------------------------#
# 此处的classes_path用于指定需要测量VOC_map的类别
# 一般情况下与训练和预测所用的classes_path一致即可
#-------------------------------------------------------#
classes_path = 'model_data/voc_classes.txt'
#-------------------------------------------------------#
# MINOVERLAP用于指定想要获得的mAP0.x
# 比如计算mAP0.75,可以设定MINOVERLAP = 0.75。
#-------------------------------------------------------#
MINOVERLAP = 0.5
#-------------------------------------------------------#
# map_vis用于指定是否开启VOC_map计算的可视化
#-------------------------------------------------------#
map_vis = False
#-------------------------------------------------------#
# 指向VOC数据集所在的文件夹
# 默认指向根目录下的VOC数据集
#-------------------------------------------------------#
VOCdevkit_path = 'VOCdevkit'
#-------------------------------------------------------#
# 结果输出的文件夹,默认为map_out
#-------------------------------------------------------#
map_out_path = 'map_out'
image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")).read().strip().split()
if not os.path.exists(map_out_path):
os.makedirs(map_out_path)
if not os.path.exists(os.path.join(map_out_path, 'ground-truth')):
os.makedirs(os.path.join(map_out_path, 'ground-truth'))
if not os.path.exists(os.path.join(map_out_path, 'detection-results')):
os.makedirs(os.path.join(map_out_path, 'detection-results'))
if not os.path.exists(os.path.join(map_out_path, 'images-optional')):
os.makedirs(os.path.join(map_out_path, 'images-optional'))
class_names, _ = get_classes(classes_path)
if map_mode == 0 or map_mode == 1:
print("Load model.")
ssd = SSD(confidence = 0.01, nms_iou = 0.5)
print("Load model done.")
print("Get predict result.")
for image_id in tqdm(image_ids):
image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg")
image = Image.open(image_path)
if map_vis:
image.save(os.path.join(map_out_path, "images-optional/" + image_id + ".jpg"))
ssd.get_map_txt(image_id, image, class_names, map_out_path)
print("Get predict result done.")
if map_mode == 0 or map_mode == 2:
print("Get ground truth result.")
for image_id in tqdm(image_ids):
with open(os.path.join(map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
root = ET.parse(os.path.join(VOCdevkit_path, "VOC2007/Annotations/"+image_id+".xml")).getroot()
for obj in root.findall('object'):
difficult_flag = False
if obj.find('difficult')!=None:
difficult = obj.find('difficult').text
if int(difficult)==1:
difficult_flag = True
obj_name = obj.find('name').text
if obj_name not in class_names:
continue
bndbox = obj.find('bndbox')
left = bndbox.find('xmin').text
top = bndbox.find('ymin').text
right = bndbox.find('xmax').text
bottom = bndbox.find('ymax').text
if difficult_flag:
new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
else:
new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
print("Get ground truth result done.")
if map_mode == 0 or map_mode == 3:
print("Get map.")
get_map(MINOVERLAP, True, path = map_out_path)
print("Get map done.")
if map_mode == 4:
print("Get map.")
get_coco_map(class_names = class_names, path = map_out_path)
print("Get map done.")
\ No newline at end of file
存放训练后的模型
\ No newline at end of file
aeroplane
bicycle
bird
boat
bottle
bus
car
cat
chair
cow
diningtable
dog
horse
motorbike
person
pottedplant
sheep
sofa
train
tvmonitor
\ No newline at end of file
import numpy as np
import tensorflow.keras.backend as K
from tensorflow.keras.layers import (Activation, Concatenate, Conv2D, Flatten,
Input, InputSpec, Layer, Reshape)
from tensorflow.keras.models import Model
from nets.vgg import VGG16
class Normalize(Layer):
def __init__(self, scale, **kwargs):
self.axis = 3
self.scale = scale
super(Normalize, self).__init__(**kwargs)
def build(self, input_shape):
self.input_spec = [InputSpec(shape=input_shape)]
shape = (input_shape[self.axis],)
init_gamma = self.scale * np.ones(shape)
self.gamma = K.variable(init_gamma, name='{}_gamma'.format(self.name))
def call(self, x, mask=None):
output = K.l2_normalize(x, self.axis)
output *= self.gamma
return output
def SSD300(input_shape, num_classes=21):
#---------------------------------#
# 典型的输入大小为[300,300,3]
#---------------------------------#
input_tensor = Input(shape=input_shape)
# net变量里面包含了整个SSD的结构,通过层名可以找到对应的特征层
net = VGG16(input_tensor)
#-----------------------将提取到的主干特征进行处理---------------------------#
# 对conv4_3的通道进行l2标准化处理
# 38,38,512
net['conv4_3_norm'] = Normalize(20, name='conv4_3_norm')(net['conv4_3'])
num_priors = 4
# 预测框的处理
# num_priors表示每个网格点先验框的数量,4是x,y,h,w的调整
net['conv4_3_norm_mbox_loc'] = Conv2D(num_priors * 4, kernel_size=(3,3), padding='same', name='conv4_3_norm_mbox_loc')(net['conv4_3_norm'])
net['conv4_3_norm_mbox_loc_flat'] = Flatten(name='conv4_3_norm_mbox_loc_flat')(net['conv4_3_norm_mbox_loc'])
# num_priors表示每个网格点先验框的数量,num_classes是所分的类
net['conv4_3_norm_mbox_conf'] = Conv2D(num_priors * num_classes, kernel_size=(3,3), padding='same',name='conv4_3_norm_mbox_conf')(net['conv4_3_norm'])
net['conv4_3_norm_mbox_conf_flat'] = Flatten(name='conv4_3_norm_mbox_conf_flat')(net['conv4_3_norm_mbox_conf'])
# 对fc7层进行处理
# 19,19,1024
num_priors = 6
# 预测框的处理
# num_priors表示每个网格点先验框的数量,4是x,y,h,w的调整
net['fc7_mbox_loc'] = Conv2D(num_priors * 4, kernel_size=(3,3),padding='same',name='fc7_mbox_loc')(net['fc7'])
net['fc7_mbox_loc_flat'] = Flatten(name='fc7_mbox_loc_flat')(net['fc7_mbox_loc'])
# num_priors表示每个网格点先验框的数量,num_classes是所分的类
net['fc7_mbox_conf'] = Conv2D(num_priors * num_classes, kernel_size=(3,3),padding='same',name='fc7_mbox_conf')(net['fc7'])
net['fc7_mbox_conf_flat'] = Flatten(name='fc7_mbox_conf_flat')(net['fc7_mbox_conf'])
# 对conv6_2进行处理
# 10,10,512
num_priors = 6
# 预测框的处理
# num_priors表示每个网格点先验框的数量,4是x,y,h,w的调整
net['conv6_2_mbox_loc'] = Conv2D(num_priors * 4, kernel_size=(3,3), padding='same',name='conv6_2_mbox_loc')(net['conv6_2'])
net['conv6_2_mbox_loc_flat'] = Flatten(name='conv6_2_mbox_loc_flat')(net['conv6_2_mbox_loc'])
# num_priors表示每个网格点先验框的数量,num_classes是所分的类
net['conv6_2_mbox_conf'] = Conv2D(num_priors * num_classes, kernel_size=(3,3), padding='same',name='conv6_2_mbox_conf')(net['conv6_2'])
net['conv6_2_mbox_conf_flat'] = Flatten(name='conv6_2_mbox_conf_flat')(net['conv6_2_mbox_conf'])
# 对conv7_2进行处理
# 5,5,256
num_priors = 6
# 预测框的处理
# num_priors表示每个网格点先验框的数量,4是x,y,h,w的调整
net['conv7_2_mbox_loc'] = Conv2D(num_priors * 4, kernel_size=(3,3), padding='same',name='conv7_2_mbox_loc')(net['conv7_2'])
net['conv7_2_mbox_loc_flat'] = Flatten(name='conv7_2_mbox_loc_flat')(net['conv7_2_mbox_loc'])
# num_priors表示每个网格点先验框的数量,num_classes是所分的类
net['conv7_2_mbox_conf'] = Conv2D(num_priors * num_classes, kernel_size=(3,3), padding='same',name='conv7_2_mbox_conf')(net['conv7_2'])
net['conv7_2_mbox_conf_flat'] = Flatten(name='conv7_2_mbox_conf_flat')(net['conv7_2_mbox_conf'])
# 对conv8_2进行处理
# 3,3,256
num_priors = 4
# 预测框的处理
# num_priors表示每个网格点先验框的数量,4是x,y,h,w的调整
net['conv8_2_mbox_loc'] = Conv2D(num_priors * 4, kernel_size=(3,3), padding='same',name='conv8_2_mbox_loc')(net['conv8_2'])
net['conv8_2_mbox_loc_flat'] = Flatten(name='conv8_2_mbox_loc_flat')(net['conv8_2_mbox_loc'])
# num_priors表示每个网格点先验框的数量,num_classes是所分的类
net['conv8_2_mbox_conf'] = Conv2D(num_priors * num_classes, kernel_size=(3,3), padding='same',name='conv8_2_mbox_conf')(net['conv8_2'])
net['conv8_2_mbox_conf_flat'] = Flatten(name='conv8_2_mbox_conf_flat')(net['conv8_2_mbox_conf'])
# 对conv9_2进行处理
# 1,1,256
num_priors = 4
# 预测框的处理
# num_priors表示每个网格点先验框的数量,4是x,y,h,w的调整
net['conv9_2_mbox_loc'] = Conv2D(num_priors * 4, kernel_size=(3,3), padding='same',name='conv9_2_mbox_loc')(net['conv9_2'])
net['conv9_2_mbox_loc_flat'] = Flatten(name='conv9_2_mbox_loc_flat')(net['conv9_2_mbox_loc'])
# num_priors表示每个网格点先验框的数量,num_classes是所分的类
net['conv9_2_mbox_conf'] = Conv2D(num_priors * num_classes, kernel_size=(3,3), padding='same',name='conv9_2_mbox_conf')(net['conv9_2'])
net['conv9_2_mbox_conf_flat'] = Flatten(name='conv9_2_mbox_conf_flat')(net['conv9_2_mbox_conf'])
# 将所有结果进行堆叠
net['mbox_loc'] = Concatenate(axis=1, name='mbox_loc')([net['conv4_3_norm_mbox_loc_flat'],
net['fc7_mbox_loc_flat'],
net['conv6_2_mbox_loc_flat'],
net['conv7_2_mbox_loc_flat'],
net['conv8_2_mbox_loc_flat'],
net['conv9_2_mbox_loc_flat']])
net['mbox_conf'] = Concatenate(axis=1, name='mbox_conf')([net['conv4_3_norm_mbox_conf_flat'],
net['fc7_mbox_conf_flat'],
net['conv6_2_mbox_conf_flat'],
net['conv7_2_mbox_conf_flat'],
net['conv8_2_mbox_conf_flat'],
net['conv9_2_mbox_conf_flat']])
# 8732,4
net['mbox_loc'] = Reshape((-1, 4), name='mbox_loc_final')(net['mbox_loc'])
# 8732,21
net['mbox_conf'] = Reshape((-1, num_classes), name='mbox_conf_logits')(net['mbox_conf'])
net['mbox_conf'] = Activation('softmax', name='mbox_conf_final')(net['mbox_conf'])
# 8732,25
net['predictions'] = Concatenate(axis =-1, name='predictions')([net['mbox_loc'], net['mbox_conf']])
model = Model(net['input'], net['predictions'])
return model
import tensorflow as tf
class MultiboxLoss(object):
def __init__(self, num_classes, alpha=1.0, neg_pos_ratio=3.0,
background_label_id=0, negatives_for_hard=100.0):
self.num_classes = num_classes
self.alpha = alpha
self.neg_pos_ratio = neg_pos_ratio
if background_label_id != 0:
raise Exception('Only 0 as background label id is supported')
self.background_label_id = background_label_id
self.negatives_for_hard = negatives_for_hard
def _l1_smooth_loss(self, y_true, y_pred):
abs_loss = tf.abs(y_true - y_pred)
sq_loss = 0.5 * (y_true - y_pred)**2
l1_loss = tf.where(tf.less(abs_loss, 1.0), sq_loss, abs_loss - 0.5)
return tf.reduce_sum(l1_loss, -1)
def _softmax_loss(self, y_true, y_pred):
y_pred = tf.maximum(y_pred, 1e-7)
softmax_loss = -tf.reduce_sum(y_true * tf.math.log(y_pred),
axis=-1)
return softmax_loss
def compute_loss(self, y_true, y_pred):
# --------------------------------------------- #
# y_true batch_size, 8732, 4 + self.num_classes + 1
# y_pred batch_size, 8732, 4 + self.num_classes
# --------------------------------------------- #
num_boxes = tf.cast(tf.shape(y_true)[1], tf.float32)
# --------------------------------------------- #
# 分类的loss
# batch_size,8732,21 -> batch_size,8732
# --------------------------------------------- #
conf_loss = self._softmax_loss(y_true[:, :, 4:-1],
y_pred[:, :, 4:])
# --------------------------------------------- #
# 框的位置的loss
# batch_size,8732,4 -> batch_size,8732
# --------------------------------------------- #
loc_loss = self._l1_smooth_loss(y_true[:, :, :4],
y_pred[:, :, :4])
# --------------------------------------------- #
# 获取所有的正标签的loss
# --------------------------------------------- #
pos_loc_loss = tf.reduce_sum(loc_loss * y_true[:, :, -1],
axis=1)
pos_conf_loss = tf.reduce_sum(conf_loss * y_true[:, :, -1],
axis=1)
# --------------------------------------------- #
# 每一张图的正样本的个数
# num_pos [batch_size,]
# --------------------------------------------- #
num_pos = tf.reduce_sum(y_true[:, :, -1], axis=-1)
# --------------------------------------------- #
# 每一张图的负样本的个数
# num_neg [batch_size,]
# --------------------------------------------- #
num_neg = tf.minimum(self.neg_pos_ratio * num_pos, num_boxes - num_pos)
# 找到了哪些值是大于0的
pos_num_neg_mask = tf.greater(num_neg, 0)
# --------------------------------------------- #
# 如果所有的图,正样本的数量均为0
# 那么则默认选取100个先验框作为负样本
# --------------------------------------------- #
has_min = tf.cast(tf.reduce_any(pos_num_neg_mask), tf.float32)
num_neg = tf.concat(axis=0, values=[num_neg, [(1 - has_min) * self.negatives_for_hard]])
# --------------------------------------------- #
# 从这里往后,与视频中看到的代码有些许不同。
# 由于以前的负样本选取方式存在一些问题,
# 我对该部分代码进行重构。
# 求整个batch应该的负样本数量总和
# --------------------------------------------- #
num_neg_batch = tf.reduce_sum(tf.boolean_mask(num_neg, tf.greater(num_neg, 0)))
num_neg_batch = tf.cast(num_neg_batch, tf.int32)
# --------------------------------------------- #
# 对预测结果进行判断,如果该先验框没有包含物体
# 那么它的不属于背景的预测概率过大的话
# 就是难分类样本
# --------------------------------------------- #
confs_start = 4 + self.background_label_id + 1
confs_end = confs_start + self.num_classes - 1
# --------------------------------------------- #
# batch_size,8732
# 把不是背景的概率求和,求和后的概率越大
# 代表越难分类。
# --------------------------------------------- #
max_confs = tf.reduce_sum(y_pred[:, :, confs_start:confs_end], axis=2)
# --------------------------------------------------- #
# 只有没有包含物体的先验框才得到保留
# 我们在整个batch里面选取最难分类的num_neg_batch个
# 先验框作为负样本。
# --------------------------------------------------- #
max_confs = tf.reshape(max_confs * (1 - y_true[:, :, -1]), [-1])
_, indices = tf.nn.top_k(max_confs, k=num_neg_batch)
neg_conf_loss = tf.gather(tf.reshape(conf_loss, [-1]), indices)
# 进行归一化
num_pos = tf.where(tf.not_equal(num_pos, 0), num_pos, tf.ones_like(num_pos))
total_loss = tf.reduce_sum(pos_conf_loss) + tf.reduce_sum(neg_conf_loss) + tf.reduce_sum(self.alpha * pos_loc_loss)
total_loss /= tf.reduce_sum(num_pos)
return total_loss
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment