Commit 5f77267a authored by Leif

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents dc67dc3e 967f0676
@@ -8,6 +8,8 @@ PPOCRLabel is a semi-automatic graphic annotation tool suitable for OCR field, w
### Recent Update
- 2022.02: (by [PeterH0323](https://github.com/peterh0323))
  - Added KIE mode, for [detection + recognition + keyword extraction] labeling.
- 2022.01: (by [PeterH0323](https://github.com/peterh0323))
  - Improve user experience: prompt for the number of files and labels, optimize interaction, and fix bugs such as inference running only on the CPU
- 2021.11.17:
@@ -72,7 +74,8 @@ PPOCRLabel
```bash
pip3 install PPOCRLabel
pip3 install opencv-contrib-python-headless==4.2.0.32
PPOCRLabel  # [Normal mode] for [detection + recognition] labeling
PPOCRLabel --kie True  # [KIE mode] for [detection + recognition + keyword extraction] labeling
```
#### 1.2.2 Build and Install the Whl Package Locally
@@ -87,7 +90,8 @@ pip3 install dist/PPOCRLabel-1.0.2-py2.py3-none-any.whl
```bash
cd ./PPOCRLabel  # Switch to the PPOCRLabel directory
python PPOCRLabel.py  # [Normal mode] for [detection + recognition] labeling
python PPOCRLabel.py --kie True  # [KIE mode] for [detection + recognition + keyword extraction] labeling
```
@@ -198,21 +202,31 @@ For some data that are difficult to recognize, the recognition results will not
- Enter the following command in the terminal to execute the dataset division script:
```
cd ./PPOCRLabel  # Change the directory to the PPOCRLabel folder
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --datasetRootPath ../train_data
```
Parameter Description:
- `trainValTestRatio` is the ratio for splitting the number of images among the training, validation, and test sets; set it according to your actual situation. The default is `6:2:2`
- `datasetRootPath` is the storage path of the complete dataset labeled by PPOCRLabel. The default path is `PaddleOCR/train_data`. Before splitting, the dataset should have the following structure:
```
|-train_data
  |-crop_img
    |- word_001_crop_0.png
    |- word_002_crop_0.jpg
    |- word_003_crop_0.jpg
    | ...
  | Label.txt
  | rec_gt.txt
  |- word_001.png
  |- word_002.jpg
  |- word_003.jpg
  | ...
```
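The split logic itself is small. Below is a minimal, hypothetical sketch of how a `6:2:2` ratio can be applied to the lines of `Label.txt`; the real `gen_ocr_train_val_test.py` (shown in the diff further down) additionally creates the target folders, copies images, and writes separate `train.txt`/`val.txt`/`test.txt` files for the detection and recognition sets.

```python
# Hypothetical sketch: split annotation lines by a "train:val:test" ratio.
def split_by_ratio(lines, ratio="6:2:2"):
    parts = [int(x) for x in ratio.split(":")]
    total = sum(parts)
    train, val, test = [], [], []
    for i, line in enumerate(lines):
        bucket = i % total  # cycle through ratio-sized buckets
        if bucket < parts[0]:
            train.append(line)
        elif bucket < parts[0] + parts[1]:
            val.append(line)
        else:
            test.append(line)
    return train, val, test

with open("Label.txt", "r", encoding="UTF-8") as f:
    train, val, test = split_by_ratio(f.readlines())
```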
### 3.6 Error message

- If paddleocr is installed with whl, it takes priority over calling the PaddleOCR class through paddleocr.py, which may cause an exception if the whl package is not updated.
...
@@ -8,6 +8,8 @@ PPOCRLabel is a semi-automatic graphic annotation tool for the OCR field, with built-in P
#### Recent Update

- 2022.02: (by [PeterH0323](https://github.com/peterh0323))
  - Added: KIE mode, for labeling [detection + recognition + keyword extraction] tasks
- 2022.01: (by [PeterH0323](https://github.com/peterh0323))
  - Improved user experience: added file and label count prompts, optimized interaction, and fixed GPU usage and other issues
- 2021.11.17:
@@ -70,7 +72,8 @@ PPOCRLabel --lang ch
```bash
pip3 install PPOCRLabel
pip3 install opencv-contrib-python-headless==4.2.0.32  # If the download is slow, add "-i https://mirror.baidu.com/pypi/simple"
PPOCRLabel --lang ch  # Launch [Normal mode], for labeling [detection + recognition] scenes
PPOCRLabel --lang ch --kie True  # Launch [KIE mode], for labeling [detection + recognition + keyword extraction] scenes
```
> If the installation above runs into problems, refer to Section 3.6, Error Message
@@ -89,7 +92,8 @@ pip3 install dist/PPOCRLabel-1.0.2-py2.py3-none-any.whl -i https://mirror.baidu.
```bash
cd ./PPOCRLabel  # Switch to the PPOCRLabel directory
python PPOCRLabel.py --lang ch  # Launch [Normal mode], for labeling [detection + recognition] scenes
python PPOCRLabel.py --lang ch --kie True  # Launch [KIE mode], for labeling [detection + recognition + keyword extraction] scenes
```
@@ -185,19 +189,29 @@ PPOCRLabel supports three export methods:
```
cd ./PPOCRLabel  # Change the directory to the PPOCRLabel folder
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --datasetRootPath ../train_data
```
Parameter description:
- `trainValTestRatio` is the split ratio for the number of images in the training, validation, and test sets; set it according to your actual situation. The default is `6:2:2`
- `labelRootPath` 是PPOCRLabel标注的数据集存放路径,默认是`../train_data/label` - `datasetRootPath` 是PPOCRLabel标注的完整数据集存放路径。默认路径是 `PaddleOCR/train_data` 分割数据集前应有如下结构:
```
- `detRootPath` 是根据PPOCRLabel标注的数据集划分后的文本检测数据集存放的路径,默认是`../train_data/det ` |-train_data
|-crop_img
- `recRootPath` 是根据PPOCRLabel标注的数据集划分后的字符识别数据集存放的路径,默认是`../train_data/rec` |- word_001_crop_0.png
|- word_002_crop_0.jpg
|- word_003_crop_0.jpg
| ...
| Label.txt
| rec_gt.txt
|- word_001.png
|- word_002.jpg
|- word_003.jpg
| ...
```
### 3.6 Error Message

- If paddleocr is also installed via the whl package, it takes priority over calling the PaddleOCR class through paddleocr.py; if the whl package is not updated, the program may fail.
...
@@ -17,15 +17,14 @@ def isCreateOrDeleteFolder(path, flag):
    return flagAbsPath


def splitTrainVal(root, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag):
    # Split the training, validation, and test sets according to the specified ratio
    dataAbsPath = os.path.abspath(root)

    if flag == "det":
        labelFilePath = os.path.join(dataAbsPath, args.detLabelFileName)
    elif flag == "rec":
        labelFilePath = os.path.join(dataAbsPath, args.recLabelFileName)

    labelFileRead = open(labelFilePath, "r", encoding="UTF-8")
    labelFileContent = labelFileRead.readlines()
@@ -38,9 +37,9 @@ def splitTrainVal(root, dir, absTrainRootPath, absValRootPath, absTestRootPath,
        imageName = os.path.basename(imageRelativePath)
        if flag == "det":
            imagePath = os.path.join(dataAbsPath, imageName)
        elif flag == "rec":
            imagePath = os.path.join(dataAbsPath, "{}\\{}".format(args.recImageDirName, imageName))

        # Split the training, validation, and test sets according to the preset ratio
        trainValTestRatio = args.trainValTestRatio.split(":")
@@ -90,15 +89,20 @@ def genDetRecTrainVal(args):
    recValTxt = open(os.path.join(args.recRootPath, "val.txt"), "a", encoding="UTF-8")
    recTestTxt = open(os.path.join(args.recRootPath, "test.txt"), "a", encoding="UTF-8")

    splitTrainVal(args.datasetRootPath, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath, detTrainTxt, detValTxt,
                  detTestTxt, "det")

    for root, dirs, files in os.walk(args.datasetRootPath):
        for dir in dirs:
            if dir == 'crop_img':
                splitTrainVal(root, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath, recTrainTxt, recValTxt,
                              recTestTxt, "rec")
            else:
                continue
        break
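Note the final `break`: `os.walk` is stopped after its first iteration, so only the immediate subdirectories of `datasetRootPath` are inspected, and only a `crop_img` folder triggers the recognition split. A small hedged illustration of that traversal pattern:

```python
import os

# Illustrative only: visit just the first level of the dataset root,
# mirroring the walk-with-break pattern in genDetRecTrainVal above.
for root, dirs, files in os.walk("../train_data"):
    for d in dirs:
        if d == "crop_img":
            print("recognition crops found at:", os.path.join(root, d))
    break  # stop after the top level
```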
if __name__ == "__main__":
    # Function: split the training, validation, and test sets for detection and recognition respectively
    # Note: adjust the parameters to your own paths and needs. Image data is usually annotated in batches by
    # several collaborators, with each batch of images placed in one folder and annotated with PPOCRLabel,
@@ -110,9 +114,9 @@ if __name__ == "__main__":
        default="6:2:2",
        help="ratio of trainset:valset:testset")
    parser.add_argument(
        "--datasetRootPath",
        type=str,
        default="../train_data/",
        help="path to the dataset marked by ppocrlabel, e.g. dataset folder named 1,2,3..."
    )
    parser.add_argument(
...
@@ -783,7 +783,7 @@ class Canvas(QWidget):
        points = [p1 + p2 for p1, p2 in zip(self.selectedShape.points, [step] * 4)]
        return True in map(self.outOfPixmap, points)

    def setLastLabel(self, text, line_color=None, fill_color=None, key_cls=None):
        assert text
        self.shapes[-1].label = text
        if line_color:
@@ -791,6 +791,10 @@ class Canvas(QWidget):
        if fill_color:
            self.shapes[-1].fill_color = fill_color
        if key_cls:
            self.shapes[-1].key_cls = key_cls
        self.storeShapes()
        return self.shapes[-1]
...
import re

from PyQt5 import QtCore
from PyQt5 import QtGui
from PyQt5 import QtWidgets
from PyQt5.Qt import QT_VERSION_STR

from libs.utils import newIcon, labelValidator

QT5 = QT_VERSION_STR[0] == '5'


# TODO(unknown):
# - Calculate optimal position so as not to go out of screen area.


class KeyQLineEdit(QtWidgets.QLineEdit):
    def setListWidget(self, list_widget):
        self.list_widget = list_widget

    def keyPressEvent(self, e):
        if e.key() in [QtCore.Qt.Key_Up, QtCore.Qt.Key_Down]:
            self.list_widget.keyPressEvent(e)
        else:
            super(KeyQLineEdit, self).keyPressEvent(e)


class KeyDialog(QtWidgets.QDialog):
    def __init__(
        self,
        text="Enter object label",
        parent=None,
        labels=None,
        sort_labels=True,
        show_text_field=True,
        completion="startswith",
        fit_to_content=None,
        flags=None,
    ):
        if fit_to_content is None:
            fit_to_content = {"row": False, "column": True}
        self._fit_to_content = fit_to_content

        super(KeyDialog, self).__init__(parent)
        self.edit = KeyQLineEdit()
        self.edit.setPlaceholderText(text)
        self.edit.setValidator(labelValidator())
        self.edit.editingFinished.connect(self.postProcess)
        if flags:
            self.edit.textChanged.connect(self.updateFlags)
        layout = QtWidgets.QVBoxLayout()
        if show_text_field:
            layout_edit = QtWidgets.QHBoxLayout()
            layout_edit.addWidget(self.edit, 6)
            layout.addLayout(layout_edit)
        # buttons
        self.buttonBox = bb = QtWidgets.QDialogButtonBox(
            QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel,
            QtCore.Qt.Horizontal,
            self,
        )
        bb.button(bb.Ok).setIcon(newIcon("done"))
        bb.button(bb.Cancel).setIcon(newIcon("undo"))
        bb.accepted.connect(self.validate)
        bb.rejected.connect(self.reject)
        layout.addWidget(bb)
        # label_list
        self.labelList = QtWidgets.QListWidget()
        if self._fit_to_content["row"]:
            self.labelList.setHorizontalScrollBarPolicy(
                QtCore.Qt.ScrollBarAlwaysOff
            )
        if self._fit_to_content["column"]:
            self.labelList.setVerticalScrollBarPolicy(
                QtCore.Qt.ScrollBarAlwaysOff
            )
        self._sort_labels = sort_labels
        if labels:
            self.labelList.addItems(labels)
        if self._sort_labels:
            self.labelList.sortItems()
        else:
            self.labelList.setDragDropMode(
                QtWidgets.QAbstractItemView.InternalMove
            )
        self.labelList.currentItemChanged.connect(self.labelSelected)
        self.labelList.itemDoubleClicked.connect(self.labelDoubleClicked)
        self.edit.setListWidget(self.labelList)
        layout.addWidget(self.labelList)
        # label_flags
        if flags is None:
            flags = {}
        self._flags = flags
        self.flagsLayout = QtWidgets.QVBoxLayout()
        self.resetFlags()
        layout.addItem(self.flagsLayout)
        self.edit.textChanged.connect(self.updateFlags)
        self.setLayout(layout)
        # completion
        completer = QtWidgets.QCompleter()
        if not QT5 and completion != "startswith":
            completion = "startswith"
        if completion == "startswith":
            completer.setCompletionMode(QtWidgets.QCompleter.InlineCompletion)
            # Default settings.
            # completer.setFilterMode(QtCore.Qt.MatchStartsWith)
        elif completion == "contains":
            completer.setCompletionMode(QtWidgets.QCompleter.PopupCompletion)
            completer.setFilterMode(QtCore.Qt.MatchContains)
        else:
            raise ValueError("Unsupported completion: {}".format(completion))
        completer.setModel(self.labelList.model())
        self.edit.setCompleter(completer)

    def addLabelHistory(self, label):
        if self.labelList.findItems(label, QtCore.Qt.MatchExactly):
            return
        self.labelList.addItem(label)
        if self._sort_labels:
            self.labelList.sortItems()

    def labelSelected(self, item):
        self.edit.setText(item.text())

    def validate(self):
        text = self.edit.text()
        if hasattr(text, "strip"):
            text = text.strip()
        else:
            text = text.trimmed()
        if text:
            self.accept()

    def labelDoubleClicked(self, item):
        self.validate()

    def postProcess(self):
        text = self.edit.text()
        if hasattr(text, "strip"):
            text = text.strip()
        else:
            text = text.trimmed()
        self.edit.setText(text)

    def updateFlags(self, label_new):
        # keep state of shared flags
        flags_old = self.getFlags()

        flags_new = {}
        for pattern, keys in self._flags.items():
            if re.match(pattern, label_new):
                for key in keys:
                    flags_new[key] = flags_old.get(key, False)
        self.setFlags(flags_new)

    def deleteFlags(self):
        for i in reversed(range(self.flagsLayout.count())):
            item = self.flagsLayout.itemAt(i).widget()
            self.flagsLayout.removeWidget(item)
            item.setParent(None)

    def resetFlags(self, label=""):
        flags = {}
        for pattern, keys in self._flags.items():
            if re.match(pattern, label):
                for key in keys:
                    flags[key] = False
        self.setFlags(flags)

    def setFlags(self, flags):
        self.deleteFlags()
        for key in flags:
            item = QtWidgets.QCheckBox(key, self)
            item.setChecked(flags[key])
            self.flagsLayout.addWidget(item)
            item.show()

    def getFlags(self):
        flags = {}
        for i in range(self.flagsLayout.count()):
            item = self.flagsLayout.itemAt(i).widget()
            flags[item.text()] = item.isChecked()
        return flags

    def popUp(self, text=None, move=True, flags=None):
        if self._fit_to_content["row"]:
            self.labelList.setMinimumHeight(
                self.labelList.sizeHintForRow(0) * self.labelList.count() + 2
            )
        if self._fit_to_content["column"]:
            self.labelList.setMinimumWidth(
                self.labelList.sizeHintForColumn(0) + 2
            )
        # if text is None, the previous label in self.edit is kept
        if text is None:
            text = self.edit.text()
        if flags:
            self.setFlags(flags)
        else:
            self.resetFlags(text)
        self.edit.setText(text)
        self.edit.setSelection(0, len(text))

        items = self.labelList.findItems(text, QtCore.Qt.MatchFixedString)
        if items:
            if len(items) != 1:
                self.labelList.setCurrentItem(items[0])
            row = self.labelList.row(items[0])
            self.edit.completer().setCurrentRow(row)
        self.edit.setFocus(QtCore.Qt.PopupFocusReason)
        if move:
            self.move(QtGui.QCursor.pos())
        if self.exec_():
            return self.edit.text(), self.getFlags()
        else:
            return None, None
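For orientation, the dialog can be exercised on its own. A minimal, hypothetical usage sketch (the label values are invented; PPOCRLabel itself constructs the dialog from its key-class history):

```python
import sys
from PyQt5 import QtWidgets

app = QtWidgets.QApplication(sys.argv)
dialog = KeyDialog(
    text="Enter object label",
    labels=["header", "question", "answer"],  # invented example key classes
    sort_labels=True,
    completion="startswith",
)
key, flags = dialog.popUp("question")  # returns (None, None) if cancelled
print(key, flags)
```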
import PIL.Image
import numpy as np


def rgb2hsv(rgb):
    # type: (np.ndarray) -> np.ndarray
    """Convert rgb to hsv.

    Parameters
    ----------
    rgb: numpy.ndarray, (H, W, 3), np.uint8
        Input rgb image.

    Returns
    -------
    hsv: numpy.ndarray, (H, W, 3), np.uint8
        Output hsv image.
    """
    hsv = PIL.Image.fromarray(rgb, mode="RGB")
    hsv = hsv.convert("HSV")
    hsv = np.array(hsv)
    return hsv


def hsv2rgb(hsv):
    # type: (np.ndarray) -> np.ndarray
    """Convert hsv to rgb.

    Parameters
    ----------
    hsv: numpy.ndarray, (H, W, 3), np.uint8
        Input hsv image.

    Returns
    -------
    rgb: numpy.ndarray, (H, W, 3), np.uint8
        Output rgb image.
    """
    rgb = PIL.Image.fromarray(hsv, mode="HSV")
    rgb = rgb.convert("RGB")
    rgb = np.array(rgb)
    return rgb


def label_colormap(n_label=256, value=None):
    """Label colormap.

    Parameters
    ----------
    n_label: int
        Number of labels (default: 256).
    value: float or int
        Value scale or value of label color in HSV space.

    Returns
    -------
    cmap: numpy.ndarray, (N, 3), numpy.uint8
        Label id to colormap.
    """

    def bitget(byteval, idx):
        return (byteval & (1 << idx)) != 0

    cmap = np.zeros((n_label, 3), dtype=np.uint8)
    for i in range(0, n_label):
        id = i
        r, g, b = 0, 0, 0
        for j in range(0, 8):
            r = np.bitwise_or(r, (bitget(id, 0) << 7 - j))
            g = np.bitwise_or(g, (bitget(id, 1) << 7 - j))
            b = np.bitwise_or(b, (bitget(id, 2) << 7 - j))
            id = id >> 3
        cmap[i, 0] = r
        cmap[i, 1] = g
        cmap[i, 2] = b

    if value is not None:
        hsv = rgb2hsv(cmap.reshape(1, -1, 3))
        if isinstance(value, float):
            hsv[:, 1:, 2] = hsv[:, 1:, 2].astype(float) * value
        else:
            assert isinstance(value, int)
            hsv[:, 1:, 2] = value
        cmap = hsv2rgb(hsv).reshape(-1, 3)
    return cmap
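As a quick sanity check of the helpers above: `label_colormap` produces the familiar VOC-style palette, so the first few label ids map to fixed colors.

```python
import numpy as np

cmap = label_colormap(256)
assert cmap.shape == (256, 3) and cmap.dtype == np.uint8
print(cmap[0])  # [0, 0, 0]   -> label 0 (background) is black
print(cmap[1])  # [128, 0, 0] -> label 1
print(cmap[2])  # [0, 128, 0] -> label 2
```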
@@ -46,12 +46,13 @@ class Shape(object):
    point_size = 8
    scale = 1.0

    def __init__(self, label=None, line_color=None, difficult=False, key_cls="None", paintLabel=False):
        self.label = label
        self.points = []
        self.fill = False
        self.selected = False
        self.difficult = difficult
        self.key_cls = key_cls
        self.paintLabel = paintLabel
        self.locked = False
        self.direction = 0
@@ -224,6 +225,7 @@ class Shape(object):
        if self.fill_color != Shape.fill_color:
            shape.fill_color = self.fill_color
        shape.difficult = self.difficult
        shape.key_cls = self.key_cls
        return shape

    def __len__(self):
...
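The net effect of this diff is that `key_cls` now travels with a shape through both construction and copying. A small hedged sketch (assuming the `Shape` class above and its copy method, whose tail is shown in the second hunk):

```python
# key_cls defaults to the string "None" and is preserved when a shape is copied.
shape = Shape(label="invoice_no", difficult=False, key_cls="question")
clone = shape.copy()  # hypothetical call site, e.g. copy/paste of a box
assert clone.key_cls == "question"
```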
# -*- encoding: utf-8 -*-
from PyQt5.QtCore import Qt
from PyQt5 import QtWidgets


class EscapableQListWidget(QtWidgets.QListWidget):
    def keyPressEvent(self, event):
        super(EscapableQListWidget, self).keyPressEvent(event)
        if event.key() == Qt.Key_Escape:
            self.clearSelection()


class UniqueLabelQListWidget(EscapableQListWidget):
    def mousePressEvent(self, event):
        super(UniqueLabelQListWidget, self).mousePressEvent(event)
        if not self.indexAt(event.pos()).isValid():
            self.clearSelection()

    def findItemsByLabel(self, label, get_row=False):
        items = []
        for row in range(self.count()):
            item = self.item(row)
            if item.data(Qt.UserRole) == label:
                items.append(item)
                if get_row:
                    return row
        return items

    def createItemFromLabel(self, label):
        item = QtWidgets.QListWidgetItem()
        item.setData(Qt.UserRole, label)
        return item

    def setItemLabel(self, item, label, color=None):
        qlabel = QtWidgets.QLabel()
        if color is None:
            qlabel.setText(f"{label}")
        else:
            qlabel.setText('<font color="#{:02x}{:02x}{:02x}">●</font> {} '.format(*color, label))
        qlabel.setAlignment(Qt.AlignBottom)

        item.setSizeHint(qlabel.sizeHint())

        self.setItemWidget(item, qlabel)
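A hedged sketch of how this widget can pair with `label_colormap` from the colormap helpers above: each key class gets a single list entry with a stable color swatch. The wiring here is illustrative, not PPOCRLabel's exact code.

```python
import sys
from PyQt5 import QtWidgets

app = QtWidgets.QApplication(sys.argv)
key_list = UniqueLabelQListWidget()

cmap = label_colormap(256)  # from the colormap helpers above
for i, label in enumerate(["header", "question", "answer"], start=1):
    if not key_list.findItemsByLabel(label):  # keep entries unique
        item = key_list.createItemFromLabel(label)
        key_list.addItem(item)
        rgb = tuple(int(c) for c in cmap[i])  # stable color per key class
        key_list.setItemLabel(item, label, rgb)
```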
@@ -10,30 +10,26 @@
# SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
# CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import hashlib
import os
import re
import sys
from math import sqrt

import cv2
import numpy as np
from PyQt5.QtCore import QRegExp, QT_VERSION_STR
from PyQt5.QtGui import QIcon, QRegExpValidator, QColor
from PyQt5.QtWidgets import QPushButton, QAction, QMenu

from libs.ustr import ustr

__dir__ = os.path.dirname(os.path.abspath(__file__))  # Get the path of this program file
__iconpath__ = os.path.abspath(os.path.join(__dir__, '../resources/icons'))
def newIcon(icon, iconSize=None):
    if iconSize is not None:
        return QIcon(QIcon(__iconpath__ + "/" + icon + ".png").pixmap(iconSize, iconSize))
    else:
        return QIcon(__iconpath__ + "/" + icon + ".png")
@@ -105,24 +101,25 @@ def generateColorByText(text):
    s = ustr(text)
    hashCode = int(hashlib.sha256(s.encode('utf-8')).hexdigest(), 16)
    r = int((hashCode / 255) % 255)
    g = int((hashCode / 65025) % 255)
    b = int((hashCode / 16581375) % 255)
    return QColor(r, g, b, 100)
def have_qstring():
    '''p3/qt5 get rid of QString wrapper as py3 has native unicode str type'''
    return not (sys.version_info.major >= 3 or QT_VERSION_STR.startswith('5.'))


def natural_sort(list, key=lambda s: s):
    """
    Sort the list into natural alphanumeric order.
    """

    def get_alphanum_key_func(key):
        convert = lambda text: int(text) if text.isdigit() else text
        return lambda s: [convert(c) for c in re.split('([0-9]+)', key(s))]

    sort_key = get_alphanum_key_func(key)
    list.sort(key=sort_key)
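For example, `natural_sort` compares numeric runs as numbers, which keeps annotated image files in the order a human expects:

```python
files = ["word_10.png", "word_2.png", "word_1.png"]
natural_sort(files)  # sorts in place, like list.sort()
print(files)         # ['word_1.png', 'word_2.png', 'word_10.png']
```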
@@ -133,8 +130,8 @@ def get_rotate_crop_image(img, points):
    d = 0.0
    for index in range(-1, 3):
        d += -0.5 * (points[index + 1][1] + points[index][1]) * (
            points[index + 1][0] - points[index][0])
    if d < 0:  # counterclockwise
        tmp = np.array(points)
        points[1], points[3] = tmp[3], tmp[1]
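The accumulated `d` is a shoelace (signed-area) sum over the four corners, so its sign encodes winding direction in image coordinates (y grows downward); when it is negative the box is counterclockwise, and corners 1 and 3 are swapped to restore clockwise order. A small hedged check of that test:

```python
import numpy as np

def signed_area(points):
    # Same shoelace sum as the loop in get_rotate_crop_image.
    d = 0.0
    for index in range(-1, 3):
        d += -0.5 * (points[index + 1][1] + points[index][1]) * (
            points[index + 1][0] - points[index][0])
    return d

box = np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=float)
print(signed_area(box))  # -1.0: counterclockwise in image coordinates
```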
@@ -163,10 +160,11 @@ def get_rotate_crop_image(img, points):
    except Exception as e:
        print(e)


def stepsInfo(lang='en'):
    if lang == 'ch':
        msg = "1. 安装与运行:使用上述命令安装与运行程序。\n" \
              "2. 打开文件夹:在菜单栏点击 “文件” - 打开目录 选择待标记图片的文件夹.\n" \
              "3. 自动标注:点击 ”自动标注“,使用PPOCR超轻量模型对图片文件名前图片状态为 “X” 的图片进行自动标注。\n" \
              "4. 手动标注:点击 “矩形标注”(推荐直接在英文模式下点击键盘中的 “W”),用户可对当前图片中模型未检出的部分进行手动" \
              "绘制标记框。点击键盘P,则使用四点标注模式(或点击“编辑” - “四点标注”),用户依次点击4个点后,双击左键表示标注完成。\n" \
@@ -181,25 +179,26 @@ def stepsInfo(lang='en'):
    else:
        msg = "1. Build and launch using the instructions above.\n" \
              "2. Click 'Open Dir' in Menu/File to select the folder of the picture.\n" \
              "3. Click 'Auto recognition', use PPOCR model to automatically annotate images which marked with 'X' before the file name." \
              "4. Create Box:\n" \
              "4.1 Click 'Create RectBox' or press 'W' in English keyboard mode to draw a new rectangle detection box. Click and release left mouse to select a region to annotate the text area.\n" \
              "4.2 Press 'P' to enter four-point labeling mode which enables you to create any four-point shape by clicking four points with the left mouse button in succession and DOUBLE CLICK the left mouse as the signal of labeling completion.\n" \
              "5. After the marking frame is drawn, the user clicks 'OK', and the detection frame will be pre-assigned a TEMPORARY label.\n" \
              "6. Click re-Recognition, model will rewrite ALL recognition results in ALL detection box.\n" \
              "7. Double click the result in 'recognition result' list to manually change inaccurate recognition results.\n" \
              "8. Click 'Save', the image status will switch to '√',then the program automatically jump to the next.\n" \
              "9. Click 'Delete Image' and the image will be deleted to the recycle bin.\n" \
              "10. Labeling result: After closing the application or switching the file path, the manually saved label will be stored in *Label.txt* under the opened picture folder.\n" \
              "    Click PaddleOCR-Save Recognition Results in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*.\n"

    return msg
def keysInfo(lang='en'):
    if lang == 'ch':
        msg = "快捷键\t\t\t说明\n" \
              "———————————————————————\n" \
              "Ctrl + shift + R\t\t对当前图片的所有标记重新识别\n" \
              "W\t\t\t新建矩形框\n" \
              "Q\t\t\t新建四点框\n" \
@@ -223,17 +222,17 @@ def keysInfo(lang='en'):
              "———————————————————————\n" \
              "Ctrl + shift + R\t\tRe-recognize all the labels\n" \
              "\t\t\tof the current image\n" \
              "\n" \
              "W\t\t\tCreate a rect box\n" \
              "Q\t\t\tCreate a four-points box\n" \
              "Ctrl + E\t\tEdit label of the selected box\n" \
              "Ctrl + R\t\tRe-recognize the selected box\n" \
              "Ctrl + C\t\tCopy and paste the selected\n" \
              "\t\t\tbox\n" \
              "\n" \
              "Ctrl + Left Mouse\tMulti select the label\n" \
              "Button\t\t\tbox\n" \
              "\n" \
              "Backspace\t\tDelete the selected box\n" \
              "Ctrl + V\t\tCheck image\n" \
              "Ctrl + Shift + d\tDelete image\n" \
@@ -245,4 +244,4 @@ def keysInfo(lang='en'):
              "———————————————————————\n" \
              "Notice:For Mac users, use the 'Command' key instead of the 'Ctrl' key"

    return msg
\ No newline at end of file
@@ -106,4 +106,7 @@ undo=Undo
undoLastPoint=Undo Last Point
autoSaveMode=Auto Export Label Mode
lockBox=Lock selected box/Unlock all box
lockBoxDetail=Lock selected box/Unlock all box
keyListTitle=Key List
keyDialogTip=Enter object label
keyChange=Change Box Key
@@ -107,3 +107,6 @@ undoLastPoint=撤销上个点
autoSaveMode=自动导出标记结果
lockBox=锁定框/解除锁定框
lockBoxDetail=若当前没有框处于锁定状态则锁定选中的框,若存在锁定框则解除所有锁定框的锁定状态
keyListTitle=关键词列表
keyDialogTip=请输入类型名称
keyChange=更改Box关键字类别
\ No newline at end of file
@@ -152,7 +152,7 @@ For a new language request, please refer to [Guideline for new language_requests
[1] PP-OCR is a practical ultra-lightweight OCR system. It is mainly composed of three parts: DB text detection, detection frame correction and CRNN text recognition. The system adopts 19 effective strategies from 8 aspects including backbone network selection and adjustment, prediction head design, data augmentation, learning rate transformation strategy, regularization parameter selection, pre-training model use, and automatic model tailoring and quantization to optimize and slim down the models of each module (as shown in the green box above). The final results are an ultra-lightweight Chinese and English OCR model with an overall size of 3.5M and a 2.8M English digital OCR model. For more details, please refer to the PP-OCR technical article (https://arxiv.org/abs/2009.09941).

[2] On the basis of PP-OCR, PP-OCRv2 is further optimized in five aspects. The detection model adopts CML(Collaborative Mutual Learning) knowledge distillation strategy and CopyPaste data expansion strategy. The recognition model adopts LCNet lightweight backbone network, U-DML knowledge distillation strategy and enhanced CTC loss function improvement (as shown in the red box above), which further improves the inference speed and prediction effect. For more details, please refer to the technical report of PP-OCRv2 (https://arxiv.org/abs/2109.03144).
@@ -181,16 +181,11 @@ For a new language request, please refer to [Guideline for new language_requests
<a name="language_requests"></a>
## Guideline for New Language Requests
If you want to request support for a new language, a PR with the following file is needed:

1. In folder [ppocr/utils/dict](./ppocr/utils/dict),
it is necessary to submit the dict text to this path and name it with `{language}_dict.txt` that contains a list of all characters. Please see the format example from other files in that folder.

If your language has unique elements, please let us know in advance through any channel, such as useful links or Wikipedia entries.

For more details, please refer to [Multilingual OCR Development Plan](https://github.com/PaddlePaddle/PaddleOCR/issues/1048).
...
@@ -26,35 +26,57 @@ def parse_args():
    parser.add_argument(
        "--filename", type=str, help="The name of log which need to analysis.")
    parser.add_argument(
        "--log_with_profiler",
        type=str,
        help="The path of train log with profiler")
    parser.add_argument(
        "--profiler_path", type=str, help="The path of profiler timeline log.")
    parser.add_argument(
        "--keyword", type=str, help="Keyword to specify analysis data")
    parser.add_argument(
        "--separator",
        type=str,
        default=None,
        help="Separator of different field in log")
    parser.add_argument(
        '--position', type=int, default=None, help='The position of data field')
    parser.add_argument(
        '--range',
        type=str,
        default="",
        help='The range of data field to intercept')
    parser.add_argument(
        '--base_batch_size', type=int, help='base_batch size on gpu')
    parser.add_argument(
        '--skip_steps',
        type=int,
        default=0,
        help='The number of steps to be skipped')
    parser.add_argument(
        '--model_mode',
        type=int,
        default=-1,
        help='Analysis mode, default value is -1')
    parser.add_argument('--ips_unit', type=str, default=None, help='IPS unit')
    parser.add_argument(
        '--model_name',
        type=str,
        default=0,
        help='training model_name, transformer_base')
    parser.add_argument(
        '--mission_name', type=str, default=0, help='training mission name')
    parser.add_argument(
        '--direction_id', type=int, default=0, help='training direction_id')
    parser.add_argument(
        '--run_mode',
        type=str,
        default="sp",
        help='multi process or single process')
    parser.add_argument(
        '--index',
        type=int,
        default=1,
        help='{1: speed, 2:mem, 3:profiler, 6:max_batch_size}')
    parser.add_argument(
        '--gpu_num', type=int, default=1, help='nums of training gpus')
    args = parser.parse_args()
@@ -72,7 +94,12 @@ def _is_number(num):
class TimeAnalyzer(object):
    def __init__(self,
                 filename,
                 keyword=None,
                 separator=None,
                 position=None,
                 range="-1"):
        if filename is None:
            raise Exception("Please specify the filename!")
@@ -99,7 +126,8 @@ class TimeAnalyzer(object):
                # Distil the string from a line.
                line = line.strip()
                line_words = line.split(
                    self.separator) if self.separator else line.split()
                if args.position:
                    result = line_words[self.position]
                else:
@@ -108,27 +136,36 @@ class TimeAnalyzer(object):
                        if line_words[i] == self.keyword:
                            result = line_words[i + 1]
                            break

                # Distil the result from the picked string.
                if not self.range:
                    result = result[0:]
                elif _is_number(self.range):
                    result = result[0:int(self.range)]
                else:
                    result = result[int(self.range.split(":")[0]):int(
                        self.range.split(":")[1])]
                self.records.append(float(result))
            except Exception as exc:
                print("line is: {}; separator={}; position={}".format(
                    line, self.separator, self.position))

        print("Extract {} records: separator={}; position={}".format(
            len(self.records), self.separator, self.position))
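Concretely, the extraction above splits each log line into tokens, takes the token right after `--keyword`, and optionally slices it with `--range`. A hedged miniature of that logic, with an invented log line:

```python
# Invented log line; the real format depends on the training script.
line = "epoch: [1/500] ips: 1234.5 images/s loss: 0.871"
words = line.split()  # no --separator given
for i, w in enumerate(words):
    if w == "ips:":
        result = words[i + 1]  # token right after the keyword
        break
print(float(result))  # 1234.5
```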
    def _get_fps(self,
                 mode,
                 batch_size,
                 gpu_num,
                 avg_of_records,
                 run_mode,
                 unit=None):
        if mode == -1 and run_mode == 'sp':
            assert unit, "Please set the unit when mode is -1."
            fps = gpu_num * avg_of_records
        elif mode == -1 and run_mode == 'mp':
            assert unit, "Please set the unit when mode is -1."
            fps = gpu_num * avg_of_records  # temporarily, not used now
            print("------------this is mp")
        elif mode == 0:
            # s/step -> samples/s
@@ -155,12 +192,20 @@ class TimeAnalyzer(object):
        return fps, unit

    def analysis(self,
                 batch_size,
                 gpu_num=1,
                 skip_steps=0,
                 mode=-1,
                 run_mode='sp',
                 unit=None):
        if batch_size <= 0:
            print("base_batch_size should larger than 0.")
            return 0, ''

        if len(
                self.records
        ) <= skip_steps:  # to address the condition which item of log equals to skip_steps
            print("no records")
            return 0, ''
@@ -180,16 +225,20 @@ class TimeAnalyzer(object):
                    skip_max = self.records[i]

        avg_of_records = sum_of_records / float(count)
        avg_of_records_skipped = sum_of_records_skipped / float(count -
                                                                skip_steps)

        fps, fps_unit = self._get_fps(mode, batch_size, gpu_num, avg_of_records,
                                      run_mode, unit)
        fps_skipped, _ = self._get_fps(mode, batch_size, gpu_num,
                                       avg_of_records_skipped, run_mode, unit)
        if mode == -1:
            print("average ips of %d steps, skip 0 step:" % count)
            print("\tAvg: %.3f %s" % (avg_of_records, fps_unit))
            print("\tFPS: %.3f %s" % (fps, fps_unit))
            if skip_steps > 0:
                print("average ips of %d steps, skip %d steps:" %
                      (count, skip_steps))
                print("\tAvg: %.3f %s" % (avg_of_records_skipped, fps_unit))
                print("\tMin: %.3f %s" % (skip_min, fps_unit))
                print("\tMax: %.3f %s" % (skip_max, fps_unit))
@@ -199,7 +248,8 @@ class TimeAnalyzer(object):
            print("\tAvg: %.3f steps/s" % avg_of_records)
            print("\tFPS: %.3f %s" % (fps, fps_unit))
            if skip_steps > 0:
                print("average latency of %d steps, skip %d steps:" %
                      (count, skip_steps))
                print("\tAvg: %.3f steps/s" % avg_of_records_skipped)
                print("\tMin: %.3f steps/s" % skip_min)
                print("\tMax: %.3f steps/s" % skip_max)
@@ -209,7 +259,8 @@ class TimeAnalyzer(object):
            print("\tAvg: %.3f s/step" % avg_of_records)
            print("\tFPS: %.3f %s" % (fps, fps_unit))
            if skip_steps > 0:
                print("average latency of %d steps, skip %d steps:" %
                      (count, skip_steps))
                print("\tAvg: %.3f s/step" % avg_of_records_skipped)
                print("\tMin: %.3f s/step" % skip_min)
                print("\tMax: %.3f s/step" % skip_max)
@@ -236,7 +287,8 @@ if __name__ == "__main__":
            if args.gpu_num == 1:
                run_info["log_with_profiler"] = args.log_with_profiler
                run_info["profiler_path"] = args.profiler_path
            analyzer = TimeAnalyzer(args.filename, args.keyword, args.separator,
                                    args.position, args.range)
            run_info["FINAL_RESULT"], run_info["UNIT"] = analyzer.analysis(
                batch_size=args.base_batch_size,
                gpu_num=args.gpu_num,
@@ -245,29 +297,50 @@ if __name__ == "__main__":
                run_mode=args.run_mode,
                unit=args.ips_unit)
            try:
                if int(os.getenv('job_fail_flag')) == 1 or int(run_info[
                        "FINAL_RESULT"]) == 0:
                    run_info["JOB_FAIL_FLAG"] = 1
            except:
                pass
        elif args.index == 3:
            run_info["FINAL_RESULT"] = {}
            records_fo_total = TimeAnalyzer(args.filename, 'Framework overhead',
                                            None, 3, '').records
            records_fo_ratio = TimeAnalyzer(args.filename, 'Framework overhead',
                                            None, 5).records
            records_ct_total = TimeAnalyzer(args.filename, 'Computation time',
                                            None, 3, '').records
            records_gm_total = TimeAnalyzer(args.filename,
                                            'GpuMemcpy Calls',
                                            None, 4, '').records
            records_gm_ratio = TimeAnalyzer(args.filename,
                                            'GpuMemcpy Calls',
                                            None, 6).records
            records_gmas_total = TimeAnalyzer(args.filename,
                                              'GpuMemcpyAsync Calls',
                                              None, 4, '').records
            records_gms_total = TimeAnalyzer(args.filename,
                                             'GpuMemcpySync Calls',
                                             None, 4, '').records
            run_info["FINAL_RESULT"]["Framework_Total"] = records_fo_total[
                0] if records_fo_total else 0
            run_info["FINAL_RESULT"]["Framework_Ratio"] = records_fo_ratio[
                0] if records_fo_ratio else 0
            run_info["FINAL_RESULT"][
                "ComputationTime_Total"] = records_ct_total[
                    0] if records_ct_total else 0
            run_info["FINAL_RESULT"]["GpuMemcpy_Total"] = records_gm_total[
                0] if records_gm_total else 0
            run_info["FINAL_RESULT"]["GpuMemcpy_Ratio"] = records_gm_ratio[
                0] if records_gm_ratio else 0
            run_info["FINAL_RESULT"][
                "GpuMemcpyAsync_Total"] = records_gmas_total[
                    0] if records_gmas_total else 0
            run_info["FINAL_RESULT"]["GpuMemcpySync_Total"] = records_gms_total[
                0] if records_gms_total else 0
        else:
            print("Not support!")
    except Exception:
        traceback.print_exc()
    print("{}".format(json.dumps(run_info))
          )  # it's required, for the log file path insert to the database
@@ -58,3 +58,4 @@ source ${BENCHMARK_ROOT}/scripts/run_model.sh # in this script, logs that conform
_set_params $@
#_train  # uncomment if you only want to produce the training log without parsing it
_run  # this function lives in run_model.sh and calls _train when executed; if you only want the training log without end-to-end debugging you can comment this line out, but it must be enabled for submission
@@ -36,3 +36,4 @@ for model_mode in ${model_mode_list[@]}; do
done
Global:
  use_gpu: true
  use_xpu: false
  epoch_num: 1200
  log_smooth_window: 20
  print_batch_step: 10
...
Global:
  use_gpu: True
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/re_layoutlmv2/
  save_epoch_step: 2000
  # evaluation is run every 10 iterations after the 0th iteration
  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2048
  infer_img: doc/vqa/input/zh_val_21.jpg
  save_res_path: ./output/re/

Architecture:
  model_type: vqa
  algorithm: &algorithm "LayoutLMv2"
  Transform:
  Backbone:
    name: LayoutLMv2ForRe
    pretrained: True
    checkpoints:

Loss:
  name: LossFromOutput
  key: loss
  reduction: mean

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  clip_norm: 10
  lr:
    learning_rate: 0.00005
    warmup_epoch: 10
  regularizer:
    name: L2
    factor: 0.00000

PostProcess:
  name: VQAReTokenLayoutLMPostProcess

Metric:
  name: VQAReTokenMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
      - train_data/XFUND/zh_train/xfun_normalize_train.json
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: True
          algorithm: *algorithm
          class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 8
    collate_fn: ListCollator

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_val/image
    label_file_list:
      - train_data/XFUND/zh_val/xfun_normalize_val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: True
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 8
    collate_fn: ListCollator
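This config follows the standard PaddleOCR layout (Global / Architecture / Loss / Optimizer / PostProcess / Metric / Train / Eval). As a hedged illustration of how such a file resolves, plain PyYAML handles the anchors (`&algorithm`, `&max_seq_len`) the same way PaddleOCR's own loader does; the filename below is an assumption inferred from `save_model_dir`.

```python
import yaml

with open("re_layoutlmv2.yml", "r", encoding="utf-8") as f:  # assumed filename
    cfg = yaml.safe_load(f)

print(cfg["Architecture"]["Backbone"]["name"])        # LayoutLMv2ForRe
print(cfg["Train"]["loader"]["batch_size_per_card"])  # 8
```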
@@ -21,7 +21,7 @@ Architecture:
  Backbone:
    name: LayoutXLMForRe
    pretrained: True
    checkpoints:

Loss:
  name: LossFromOutput
@@ -35,6 +35,7 @@ Optimizer:
  clip_norm: 10
  lr:
    learning_rate: 0.00005
    warmup_epoch: 10
  regularizer:
    name: L2
    factor: 0.00000
@@ -81,7 +82,7 @@ Train:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 8
    collate_fn: ListCollator

Eval:
@@ -118,5 +119,5 @@ Eval:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 8
    collate_fn: ListCollator
Global:
  use_gpu: True
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/ser_layoutlmv2/
  save_epoch_step: 2000
  # evaluation is run every 10 iterations after the 0th iteration
  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: doc/vqa/input/zh_val_0.jpg
  save_res_path: ./output/ser/

Architecture:
  model_type: vqa
  algorithm: &algorithm "LayoutLMv2"
  Transform:
  Backbone:
    name: LayoutLMv2ForSer
    pretrained: True
    checkpoints:
    num_classes: &num_classes 7

Loss:
  name: VQASerTokenLayoutLMLoss
  num_classes: *num_classes

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Linear
    learning_rate: 0.00005
    epochs: *epoch_num
    warmup_epoch: 2
  regularizer:
    name: L2
    factor: 0.00000

PostProcess:
  name: VQASerTokenLayoutLMPostProcess
  class_path: &class_path ppstructure/vqa/labels/labels_ser.txt

Metric:
  name: VQASerTokenMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
      - train_data/XFUND/zh_train/xfun_normalize_train.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_val/image
    label_file_list:
      - train_data/XFUND/zh_val/xfun_normalize_val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4