Commit c89bc397 authored by andyjpaddle's avatar andyjpaddle
Browse files

Merge branch 'release/2.5' of https://github.com/PaddlePaddle/PaddleOCR into release/2.5

parents 2274364b 61b03628
......@@ -64,7 +64,7 @@ python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/
```
上述指令中,通过-c 选择训练使用configs/det/det_db_mv3.yml配置文件。
上述指令中,通过-c 选择训练使用configs/det/det_mv3_db.yml配置文件。
有关配置文件的详细解释,请参考[链接](./config.md)
您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001
......
......@@ -7,7 +7,8 @@
- [1. 文本检测模型推理](#1-文本检测模型推理)
- [2. 文本识别模型推理](#2-文本识别模型推理)
- [2.1 超轻量中文识别模型推理](#21-超轻量中文识别模型推理)
- [2.2 多语言模型的推理](#22-多语言模型的推理)
- [2.2 英文识别模型推理](#22-英文识别模型推理)
- [2.3 多语言模型的推理](#23-多语言模型的推理)
- [3. 方向分类模型推理](#3-方向分类模型推理)
- [4. 文本检测、方向分类和文字识别串联推理](#4-文本检测方向分类和文字识别串联推理)
......@@ -78,9 +79,29 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_4.jpg"
Predicts of ./doc/imgs_words/ch/word_4.jpg:('实力活力', 0.9956803321838379)
```
<a name="英文识别模型推理"></a>
### 2.2 英文识别模型推理
英文识别模型推理,可以执行如下命令, 注意修改字典路径:
```
# 下载英文数字识别模型:
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_infer.tar
tar xf en_PP-OCRv3_rec_infer.tar
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./en_PP-OCRv3_rec_infer/" --rec_char_dict_path="ppocr/utils/en_dict.txt"
```
![](../imgs_words/en/word_1.png)
执行命令后,上图的预测结果为:
```
Predicts of ./doc/imgs_words/en/word_1.png: ('JOINT', 0.998160719871521)
```
<a name="多语言模型的推理"></a>
### 2.2 多语言模型的推理
### 2.3 多语言模型的推理
如果您需要预测的是其他语言模型,可以在[此链接](./models_list.md#%E5%A4%9A%E8%AF%AD%E8%A8%80%E8%AF%86%E5%88%AB%E6%A8%A1%E5%9E%8B)中找到对应语言的inference模型,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径, 同时为了得到正确的可视化结果,需要通过 `--vis_font_path` 指定可视化的字体路径,`doc/fonts/` 路径下有默认提供的小语种字体,例如韩文识别:
```
......
......@@ -591,8 +591,9 @@ Metric:
#### 2.2.5 检测蒸馏模型finetune
PP-OCRv3检测蒸馏有两种方式:
- 采用ch_PP-OCRv3_det_cml.yml,采用cml蒸馏,同样Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型
- 采用ch_PP-OCRv3_det_cml.yml,采用CML蒸馏,同样Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型
- 采用ch_PP-OCRv3_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法,在PaddleOCR采用的数据集上相比单独训练Student模型有1%-2%的提升。
> 如果您在自己的场景中没有训练过高精度大模型,或原始PP-OCR模型在您的场景中表现不好,则无法使用CML训练以达到更高精度,更应该采用DML训练
在具体fine-tune时,需要在网络结构的`pretrained`参数中设置要加载的预训练模型。
......
# 《动手学OCR》电子书
《动手学OCR》是PaddleOCR团队携手复旦大学青年研究员陈智能、中国移动研究院视觉领域资深专家黄文辉等产学研同仁,以及OCR开发者共同打造的结合OCR前沿理论与代码实践的教材。主要特色如下:
《动手学OCR》是PaddleOCR团队携手华中科技大学博导/教授,IAPR Fellow 白翔、复旦大学青年研究员陈智能、中国移动研究院视觉领域资深专家黄文辉、中国工商银行大数据人工智能实验室研究员等产学研同仁,以及OCR开发者共同打造的结合OCR前沿理论与代码实践的教材。主要特色如下:
- 覆盖从文本检测识别到文档分析的OCR全栈技术
- 紧密结合理论实践,跨越代码实现鸿沟,并配套教学视频
......
......@@ -30,11 +30,11 @@ PP-OCR系统pipeline如下:
PP-OCR系统在持续迭代优化,目前已发布PP-OCR和PP-OCRv2两个版本:
PP-OCR从骨干网络选择和调整、预测头部的设计、数据增强、学习率变换策略、正则化参数选择、预训练模型使用以及模型自动裁剪量化8个方面,采用19个有效策略,对各个模块的模型进行效果调优和瘦身(如绿框所示),最终得到整体大小为3.5M的超轻量中英文OCR和2.8M的英文数字OCR。更多细节请参考PP-OCR技术方案 https://arxiv.org/abs/2009.09941
PP-OCR从骨干网络选择和调整、预测头部的设计、数据增强、学习率变换策略、正则化参数选择、预训练模型使用以及模型自动裁剪量化8个方面,采用19个有效策略,对各个模块的模型进行效果调优和瘦身(如绿框所示),最终得到整体大小为3.5M的超轻量中英文OCR和2.8M的英文数字OCR。更多细节请参考[PP-OCR技术报告](https://arxiv.org/abs/2009.09941)
#### PP-OCRv2
PP-OCRv2在PP-OCR的基础上,进一步在5个方面重点优化,检测模型采用CML协同互学习知识蒸馏策略和CopyPaste数据增广策略;识别模型采用LCNet轻量级骨干网络、UDML 改进知识蒸馏策略和[Enhanced CTC loss](./enhanced_ctc_loss.md)损失函数改进(如上图红框所示),进一步在推理速度和预测效果上取得明显提升。更多细节请参考PP-OCRv2[技术报告](https://arxiv.org/abs/2109.03144)
PP-OCRv2在PP-OCR的基础上,进一步在5个方面重点优化,检测模型采用CML协同互学习知识蒸馏策略和CopyPaste数据增广策略;识别模型采用LCNet轻量级骨干网络、UDML 改进知识蒸馏策略和[Enhanced CTC loss](./enhanced_ctc_loss.md)损失函数改进(如上图红框所示),进一步在推理速度和预测效果上取得明显提升。更多细节请参考[PP-OCRv2技术报告](https://arxiv.org/abs/2109.03144)
#### PP-OCRv3
......@@ -48,7 +48,7 @@ PP-OCRv3系统pipeline如下:
<img src="../ppocrv3_framework.png" width="800">
</div>
更多细节请参考PP-OCRv3[技术报告](./PP-OCRv3_introduction.md)
更多细节请参考[PP-OCRv3技术报告](https://arxiv.org/abs/2206.03001v2) 👉[中文简洁版](./PP-OCRv3_introduction.md)
<a name="2"></a>
......
......@@ -201,3 +201,7 @@ im_show.save('result.jpg')
通过本节内容,相信您已经熟练掌握PaddleOCR whl包的使用方法并获得了初步效果。
PaddleOCR是一套丰富领先实用的OCR工具库,打通数据、模型训练、压缩和推理部署全流程,您可以参考[文档教程](../../README_ch.md#文档教程),正式开启PaddleOCR的应用之旅。
<a href="https://trackgit.com">
<img src="https://us-central1-trackgit-analytics.cloudfunctions.net/token/ping/l63dyu35d5dpcg6y3ibl" alt="trackgit-views" />
</a>
......@@ -550,4 +550,4 @@ inference/en_PP-OCRv3_rec/
Q1: 训练模型转inference 模型之后预测效果不一致?
**A**:此类问题出现较多,问题多是trained model预测时候的预处理、后处理参数和inference model预测的时候的预处理、后处理参数不一致导致的。可以对比训练使用的配置文件中的预处理、后处理和预测时是否存在差异。
**A**:此类问题出现较多,问题多是trained model预测时候的预处理、后处理参数和inference model预测的时候的预处理、后处理参数不一致导致的。可以对比训练使用的配置文件中的预处理、后处理和预测时是否存在差异。更多内容请参考[FAQ](./FAQ.md#210-%E6%A8%A1%E5%9E%8B%E6%95%88%E6%9E%9C%E4%B8%8E%E6%95%88%E6%9E%9C%E4%B8%8D%E4%B8%80%E8%87%B4).
......@@ -4,8 +4,7 @@
- 半自动标注工具[PPOCRLabelv2](../../PPOCRLabel):新增表格文字图像、图像关键信息抽取任务和不规则文字图像的标注功能;
- OCR产业落地工具集:打通22种训练部署软硬件环境与方式,覆盖企业90%的训练部署环境需求
- 交互式OCR开源电子书[《动手学OCR》](./ocr_book.md),覆盖OCR全栈技术的前沿理论与代码实践,并配套教学视频。
- 2022.5.7 添加对[Weights & Biases](https://docs.wandb.ai/)训练日志记录工具的支持。
- 2021.12.21 《OCR十讲》课程开讲,12月21日起每晚八点半线上授课! 【免费】报名地址:https://aistudio.baidu.com/aistudio/course/introduce/25207
- 2021.12.21 《动手学OCR·十讲》课程开讲,12月21日起每晚八点半线上授课! 【免费】[报名地址](https://aistudio.baidu.com/aistudio/course/introduce/25207)
- 2021.12.21 发布PaddleOCR v2.4。OCR算法新增1种文本检测算法(PSENet),3种文本识别算法(NRTR、SEED、SAR);文档结构化算法新增1种关键信息提取算法(SDMGR),3种DocVQA算法(LayoutLM、LayoutLMv2,LayoutXLM)。
- 2021.9.7 发布PaddleOCR v2.3,发布[PP-OCRv2](#PP-OCRv2),CPU推理速度相比于PP-OCR server提升220%;效果相比于PP-OCR mobile 提升7%。
- 2021.8.3 发布PaddleOCR v2.2,新增文档结构分析[PP-Structure](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.2/ppstructure/README_ch.md)工具包,支持版面分析与表格识别(含Excel导出)。
......@@ -29,7 +28,7 @@
- 2020.7.9 添加支持空格的识别模型,识别效果,预测及训练方式请参考快速开始和文本识别训练相关文档
- 2020.7.9 添加数据增强、学习率衰减策略,具体参考[配置文件](./config.md)
- 2020.6.8 添加[数据集](dataset/datasets.md),并保持持续更新
- 2020.6.5 支持 `attetnion` 模型导出 `inference_model`
- 2020.6.5 支持 `attention` 模型导出 `inference_model`
- 2020.6.5 支持单独预测识别时,输出结果得分
- 2020.5.30 提供超轻量级中文OCR在线体验
- 2020.5.30 模型预测、训练支持Windows系统
......
......@@ -101,7 +101,7 @@ Considering that the features of some channels will be suppressed if the convolu
The recognition module of PP-OCRv3 is optimized based on the text recognition algorithm [SVTR](https://arxiv.org/abs/2205.00159). RNN is abandoned in SVTR, and the context information of the text line image is more effectively mined by introducing the Transformers structure, thereby improving the text recognition ability.
The recognition accuracy of SVTR_inty outperforms PP-OCRv2 recognition model by 5.3%, while the prediction speed nearly 11 times slower. It takes nearly 100ms to predict a text line on CPU. Therefore, as shown in the figure below, PP-OCRv3 adopts the following six optimization strategies to accelerate the recognition model.
The recognition accuracy of SVTR_tiny outperforms PP-OCRv2 recognition model by 5.3%, while the prediction speed nearly 11 times slower. It takes nearly 100ms to predict a text line on CPU. Therefore, as shown in the figure below, PP-OCRv3 adopts the following six optimization strategies to accelerate the recognition model.
<div align="center">
<img src="../ppocr_v3/v3_rec_pipeline.png" width=800>
......@@ -200,7 +200,7 @@ UDML (Unified-Deep Mutual Learning) is a strategy proposed in PP-OCRv2 which is
**(6)UIM:Unlabeled Images Mining**
UIM (Unlabeled Images Mining) is a very simple unlabeled data mining strategy. The main idea is to use a high-precision text recognition model to predict unlabeled images to obtain pseudo-labels, and select samples with high prediction confidence as training data for training lightweight models. Using this strategy, the accuracy of the recognition model is further improved to 79.4% (+1%).
UIM (Unlabeled Images Mining) is a very simple unlabeled data mining strategy. The main idea is to use a high-precision text recognition model to predict unlabeled images to obtain pseudo-labels, and select samples with high prediction confidence as training data for training lightweight models. Using this strategy, the accuracy of the recognition model is further improved to 79.4% (+1%). In practice, we use the full data set to train the high-precision SVTR_Tiny model (acc=82.5%) for data mining. [SVTR_Tiny model download and tutorial](../../applications/高精度中文识别模型.md).
<div align="center">
<img src="../ppocr_v3/UIM.png" width="500">
......
......@@ -51,7 +51,7 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml \
-o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
```
In the above instruction, use `-c` to select the training to use the `configs/det/det_db_mv3.yml` configuration file.
In the above instruction, use `-c` to select the training to use the `configs/det/det_mv3_db.yml` configuration file.
For a detailed explanation of the configuration file, please refer to [config](./config_en.md).
You can also use `-o` to change the training parameters without modifying the yml file. For example, adjust the training learning rate to 0.0001
......
......@@ -8,7 +8,8 @@ This article introduces the use of the Python inference engine for the PP-OCR mo
- [Text Detection Model Inference](#text-detection-model-inference)
- [Text Recognition Model Inference](#text-recognition-model-inference)
- [1. Lightweight Chinese Recognition Model Inference](#1-lightweight-chinese-recognition-model-inference)
- [2. Multilingual Model Inference](#2-multilingual-model-inference)
- [2. English Recognition Model Inference](#2-english-recognition-model-inference)
- [3. Multilingual Model Inference](#3-multilingual-model-inference)
- [Angle Classification Model Inference](#angle-classification-model-inference)
- [Text Detection Angle Classification and Recognition Inference Concatenation](#text-detection-angle-classification-and-recognition-inference-concatenation)
......@@ -76,10 +77,31 @@ After executing the command, the prediction results (recognized text and score)
```bash
Predicts of ./doc/imgs_words_en/word_10.png:('PAIN', 0.988671)
```
<a name="2-english-recognition-model-inference"></a>
### 2. English Recognition Model Inference
<a name="MULTILINGUAL_MODEL_INFERENCE"></a>
For English recognition model inference, you can execute the following commands,you need to specify the dictionary path used by `--rec_char_dict_path`:
### 2. Multilingual Model Inference
```
# download en model:
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_det_infer.tar
tar xf en_PP-OCRv3_det_infer.tar
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./en_PP-OCRv3_det_infer/" --rec_char_dict_path="ppocr/utils/en_dict.txt"
```
![](../imgs_words/en/word_1.png)
After executing the command, the prediction result of the above figure is:
```
Predicts of ./doc/imgs_words/en/word_1.png: ('JOINT', 0.998160719871521)
```
<a name="3-multilingual-model-inference"></a>
### 3. Multilingual Model Inference
If you need to predict [other language models](./models_list_en.md#Multilingual), when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results,
You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/fonts` path, such as Korean recognition:
......
......@@ -29,10 +29,10 @@ PP-OCR pipeline is as follows:
PP-OCR system is in continuous optimization. At present, PP-OCR and PP-OCRv2 have been released:
PP-OCR adopts 19 effective strategies from 8 aspects including backbone network selection and adjustment, prediction head design, data augmentation, learning rate transformation strategy, regularization parameter selection, pre-training model use, and automatic model tailoring and quantization to optimize and slim down the models of each module (as shown in the green box above). The final results are an ultra-lightweight Chinese and English OCR model with an overall size of 3.5M and a 2.8M English digital OCR model. For more details, please refer to the PP-OCR technical article (https://arxiv.org/abs/2009.09941).
PP-OCR adopts 19 effective strategies from 8 aspects including backbone network selection and adjustment, prediction head design, data augmentation, learning rate transformation strategy, regularization parameter selection, pre-training model use, and automatic model tailoring and quantization to optimize and slim down the models of each module (as shown in the green box above). The final results are an ultra-lightweight Chinese and English OCR model with an overall size of 3.5M and a 2.8M English digital OCR model. For more details, please refer to the [PP-OCR technical report](https://arxiv.org/abs/2009.09941).
#### PP-OCRv2
On the basis of PP-OCR, PP-OCRv2 is further optimized in five aspects. The detection model adopts CML(Collaborative Mutual Learning) knowledge distillation strategy and CopyPaste data expansion strategy. The recognition model adopts LCNet lightweight backbone network, U-DML knowledge distillation strategy and enhanced CTC loss function improvement (as shown in the red box above), which further improves the inference speed and prediction effect. For more details, please refer to the technical report of PP-OCRv2 (https://arxiv.org/abs/2109.03144).
On the basis of PP-OCR, PP-OCRv2 is further optimized in five aspects. The detection model adopts CML(Collaborative Mutual Learning) knowledge distillation strategy and CopyPaste data expansion strategy. The recognition model adopts LCNet lightweight backbone network, U-DML knowledge distillation strategy and enhanced CTC loss function improvement (as shown in the red box above), which further improves the inference speed and prediction effect. For more details, please refer to the [PP-OCRv2 technical report](https://arxiv.org/abs/2109.03144).
#### PP-OCRv3
......@@ -46,7 +46,7 @@ PP-OCRv3 pipeline is as follows:
<img src="../ppocrv3_framework.png" width="800">
</div>
For more details, please refer to [PP-OCRv3 technical report](./PP-OCRv3_introduction_en.md).
For more details, please refer to [PP-OCRv3 technical report](https://arxiv.org/abs/2206.03001v2).
<a name="2"></a>
## 2. Features
......
# Paddleocr Package
# PaddleOCR Package
## 1 Get started quickly
### 1.1 install package
......
......@@ -446,7 +446,7 @@ class PaddleOCR(predict_system.TextSystem):
"""
ocr with paddleocr
args:
img: img for ocr, support ndarray, img_path and list or ndarray
img: img for ocr, support ndarray, img_path and list of ndarray
det: use text detection or not. If false, only rec will be exec. Default is True
rec: use text recognition or not. If false, only det will be exec. Default is True
cls: use angle classifier or not. Default is True. If true, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
......
......@@ -23,7 +23,8 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, \
SVTRRecResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
......
......@@ -35,10 +35,12 @@ class CopyPaste(object):
point_num = data['polys'].shape[1]
src_img = data['image']
src_polys = data['polys'].tolist()
src_texts = data['texts']
src_ignores = data['ignore_tags'].tolist()
ext_data = data['ext_data'][0]
ext_image = ext_data['image']
ext_polys = ext_data['polys']
ext_texts = ext_data['texts']
ext_ignores = ext_data['ignore_tags']
indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
......@@ -53,7 +55,7 @@ class CopyPaste(object):
src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
src_img = Image.fromarray(src_img).convert('RGBA')
for poly, tag in zip(select_polys, select_ignores):
for idx, poly, tag in zip(select_idxs, select_polys, select_ignores):
box_img = get_rotate_crop_image(ext_image, poly)
src_img, box = self.paste_img(src_img, box_img, src_polys)
......@@ -62,6 +64,7 @@ class CopyPaste(object):
for _ in range(len(box), point_num):
box.append(box[-1])
src_polys.append(box)
src_texts.append(ext_texts[idx])
src_ignores.append(tag)
src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
h, w = src_img.shape[:2]
......@@ -70,6 +73,7 @@ class CopyPaste(object):
src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
data['image'] = src_img
data['polys'] = src_polys
data['texts'] = src_texts
data['ignore_tags'] = np.array(src_ignores)
return data
......
......@@ -23,7 +23,7 @@ import string
from shapely.geometry import LineString, Point, Polygon
import json
import copy
from scipy.spatial import distance as dist
from ppocr.utils.logging import get_logger
......@@ -74,9 +74,10 @@ class DetLabelEncode(object):
s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)]
rect[2] = pts[np.argmax(s)]
diff = np.diff(pts, axis=1)
rect[1] = pts[np.argmin(diff)]
rect[3] = pts[np.argmax(diff)]
tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
diff = np.diff(np.array(tmp), axis=1)
rect[1] = tmp[np.argmin(diff)]
rect[3] = tmp[np.argmax(diff)]
return rect
def expand_points_num(self, boxes):
......@@ -443,7 +444,9 @@ class KieLabelEncode(object):
elif 'key_cls' in ann.keys():
labels.append(ann['key_cls'])
else:
raise ValueError("Cannot found 'key_cls' in ann.keys(), please check your training annotation.")
raise ValueError(
"Cannot found 'key_cls' in ann.keys(), please check your training annotation."
)
edges.append(ann.get('edge', 0))
ann_infos = dict(
image=data['image'],
......
......@@ -207,6 +207,21 @@ class PRENResizeImg(object):
return data
class SVTRRecResizeImg(object):
def __init__(self, image_shape, padding=True, **kwargs):
self.image_shape = image_shape
self.padding = padding
def __call__(self, data):
img = data['image']
norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
self.padding)
data['image'] = norm_img
data['valid_ratio'] = valid_ratio
return data
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
......
......@@ -57,17 +57,27 @@ class CELoss(nn.Layer):
class KLJSLoss(object):
def __init__(self, mode='kl'):
assert mode in ['kl', 'js', 'KL', 'JS'
], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
], "mode can only be one of ['kl', 'KL', 'js', 'JS']"
self.mode = mode
def __call__(self, p1, p2, reduction="mean"):
loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
if self.mode.lower() == "js":
if self.mode.lower() == 'kl':
loss = paddle.multiply(p2,
paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
loss += paddle.multiply(
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
loss *= 0.5
elif self.mode.lower() == "js":
loss = paddle.multiply(
p2, paddle.log((2 * p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
loss += paddle.multiply(
p1, paddle.log((2 * p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
loss *= 0.5
else:
raise ValueError(
"The mode.lower() if KLJSLoss should be one of ['kl', 'js']")
if reduction == "mean":
loss = paddle.mean(loss, axis=[1, 2])
elif reduction == "none" or reduction is None:
......@@ -95,7 +105,7 @@ class DMLLoss(nn.Layer):
self.act = None
self.use_log = use_log
self.jskl_loss = KLJSLoss(mode="js")
self.jskl_loss = KLJSLoss(mode="kl")
def _kldiv(self, x, target):
eps = 1.0e-10
......
......@@ -27,12 +27,12 @@ class CosineEmbeddingLoss(nn.Layer):
self.epsilon = 1e-12
def forward(self, x1, x2, target):
similarity = paddle.fluid.layers.reduce_sum(
similarity = paddle.sum(
x1 * x2, dim=-1) / (paddle.norm(
x1, axis=-1) * paddle.norm(
x2, axis=-1) + self.epsilon)
one_list = paddle.full_like(target, fill_value=1)
out = paddle.fluid.layers.reduce_mean(
out = paddle.mean(
paddle.where(
paddle.equal(target, one_list), 1. - similarity,
paddle.maximum(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment