Commit cce6e1bf authored by chenych

First commit.

*.mdb
*.tar
*.zip
*.eps
*.pdf
### Linux ###
*~
# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*
# KDE directory preferences
.directory
# Linux trash folder which might appear on any partition or disk
.Trash-*
# .nfs files are created when an open file is removed but is still being accessed
.nfs*
### OSX ###
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
### Python Patch ###
.venv/
### Python.VirtualEnv Stack ###
# Virtualenv
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
[Bb]in
[Ii]nclude
[Ll]ib64
[Ll]ocal
[Ss]cripts
pyvenv.cfg
pip-selfcheck.json
### Windows ###
# Windows thumbnail cache files
Thumbs.db
ehthumbs.db
ehthumbs_vista.db
# Dump file
*.stackdump
# Folder config file
[Dd]esktop.ini
# Recycle Bin used on file shares
$RECYCLE.BIN/
# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp
# Windows shortcuts
*.lnk
.idea/
.vscode/
output/
exp/
data/
*.pyc
*.mp4
*.zip
AdelaiDet for non-commercial purposes
(For commercial use, contact chhshen@gmail.com for obtaining a commercial license.)
Copyright (c) 2019 the authors
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# DeepSolo
## Papers
[DeepSolo: Let Transformer Decoder with Explicit Points Solo for Text Spotting](https://arxiv.org/abs/2211.10772)
[DeepSolo++: Let Transformer Decoder with Explicit Points Solo for Multilingual Text Spotting](https://arxiv.org/abs/2305.19957)
## Model Architecture
A concise DETR-like baseline that lets a single decoder with explicit points perform detection and recognition simultaneously (see (c) and (f) in the figure below).
<div align=center>
<img src="./doc/image.png"/>
</div>
## Algorithm
In DeepSolo, the encoder takes the image features and produces Bezier center-curve proposals, each represented by four Bezier control points, together with their scores; the top-K scoring candidates are then selected. For each selected curve candidate, N points are sampled uniformly along the curve, and their coordinates are encoded as positional queries that are added to the content queries to form composite queries. The composite queries are then fed into a deformable cross-attention decoder to gather useful text features. After the decoder, several simple parallel prediction heads (linear layers or MLPs) decode the queries into the text center line, boundary, script, and confidence, solving detection and recognition at the same time. A small sketch of the center-curve sampling step follows the figure.
<div align=center>
<img src="./doc/DeepSolo.jpg"/>
</div>
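The sketch below illustrates only the curve-sampling step described above; it is not the repository's implementation, and the function name and tensor shapes are assumptions.
```
import torch

def sample_bezier_center_points(ctrl_pts: torch.Tensor, num_points: int = 25) -> torch.Tensor:
    """Uniformly sample points on cubic Bezier center curves.

    ctrl_pts: (K, 4, 2) control points of the top-K curve proposals.
    returns:  (K, num_points, 2) on-curve point coordinates.
    """
    t = torch.linspace(0, 1, num_points, device=ctrl_pts.device)  # (N,)
    # Cubic Bernstein basis: (1-t)^3, 3t(1-t)^2, 3t^2(1-t), t^3
    basis = torch.stack([(1 - t) ** 3,
                         3 * t * (1 - t) ** 2,
                         3 * t ** 2 * (1 - t),
                         t ** 3], dim=-1)                         # (N, 4)
    return torch.einsum('nc,kcd->knd', basis, ctrl_pts)          # (K, N, 2)

# Toy example: one horizontal curve proposal and 25 sampled points whose (x, y)
# coordinates would then be encoded as positional queries and added to the
# shared content queries.
ctrl = torch.tensor([[[0.0, 0.5], [0.33, 0.5], [0.66, 0.5], [1.0, 0.5]]])
print(sample_bezier_center_points(ctrl, num_points=25).shape)  # torch.Size([1, 25, 2])
```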
## Environment Setup
Training depends on the Detectron2 library. Building Detectron2 requires Python ≥ 3.7, PyTorch ≥ 1.8 with a matching torchvision version, and gcc & g++ ≥ 5.4. For faster builds, installing Ninja is recommended.
Tips: If the detectron2 installation fails, you can try installing it as follows:
```
git clone https://github.com/facebookresearch/detectron2.git
python -m pip install -e detectron2
```
### Docker (Method 1)
Adjust the -v paths, docker_name, and imageID to match your environment.
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk-23.04-py38-latest
docker run -it -v /path/your_code_data/:/path/your_code_data/ --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
cd /your_code_path/deepsolo_pytorch
pip install -r requirements.txt
python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
bash make.sh
```
### Dockerfile (Method 2)
Adjust the -v paths, docker_name, and imageID to match your environment.
```
cd ./docker
cp ../requirements.txt requirements.txt
docker build --no-cache -t deepsolo:latest .
docker run -it -v /path/your_code_data/:/path/your_code_data/ --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
cd /your_code_path/deepsolo_pytorch
pip install -r requirements.txt
python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
bash make.sh
```
### Anaconda (Method 3)
1. The DCU-specific deep learning libraries required by this project can be downloaded and installed from the 光合 developer community: https://developer.hpccube.com/tool/
```
DTK软件栈:dtk23.04
python:python3.8
torch:1.13.1
torchvision:0.14.1
```
Tips: The DTK stack, Python, torch, and other DCU-related tool versions listed above must correspond to each other exactly.
2. Install the remaining, non-DCU-specific libraries with the following steps:
```
pip install -r requirements.txt
python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
bash make.sh
```
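After installation (whichever method you used), a quick sanity check can confirm that torch, torchvision, and detectron2 import correctly and that a device is visible. This is a minimal sketch; the exact version strings depend on your DTK/DCU environment.
```
import torch
import torchvision
import detectron2

print("torch:", torch.__version__, "torchvision:", torchvision.__version__)
print("detectron2:", detectron2.__version__)
print("device available:", torch.cuda.is_available())
```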
## Datasets
Please keep all datasets under deepsolo_pytorch/datasets. Because the datasets are large, download only the ones your training run needs; the required datasets are listed in the DATASETS field of the YAML files under configs.
### Training datasets
`[SynthText150K (CurvedSynText150K)]` [images](https://github.com/aim-uofa/AdelaiDet/tree/master/datasets) | [annotations(Part1)](https://1drv.ms/u/s!ApEsJ9RIZdBQgQTfQC578sYbkPik?e=2Yz06g) | [annotations(Part2)](https://1drv.ms/u/s!ApEsJ9RIZdBQgQJWqH404p34Wb1m?e=KImg6N)
`[MLT]` [images](https://github.com/aim-uofa/AdelaiDet/tree/master/datasets) | [annotations](https://1drv.ms/u/s!ApEsJ9RIZdBQgQBpvuvV2KBBbN64?e=HVTCab)
`[ICDAR2013]` [images](https://1drv.ms/u/s!ApEsJ9RIZdBQgQcK05sWzK3_t26T?e=5jTWAa) | [annotations](https://1drv.ms/u/s!ApEsJ9RIZdBQfbgqFCeiKOrTM0E?e=UMfIQh)
`[ICDAR2015]` [images](https://1drv.ms/u/s!ApEsJ9RIZdBQgQbupfCNqVxtYGna?e=b4TQY2) | [annotations](https://1drv.ms/u/s!ApEsJ9RIZdBQfhGW5JDiNcDxfWQ?e=PZ2JCX)
`[Total-Text]` [images](https://1drv.ms/u/s!ApEsJ9RIZdBQgQjyPyivo_FnjJ1H?e=qgSFYL) | [annotations](https://1drv.ms/u/s!ApEsJ9RIZdBQgQOShwd8O0K5Dd1f?e=GYyPAX)
`[CTW1500]` [images](https://1drv.ms/u/s!ApEsJ9RIZdBQgQlZVAH5AJld3Y9g?e=zgG71Z) | [annotations](https://1drv.ms/u/s!ApEsJ9RIZdBQfPpyzxoFV34zBg4?e=WK20AN)
`[TextOCR]` [images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip) | [annotations](https://1drv.ms/u/s!ApEsJ9RIZdBQgQHY3mjH13GRLPGI?e=Dx1O99)
`[Inverse-Text]` [images](https://1drv.ms/u/s!AimBgYV7JjTlgccVhlbD4I3z5QfmsQ?e=myu7Ue) | [annotations](https://1drv.ms/u/s!ApEsJ9RIZdBQf3G4vZpf4QD5NKo?e=xR3GtY)
`[SynChinese130K]` [images](https://github.com/aim-uofa/AdelaiDet/tree/master/datasets) | [annotations](https://1drv.ms/u/s!AimBgYV7JjTlgch5W0n1Iv397i0csw?e=Gq8qww)
`[ArT]` [images](https://github.com/aim-uofa/AdelaiDet/tree/master/datasets) | [annotations](https://1drv.ms/u/s!AimBgYV7JjTlgch45d0VHNCoPC1jfQ?e=likK00)
`[LSVT]` [images](https://github.com/aim-uofa/AdelaiDet/tree/master/datasets) | [annotations](https://1drv.ms/u/s!AimBgYV7JjTlgch7yjmrCSN0TgoO4w?e=NKd5OG)
`[ReCTS]` [images](https://github.com/aim-uofa/AdelaiDet/tree/master/datasets) | [annotations](https://1drv.ms/u/s!AimBgYV7JjTlgch_xZ8otxFWfNgZSg?e=pdq28B)
`[Evaluation ground-truth]` [Link](https://1drv.ms/u/s!ApEsJ9RIZdBQem-MG1TjuRWApyA?e=fVPnmT)
### Evaluation datasets
```
cd datasets
mkdir evaluation
cd evaluation
wget -O gt_ctw1500.zip https://cloudstor.aarnet.edu.au/plus/s/xU3yeM3GnidiSTr/download
wget -O gt_totaltext.zip https://cloudstor.aarnet.edu.au/plus/s/SFHvin8BLUM4cNd/download
wget -O gt_icdar2015.zip https://drive.google.com/file/d/1wrq_-qIyb_8dhYVlDzLZTTajQzbic82Z/view?usp=sharing
wget -O gt_inversetext.zip https://cloudstor.aarnet.edu.au/plus/s/xU3yeM3GnidiSTr/download
```
### Dataset directory structure
For standard training, organize the datasets according to the following directory structure:
```
├── ./datasets
│ ├── simple
│ ├── test_images
│ ├── train_images
│ ├── test.json
│ └── train.json
│ ├── evaluation
│ ├── gt_totaltext.zip
│ ├── gt_ctw1500.zip
│ ├── gt_icdar2015.zip
│ └── gt_inversetext.zip
│ ├── syntext1
│ ├── train_images
│ └── annotations
│ ├── train_37voc.json
│ └── train_96voc.json
│ ├── syntext2
│ ├── train_images
│ └── annotations
│ ├── train_37voc.json
│ └── train_96voc.json
│ ├── mlt2017
│ ├── train_images
│ ├── train_37voc.json
│ └── train_96voc.json
│ ├── totaltext
│ ├── train_images
│ ├── test_images
│ ├── weak_voc_new.txt
│ ├── weak_voc_pair_list.txt
│ ├── train_37voc.json
│ ├── train_96voc.json
│ └── test.json
│ ├── ic13
│ ├── train_images
│ ├── train_37voc.json
│ └── train_96voc.json
│ ├── ic15
│ ├── train_images
│ ├── test_images
│ ├── new_strong_lexicon
│ ├── strong_lexicon
│ ├── ch4_test_vocabulary.txt
│ ├── ch4_test_vocabulary_new.txt
│ ├── ch4_test_vocabulary_pair_list.txt
│ ├── GenericVocabulary.txt
│ ├── GenericVocabulary_new.txt
│ ├── GenericVocabulary_pair_list.txt
│ ├── train_37voc.json
│ ├── train_96voc.json
│ └── test.json
│ ├── ctw1500
│ ├── train_images
│ ├── test_images
│ ├── weak_voc_new.txt
│ ├── weak_voc_pair_list.txt
│ ├── train_96voc.json
│ └── test.json
│ ├── textocr
│ ├── train_images
│ ├── train_37voc_1.json
│ └── train_37voc_2.json
│ ├── inversetext
│ ├── test_images
│ └── test.json
│ ├── chnsyntext
│ ├── syn_130k_images
│ └── chn_syntext.json
│ ├── ArT
│ ├── rename_artimg_train
│ └── art_train.json
│ ├── LSVT
│ ├── rename_lsvtimg_train
│ └── lsvt_train.json
│ ├── ReCTS
│ ├── ReCTS_train_images # 18,000 images
│ ├── ReCTS_val_images # 2,000 images
│ ├── ReCTS_test_images # 5,000 images
│ ├── rects_train.json
│ ├── rects_val.json
│ └── rects_test.json
```
If you use your own dataset, convert its annotations to COCO format and add an entry for it, following the existing structure, to the _PREDEFINED_SPLITS_TEXT dictionary at line 18 of DeepSolo/adet/data/builtin.py, as shown in the sketch below.
The project also provides a mini dataset, simple, for getting started.
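A sketch of what such an entry might look like, assuming _PREDEFINED_SPLITS_TEXT maps a dataset name to an (image directory, COCO-format JSON) pair relative to datasets/; the dataset names and paths below are placeholders, so check builtin.py for the exact structure used in this repository.
```
# adet/data/builtin.py (illustrative entry only)
_PREDEFINED_SPLITS_TEXT = {
    # ... existing entries ...
    "mydata_train": ("mydata/train_images", "mydata/train.json"),
    "mydata_test": ("mydata/test_images", "mydata/test.json"),
}
```
The new names can then be referenced from the DATASETS.TRAIN / DATASETS.TEST fields of your YAML config.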
## Training
### Single node, multiple GPUs
Tips: Adjust the following settings in train.sh to match your setup:
--config-file  path to the YAML config file
--num-gpus  number of training GPUs
After modifying them, run:
```
bash train.sh
```
## Inference
Tips:
To run your own pretrained model, modify the configuration:
${CONFIG_FILE}  path to the YAML config file (remember to update the pretrained-model path in it)
${IMAGE_PATH}  path to the data to be tested
Example steps:
1. Download the pretrained model for CTW1500:
|Backbone|Training Data|Weights|
|:------:|:------:|:------:|
|Res-50|Synth150K+Total-Text+MLT17+IC13+IC15|[OneDrive](https://1drv.ms/u/s!AimBgYV7JjTlgcdtYzwEBGvOH6CiBw?e=trgKFE)|
Place the pretrained model under the pretrained_models/CTW1500/ folder; if you store it elsewhere, update MODEL.WEIGHTS in the config file accordingly.
2. Put the data to be tested under ${IMAGE_PATH} and run:
```
bash test.sh
```
By default, inference results are saved to the test_results folder; use the --output argument to change the output path.
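For reference, single-image inference broadly follows detectron2's predictor pattern. This is a hedged sketch only: the paths are placeholders, and the repository's own test.sh / demo scripts may differ in pre- and post-processing.
```
import cv2
from detectron2.engine import DefaultPredictor
from adet.config import get_cfg

cfg = get_cfg()
cfg.merge_from_file("path/to/config.yaml")                 # the ${CONFIG_FILE} used for training
cfg.MODEL.WEIGHTS = "pretrained_models/CTW1500/model.pth"  # placeholder weight file name
predictor = DefaultPredictor(cfg)

image = cv2.imread("path/to/test_image.jpg")               # an image from ${IMAGE_PATH}
outputs = predictor(image)                                 # predicted instances (polygons, text, scores)
print(outputs)
```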
## Results
Example results on CTW1500:
<div align=center>
<img src="./doc/results.jpg"/>
</div>
### Accuracy
Test results on CTW1500 with a ResNet-50 backbone are shown in the table below:
|Backbone|External Data|Det-P|Det-R|Det-F1|E2E-None|E2E-Full|
|:------:|:------:|:------:|:------:|:------:|:------:|:------:|
|Res-50(ours)|Synth150K+Total-Text+MLT17+IC13+IC15|0.9325|0.8475|0.8879|0.6408|0.812|
|Res-50|Synth150K+Total-Text+MLT17+IC13+IC15|0.932|0.85|0.889|0.642|0.814|
## Application Scenarios
### Algorithm Category
OCR
### Key Application Industries
Government, transportation, logistics
## Source Repository and Issue Feedback
http://developer.hpccube.com/codes/modelzoo/deepsolo_pytorch.git
## References
https://github.com/ViTAE-Transformer/DeepSolo.git
from adet import modeling
__version__ = "0.1.1"
from .adet_checkpoint import AdetCheckpointer
__all__ = ["AdetCheckpointer"]
import pickle, os
from fvcore.common.file_io import PathManager
from detectron2.checkpoint import DetectionCheckpointer
class AdetCheckpointer(DetectionCheckpointer):
"""
Same as :class:`DetectionCheckpointer`, but is able to convert models
in AdelaiDet, such as LPF backbone.
"""
def _load_file(self, filename):
if filename.endswith(".pkl"):
with PathManager.open(filename, "rb") as f:
data = pickle.load(f, encoding="latin1")
if "model" in data and "__author__" in data:
# file is in Detectron2 model zoo format
self.logger.info("Reading a file from '{}'".format(data["__author__"]))
return data
else:
# assume file is from Caffe2 / Detectron1 model zoo
if "blobs" in data:
# Detection models have "blobs", but ImageNet models don't
data = data["blobs"]
data = {k: v for k, v in data.items() if not k.endswith("_momentum")}
if "weight_order" in data:
del data["weight_order"]
return {"model": data, "__author__": "Caffe2", "matching_heuristics": True}
loaded = super()._load_file(filename) # load native pth checkpoint
if "model" not in loaded:
loaded = {"model": loaded}
basename = os.path.basename(filename).lower()
if "lpf" in basename or "dla" in basename:
loaded["matching_heuristics"] = True
return loaded
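# Usage sketch (illustrative only; assumes `model` is an already-built detectron2/AdelaiDet
# model and the weight path is a placeholder):
#
#   checkpointer = AdetCheckpointer(model, save_dir="output")
#   checkpointer.load("path/to/model.pth")  # .pkl files from Detectron2/Caffe2 model zoos are also handled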
from .config import get_cfg
__all__ = [
"get_cfg",
]
from detectron2.config import CfgNode
def get_cfg() -> CfgNode:
"""
Get a copy of the default config.
Returns:
a detectron2 CfgNode instance.
"""
from .defaults import _C
return _C.clone()
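# Typical usage (a sketch; the YAML path is a placeholder):
#
#   from adet.config import get_cfg
#   cfg = get_cfg()
#   cfg.merge_from_file("configs/your_experiment.yaml")
#   cfg.freeze()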
from detectron2.config.defaults import _C
from detectron2.config import CfgNode as CN
# ---------------------------------------------------------------------------- #
# Additional Configs
# ---------------------------------------------------------------------------- #
_C.MODEL.MOBILENET = False
_C.MODEL.BACKBONE.ANTI_ALIAS = False
_C.MODEL.RESNETS.DEFORM_INTERVAL = 1
_C.INPUT.HFLIP_TRAIN = False
_C.INPUT.CROP.CROP_INSTANCE = True
_C.INPUT.ROTATE = True
_C.MODEL.BASIS_MODULE = CN()
_C.MODEL.BASIS_MODULE.NAME = "ProtoNet"
_C.MODEL.BASIS_MODULE.NUM_BASES = 4
_C.MODEL.BASIS_MODULE.LOSS_ON = False
_C.MODEL.BASIS_MODULE.ANN_SET = "coco"
_C.MODEL.BASIS_MODULE.CONVS_DIM = 128
_C.MODEL.BASIS_MODULE.IN_FEATURES = ["p3", "p4", "p5"]
_C.MODEL.BASIS_MODULE.NORM = "SyncBN"
_C.MODEL.BASIS_MODULE.NUM_CONVS = 3
_C.MODEL.BASIS_MODULE.COMMON_STRIDE = 8
_C.MODEL.BASIS_MODULE.NUM_CLASSES = 80
_C.MODEL.BASIS_MODULE.LOSS_WEIGHT = 0.3
_C.MODEL.TOP_MODULE = CN()
_C.MODEL.TOP_MODULE.NAME = "conv"
_C.MODEL.TOP_MODULE.DIM = 16
# ---------------------------------------------------------------------------- #
# BAText Options
# ---------------------------------------------------------------------------- #
_C.MODEL.BATEXT = CN()
_C.MODEL.BATEXT.VOC_SIZE = 96
_C.MODEL.BATEXT.NUM_CHARS = 25
_C.MODEL.BATEXT.POOLER_RESOLUTION = (8, 32)
_C.MODEL.BATEXT.IN_FEATURES = ["p2", "p3", "p4"]
_C.MODEL.BATEXT.POOLER_SCALES = (0.25, 0.125, 0.0625)
_C.MODEL.BATEXT.SAMPLING_RATIO = 1
_C.MODEL.BATEXT.CONV_DIM = 256
_C.MODEL.BATEXT.NUM_CONV = 2
_C.MODEL.BATEXT.RECOGNITION_LOSS = "ctc"
_C.MODEL.BATEXT.RECOGNIZER = "attn"
_C.MODEL.BATEXT.CANONICAL_SIZE = 96 # largest min_size for level 3 (stride=8)
_C.MODEL.BATEXT.USE_COORDCONV = False
_C.MODEL.BATEXT.USE_AET = False
_C.MODEL.BATEXT.CUSTOM_DICT = "" # Path to the class file.
# ---------------------------------------------------------------------------- #
# SwinTransformer Options
# ---------------------------------------------------------------------------- #
_C.MODEL.SWIN = CN()
_C.MODEL.SWIN.TYPE = 'tiny'
_C.MODEL.SWIN.DROP_PATH_RATE = 0.2
# ---------------------------------------------------------------------------- #
# ViTAE-v2 Options
# ---------------------------------------------------------------------------- #
_C.MODEL.ViTAEv2 = CN()
_C.MODEL.ViTAEv2.TYPE = 'vitaev2_s'
_C.MODEL.ViTAEv2.DROP_PATH_RATE = 0.2
# ---------------------------------------------------------------------------- #
# (Deformable) Transformer Options
# ---------------------------------------------------------------------------- #
_C.MODEL.TRANSFORMER = CN()
_C.MODEL.TRANSFORMER.ENABLED = False
_C.MODEL.TRANSFORMER.INFERENCE_TH_TEST = 0.4
_C.MODEL.TRANSFORMER.AUX_LOSS = True
_C.MODEL.TRANSFORMER.ENC_LAYERS = 6
_C.MODEL.TRANSFORMER.DEC_LAYERS = 6
_C.MODEL.TRANSFORMER.DIM_FEEDFORWARD = 1024
_C.MODEL.TRANSFORMER.HIDDEN_DIM = 256
_C.MODEL.TRANSFORMER.DROPOUT = 0.0
_C.MODEL.TRANSFORMER.NHEADS = 8
_C.MODEL.TRANSFORMER.NUM_QUERIES = 100
_C.MODEL.TRANSFORMER.ENC_N_POINTS = 4
_C.MODEL.TRANSFORMER.DEC_N_POINTS = 4
_C.MODEL.TRANSFORMER.POSITION_EMBEDDING_SCALE = 6.283185307179586 # 2 PI
_C.MODEL.TRANSFORMER.NUM_FEATURE_LEVELS = 4
_C.MODEL.TRANSFORMER.VOC_SIZE = 37 # a-z + 0-9 + unknown
_C.MODEL.TRANSFORMER.CUSTOM_DICT = "" # Path to the character class file.
_C.MODEL.TRANSFORMER.NUM_POINTS = 25 # the number of point queries for each instance
_C.MODEL.TRANSFORMER.TEMPERATURE = 10000
_C.MODEL.TRANSFORMER.BOUNDARY_HEAD = True # True: with boundary predictions, False: only with center lines
_C.MODEL.TRANSFORMER.LOSS = CN()
_C.MODEL.TRANSFORMER.LOSS.AUX_LOSS = True
_C.MODEL.TRANSFORMER.LOSS.FOCAL_ALPHA = 0.25
_C.MODEL.TRANSFORMER.LOSS.FOCAL_GAMMA = 2.0
# bezier proposal loss
_C.MODEL.TRANSFORMER.LOSS.BEZIER_CLASS_WEIGHT = 1.0
_C.MODEL.TRANSFORMER.LOSS.BEZIER_COORD_WEIGHT = 1.0
_C.MODEL.TRANSFORMER.LOSS.BEZIER_SAMPLE_POINTS = 25
# supervise the sampled on-curve points but not 4 Bezier control points
# target loss
_C.MODEL.TRANSFORMER.LOSS.POINT_CLASS_WEIGHT = 1.0
_C.MODEL.TRANSFORMER.LOSS.POINT_COORD_WEIGHT = 1.0
_C.MODEL.TRANSFORMER.LOSS.POINT_TEXT_WEIGHT = 0.5
_C.MODEL.TRANSFORMER.LOSS.BOUNDARY_WEIGHT = 0.5
_C.SOLVER.OPTIMIZER = "ADAMW"
_C.SOLVER.LR_BACKBONE = 1e-5
_C.SOLVER.LR_BACKBONE_NAMES = []
_C.SOLVER.LR_LINEAR_PROJ_NAMES = []
_C.SOLVER.LR_LINEAR_PROJ_MULT = 0.1
# 1 - Generic, 2 - Weak, 3 - Strong (for icdar2015)
# 1 - Full lexicon (for totaltext)
_C.TEST.LEXICON_TYPE = 1
from .text_evaluation_all import TextEvaluator
#!/usr/bin/env python2
#encoding: UTF-8
import json
import sys;sys.path.append('./')
import zipfile
import re
import sys
import os
import codecs
import importlib
from io import StringIO
from shapely.geometry import *
def print_help():
sys.stdout.write('Usage: python %s.py -g=<gtFile> -s=<submFile> [-o=<outputFolder> -p=<jsonParams>]' %sys.argv[0])
sys.exit(2)
def load_zip_file_keys(file,fileNameRegExp=''):
"""
Returns an array with the entries of the ZIP file that match the regular expression.
The keys are the file names, or the capturing group defined in fileNameRegExp.
"""
try:
archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
except :
raise Exception('Error loading the ZIP archive.')
pairs = []
for name in archive.namelist():
addFile = True
keyName = name
if fileNameRegExp!="":
m = re.match(fileNameRegExp,name)
if m == None:
addFile = False
else:
if len(m.groups())>0:
keyName = m.group(1)
if addFile:
pairs.append( keyName )
return pairs
def load_zip_file(file,fileNameRegExp='',allEntries=False):
"""
Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file.
The keys are the file names, or the capturing group defined in fileNameRegExp.
allEntries validates that all entries in the ZIP file pass the fileNameRegExp
"""
try:
archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
except :
raise Exception('Error loading the ZIP archive')
pairs = []
for name in archive.namelist():
addFile = True
keyName = name
if fileNameRegExp!="":
m = re.match(fileNameRegExp,name)
if m == None:
addFile = False
else:
if len(m.groups())>0:
keyName = m.group(1)
if addFile:
pairs.append( [ keyName , archive.read(name)] )
else:
if allEntries:
raise Exception('ZIP entry not valid: %s' %name)
return dict(pairs)
def decode_utf8(raw):
"""
Returns a Unicode object on success, or None on failure
"""
try:
raw = codecs.decode(raw,'utf-8', 'replace')
#extracts BOM if exists
raw = raw.encode('utf8')
if raw.startswith(codecs.BOM_UTF8):
raw = raw.replace(codecs.BOM_UTF8, '', 1)
return raw.decode('utf-8')
except:
return None
def validate_lines_in_file_gt(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
"""
This function validates all lines of the file, calling the line validation function on each line.
"""
utf8File = decode_utf8(file_contents)
if (utf8File is None) :
raise Exception("The file %s is not UTF-8" %fileName)
lines = utf8File.split( "\r\n" if CRLF else "\n" )
for line in lines:
line = line.replace("\r","").replace("\n","")
if(line != ""):
try:
validate_tl_line_gt(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
except Exception as e:
raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace'))
def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
"""
This function validates all lines of the file, calling the line validation function on each line.
"""
utf8File = decode_utf8(file_contents)
if (utf8File is None) :
raise Exception("The file %s is not UTF-8" %fileName)
lines = utf8File.split( "\r\n" if CRLF else "\n" )
for line in lines:
line = line.replace("\r","").replace("\n","")
if(line != ""):
try:
validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
except Exception as e:
raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace'))
def validate_tl_line_gt(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0):
"""
Validate the format of the line. If the line is not valid an exception will be raised.
If maxWidth and maxHeight are specified, all points must be inside the image bounds.
Possible values are:
LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
"""
get_tl_line_values_gt(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0):
"""
Validate the format of the line. If the line is not valid an exception will be raised.
If maxWidth and maxHeight are specified, all points must be inside the image bounds.
Possible values are:
LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
"""
get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
def get_tl_line_values_gt(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
"""
Validate the format of the line. If the line is not valid an exception will be raised.
If maxWidth and maxHeight are specified, all points must be inside the image bounds.
Possible values are:
LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
Returns values from a textline. Points , [Confidences], [Transcriptions]
"""
confidence = 0.0
transcription = "";
points = []
if LTRB:
# do not use
raise Exception('Not implemented.')
else:
# if withTranscription and withConfidence:
# cors = line.split(',')
# assert(len(cors)%2 -2 == 0), 'num cors should be even.'
# try:
# points = [ float(ic) for ic in cors[:-2]]
# except Exception as e:
# raise(e)
# elif withConfidence:
# cors = line.split(',')
# assert(len(cors)%2 -1 == 0), 'num cors should be even.'
# try:
# points = [ float(ic) for ic in cors[:-1]]
# except Exception as e:
# raise(e)
# elif withTranscription:
# cors = line.split(',')
# assert(len(cors)%2 -1 == 0), 'num cors should be even.'
# try:
# points = [ float(ic) for ic in cors[:-1]]
# except Exception as e:
# raise(e)
# else:
# cors = line.split(',')
# assert(len(cors)%2 == 0), 'num cors should be even.'
# try:
# points = [ float(ic) for ic in cors[:]]
# except Exception as e:
# raise(e)
if withTranscription and withConfidence:
raise Exception('not implemented')
elif withConfidence:
raise Exception('not implemented')
elif withTranscription:
ptr = line.strip().split(',####')
cors = ptr[0].split(',')
recs = ptr[1].strip()
assert(len(cors)%2 == 0), 'num cors should be even.'
try:
points = [ float(ic) for ic in cors[:]]
except Exception as e:
raise(e)
else:
raise Exception('not implemented')
validate_clockwise_points(points)
if (imWidth>0 and imHeight>0):
for ip in range(0, len(points), 2):
validate_point_inside_bounds(points[ip],points[ip+1],imWidth,imHeight);
if withConfidence:
try:
confidence = 1.0
except ValueError:
raise Exception("Confidence value must be a float")
if withTranscription:
# posTranscription = numPoints + (2 if withConfidence else 1)
# transcription = cors[-1].strip()
transcription = recs
m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription)
if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters
transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"")
return points,confidence,transcription
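# Example of the ',####'-delimited ground-truth format parsed above (illustrative values;
# the polygon vertices are ordered so that validate_clockwise_points accepts them):
#
#   pts, conf, text = get_tl_line_values_gt(
#       "10,10,10,40,100,40,100,10,####HELLO", LTRB=False, withTranscription=True)
#   # pts  -> [10.0, 10.0, 10.0, 40.0, 100.0, 40.0, 100.0, 10.0]
#   # conf -> 0.0, text -> "HELLO"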
def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
"""
Validate the format of the line. If the line is not valid an exception will be raised.
If maxWidth and maxHeight are specified, all points must be inside the image bounds.
Possible values are:
LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
Returns values from a textline. Points , [Confidences], [Transcriptions]
"""
confidence = 0.0
transcription = "";
points = []
if LTRB:
# do not use
raise Exception('Not implemented.')
else:
if withTranscription and withConfidence:
raise Exception('not implemented')
elif withConfidence:
raise Exception('not implemented')
elif withTranscription:
ptr = line.strip().split(',####')
cors = ptr[0].split(',')
recs = ptr[1].strip()
assert(len(cors)%2 == 0), 'num cors should be even.'
try:
points = [ float(ic) for ic in cors[:]]
except Exception as e:
raise(e)
else:
raise Exception('not implemented')
# print('det clock wise')
validate_clockwise_points(points)
if (imWidth>0 and imHeight>0):
for ip in range(0, len(points), 2):
validate_point_inside_bounds(points[ip],points[ip+1],imWidth,imHeight);
if withConfidence:
try:
confidence = 1.0
except ValueError:
raise Exception("Confidence value must be a float")
if withTranscription:
# posTranscription = numPoints + (2 if withConfidence else 1)
transcription = recs
m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription)
if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters
transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"")
return points,confidence,transcription
def validate_point_inside_bounds(x,y,imWidth,imHeight):
if(x<0 or x>imWidth):
raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(x,imWidth,imHeight))
if(y<0 or y>imHeight):
raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s)" %(y,imWidth,imHeight))
def validate_clockwise_points(points):
"""
Validates that the 4 points delimiting a polygon are in clockwise order.
"""
# if len(points) != 8:
# raise Exception("Points list not valid." + str(len(points)))
# point = [
# [int(points[0]) , int(points[1])],
# [int(points[2]) , int(points[3])],
# [int(points[4]) , int(points[5])],
# [int(points[6]) , int(points[7])]
# ]
# edge = [
# ( point[1][0] - point[0][0])*( point[1][1] + point[0][1]),
# ( point[2][0] - point[1][0])*( point[2][1] + point[1][1]),
# ( point[3][0] - point[2][0])*( point[3][1] + point[2][1]),
# ( point[0][0] - point[3][0])*( point[0][1] + point[3][1])
# ]
# summatory = edge[0] + edge[1] + edge[2] + edge[3];
# if summatory>0:
# raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.")
pts = [(points[j], points[j+1]) for j in range(0,len(points),2)]
try:
pdet = Polygon(pts)
except:
assert(0), ('not a valid polygon', pts)
# The polygon should be valid.
if not pdet.is_valid:
assert(0), ('polygon has intersection sides', pts)
pRing = LinearRing(pts)
if pRing.is_ccw:
assert(0), ("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.")
def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True):
"""
Returns all points, confidences and transcriptions of a file in lists. Valid line formats:
xmin,ymin,xmax,ymax,[confidence],[transcription]
x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription]
"""
pointsList = []
transcriptionsList = []
confidencesList = []
lines = content.split( "\r\n" if CRLF else "\n" )
for line in lines:
line = line.replace("\r","").replace("\n","")
if(line != "") :
points, confidence, transcription = get_tl_line_values_gt(line,LTRB,withTranscription,withConfidence,imWidth,imHeight);
pointsList.append(points)
transcriptionsList.append(transcription)
confidencesList.append(confidence)
if withConfidence and len(confidencesList)>0 and sort_by_confidences:
import numpy as np
sorted_ind = np.argsort(-np.array(confidencesList))
confidencesList = [confidencesList[i] for i in sorted_ind]
pointsList = [pointsList[i] for i in sorted_ind]
transcriptionsList = [transcriptionsList[i] for i in sorted_ind]
return pointsList,confidencesList,transcriptionsList
def get_tl_line_values_from_file_contents_det(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True):
"""
Returns all points, confidences and transcriptions of a file in lists. Valid line formats:
xmin,ymin,xmax,ymax,[confidence],[transcription]
x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription]
"""
pointsList = []
transcriptionsList = []
confidencesList = []
lines = content.split( "\r\n" if CRLF else "\n" )
for line in lines:
line = line.replace("\r","").replace("\n","")
if(line != "") :
points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight);
pointsList.append(points)
transcriptionsList.append(transcription)
confidencesList.append(confidence)
if withConfidence and len(confidencesList)>0 and sort_by_confidences:
import numpy as np
sorted_ind = np.argsort(-np.array(confidencesList))
confidencesList = [confidencesList[i] for i in sorted_ind]
pointsList = [pointsList[i] for i in sorted_ind]
transcriptionsList = [transcriptionsList[i] for i in sorted_ind]
return pointsList,confidencesList,transcriptionsList
def main_evaluation(p,det_file, gt_file, default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True):
"""
This process validates a method, evaluates it and, if it succeeds, generates a ZIP file with a JSON entry for each sample.
Params:
p: Dictionary of parameters with the GT/submission locations. If None is passed, the parameters sent by the system are used.
default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
validate_data_fn: points to a method that validates the correct format of the submission
evaluate_method_fn: points to a function that evaluates the submission and returns a Dictionary with the results
"""
# if (p == None):
# p = dict([s[1:].split('=') for s in sys.argv[1:]])
# if(len(sys.argv)<3):
# print_help()
p = {}
p['g'] =gt_file #'tttgt.zip'
p['s'] =det_file #'det.zip'
evalParams = default_evaluation_params_fn()
if 'p' in p.keys():
evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'}
# try:
validate_data_fn(p['g'], p['s'], evalParams)
evalData = evaluate_method_fn(p['g'], p['s'], evalParams)
resDict.update(evalData)
# except Exception as e:
# resDict['Message']= str(e)
# resDict['calculated']=False
if 'o' in p:
if not os.path.exists(p['o']):
os.makedirs(p['o'])
resultsOutputname = p['o'] + '/results.zip'
outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True)
del resDict['per_sample']
if 'output_items' in resDict.keys():
del resDict['output_items']
outZip.writestr('method.json',json.dumps(resDict))
if not resDict['calculated']:
if show_result:
sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n')
if 'o' in p:
outZip.close()
return resDict
if 'o' in p:
if per_sample == True:
for k,v in evalData['per_sample'].items():
outZip.writestr( k + '.json',json.dumps(v))
if 'output_items' in evalData.keys():
for k, v in evalData['output_items'].items():
outZip.writestr( k,v)
outZip.close()
# if show_result:
# sys.stdout.write("Calculated!")
# sys.stdout.write('\n')
# sys.stdout.write(json.dumps(resDict['e2e_method']))
# sys.stdout.write('\n')
# sys.stdout.write(json.dumps(resDict['det_only_method']))
# sys.stdout.write('\n')
return resDict
def main_validation(default_evaluation_params_fn,validate_data_fn):
"""
This process validates a method
Params:
default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
validate_data_fn: points to a method that validates the correct format of the submission
"""
try:
p = dict([s[1:].split('=') for s in sys.argv[1:]])
evalParams = default_evaluation_params_fn()
if 'p' in p.keys():
evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
validate_data_fn(p['g'], p['s'], evalParams)
print('SUCCESS')
sys.exit(0)
except Exception as e:
print(str(e))
sys.exit(101)
#!/usr/bin/env python2
#encoding: UTF-8
import json
import sys;sys.path.append('./')
import zipfile
import re
import sys
import os
import codecs
import importlib
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
def print_help():
sys.stdout.write('Usage: python %s.py -g=<gtFile> -s=<submFile> [-o=<outputFolder> -p=<jsonParams>]' %sys.argv[0])
sys.exit(2)
def load_zip_file_keys(file,fileNameRegExp=''):
"""
Returns an array with the entries of the ZIP file that match the regular expression.
The keys are the file names, or the capturing group defined in fileNameRegExp.
"""
try:
archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
except :
raise Exception('Error loading the ZIP archive.')
pairs = []
for name in archive.namelist():
addFile = True
keyName = name
if fileNameRegExp!="":
m = re.match(fileNameRegExp,name)
if m == None:
addFile = False
else:
if len(m.groups())>0:
keyName = m.group(1)
if addFile:
pairs.append( keyName )
return pairs
def load_zip_file(file,fileNameRegExp='',allEntries=False):
"""
Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file.
The keys are the file names, or the capturing group defined in fileNameRegExp.
allEntries validates that all entries in the ZIP file pass the fileNameRegExp
"""
try:
archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
except :
raise Exception('Error loading the ZIP archive')
pairs = []
for name in archive.namelist():
addFile = True
keyName = name
if fileNameRegExp!="":
m = re.match(fileNameRegExp,name)
if m == None:
addFile = False
else:
if len(m.groups())>0:
keyName = m.group(1)
if addFile:
pairs.append( [ keyName , archive.read(name)] )
else:
if allEntries:
raise Exception('ZIP entry not valid: %s' %name)
return dict(pairs)
def decode_utf8(raw):
"""
Returns a Unicode object on success, or None on failure
"""
try:
raw = codecs.decode(raw,'utf-8', 'replace')
#extracts BOM if exists
raw = raw.encode('utf8')
if raw.startswith(codecs.BOM_UTF8):
raw = raw.replace(codecs.BOM_UTF8, '', 1)
return raw.decode('utf-8')
except:
return None
def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
"""
This function validates all lines of the file, calling the line validation function on each line.
"""
utf8File = decode_utf8(file_contents)
if (utf8File is None) :
raise Exception("The file %s is not UTF-8" %fileName)
lines = utf8File.split( "\r\n" if CRLF else "\n" )
for line in lines:
line = line.replace("\r","").replace("\n","")
if(line != ""):
try:
validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
except Exception as e:
raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace'))
def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0):
"""
Validate the format of the line. If the line is not valid an exception will be raised.
If maxWidth and maxHeight are specified, all points must be inside the image bounds.
Possible values are:
LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
"""
get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
"""
Validate the format of the line. If the line is not valid an exception will be raised.
If maxWidth and maxHeight are specified, all points must be inside the image bounds.
Possible values are:
LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
Returns values from a textline. Points , [Confidences], [Transcriptions]
"""
confidence = 0.0
transcription = "";
points = []
numPoints = 4;
if LTRB:
numPoints = 4;
if withTranscription and withConfidence:
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
if m == None :
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription")
elif withConfidence:
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
if m == None :
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence")
elif withTranscription:
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line)
if m == None :
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription")
else:
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line)
if m == None :
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax")
xmin = int(m.group(1))
ymin = int(m.group(2))
xmax = int(m.group(3))
ymax = int(m.group(4))
if(xmax<xmin):
raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." %(xmax))
if(ymax<ymin):
raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." %(ymax))
points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ]
if (imWidth>0 and imHeight>0):
validate_point_inside_bounds(xmin,ymin,imWidth,imHeight);
validate_point_inside_bounds(xmax,ymax,imWidth,imHeight);
else:
numPoints = 8;
if withTranscription and withConfidence:
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
if m == None :
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription")
elif withConfidence:
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
if m == None :
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence")
elif withTranscription:
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line)
if m == None :
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription")
else:
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line)
if m == None :
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4")
points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ]
validate_clockwise_points(points)
if (imWidth>0 and imHeight>0):
validate_point_inside_bounds(points[0],points[1],imWidth,imHeight);
validate_point_inside_bounds(points[2],points[3],imWidth,imHeight);
validate_point_inside_bounds(points[4],points[5],imWidth,imHeight);
validate_point_inside_bounds(points[6],points[7],imWidth,imHeight);
if withConfidence:
try:
confidence = float(m.group(numPoints+1))
except ValueError:
raise Exception("Confidence value must be a float")
if withTranscription:
posTranscription = numPoints + (2 if withConfidence else 1)
transcription = m.group(posTranscription)
m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription)
if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters
transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"")
return points,confidence,transcription
def validate_point_inside_bounds(x,y,imWidth,imHeight):
if(x<0 or x>imWidth):
raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(x,imWidth,imHeight))
if(y<0 or y>imHeight):
raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s)" %(y,imWidth,imHeight))
def validate_clockwise_points(points):
"""
Validates that the 4 points delimiting a polygon are in clockwise order.
"""
if len(points) != 8:
raise Exception("Points list not valid." + str(len(points)))
point = [
[int(points[0]) , int(points[1])],
[int(points[2]) , int(points[3])],
[int(points[4]) , int(points[5])],
[int(points[6]) , int(points[7])]
]
edge = [
( point[1][0] - point[0][0])*( point[1][1] + point[0][1]),
( point[2][0] - point[1][0])*( point[2][1] + point[1][1]),
( point[3][0] - point[2][0])*( point[3][1] + point[2][1]),
( point[0][0] - point[3][0])*( point[0][1] + point[3][1])
]
summatory = edge[0] + edge[1] + edge[2] + edge[3];
if summatory>0:
raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.")
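# Worked example of the edge-sum check above: for the square (10,10),(100,10),(100,40),(10,40),
# which is clockwise in image coordinates (y axis pointing down), the edges give
# 90*20 + 0*50 + (-90)*80 + 0*50 = 1800 + 0 - 7200 + 0 = -5400 <= 0, so the order is accepted.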
def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True):
"""
Returns all points, confidences and transcriptions of a file in lists. Valid line formats:
xmin,ymin,xmax,ymax,[confidence],[transcription]
x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription]
"""
pointsList = []
transcriptionsList = []
confidencesList = []
lines = content.split( "\r\n" if CRLF else "\n" )
for line in lines:
line = line.replace("\r","").replace("\n","")
if(line != "") :
points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight);
pointsList.append(points)
transcriptionsList.append(transcription)
confidencesList.append(confidence)
if withConfidence and len(confidencesList)>0 and sort_by_confidences:
import numpy as np
sorted_ind = np.argsort(-np.array(confidencesList))
confidencesList = [confidencesList[i] for i in sorted_ind]
pointsList = [pointsList[i] for i in sorted_ind]
transcriptionsList = [transcriptionsList[i] for i in sorted_ind]
return pointsList,confidencesList,transcriptionsList
def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True):
"""
This process validates a method, evaluates it and, if it succeeds, generates a ZIP file with a JSON entry for each sample.
Params:
p: Dictionary of parameters with the GT/submission locations. If None is passed, the parameters sent by the system are used.
default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
validate_data_fn: points to a method that validates the correct format of the submission
evaluate_method_fn: points to a function that evaluates the submission and returns a Dictionary with the results
"""
if (p == None):
p = dict([s[1:].split('=') for s in sys.argv[1:]])
if(len(sys.argv)<3):
print_help()
evalParams = default_evaluation_params_fn()
if 'p' in p.keys():
evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'}
try:
validate_data_fn(p['g'], p['s'], evalParams)
evalData = evaluate_method_fn(p['g'], p['s'], evalParams)
resDict.update(evalData)
except Exception as e:
resDict['Message']= str(e)
resDict['calculated']=False
if 'o' in p:
if not os.path.exists(p['o']):
os.makedirs(p['o'])
resultsOutputname = p['o'] + '/results.zip'
outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True)
del resDict['per_sample']
if 'output_items' in resDict.keys():
del resDict['output_items']
outZip.writestr('method.json',json.dumps(resDict))
if not resDict['calculated']:
if show_result:
sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n')
if 'o' in p:
outZip.close()
return resDict
if 'o' in p:
if per_sample == True:
for k,v in evalData['per_sample'].items():
outZip.writestr( k + '.json',json.dumps(v))
if 'output_items' in evalData.keys():
for k, v in evalData['output_items'].items():
outZip.writestr( k,v)
outZip.close()
# if show_result:
# sys.stdout.write("Calculated!")
# sys.stdout.write("\n")
# sys.stdout.write(json.dumps(resDict['e2e_method']))
# sys.stdout.write("\n")
return resDict
def main_validation(default_evaluation_params_fn,validate_data_fn):
"""
This process validates a method
Params:
default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
validate_data_fn: points to a method that validates the correct format of the submission
"""
try:
p = dict([s[1:].split('=') for s in sys.argv[1:]])
evalParams = default_evaluation_params_fn()
if 'p' in p.keys():
evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
validate_data_fn(p['g'], p['s'], evalParams)
print('SUCCESS')
sys.exit(0)
except Exception as e:
print(str(e))
sys.exit(101)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# encoding=utf8
from collections import namedtuple
from adet.evaluation import rrc_evaluation_funcs
import importlib
import sys
import math
from rapidfuzz import string_metric
WORD_SPOTTING =True
def evaluation_imports():
"""
evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation.
"""
return {
'Polygon':'plg',
'numpy':'np'
}
def default_evaluation_params():
"""
default_evaluation_params: Default parameters to use for the validation and evaluation.
"""
global WORD_SPOTTING
return {
'IOU_CONSTRAINT' :0.5,
'AREA_PRECISION_CONSTRAINT' :0.5,
'WORD_SPOTTING' :WORD_SPOTTING,
'MIN_LENGTH_CARE_WORD' :3,
'GT_SAMPLE_NAME_2_ID':'([0-9]+).txt',
'DET_SAMPLE_NAME_2_ID':'([0-9]+).txt',
'LTRB':False, #LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4)
'CRLF':False, # Lines are delimited by Windows CRLF format
'CONFIDENCES':False, #Detections must include confidence value. MAP and MAR will be calculated,
'SPECIAL_CHARACTERS':str('!?.:,*"()·[]/\''),
'ONLY_REMOVE_FIRST_LAST_CHARACTER' : True
}
def validate_data(gtFilePath, submFilePath, evaluationParams):
"""
Method validate_data: validates that all files in the results folder are correct (have the correct name and contents).
Validates also that there are no missing files in the folder.
If some error detected, the method raises the error
"""
gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
#Validate format of GroundTruth
for k in gt:
rrc_evaluation_funcs.validate_lines_in_file_gt(k,gt[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True)
#Validate format of results
for k in subm:
if (k in gt) == False :
raise Exception("The sample %s not present in GT" %k)
rrc_evaluation_funcs.validate_lines_in_file(k,subm[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES'])
def evaluate_method(gtFilePath, submFilePath, evaluationParams):
"""
Method evaluate_method: evaluate method and returns the results
Results. Dictionary with the following values:
- method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 }
- samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } }
"""
for module,alias in evaluation_imports().items():
globals()[alias] = importlib.import_module(module)
def polygon_from_points(points):
"""
Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
"""
num_points = len(points)
# resBoxes=np.empty([1,num_points],dtype='int32')
resBoxes=np.empty([1,num_points],dtype='float32')
for inp in range(0, num_points, 2):
resBoxes[0, int(inp/2)] = float(points[int(inp)])
resBoxes[0, int(inp/2+num_points/2)] = float(points[int(inp+1)])
pointMat = resBoxes[0].reshape([2,int(num_points/2)]).T
return plg.Polygon(pointMat)
def rectangle_to_polygon(rect):
resBoxes=np.empty([1,8],dtype='int32')
resBoxes[0,0]=int(rect.xmin)
resBoxes[0,4]=int(rect.ymax)
resBoxes[0,1]=int(rect.xmin)
resBoxes[0,5]=int(rect.ymin)
resBoxes[0,2]=int(rect.xmax)
resBoxes[0,6]=int(rect.ymin)
resBoxes[0,3]=int(rect.xmax)
resBoxes[0,7]=int(rect.ymax)
pointMat = resBoxes[0].reshape([2,4]).T
return plg.Polygon( pointMat)
def rectangle_to_points(rect):
points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin)]
return points
def get_union(pD,pG):
areaA = pD.area();
areaB = pG.area();
return areaA + areaB - get_intersection(pD, pG);
def get_intersection_over_union(pD,pG):
try:
return get_intersection(pD, pG) / get_union(pD, pG);
except:
return 0
def get_intersection(pD,pG):
pInt = pD & pG
if len(pInt) == 0:
return 0
return pInt.area()
def compute_ap(confList, matchList,numGtCare):
correct = 0
AP = 0
if len(confList)>0:
confList = np.array(confList)
matchList = np.array(matchList)
sorted_ind = np.argsort(-confList)
confList = confList[sorted_ind]
matchList = matchList[sorted_ind]
for n in range(len(confList)):
match = matchList[n]
if match:
correct += 1
AP += float(correct)/(n + 1)
if numGtCare>0:
AP /= numGtCare
return AP
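# Worked example for compute_ap: confList=[0.9, 0.8, 0.7], matchList=[1, 0, 1], numGtCare=2.
# After sorting by confidence, the running precisions at the two matches are 1/1 and 2/3,
# so AP = (1.0 + 0.6667) / 2 ≈ 0.83.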
def transcription_match(transGt,transDet,specialCharacters=str(r'!?.:,*"()·[]/\''),onlyRemoveFirstLastCharacterGT=True):
if onlyRemoveFirstLastCharacterGT:
#special characters in GT are allowed only at initial or final position
if (transGt==transDet):
return True
if specialCharacters.find(transGt[0])>-1:
if transGt[1:]==transDet:
return True
if specialCharacters.find(transGt[-1])>-1:
if transGt[0:len(transGt)-1]==transDet:
return True
if specialCharacters.find(transGt[0])>-1 and specialCharacters.find(transGt[-1])>-1:
if transGt[1:len(transGt)-1]==transDet:
return True
return False
else:
#Special characters are removed from the begining and the end of both Detection and GroundTruth
while len(transGt)>0 and specialCharacters.find(transGt[0])>-1:
transGt = transGt[1:]
while len(transDet)>0 and specialCharacters.find(transDet[0])>-1:
transDet = transDet[1:]
while len(transGt)>0 and specialCharacters.find(transGt[-1])>-1 :
transGt = transGt[0:len(transGt)-1]
while len(transDet)>0 and specialCharacters.find(transDet[-1])>-1:
transDet = transDet[0:len(transDet)-1]
return transGt == transDet
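# Example with the default setting (special characters stripped only from the ends of the GT):
# transcription_match("(HELLO)", "HELLO") returns True, because removing the leading "(" and
# trailing ")" from the ground-truth string leaves "HELLO".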
def include_in_dictionary(transcription):
"""
Function used in Word Spotting that finds if the Ground Truth transcription meets the rules to enter into the dictionary. If not, the transcription will be treated as don't care
"""
#special case 's at final
if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S":
transcription = transcription[0:len(transcription)-2]
#hyphens at the beginning or end of the word
transcription = transcription.strip('-');
specialCharacters = str("'!?.:,*\"()·[]/");
for character in specialCharacters:
transcription = transcription.replace(character,' ')
transcription = transcription.strip()
if len(transcription) != len(transcription.replace(" ","")) :
return False;
if len(transcription) < evaluationParams['MIN_LENGTH_CARE_WORD']:
return False;
notAllowed = str("×÷·");
range1 = [ ord(u'a'), ord(u'z') ]
range2 = [ ord(u'A'), ord(u'Z') ]
range3 = [ ord(u'À'), ord(u'ƿ') ]
range4 = [ ord(u'DŽ'), ord(u'ɿ') ]
range5 = [ ord(u'Ά'), ord(u'Ͽ') ]
range6 = [ ord(u'-'), ord(u'-') ]
for char in transcription :
charCode = ord(char)
if(notAllowed.find(char) != -1):
return False
valid = ( charCode>=range1[0] and charCode<=range1[1] ) or ( charCode>=range2[0] and charCode<=range2[1] ) or ( charCode>=range3[0] and charCode<=range3[1] ) or ( charCode>=range4[0] and charCode<=range4[1] ) or ( charCode>=range5[0] and charCode<=range5[1] ) or ( charCode>=range6[0] and charCode<=range6[1] )
if valid == False:
return False
return True
def include_in_dictionary_transcription(transcription):
"""
Function applied to the Ground Truth transcriptions used in Word Spotting. It removes special characters or terminations
"""
#special case 's at final
if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S":
transcription = transcription[0:len(transcription)-2]
#hyphens at the beginning or end of the word
transcription = transcription.strip('-');
specialCharacters = str("'!?.:,*\"()·[]/");
for character in specialCharacters:
transcription = transcription.replace(character,' ')
transcription = transcription.strip()
return transcription
perSampleMetrics = {}
matchedSum = 0
det_only_matchedSum = 0
Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
gt = rrc_evaluation_funcs.load_zip_file(gtFilePath,evaluationParams['GT_SAMPLE_NAME_2_ID'])
subm = rrc_evaluation_funcs.load_zip_file(submFilePath,evaluationParams['DET_SAMPLE_NAME_2_ID'],True)
numGlobalCareGt = 0;
numGlobalCareDet = 0;
det_only_numGlobalCareGt = 0;
det_only_numGlobalCareDet = 0;
arrGlobalConfidences = [];
arrGlobalMatches = [];
for resFile in gt:
# print('resgt', resFile)
gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile])
if (gtFile is None) :
raise Exception("The file %s is not UTF-8" %resFile)
recall = 0
precision = 0
hmean = 0
detCorrect = 0
detOnlyCorrect = 0
iouMat = np.empty([1,1])
gtPols = []
detPols = []
gtTrans = []
detTrans = []
gtPolPoints = []
detPolPoints = []
gtDontCarePolsNum = [] #Array of Ground Truth Polygons' keys marked as don't Care
det_only_gtDontCarePolsNum = []
detDontCarePolsNum = [] #Array of Detected Polygons' matched with a don't Care GT
det_only_detDontCarePolsNum = []
detMatchedNums = []
pairs = []
arrSampleConfidences = [];
arrSampleMatch = [];
sampleAP = 0;
pointsList,_,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,False)
for n in range(len(pointsList)):
points = pointsList[n]
transcription = transcriptionsList[n]
det_only_dontCare = dontCare = transcription == "###" # ctw1500 and total_text gt have been modified to the same format.
if evaluationParams['LTRB']:
gtRect = Rectangle(*points)
gtPol = rectangle_to_polygon(gtRect)
else:
gtPol = polygon_from_points(points)
gtPols.append(gtPol)
gtPolPoints.append(points)
#On word spotting we will filter some transcriptions with special characters
if evaluationParams['WORD_SPOTTING'] :
if dontCare == False :
if include_in_dictionary(transcription) == False :
dontCare = True
else:
transcription = include_in_dictionary_transcription(transcription)
gtTrans.append(transcription)
if dontCare:
gtDontCarePolsNum.append( len(gtPols)-1 )
if det_only_dontCare:
det_only_gtDontCarePolsNum.append( len(gtPols)-1 )
if resFile in subm:
detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile])
pointsList,confidencesList,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents_det(detFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES'])
for n in range(len(pointsList)):
points = pointsList[n]
transcription = transcriptionsList[n]
if evaluationParams['LTRB']:
detRect = Rectangle(*points)
detPol = rectangle_to_polygon(detRect)
else:
detPol = polygon_from_points(points)
detPols.append(detPol)
detPolPoints.append(points)
detTrans.append(transcription)
if len(gtDontCarePolsNum)>0 :
for dontCarePol in gtDontCarePolsNum:
dontCarePol = gtPols[dontCarePol]
intersected_area = get_intersection(dontCarePol,detPol)
pdDimensions = detPol.area()
precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT'] ):
detDontCarePolsNum.append( len(detPols)-1 )
break
if len(det_only_gtDontCarePolsNum)>0 :
for dontCarePol in det_only_gtDontCarePolsNum:
dontCarePol = gtPols[dontCarePol]
intersected_area = get_intersection(dontCarePol,detPol)
pdDimensions = detPol.area()
precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT'] ):
det_only_detDontCarePolsNum.append( len(detPols)-1 )
break
if len(gtPols)>0 and len(detPols)>0:
#Calculate IoU and precision matrices
outputShape=[len(gtPols),len(detPols)]
iouMat = np.empty(outputShape)
gtRectMat = np.zeros(len(gtPols),np.int8)
detRectMat = np.zeros(len(detPols),np.int8)
det_only_gtRectMat = np.zeros(len(gtPols),np.int8)
det_only_detRectMat = np.zeros(len(detPols),np.int8)
for gtNum in range(len(gtPols)):
for detNum in range(len(detPols)):
pG = gtPols[gtNum]
pD = detPols[detNum]
iouMat[gtNum,detNum] = get_intersection_over_union(pD,pG)
for gtNum in range(len(gtPols)):
for detNum in range(len(detPols)):
if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum :
if iouMat[gtNum,detNum]>evaluationParams['IOU_CONSTRAINT']:
gtRectMat[gtNum] = 1
detRectMat[detNum] = 1
#detection matched only if transcription is equal
# det_only_correct = True
# detOnlyCorrect += 1
if evaluationParams['WORD_SPOTTING']:
edd = string_metric.levenshtein(gtTrans[gtNum].upper(), detTrans[detNum].upper())
if edd<=0:
correct = True
else:
correct = False
# correct = gtTrans[gtNum].upper() == detTrans[detNum].upper()
else:
try:
correct = transcription_match(gtTrans[gtNum].upper(),detTrans[detNum].upper(),evaluationParams['SPECIAL_CHARACTERS'],evaluationParams['ONLY_REMOVE_FIRST_LAST_CHARACTER'])==True
except: # e.g. empty transcription
correct = False
detCorrect += (1 if correct else 0)
if correct:
detMatchedNums.append(detNum)
for gtNum in range(len(gtPols)):
for detNum in range(len(detPols)):
if det_only_gtRectMat[gtNum] == 0 and det_only_detRectMat[detNum] == 0 and gtNum not in det_only_gtDontCarePolsNum and detNum not in det_only_detDontCarePolsNum:
if iouMat[gtNum,detNum]>evaluationParams['IOU_CONSTRAINT']:
det_only_gtRectMat[gtNum] = 1
det_only_detRectMat[detNum] = 1
#detection matched only if transcription is equal
det_only_correct = True
detOnlyCorrect += 1
numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
numDetCare = (len(detPols) - len(detDontCarePolsNum))
det_only_numGtCare = (len(gtPols) - len(det_only_gtDontCarePolsNum))
det_only_numDetCare = (len(detPols) - len(det_only_detDontCarePolsNum))
if numGtCare == 0:
recall = float(1)
precision = float(0) if numDetCare >0 else float(1)
else:
recall = float(detCorrect) / numGtCare
precision = 0 if numDetCare==0 else float(detCorrect) / numDetCare
if det_only_numGtCare == 0:
det_only_recall = float(1)
det_only_precision = float(0) if det_only_numDetCare >0 else float(1)
else:
det_only_recall = float(detOnlyCorrect) / det_only_numGtCare
det_only_precision = 0 if det_only_numDetCare==0 else float(detOnlyCorrect) / det_only_numDetCare
hmean = 0 if (precision + recall)==0 else 2.0 * precision * recall / (precision + recall)
det_only_hmean = 0 if (det_only_precision + det_only_recall)==0 else 2.0 * det_only_precision * det_only_recall / (det_only_precision + det_only_recall)
matchedSum += detCorrect
det_only_matchedSum += detOnlyCorrect
numGlobalCareGt += numGtCare
numGlobalCareDet += numDetCare
det_only_numGlobalCareGt += det_only_numGtCare
det_only_numGlobalCareDet += det_only_numDetCare
perSampleMetrics[resFile] = {
'precision':precision,
'recall':recall,
'hmean':hmean,
'iouMat':[] if len(detPols)>100 else iouMat.tolist(),
'gtPolPoints':gtPolPoints,
'detPolPoints':detPolPoints,
'gtTrans':gtTrans,
'detTrans':detTrans,
'gtDontCare':gtDontCarePolsNum,
'detDontCare':detDontCarePolsNum,
'evaluationParams': evaluationParams,
}
methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum)/numGlobalCareGt
methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum)/numGlobalCareDet
methodHmean = 0 if methodRecall + methodPrecision==0 else 2* methodRecall * methodPrecision / (methodRecall + methodPrecision)
det_only_methodRecall = 0 if det_only_numGlobalCareGt == 0 else float(det_only_matchedSum)/det_only_numGlobalCareGt
det_only_methodPrecision = 0 if det_only_numGlobalCareDet == 0 else float(det_only_matchedSum)/det_only_numGlobalCareDet
det_only_methodHmean = 0 if det_only_methodRecall + det_only_methodPrecision==0 else 2* det_only_methodRecall * det_only_methodPrecision / (det_only_methodRecall + det_only_methodPrecision)
methodMetrics = r"E2E_RESULTS: precision: {}, recall: {}, hmean: {}".format(methodPrecision, methodRecall, methodHmean)
det_only_methodMetrics = r"DETECTION_ONLY_RESULTS: precision: {}, recall: {}, hmean: {}".format(det_only_methodPrecision, det_only_methodRecall, det_only_methodHmean)
resDict = {'calculated':True,'Message':'','e2e_method': methodMetrics,'det_only_method': det_only_methodMetrics,'per_sample': perSampleMetrics}
return resDict;
def text_eval_main(det_file, gt_file, is_word_spotting):
global WORD_SPOTTING
WORD_SPOTTING = is_word_spotting
return rrc_evaluation_funcs.main_evaluation(None,det_file, gt_file, default_evaluation_params,validate_data,evaluate_method)
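# Usage sketch (illustrative only; the zip paths are placeholders that must
# follow the GT_SAMPLE_NAME_2_ID / DET_SAMPLE_NAME_2_ID naming patterns):
#
#   from adet.evaluation import text_eval_script
#   res = text_eval_script.text_eval_main(det_file="det.zip",
#                                          gt_file="datasets/evaluation/gt_totaltext.zip",
#                                          is_word_spotting=True)
#   # res["e2e_method"]      -> "E2E_RESULTS: precision: ..., recall: ..., hmean: ..."
#   # res["det_only_method"] -> "DETECTION_ONLY_RESULTS: ..."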
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# encoding=utf8
from collections import namedtuple
from adet.evaluation import rrc_evaluation_funcs_ic15 as rrc_evaluation_funcs
import importlib
import sys
import math
from rapidfuzz import string_metric
WORD_SPOTTING =True
def evaluation_imports():
"""
evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation.
"""
return {
'Polygon':'plg',
'numpy':'np'
}
def default_evaluation_params():
"""
default_evaluation_params: Default parameters to use for the validation and evaluation.
"""
global WORD_SPOTTING
return {
'IOU_CONSTRAINT' :0.5,
'AREA_PRECISION_CONSTRAINT' :0.5,
'WORD_SPOTTING' :WORD_SPOTTING,
'MIN_LENGTH_CARE_WORD' :3,
'GT_SAMPLE_NAME_2_ID':'gt_img_([0-9]+).txt',
'DET_SAMPLE_NAME_2_ID':'res_img_([0-9]+).txt',
'LTRB':False, # True: 2 points (left,top,right,bottom); False: 4 points (x1,y1,x2,y2,x3,y3,x4,y4)
'CRLF':False, # Lines are delimited by Windows CRLF format
'CONFIDENCES':False, # Detections must include a confidence value; if True, AP is also computed
'SPECIAL_CHARACTERS':'!?.:,*"()·[]/\'',
'ONLY_REMOVE_FIRST_LAST_CHARACTER' : True
}
def validate_data(gtFilePath, submFilePath, evaluationParams):
"""
Method validate_data: validates that all files in the results folder are correct (have the correct name and contents).
It also validates that there are no missing files in the folder.
If an error is detected, the method raises it.
"""
gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
#Validate format of GroundTruth
for k in gt:
rrc_evaluation_funcs.validate_lines_in_file(k,gt[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True)
#Validate format of results
for k in subm:
if k not in gt:
raise Exception("The sample %s not present in GT" %k)
rrc_evaluation_funcs.validate_lines_in_file(k,subm[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES'])
def evaluate_method(gtFilePath, submFilePath, evaluationParams):
"""
Method evaluate_method: evaluate method and returns the results
Results. Dictionary with the following values:
- method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 }
- samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } }
"""
for module,alias in evaluation_imports().items():
globals()[alias] = importlib.import_module(module)
def polygon_from_points(points,correctOffset=False):
"""
Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
"""
if correctOffset: #this will subtract 1 from the coordinates that correspond to the xmax and ymax
points[2] -= 1
points[4] -= 1
points[5] -= 1
points[7] -= 1
resBoxes=np.empty([1,8],dtype='int32')
resBoxes[0,0]=int(points[0])
resBoxes[0,4]=int(points[1])
resBoxes[0,1]=int(points[2])
resBoxes[0,5]=int(points[3])
resBoxes[0,2]=int(points[4])
resBoxes[0,6]=int(points[5])
resBoxes[0,3]=int(points[6])
resBoxes[0,7]=int(points[7])
pointMat = resBoxes[0].reshape([2,4]).T
return plg.Polygon( pointMat)
def rectangle_to_polygon(rect):
resBoxes=np.empty([1,8],dtype='int32')
resBoxes[0,0]=int(rect.xmin)
resBoxes[0,4]=int(rect.ymax)
resBoxes[0,1]=int(rect.xmin)
resBoxes[0,5]=int(rect.ymin)
resBoxes[0,2]=int(rect.xmax)
resBoxes[0,6]=int(rect.ymin)
resBoxes[0,3]=int(rect.xmax)
resBoxes[0,7]=int(rect.ymax)
pointMat = resBoxes[0].reshape([2,4]).T
return plg.Polygon( pointMat)
def rectangle_to_points(rect):
points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin)]
return points
def get_union(pD,pG):
areaA = pD.area();
areaB = pG.area();
return areaA + areaB - get_intersection(pD, pG);
def get_intersection_over_union(pD,pG):
try:
return get_intersection(pD, pG) / get_union(pD, pG);
except:
return 0
def get_intersection(pD,pG):
pInt = pD & pG
if len(pInt) == 0:
return 0
return pInt.area()
def compute_ap(confList, matchList,numGtCare):
correct = 0
AP = 0
if len(confList)>0:
confList = np.array(confList)
matchList = np.array(matchList)
sorted_ind = np.argsort(-confList)
confList = confList[sorted_ind]
matchList = matchList[sorted_ind]
for n in range(len(confList)):
match = matchList[n]
if match:
correct += 1
AP += float(correct)/(n + 1)
if numGtCare>0:
AP /= numGtCare
return AP
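# Worked example for compute_ap (illustrative): with confList=[0.9, 0.8, 0.7],
# matchList=[1, 0, 1] and numGtCare=2, the detections are already sorted by
# confidence, so AP = (1/1 + 2/3) / 2 ~= 0.83; the unmatched middle detection
# contributes nothing to the sum.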
def transcription_match(transGt,transDet,specialCharacters='!?.:,*"()·[]/\'',onlyRemoveFirstLastCharacterGT=True):
if onlyRemoveFirstLastCharacterGT:
#special characters in GT are allowed only at initial or final position
if (transGt==transDet):
return True
if specialCharacters.find(transGt[0])>-1:
if transGt[1:]==transDet:
return True
if specialCharacters.find(transGt[-1])>-1:
if transGt[0:len(transGt)-1]==transDet:
return True
if specialCharacters.find(transGt[0])>-1 and specialCharacters.find(transGt[-1])>-1:
if transGt[1:len(transGt)-1]==transDet:
return True
return False
else:
#Special characters are removed from the beginning and the end of both Detection and GroundTruth
while len(transGt)>0 and specialCharacters.find(transGt[0])>-1:
transGt = transGt[1:]
while len(transDet)>0 and specialCharacters.find(transDet[0])>-1:
transDet = transDet[1:]
while len(transGt)>0 and specialCharacters.find(transGt[-1])>-1 :
transGt = transGt[0:len(transGt)-1]
while len(transDet)>0 and specialCharacters.find(transDet[-1])>-1:
transDet = transDet[0:len(transDet)-1]
return transGt == transDet
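# Illustrative cases for transcription_match with the default special-character
# set: ('"HELLO"', 'HELLO') -> True and ('HELLO!', 'HELLO') -> True, because
# punctuation at the first/last position of the GT is ignored, but
# ('HE!LLO', 'HELLO') -> False, because special characters inside the word are
# kept when onlyRemoveFirstLastCharacterGT is True.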
def include_in_dictionary(transcription):
"""
Function used in Word Spotting that checks whether the Ground Truth transcription meets the rules to enter the dictionary. If not, the transcription is treated as don't care
"""
#special case: trailing 's
if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S":
transcription = transcription[0:len(transcription)-2]
#hyphens at the start or end of the word
transcription = transcription.strip('-');
specialCharacters = "'!?.:,*\"()·[]/";
for character in specialCharacters:
transcription = transcription.replace(character,' ')
transcription = transcription.strip()
if len(transcription) != len(transcription.replace(" ","")) :
return False;
if len(transcription) < evaluationParams['MIN_LENGTH_CARE_WORD']:
return False;
notAllowed = "×÷·";
range1 = [ ord(u'a'), ord(u'z') ]
range2 = [ ord(u'A'), ord(u'Z') ]
range3 = [ ord(u'À'), ord(u'ƿ') ]
range4 = [ ord(u'DŽ'), ord(u'ɿ') ]
range5 = [ ord(u'Ά'), ord(u'Ͽ') ]
range6 = [ ord(u'-'), ord(u'-') ]
for char in transcription :
charCode = ord(char)
if(notAllowed.find(char) != -1):
return False
valid = ( charCode>=range1[0] and charCode<=range1[1] ) or ( charCode>=range2[0] and charCode<=range2[1] ) or ( charCode>=range3[0] and charCode<=range3[1] ) or ( charCode>=range4[0] and charCode<=range4[1] ) or ( charCode>=range5[0] and charCode<=range5[1] ) or ( charCode>=range6[0] and charCode<=range6[1] )
if valid == False:
return False
return True
def include_in_dictionary_transcription(transcription):
"""
Function applied to the Ground Truth transcriptions used in Word Spotting. It removes special characters or terminations
"""
#special case: trailing 's
if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S":
transcription = transcription[0:len(transcription)-2]
#hyphens at the start or end of the word
transcription = transcription.strip('-');
specialCharacters = "'!?.:,*\"()·[]/";
for character in specialCharacters:
transcription = transcription.replace(character,' ')
transcription = transcription.strip()
return transcription
perSampleMetrics = {}
matchedSum = 0
det_only_matchedSum = 0
Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
gt = rrc_evaluation_funcs.load_zip_file(gtFilePath,evaluationParams['GT_SAMPLE_NAME_2_ID'])
subm = rrc_evaluation_funcs.load_zip_file(submFilePath,evaluationParams['DET_SAMPLE_NAME_2_ID'],True)
numGlobalCareGt = 0;
numGlobalCareDet = 0;
det_only_numGlobalCareGt = 0;
det_only_numGlobalCareDet = 0;
arrGlobalConfidences = [];
arrGlobalMatches = [];
for resFile in gt:
gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile])
if (gtFile is None) :
raise Exception("The file %s is not UTF-8" %resFile)
recall = 0
precision = 0
hmean = 0
detCorrect = 0
detOnlyCorrect = 0
iouMat = np.empty([1,1])
gtPols = []
detPols = []
gtTrans = []
detTrans = []
gtPolPoints = []
detPolPoints = []
gtDontCarePolsNum = [] #Array of Ground Truth Polygons' keys marked as don't Care
det_only_gtDontCarePolsNum = []
detDontCarePolsNum = [] #Array of Detected Polygons' matched with a don't Care GT
det_only_detDontCarePolsNum = []
detMatchedNums = []
pairs = []
arrSampleConfidences = [];
arrSampleMatch = [];
sampleAP = 0;
evaluationLog = ""
pointsList,_,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,False)
for n in range(len(pointsList)):
points = pointsList[n]
transcription = transcriptionsList[n]
# dontCare = transcription == "###"
det_only_dontCare = dontCare = transcription == "###" # ctw1500 and total_text gt have been modified to the same format.
if evaluationParams['LTRB']:
gtRect = Rectangle(*points)
gtPol = rectangle_to_polygon(gtRect)
else:
gtPol = polygon_from_points(points)
gtPols.append(gtPol)
gtPolPoints.append(points)
#In word spotting, transcriptions containing special characters are filtered out
if evaluationParams['WORD_SPOTTING'] :
if dontCare == False :
if include_in_dictionary(transcription) == False :
dontCare = True
else:
transcription = include_in_dictionary_transcription(transcription)
gtTrans.append(transcription)
if dontCare:
gtDontCarePolsNum.append( len(gtPols)-1 )
if det_only_dontCare:
det_only_gtDontCarePolsNum.append( len(gtPols)-1 )
evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum)>0 else "\n")
if resFile in subm:
detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile])
pointsList,confidencesList,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES'])
for n in range(len(pointsList)):
points = pointsList[n]
transcription = transcriptionsList[n]
if evaluationParams['LTRB']:
detRect = Rectangle(*points)
detPol = rectangle_to_polygon(detRect)
else:
detPol = polygon_from_points(points)
detPols.append(detPol)
detPolPoints.append(points)
detTrans.append(transcription)
if len(gtDontCarePolsNum)>0 :
for dontCarePol in gtDontCarePolsNum:
dontCarePol = gtPols[dontCarePol]
intersected_area = get_intersection(dontCarePol,detPol)
pdDimensions = detPol.area()
precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT'] ):
detDontCarePolsNum.append( len(detPols)-1 )
break
if len(det_only_gtDontCarePolsNum)>0 :
for dontCarePol in det_only_gtDontCarePolsNum:
dontCarePol = gtPols[dontCarePol]
intersected_area = get_intersection(dontCarePol,detPol)
pdDimensions = detPol.area()
precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT'] ):
det_only_detDontCarePolsNum.append( len(detPols)-1 )
break
evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum)>0 else "\n")
if len(gtPols)>0 and len(detPols)>0:
#Calculate IoU and precision matrices
outputShape=[len(gtPols),len(detPols)]
iouMat = np.empty(outputShape)
gtRectMat = np.zeros(len(gtPols),np.int8)
detRectMat = np.zeros(len(detPols),np.int8)
det_only_gtRectMat = np.zeros(len(gtPols),np.int8)
det_only_detRectMat = np.zeros(len(detPols),np.int8)
for gtNum in range(len(gtPols)):
for detNum in range(len(detPols)):
pG = gtPols[gtNum]
pD = detPols[detNum]
iouMat[gtNum,detNum] = get_intersection_over_union(pD,pG)
for gtNum in range(len(gtPols)):
for detNum in range(len(detPols)):
if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum :
if iouMat[gtNum,detNum]>evaluationParams['IOU_CONSTRAINT']:
gtRectMat[gtNum] = 1
detRectMat[detNum] = 1
#detection matched only if transcription is equal
if evaluationParams['WORD_SPOTTING']:
correct = gtTrans[gtNum].upper() == detTrans[detNum].upper()
else:
correct = transcription_match(gtTrans[gtNum].upper(),detTrans[detNum].upper(),evaluationParams['SPECIAL_CHARACTERS'],evaluationParams['ONLY_REMOVE_FIRST_LAST_CHARACTER'])==True
detCorrect += (1 if correct else 0)
if correct:
detMatchedNums.append(detNum)
pairs.append({'gt':gtNum,'det':detNum,'correct':correct})
evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + " trans. correct: " + str(correct) + "\n"
for gtNum in range(len(gtPols)):
for detNum in range(len(detPols)):
if det_only_gtRectMat[gtNum] == 0 and det_only_detRectMat[detNum] == 0 and gtNum not in det_only_gtDontCarePolsNum and detNum not in det_only_detDontCarePolsNum:
if iouMat[gtNum,detNum]>evaluationParams['IOU_CONSTRAINT']:
det_only_gtRectMat[gtNum] = 1
det_only_detRectMat[detNum] = 1
#detection matched only if transcription is equal
det_only_correct = True
detOnlyCorrect += 1
if evaluationParams['CONFIDENCES']:
for detNum in range(len(detPols)):
if detNum not in detDontCarePolsNum :
#we exclude the don't care detections
match = detNum in detMatchedNums
arrSampleConfidences.append(confidencesList[detNum])
arrSampleMatch.append(match)
arrGlobalConfidences.append(confidencesList[detNum]);
arrGlobalMatches.append(match);
numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
numDetCare = (len(detPols) - len(detDontCarePolsNum))
det_only_numGtCare = (len(gtPols) - len(det_only_gtDontCarePolsNum))
det_only_numDetCare = (len(detPols) - len(det_only_detDontCarePolsNum))
if numGtCare == 0:
recall = float(1)
precision = float(0) if numDetCare >0 else float(1)
sampleAP = precision
else:
recall = float(detCorrect) / numGtCare
precision = 0 if numDetCare==0 else float(detCorrect) / numDetCare
if evaluationParams['CONFIDENCES']:
sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare )
if det_only_numGtCare == 0:
det_only_recall = float(1)
det_only_precision = float(0) if det_only_numDetCare >0 else float(1)
else:
det_only_recall = float(detOnlyCorrect) / det_only_numGtCare
det_only_precision = 0 if det_only_numDetCare==0 else float(detOnlyCorrect) / det_only_numDetCare
hmean = 0 if (precision + recall)==0 else 2.0 * precision * recall / (precision + recall)
det_only_hmean = 0 if (det_only_precision + det_only_recall)==0 else 2.0 * det_only_precision * det_only_recall / (det_only_precision + det_only_recall)
matchedSum += detCorrect
det_only_matchedSum += detOnlyCorrect
numGlobalCareGt += numGtCare
numGlobalCareDet += numDetCare
det_only_numGlobalCareGt += det_only_numGtCare
det_only_numGlobalCareDet += det_only_numDetCare
perSampleMetrics[resFile] = {
'precision':precision,
'recall':recall,
'hmean':hmean,
'pairs':pairs,
'AP':sampleAP,
'iouMat':[] if len(detPols)>100 else iouMat.tolist(),
'gtPolPoints':gtPolPoints,
'detPolPoints':detPolPoints,
'gtTrans':gtTrans,
'detTrans':detTrans,
'gtDontCare':gtDontCarePolsNum,
'detDontCare':detDontCarePolsNum,
'evaluationParams': evaluationParams,
'evaluationLog': evaluationLog
}
# Compute AP
AP = 0
if evaluationParams['CONFIDENCES']:
AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt)
methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum)/numGlobalCareGt
methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum)/numGlobalCareDet
methodHmean = 0 if methodRecall + methodPrecision==0 else 2* methodRecall * methodPrecision / (methodRecall + methodPrecision)
det_only_methodRecall = 0 if det_only_numGlobalCareGt == 0 else float(det_only_matchedSum)/det_only_numGlobalCareGt
det_only_methodPrecision = 0 if det_only_numGlobalCareDet == 0 else float(det_only_matchedSum)/det_only_numGlobalCareDet
det_only_methodHmean = 0 if det_only_methodRecall + det_only_methodPrecision==0 else 2* det_only_methodRecall * det_only_methodPrecision / (det_only_methodRecall + det_only_methodPrecision)
methodMetrics = r"E2E_RESULTS: precision: {}, recall: {}, hmean: {}".format(methodPrecision, methodRecall, methodHmean)
det_only_methodMetrics = r"DETECTION_ONLY_RESULTS: precision: {}, recall: {}, hmean: {}".format(det_only_methodPrecision, det_only_methodRecall, det_only_methodHmean)
resDict = {'calculated':True,'Message':'','e2e_method': methodMetrics, 'det_only_method': det_only_methodMetrics, 'per_sample': perSampleMetrics}
return resDict;
def text_eval_main_ic15(det_file, gt_file, is_word_spotting):
global WORD_SPOTTING
WORD_SPOTTING = is_word_spotting
p = {
'g': gt_file,
's': det_file
}
return rrc_evaluation_funcs.main_evaluation(p,default_evaluation_params,validate_data,evaluate_method)
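# Usage sketch for the ICDAR2015 variant (illustrative only; the zip paths are
# placeholders):
#
#   from adet.evaluation import text_eval_script_ic15
#   res = text_eval_script_ic15.text_eval_main_ic15(
#       det_file="det.zip",
#       gt_file="datasets/evaluation/gt_icdar2015.zip",
#       is_word_spotting=False)
#   # res["e2e_method"] and res["det_only_method"] hold the summary strings
#   # that TextEvaluator parses further below.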
import contextlib
import copy
import io
import itertools
import json
import logging
import numpy as np
import os
import re
import torch
from collections import OrderedDict
from fvcore.common.file_io import PathManager
from pycocotools.coco import COCO
from detectron2.utils import comm
from detectron2.data import MetadataCatalog
from detectron2.evaluation.evaluator import DatasetEvaluator
import glob
import shutil
from shapely.geometry import Polygon, LinearRing
from adet.evaluation import text_eval_script
from adet.evaluation import text_eval_script_ic15
import zipfile
import pickle
import editdistance
import cv2
class TextEvaluator():
"""
Evaluate text proposals and recognition.
"""
def __init__(self, dataset_name, cfg, distributed, output_dir=None):
self._tasks = ("polygon", "recognition")
self._distributed = distributed
self._output_dir = output_dir
self._cpu_device = torch.device("cpu")
self._logger = logging.getLogger(__name__)
self._metadata = MetadataCatalog.get(dataset_name)
if not hasattr(self._metadata, "json_file"):
raise AttributeError(
f"json_file was not found in MetaDataCatalog for '{dataset_name}'."
)
self.voc_size = cfg.MODEL.TRANSFORMER.VOC_SIZE
self.use_customer_dict = cfg.MODEL.TRANSFORMER.CUSTOM_DICT
if self.voc_size == 37:
self.CTLABELS = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
elif self.voc_size == 96:
self.CTLABELS = [' ','!','"','#','$','%','&','\'','(',')','*','+',',','-','.','/',
'0','1','2','3','4','5','6','7','8','9',':',';','<','=','>','?','@',
'A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V',
'W','X','Y','Z','[','\\',']','^','_','`','a','b','c','d','e','f','g','h','i','j','k','l',
'm','n','o','p','q','r','s','t','u','v','w','x','y','z','{','|','}','~']
else:
with open(self.use_customer_dict, 'rb') as fp:
self.CTLABELS = pickle.load(fp)
# voc_size includes the unknown class, which is not in self.CTLABELS
assert(int(self.voc_size - 1) == len(self.CTLABELS)), "voc_size does not match the dictionary size, got {} and {}.".format(int(self.voc_size - 1), len(self.CTLABELS))
json_file = PathManager.get_local_path(self._metadata.json_file)
with contextlib.redirect_stdout(io.StringIO()):
self._coco_api = COCO(json_file)
self.dataset_name = dataset_name
self.submit = False
# use dataset_name to decide eval_gt_path
self.lexicon_type = cfg.TEST.LEXICON_TYPE
if "totaltext" in dataset_name:
self._text_eval_gt_path = "datasets/evaluation/gt_totaltext.zip"
self._word_spotting = True
self.dataset_name = "totaltext"
elif "ctw1500" in dataset_name:
self._text_eval_gt_path = "datasets/evaluation/gt_ctw1500.zip"
self._word_spotting = False
self.dataset_name = "ctw1500"
elif "ic15" in dataset_name:
self._text_eval_gt_path = "datasets/evaluation/gt_icdar2015.zip"
self._word_spotting = False
self.dataset_name = "ic15"
elif "inversetext" in dataset_name:
self._text_eval_gt_path = "datasets/evaluation/gt_inversetext.zip"
self._word_spotting = False
self.dataset_name = "inversetext"
elif "rects" in dataset_name:
self.submit = True
self._text_eval_gt_path = ""
self.dataset_name = "rects"
else:
raise NotImplementedError
def reset(self):
self._predictions = []
def process(self, inputs, outputs):
for input, output in zip(inputs, outputs):
prediction = {"image_id": input["image_id"]}
instances = output["instances"].to(self._cpu_device)
prediction["instances"] = self.instances_to_coco_json(instances, input)
self._predictions.append(prediction)
def to_eval_format(self, file_path, temp_dir="temp_det_results"):
def fis_ascii(s):
a = (ord(c) < 128 for c in s)
return all(a)
def de_ascii(s):
a = [c for c in s if ord(c) < 128]
outa = ''
for i in a:
outa +=i
return outa
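# Illustrative behaviour of the helpers above: fis_ascii("cafe") is True, while
# de_ascii("café") drops the non-ASCII character and returns "caf".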
with open(file_path, 'r') as f:
data = json.load(f)
with open('temp_all_det_cors.txt', 'w') as f2:
for ix in range(len(data)):
if data[ix]['score'] > 0.1:
outstr = '{}: '.format(data[ix]['image_id'])
for i in range(len(data[ix]['polys'])):
if "ctw1500" in self.dataset_name:
# ctw1500 uses many boundary points on each side, so 'float' coordinates are kept
# (the original AdelaiDet implementation uses 'int')
outstr = outstr + str(float(data[ix]['polys'][i][0])) +\
','+str(float(data[ix]['polys'][i][1])) +','
else:
outstr = outstr + str(int(data[ix]['polys'][i][0])) + \
',' + str(int(data[ix]['polys'][i][1])) + ','
ass = str(data[ix]['rec'])
if len(ass)>=0: # always true: every recognition string (even an empty one) is written out
outstr = outstr + str(round(data[ix]['score'], 3)) +',####'+ass+'\n'
f2.writelines(outstr)
f2.close()
dirn = temp_dir
fres = open('temp_all_det_cors.txt', 'r').readlines()
if not os.path.isdir(dirn):
os.mkdir(dirn)
for line in fres:
line = line.strip()
s = line.split(': ')
filename = '{:07d}.txt'.format(int(s[0]))
outName = os.path.join(dirn, filename)
with open(outName, 'a') as fout:
ptr = s[1].strip().split(',####')
score = ptr[0].split(',')[-1]
cors = ','.join(e for e in ptr[0].split(',')[:-1])
fout.writelines(cors+',####'+str(ptr[1])+'\n')
os.remove("temp_all_det_cors.txt")
def sort_detection(self, temp_dir):
origin_file = temp_dir
output_file = "final_"+temp_dir
output_file_full = "full_final_"+temp_dir
if not os.path.isdir(output_file_full):
os.mkdir(output_file_full)
if not os.path.isdir(output_file):
os.mkdir(output_file)
files = glob.glob(origin_file+'*.txt')
files.sort()
if "totaltext" in self.dataset_name:
if self.lexicon_type is not None:
lexicon_path = 'datasets/totaltext/weak_voc_new.txt'
lexicon_fid=open(lexicon_path, 'r')
pair_list = open('datasets/totaltext/weak_voc_pair_list.txt', 'r')
pairs = dict()
for line in pair_list.readlines():
line=line.strip()
word = line.split(' ')[0].upper()
word_gt = line[len(word)+1:]
pairs[word] = word_gt
lexicon_fid=open(lexicon_path, 'r')
lexicon=[]
for line in lexicon_fid.readlines():
line=line.strip()
lexicon.append(line)
elif "ctw1500" in self.dataset_name:
if self.lexicon_type is not None:
lexicon_path = 'datasets/ctw1500/weak_voc_new.txt'
lexicon_fid=open(lexicon_path, 'r')
pair_list = open('datasets/ctw1500/weak_voc_pair_list.txt', 'r')
pairs = dict()
lexicon_fid=open(lexicon_path, 'r')
lexicon=[]
for line in lexicon_fid.readlines():
line=line.strip()
lexicon.append(line)
pairs[line.upper()] = line
elif "ic15" in self.dataset_name:
if self.lexicon_type==1:
# generic lexicon
lexicon_path = 'datasets/ic15/GenericVocabulary_new.txt'
lexicon_fid=open(lexicon_path, 'r')
pair_list = open('datasets/ic15/GenericVocabulary_pair_list.txt', 'r')
pairs = dict()
for line in pair_list.readlines():
line=line.strip()
word = line.split(' ')[0].upper()
word_gt = line[len(word)+1:]
pairs[word] = word_gt
lexicon_fid=open(lexicon_path, 'r')
lexicon=[]
for line in lexicon_fid.readlines():
line=line.strip()
lexicon.append(line)
if self.lexicon_type==2:
# weak lexicon
lexicon_path = 'datasets/ic15/ch4_test_vocabulary_new.txt'
lexicon_fid=open(lexicon_path, 'r')
pair_list = open('datasets/ic15/ch4_test_vocabulary_pair_list.txt', 'r')
pairs = dict()
for line in pair_list.readlines():
line=line.strip()
word = line.split(' ')[0].upper()
word_gt = line[len(word)+1:]
pairs[word] = word_gt
lexicon_fid=open(lexicon_path, 'r')
lexicon=[]
for line in lexicon_fid.readlines():
line=line.strip()
lexicon.append(line)
elif "inversetext" in self.dataset_name:
if self.lexicon_type is not None:
lexicon_path = 'datasets/inversetext/inversetext_lexicon.txt'
lexicon_fid=open(lexicon_path, 'r')
pair_list = open('datasets/inversetext/inversetext_pair_list.txt', 'r')
pairs = dict()
for line in pair_list.readlines():
line=line.strip()
word = line.split(' ')[0].upper()
word_gt = line[len(word)+1:]
pairs[word] = word_gt
lexicon_fid=open(lexicon_path, 'r')
lexicon=[]
for line in lexicon_fid.readlines():
line=line.strip()
lexicon.append(line)
def find_match_word(rec_str, pairs, lexicon=None):
rec_str = rec_str.upper()
dist_min = 100
dist_min_pre = 100
match_word = ''
match_dist = 100
for word in lexicon:
word = word.upper()
ed = editdistance.eval(rec_str, word)
length_dist = abs(len(word) - len(rec_str))
dist = ed
if dist<dist_min:
dist_min = dist
match_word = pairs[word]
match_dist = dist
return match_word, match_dist
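# Illustrative example for find_match_word (assuming a lexicon that contains
# "HOTEL" with pairs["HOTEL"] == "hotel" as its closest entry):
# find_match_word("h0tel", pairs, lexicon) returns ("hotel", 1), i.e. the
# lexicon word with the smallest edit distance together with that distance.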
for i in files:
if "ic15" in self.dataset_name:
out = output_file + 'res_img_' + str(int(i.split('/')[-1].split('.')[0])) + '.txt'
out_full = output_file_full + 'res_img_' + str(int(i.split('/')[-1].split('.')[0])) + '.txt'
if self.lexicon_type==3:
lexicon_path = 'datasets/ic15/new_strong_lexicon/new_voc_img_' + str(int(i.split('/')[-1].split('.')[0])) + '.txt'
lexicon_fid=open(lexicon_path, 'r')
pair_list = open('datasets/ic15/new_strong_lexicon/pair_voc_img_' + str(int(i.split('/')[-1].split('.')[0])) + '.txt')
pairs = dict()
for line in pair_list.readlines():
line=line.strip()
word = line.split(' ')[0].upper()
word_gt = line[len(word)+1:]
pairs[word] = word_gt
lexicon_fid=open(lexicon_path, 'r')
lexicon=[]
for line in lexicon_fid.readlines():
line=line.strip()
lexicon.append(line)
else:
out = i.replace(origin_file, output_file)
out_full = i.replace(origin_file, output_file_full)
fin = open(i, 'r').readlines()
fout = open(out, 'w')
fout_full = open(out_full, 'w')
for iline, line in enumerate(fin):
ptr = line.strip().split(',####')
rec = ptr[1]
cors = ptr[0].split(',')
assert(len(cors) %2 == 0), 'cors invalid.'
if "ctw1500" in self.dataset_name:
pts = [(float(cors[j]), float(cors[j+1])) for j in range(0,len(cors),2)] # int->float
else:
pts = [(int(cors[j]), int(cors[j + 1])) for j in range(0, len(cors), 2)]
try:
pgt = Polygon(pts)
except Exception as e:
print(e)
print('An invalid detection in {} line {} is removed ... '.format(i, iline))
continue
if not pgt.is_valid:
print('An invalid detection in {} line {} is removed ... '.format(i, iline))
continue
pRing = LinearRing(pts)
if not "ic15" in self.dataset_name:
if pRing.is_ccw:
pts.reverse()
outstr = ''
for ipt in pts:
if "ctw1500" in self.dataset_name:
outstr += (str(float(ipt[0]))+','+ str(float(ipt[1]))+',') # int->float
else:
outstr += (str(int(ipt[0])) + ',' + str(int(ipt[1])) + ',')
outstr = outstr[:-1]
pts = outstr
if "ic15" in self.dataset_name:
outstr = outstr + ',' + rec
else:
outstr = outstr + ',####' + rec
fout.writelines(outstr+'\n')
if self.lexicon_type is None:
rec_full = rec
else:
match_word, match_dist = find_match_word(rec,pairs,lexicon)
if match_dist<1.5:
rec_full = match_word
if "ic15" in self.dataset_name:
pts = pts + ',' + rec_full
else:
pts = pts + ',####' + rec_full
fout_full.writelines(pts+'\n')
fout.close()
fout_full.close()
def zipdir(path, ziph):
# ziph is zipfile handle
for root, dirs, files in os.walk(path):
for file in files:
ziph.write(os.path.join(root, file))
if "ic15" in self.dataset_name:
os.system('zip -r -q -j '+'det.zip'+' '+output_file+'/*')
os.system('zip -r -q -j '+'det_full.zip'+' '+output_file_full+'/*')
shutil.rmtree(origin_file)
shutil.rmtree(output_file)
shutil.rmtree(output_file_full)
return "det.zip", "det_full.zip"
else:
os.chdir(output_file)
zipf = zipfile.ZipFile('../det.zip', 'w', zipfile.ZIP_DEFLATED)
zipdir('./', zipf)
zipf.close()
os.chdir("../")
os.chdir(output_file_full)
zipf_full = zipfile.ZipFile('../det_full.zip', 'w', zipfile.ZIP_DEFLATED)
zipdir('./', zipf_full)
zipf_full.close()
os.chdir("../")
# clean temp files
shutil.rmtree(origin_file)
shutil.rmtree(output_file)
shutil.rmtree(output_file_full)
return "det.zip", "det_full.zip"
def evaluate_with_official_code(self, result_path, gt_path):
if "ic15" in self.dataset_name:
return text_eval_script_ic15.text_eval_main_ic15(det_file=result_path, gt_file=gt_path, is_word_spotting=self._word_spotting)
else:
return text_eval_script.text_eval_main(det_file=result_path, gt_file=gt_path, is_word_spotting=self._word_spotting)
def evaluate(self):
if self._distributed:
comm.synchronize()
predictions = comm.gather(self._predictions, dst=0)
predictions = list(itertools.chain(*predictions))
if not comm.is_main_process():
return {}
else:
predictions = self._predictions
if len(predictions) == 0:
self._logger.warning("[COCOEvaluator] Did not receive valid predictions.")
return {}
PathManager.mkdirs(self._output_dir)
if self.submit:
file_path = os.path.join(self._output_dir, self.dataset_name+"_submit.txt")
self._logger.info("Saving results to {}".format(file_path))
with PathManager.open(file_path, "w") as f:
for prediction in predictions:
write_id = "{:06d}".format(prediction["image_id"]+1)
write_img_name = "test_"+write_id+'.jpg\n'
f.write(write_img_name)
if len(prediction["instances"]) > 0:
for inst in prediction["instances"]:
write_poly, write_text = inst["polys"], inst["rec"]
if write_text == '':
continue
if not LinearRing(write_poly).is_ccw:
write_poly.reverse()
write_poly = np.array(write_poly).reshape(-1).tolist()
write_poly = ','.join(list(map(str,write_poly)))
f.write(write_poly+','+write_text+'\n')
f.flush()
self._logger.info("Ready to submit results from {}".format(file_path))
else:
coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
file_path = os.path.join(self._output_dir, "text_results.json")
self._logger.info("Saving results to {}".format(file_path))
with PathManager.open(file_path, "w") as f:
f.write(json.dumps(coco_results))
f.flush()
self._results = OrderedDict()
# eval text
if not self._text_eval_gt_path:
return copy.deepcopy(self._results)
temp_dir = "temp_det_results/"
self.to_eval_format(file_path, temp_dir)
result_path, result_path_full = self.sort_detection(temp_dir)
text_result = self.evaluate_with_official_code(result_path, self._text_eval_gt_path) # without lexicon
text_result["e2e_method"] = "None-" + text_result["e2e_method"]
dict_lexicon = {"1": "Generic", "2": "Weak", "3": "Strong"}
text_result_full = self.evaluate_with_official_code(result_path_full, self._text_eval_gt_path) # with lexicon
text_result_full["e2e_method"] = dict_lexicon[str(self.lexicon_type)] + "-" + text_result_full["e2e_method"]
os.remove(result_path)
os.remove(result_path_full)
# parse the summary strings produced by the official script into the results dict
template = r"(\S+): (\S+): (\S+), (\S+): (\S+), (\S+): (\S+)"
result = text_result["det_only_method"]
groups = re.match(template, result).groups()
self._results[groups[0]] = {groups[i*2+1]: float(groups[(i+1)*2]) for i in range(3)}
result = text_result["e2e_method"]
groups = re.match(template, result).groups()
self._results[groups[0]] = {groups[i*2+1]: float(groups[(i+1)*2]) for i in range(3)}
result = text_result_full["e2e_method"]
groups = re.match(template, result).groups()
self._results[groups[0]] = {groups[i*2+1]: float(groups[(i+1)*2]) for i in range(3)}
return copy.deepcopy(self._results)
def instances_to_coco_json(self, instances, inputs):
img_id = inputs["image_id"]
width = inputs['width']
height = inputs['height']
num_instances = len(instances)
if num_instances == 0:
return []
scores = instances.scores.tolist()
pnts = instances.bd.numpy()
recs = instances.recs.numpy()
results = []
for pnt, rec, score in zip(pnts, recs, scores):
poly = self.pnt_to_polygon(pnt)
if 'ic15' in self.dataset_name or 'rects' in self.dataset_name:
poly = polygon2rbox(poly, height, width)
s = self.ctc_decode(rec)
result = {
"image_id": img_id,
"category_id": 1,
"polys": poly,
"rec": s,
"score": score,
}
results.append(result)
return results
def pnt_to_polygon(self, ctrl_pnt):
ctrl_pnt = np.hsplit(ctrl_pnt, 2)
ctrl_pnt = np.vstack([ctrl_pnt[0], ctrl_pnt[1][::-1]])
return ctrl_pnt.tolist()
def ctc_decode(self, rec):
last_char = '###'
s = ''
for c in rec:
c = int(c)
if c < self.voc_size - 1:
if last_char != c:
if self.voc_size == 37 or self.voc_size == 96:
s += self.CTLABELS[c]
last_char = c
else:
s += str(chr(self.CTLABELS[c]))
last_char = c
else:
last_char = '###'
return s
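# Illustrative CTC decoding for voc_size == 37 (index 36 acts as the
# blank/unknown symbol): ctc_decode([7, 7, 36, 7, 4]) -> "hhe".  The repeated
# 7s collapse into one 'h', the blank resets last_char so the following 7 is
# emitted again, and 4 maps to 'e'.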
def polygon2rbox(polygon, image_height, image_width):
poly = np.array(polygon).reshape((-1, 2)).astype(np.float32)
rect = cv2.minAreaRect(poly)
corners = cv2.boxPoints(rect)
corners = np.array(corners, dtype="int")
pts = get_tight_rect(corners, 0, 0, image_height, image_width, 1)
pts = np.array(pts).reshape(-1,2)
pts = pts.tolist()
return pts
def get_tight_rect(points, start_x, start_y, image_height, image_width, scale):
points = list(points)
ps = sorted(points, key=lambda x: x[0])
if ps[1][1] > ps[0][1]:
px1 = ps[0][0] * scale + start_x
py1 = ps[0][1] * scale + start_y
px4 = ps[1][0] * scale + start_x
py4 = ps[1][1] * scale + start_y
else:
px1 = ps[1][0] * scale + start_x
py1 = ps[1][1] * scale + start_y
px4 = ps[0][0] * scale + start_x
py4 = ps[0][1] * scale + start_y
if ps[3][1] > ps[2][1]:
px2 = ps[2][0] * scale + start_x
py2 = ps[2][1] * scale + start_y
px3 = ps[3][0] * scale + start_x
py3 = ps[3][1] * scale + start_y
else:
px2 = ps[3][0] * scale + start_x
py2 = ps[3][1] * scale + start_y
px3 = ps[2][0] * scale + start_x
py3 = ps[2][1] * scale + start_y
px1 = min(max(px1, 1), image_width - 1)
px2 = min(max(px2, 1), image_width - 1)
px3 = min(max(px3, 1), image_width - 1)
px4 = min(max(px4, 1), image_width - 1)
py1 = min(max(py1, 1), image_height - 1)
py2 = min(max(py2, 1), image_height - 1)
py3 = min(max(py3, 1), image_height - 1)
py4 = min(max(py4, 1), image_height - 1)
return [px1, py1, px2, py2, px3, py3, px4, py4]
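# Usage sketch (illustrative; the dataset name and output directory are
# placeholders): TextEvaluator follows detectron2's evaluator interface
# (reset / process / evaluate), so a caller could do roughly
#
#   evaluator = TextEvaluator("totaltext_test", cfg, distributed=False,
#                             output_dir="output/eval")
#   evaluator.reset()
#   for inputs, outputs in inference_batches:   # hypothetical inference loop
#       evaluator.process(inputs, outputs)
#   results = evaluator.evaluate()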
from .ms_deform_attn import MSDeformAttn
__all__ = [k for k in globals().keys() if not k.startswith("_")]
\ No newline at end of file
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include "ms_deform_attn_cpu.h"
#ifdef WITH_CUDA
#include "ms_deform_attn_cuda.h"
#endif
at::Tensor
ms_deform_attn_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
if (value.type().is_cuda())
{
#ifdef WITH_CUDA
return ms_deform_attn_cuda_forward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
std::vector<at::Tensor>
ms_deform_attn_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
if (value.type().is_cuda())
{
#ifdef WITH_CUDA
return ms_deform_attn_cuda_backward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ERROR("Not implement on cpu");
}
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ERROR("Not implement on cpu");
}
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include "ms_deform_im2col_cuda.cuh"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must be a multiple of im2col_step(%d)", batch, im2col_step_);
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
columns.data<scalar_t>());
}));
}
output = output.view({batch, num_query, num_heads*channels});
return output;
}
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must be a multiple of im2col_step(%d)", batch, im2col_step_);
auto grad_value = at::zeros_like(value);
auto grad_sampling_loc = at::zeros_like(sampling_loc);
auto grad_attn_weight = at::zeros_like(attn_weight);
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
}));
}
return {
grad_value, grad_sampling_loc, grad_attn_weight
};
}
\ No newline at end of file