Commit 5f77267a authored by Leif's avatar Leif
Browse files

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents dc67dc3e 967f0676
......@@ -10,7 +10,6 @@
# SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
# CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
# !/usr/bin/env python
# -*- coding: utf-8 -*-
# pyrcc5 -o libs/resources.py resources.qrc
......@@ -24,13 +23,11 @@ import subprocess
import sys
from functools import partial
try:
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
except ImportError:
print("Please install pyqt5...")
from PyQt5.QtCore import QSize, Qt, QPoint, QByteArray, QTimer, QFileInfo, QPointF, QProcess
from PyQt5.QtGui import QImage, QCursor, QPixmap, QImageReader
from PyQt5.QtWidgets import QMainWindow, QListWidget, QVBoxLayout, QToolButton, QHBoxLayout, QDockWidget, QWidget, \
QSlider, QGraphicsOpacityEffect, QMessageBox, QListView, QScrollArea, QWidgetAction, QApplication, QLabel, \
QFileDialog, QListWidgetItem, QComboBox, QDialog
__dir__ = os.path.dirname(os.path.abspath(__file__))
......@@ -42,6 +39,7 @@ sys.path.append("..")
from paddleocr import PaddleOCR
from libs.constants import *
from libs.utils import *
from libs.labelColor import label_colormap
from libs.settings import Settings
from libs.shape import Shape, DEFAULT_LINE_COLOR, DEFAULT_FILL_COLOR, DEFAULT_LOCK_COLOR
from libs.stringBundle import StringBundle
......@@ -53,9 +51,13 @@ from libs.colorDialog import ColorDialog
from libs.ustr import ustr
from libs.hashableQListWidgetItem import HashableQListWidgetItem
from libs.editinlist import EditInList
from libs.unique_label_qlist_widget import UniqueLabelQListWidget
from libs.keyDialog import KeyDialog
__appname__ = 'PPOCRLabel'
LABEL_COLORMAP = label_colormap()
class MainWindow(QMainWindow):
FIT_WINDOW, FIT_WIDTH, MANUAL_ZOOM = list(range(3))
......@@ -63,6 +65,7 @@ class MainWindow(QMainWindow):
def __init__(self,
lang="ch",
gpu=False,
kie_mode=False,
default_filename=None,
default_predefined_class_file=None,
default_save_dir=None):
......@@ -76,12 +79,19 @@ class MainWindow(QMainWindow):
self.settings.load()
settings = self.settings
self.lang = lang
# Load string bundle for i18n
if lang not in ['ch', 'en']:
lang = 'en'
self.stringBundle = StringBundle.getBundle(localeStr='zh-CN' if lang == 'ch' else 'en') # 'en'
getStr = lambda strId: self.stringBundle.getString(strId)
# KIE setting
self.kie_mode = kie_mode
self.key_previous_text = ""
self.existed_key_cls_set = set()
self.key_dialog_tip = getStr('keyDialogTip')
self.defaultSaveDir = default_save_dir
self.ocr = PaddleOCR(use_pdserving=False,
use_angle_cls=True,
......@@ -133,11 +143,13 @@ class MainWindow(QMainWindow):
self.autoSaveNum = 5
# ================== File List ==================
filelistLayout = QVBoxLayout()
filelistLayout.setContentsMargins(0, 0, 0, 0)
self.fileListWidget = QListWidget()
self.fileListWidget.itemClicked.connect(self.fileitemDoubleClicked)
self.fileListWidget.setIconSize(QSize(25, 25))
filelistLayout = QVBoxLayout()
filelistLayout.setContentsMargins(0, 0, 0, 0)
filelistLayout.addWidget(self.fileListWidget)
self.AutoRecognition = QToolButton()
......@@ -158,10 +170,24 @@ class MainWindow(QMainWindow):
self.fileDock.setWidget(fileListContainer)
self.addDockWidget(Qt.LeftDockWidgetArea, self.fileDock)
# ================== Key List ==================
if self.kie_mode:
# self.keyList = QListWidget()
self.keyList = UniqueLabelQListWidget()
# self.keyList.itemSelectionChanged.connect(self.keyListSelectionChanged)
# self.keyList.itemDoubleClicked.connect(self.editBox)
# self.keyList.itemChanged.connect(self.keyListItemChanged)
self.keyListDockName = getStr('keyListTitle')
self.keyListDock = QDockWidget(self.keyListDockName, self)
self.keyListDock.setWidget(self.keyList)
self.keyListDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
filelistLayout.addWidget(self.keyListDock)
# ================== Right Area ==================
listLayout = QVBoxLayout()
listLayout.setContentsMargins(0, 0, 0, 0)
# Buttons
self.editButton = QToolButton()
self.reRecogButton = QToolButton()
self.reRecogButton.setIcon(newIcon('reRec', 30))
......@@ -174,12 +200,12 @@ class MainWindow(QMainWindow):
self.DelButton = QToolButton()
self.DelButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
lefttoptoolbox = QHBoxLayout()
lefttoptoolbox.addWidget(self.newButton)
lefttoptoolbox.addWidget(self.reRecogButton)
lefttoptoolboxcontainer = QWidget()
lefttoptoolboxcontainer.setLayout(lefttoptoolbox)
listLayout.addWidget(lefttoptoolboxcontainer)
leftTopToolBox = QHBoxLayout()
leftTopToolBox.addWidget(self.newButton)
leftTopToolBox.addWidget(self.reRecogButton)
leftTopToolBoxContainer = QWidget()
leftTopToolBoxContainer.setLayout(leftTopToolBox)
listLayout.addWidget(leftTopToolBoxContainer)
# ================== Label List ==================
# Create and add a widget for showing current label items
......@@ -341,7 +367,7 @@ class MainWindow(QMainWindow):
resetAll = action(getStr('resetAll'), self.resetAll, None, 'resetall', getStr('resetAllDetail'))
color1 = action(getStr('boxLineColor'), self.chooseColor1,
color1 = action(getStr('boxLineColor'), self.chooseColor,
'Ctrl+L', 'color_line', getStr('boxLineColorDetail'))
createMode = action(getStr('crtBox'), self.setCreateMode,
......@@ -402,11 +428,12 @@ class MainWindow(QMainWindow):
self.MANUAL_ZOOM: lambda: 1,
}
# ================== New Actions ==================
edit = action(getStr('editLabel'), self.editLabel,
'Ctrl+E', 'edit', getStr('editLabelDetail'),
enabled=False)
# ================== New Actions ==================
AutoRec = action(getStr('autoRecognition'), self.autoRecognition,
'', 'Auto', getStr('autoRecognition'), enabled=False)
......@@ -437,6 +464,9 @@ class MainWindow(QMainWindow):
undo = action(getStr("undo"), self.undoShapeEdit,
'Ctrl+Z', "undo", getStr("undo"), enabled=False)
change_cls = action(getStr("keyChange"), self.change_box_key,
'Ctrl+B', "edit", getStr("keyChange"), enabled=False)
lock = action(getStr("lockBox"), self.lockSelectedShape,
None, "lock", getStr("lockBoxDetail"),
enabled=False)
......@@ -482,8 +512,7 @@ class MainWindow(QMainWindow):
addActions(labelMenu, (edit, delete))
self.labelList.setContextMenuPolicy(Qt.CustomContextMenu)
self.labelList.customContextMenuRequested.connect(
self.popLabelListMenu)
self.labelList.customContextMenuRequested.connect(self.popLabelListMenu)
# Draw squares/rectangles
self.drawSquaresOption = QAction(getStr('drawSquares'), self)
......@@ -499,14 +528,15 @@ class MainWindow(QMainWindow):
shapeLineColor=shapeLineColor, shapeFillColor=shapeFillColor,
zoom=zoom, zoomIn=zoomIn, zoomOut=zoomOut, zoomOrg=zoomOrg,
fitWindow=fitWindow, fitWidth=fitWidth,
zoomActions=zoomActions, saveLabel=saveLabel,
zoomActions=zoomActions, saveLabel=saveLabel, change_cls=change_cls,
undo=undo, undoLastPoint=undoLastPoint, open_dataset_dir=open_dataset_dir,
rotateLeft=rotateLeft, rotateRight=rotateRight, lock=lock,
fileMenuActions=(opendir, open_dataset_dir, saveLabel, resetAll, quit),
beginner=(), advanced=(),
editMenu=(createpoly, edit, copy, delete, singleRere, None, undo, undoLastPoint,
None, rotateLeft, rotateRight, None, color1, self.drawSquaresOption, lock),
beginnerContext=(create, edit, copy, delete, singleRere, rotateLeft, rotateRight, lock),
beginnerContext=(
create, edit, copy, delete, singleRere, rotateLeft, rotateRight, lock, change_cls),
advancedContext=(createMode, editMode, edit, copy,
delete, shapeLineColor, shapeFillColor),
onLoadActive=(create, createMode, editMode),
......@@ -615,6 +645,8 @@ class MainWindow(QMainWindow):
elif self.filePath:
self.queueEvent(partial(self.loadFile, self.filePath or ""))
self.keyDialog = None
# Callbacks:
self.zoomWidget.valueChanged.connect(self.paintCanvas)
......@@ -949,6 +981,12 @@ class MainWindow(QMainWindow):
self.labelList.scrollToItem(self.currentItem()) # QAbstractItemView.EnsureVisible
self.BoxList.scrollToItem(self.currentBox())
if self.kie_mode:
if len(self.canvas.selectedShapes) == 1 and self.keyList.count() > 0:
selected_key_item_row = self.keyList.findItemsByLabel(self.canvas.selectedShapes[0].key_cls,
get_row=True)
self.keyList.setCurrentRow(selected_key_item_row)
self._noSelectionSlot = False
n_selected = len(selected_shapes)
self.actions.singleRere.setEnabled(n_selected)
......@@ -956,6 +994,7 @@ class MainWindow(QMainWindow):
self.actions.copy.setEnabled(n_selected)
self.actions.edit.setEnabled(n_selected == 1)
self.actions.lock.setEnabled(n_selected)
self.actions.change_cls.setEnabled(n_selected)
def addLabel(self, shape):
shape.paintLabel = self.displayLabelOption.isChecked()
......@@ -1002,8 +1041,8 @@ class MainWindow(QMainWindow):
def loadLabels(self, shapes):
s = []
for label, points, line_color, fill_color, difficult in shapes:
shape = Shape(label=label, line_color=line_color)
for label, points, line_color, key_cls, difficult in shapes:
shape = Shape(label=label, line_color=line_color, key_cls=key_cls)
for x, y in points:
# Ensure the labels are within the bounds of the image. If not, fix them.
......@@ -1017,16 +1056,7 @@ class MainWindow(QMainWindow):
shape.close()
s.append(shape)
# if line_color:
# shape.line_color = QColor(*line_color)
# else:
# shape.line_color = generateColorByText(label)
#
# if fill_color:
# shape.fill_color = QColor(*fill_color)
# else:
# shape.fill_color = generateColorByText(label)
self._update_shape_color(shape)
self.addLabel(shape)
self.updateComboBox()
......@@ -1066,14 +1096,16 @@ class MainWindow(QMainWindow):
line_color=s.line_color.getRgb(),
fill_color=s.fill_color.getRgb(),
points=[(int(p.x()), int(p.y())) for p in s.points], # QPonitF
# add chris
difficult=s.difficult) # bool
difficult=s.difficult,
key_cls=s.key_cls) # bool
shapes = [] if mode == 'Auto' else \
[format_shape(shape) for shape in self.canvas.shapes if shape.line_color != DEFAULT_LOCK_COLOR]
if mode == 'Auto':
shapes = []
else:
shapes = [format_shape(shape) for shape in self.canvas.shapes if shape.line_color != DEFAULT_LOCK_COLOR]
# Can add differrent annotation formats here
for box in self.result_dic:
trans_dic = {"label": box[1][0], "points": box[0], 'difficult': False}
trans_dic = {"label": box[1][0], "points": box[0], "difficult": False, "key_cls": "None"}
if trans_dic["label"] == "" and mode == 'Auto':
continue
shapes.append(trans_dic)
......@@ -1081,8 +1113,8 @@ class MainWindow(QMainWindow):
try:
trans_dic = []
for box in shapes:
trans_dic.append(
{"transcription": box['label'], "points": box['points'], 'difficult': box['difficult']})
trans_dic.append({"transcription": box['label'], "points": box['points'],
"difficult": box['difficult'], "key_cls": box['key_cls']})
self.PPlabel[annotationFilePath] = trans_dic
if mode == 'Auto':
self.Cachelabel[annotationFilePath] = trans_dic
......@@ -1148,8 +1180,7 @@ class MainWindow(QMainWindow):
position MUST be in global coordinates.
"""
if len(self.labelHist) > 0:
self.labelDialog = LabelDialog(
parent=self, listItem=self.labelHist)
self.labelDialog = LabelDialog(parent=self, listItem=self.labelHist)
if value:
text = self.labelDialog.popUp(text=self.prevLabelText)
......@@ -1159,8 +1190,22 @@ class MainWindow(QMainWindow):
if text is not None:
self.prevLabelText = self.stringBundle.getString('tempLabel')
# generate_color = generateColorByText(text)
shape = self.canvas.setLastLabel(text, None, None) # generate_color, generate_color
shape = self.canvas.setLastLabel(text, None, None, None) # generate_color, generate_color
if self.kie_mode:
key_text, _ = self.keyDialog.popUp(self.key_previous_text)
if key_text is not None:
shape = self.canvas.setLastLabel(text, None, None, key_text) # generate_color, generate_color
self.key_previous_text = key_text
if not self.keyList.findItemsByLabel(key_text):
item = self.keyList.createItemFromLabel(key_text)
self.keyList.addItem(item)
rgb = self._get_rgb_by_label(key_text, self.kie_mode)
self.keyList.setItemLabel(item, key_text, rgb)
self._update_shape_color(shape)
self.keyDialog.addLabelHistory(key_text)
self.addLabel(shape)
if self.beginner(): # Switch to edit mode.
self.canvas.setEditing(True)
......@@ -1175,6 +1220,25 @@ class MainWindow(QMainWindow):
# self.canvas.undoLastLine()
self.canvas.resetAllLines()
def _update_shape_color(self, shape):
r, g, b = self._get_rgb_by_label(shape.key_cls, self.kie_mode)
shape.line_color = QColor(r, g, b)
shape.vertex_fill_color = QColor(r, g, b)
shape.hvertex_fill_color = QColor(255, 255, 255)
shape.fill_color = QColor(r, g, b, 128)
shape.select_line_color = QColor(255, 255, 255)
shape.select_fill_color = QColor(r, g, b, 155)
def _get_rgb_by_label(self, label, kie_mode):
shift_auto_shape_color = 2 # use for random color
if kie_mode and label != "None":
item = self.keyList.findItemsByLabel(label)[0]
label_id = self.keyList.indexFromItem(item).row() + 1
label_id += shift_auto_shape_color
return LABEL_COLORMAP[label_id % len(LABEL_COLORMAP)]
else:
return (0, 255, 0)
def scrollRequest(self, delta, orientation):
units = - delta / (8 * 15)
bar = self.scrollBars[orientation]
......@@ -1344,7 +1408,7 @@ class MainWindow(QMainWindow):
select_indexes = self.fileListWidget.selectedIndexes()
if len(select_indexes) > 0:
self.fileDock.setWindowTitle(self.fileListName + f" ({select_indexes[0].row() + 1}"
f"/{self.fileListWidget.count()})")
f"/{self.fileListWidget.count()})")
# update show counting
self.BoxListDock.setWindowTitle(self.BoxListDockName + f" ({self.BoxList.count()})")
self.labelListDock.setWindowTitle(self.labelListDockName + f" ({self.labelList.count()})")
......@@ -1362,13 +1426,13 @@ class MainWindow(QMainWindow):
for box in self.canvas.lockedShapes:
if self.canvas.isInTheSameImage:
shapes.append((box['transcription'], [[s[0] * width, s[1] * height] for s in box['ratio']],
DEFAULT_LOCK_COLOR, None, box['difficult']))
DEFAULT_LOCK_COLOR, box['key_cls'], box['difficult']))
else:
shapes.append(('锁定框:待检测', [[s[0] * width, s[1] * height] for s in box['ratio']],
DEFAULT_LOCK_COLOR, None, box['difficult']))
DEFAULT_LOCK_COLOR, box['key_cls'], box['difficult']))
if imgidx in self.PPlabel.keys():
for box in self.PPlabel[imgidx]:
shapes.append((box['transcription'], box['points'], None, None, box['difficult']))
shapes.append((box['transcription'], box['points'], None, box['key_cls'], box['difficult']))
self.loadLabels(shapes)
self.canvas.verified = False
......@@ -1504,6 +1568,39 @@ class MainWindow(QMainWindow):
self.actions.open_dataset_dir.setEnabled(False)
defaultOpenDirPath = os.path.dirname(self.filePath) if self.filePath else '.'
def init_key_list(self, label_dict):
if not self.kie_mode:
return
# load key_cls
for image, info in label_dict.items():
for box in info:
if "key_cls" not in box:
continue
self.existed_key_cls_set.add(box["key_cls"])
if len(self.existed_key_cls_set) > 0:
for key_text in self.existed_key_cls_set:
if not self.keyList.findItemsByLabel(key_text):
item = self.keyList.createItemFromLabel(key_text)
self.keyList.addItem(item)
rgb = self._get_rgb_by_label(key_text, self.kie_mode)
self.keyList.setItemLabel(item, key_text, rgb)
if self.keyDialog is None:
# key list dialog
self.keyDialog = KeyDialog(
text=self.key_dialog_tip,
parent=self,
labels=self.existed_key_cls_set,
sort_labels=True,
show_text_field=True,
completion="startswith",
fit_to_content={'column': True, 'row': False},
flags=None
)
else:
self.keyDialog.labelList.addItems(self.existed_key_cls_set)
def importDirImages(self, dirpath, isDelete=False):
if not self.mayContinue() or not dirpath:
return
......@@ -1518,6 +1615,9 @@ class MainWindow(QMainWindow):
self.Cachelabel = self.loadLabelFile(self.Cachelabelpath)
if self.Cachelabel:
self.PPlabel = dict(self.Cachelabel, **self.PPlabel)
self.init_key_list(self.PPlabel)
self.lastOpenDir = dirpath
self.dirname = dirpath
......@@ -1737,7 +1837,7 @@ class MainWindow(QMainWindow):
def currentPath(self):
return os.path.dirname(self.filePath) if self.filePath else '.'
def chooseColor1(self):
def chooseColor(self):
color = self.colorDialog.getColor(self.lineColor, u'Choose line color',
default=DEFAULT_LINE_COLOR)
if color:
......@@ -1854,6 +1954,8 @@ class MainWindow(QMainWindow):
self.setDirty()
self.saveCacheLabel()
self.init_key_list(self.Cachelabel)
def reRecognition(self):
img = cv2.imread(self.filePath)
# org_box = [dic['points'] for dic in self.PPlabel[self.getImglabelidx(self.filePath)]]
......@@ -2059,7 +2161,8 @@ class MainWindow(QMainWindow):
try:
img = cv2.imread(key)
for i, label in enumerate(self.PPlabel[idx]):
if label['difficult']: continue
if label['difficult']:
continue
img_crop = get_rotate_crop_image(img, np.array(label['points'], np.float32))
img_name = os.path.splitext(os.path.basename(idx))[0] + '_crop_' + str(i) + '.jpg'
cv2.imwrite(crop_img_dir + img_name, img_crop)
......@@ -2096,6 +2199,15 @@ class MainWindow(QMainWindow):
self.autoSaveNum = 5 # Used for backup
print('The program will automatically save once after confirming 5 images (default)')
def change_box_key(self):
key_text, _ = self.keyDialog.popUp(self.key_previous_text)
if key_text is None:
return
self.key_previous_text = key_text
for shape in self.canvas.selectedShapes:
shape.key_cls = key_text
self._update_shape_color(shape)
def undoShapeEdit(self):
self.canvas.restoreShape()
self.labelList.clear()
......@@ -2126,8 +2238,9 @@ class MainWindow(QMainWindow):
line_color=s.line_color.getRgb(),
fill_color=s.fill_color.getRgb(),
ratio=[[int(p.x()) / width, int(p.y()) / height] for p in s.points], # QPonitF
# add chris
difficult=s.difficult) # bool
difficult=s.difficult, # bool
key_cls=s.key_cls, # bool
)
# lock
if len(self.canvas.lockedShapes) == 0:
......@@ -2137,7 +2250,9 @@ class MainWindow(QMainWindow):
shapes = [format_shape(shape) for shape in self.canvas.selectedShapes]
trans_dic = []
for box in shapes:
trans_dic.append({"transcription": box['label'], "ratio": box['ratio'], 'difficult': box['difficult']})
trans_dic.append({"transcription": box['label'], "ratio": box['ratio'],
"difficult": box['difficult'],
"key_cls": "None" if "key_cls" not in box else box["key_cls"]})
self.canvas.lockedShapes = trans_dic
self.actions.save.setEnabled(True)
......@@ -2179,6 +2294,7 @@ def get_main_app(argv=[]):
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--lang", type=str, default='en', nargs="?")
arg_parser.add_argument("--gpu", type=str2bool, default=True, nargs="?")
arg_parser.add_argument("--kie", type=str2bool, default=False, nargs="?")
arg_parser.add_argument("--predefined_classes_file",
default=os.path.join(os.path.dirname(__file__), "data", "predefined_classes.txt"),
nargs="?")
......@@ -2186,6 +2302,7 @@ def get_main_app(argv=[]):
win = MainWindow(lang=args.lang,
gpu=args.gpu,
kie_mode=args.kie,
default_predefined_class_file=args.predefined_classes_file)
win.show()
return app, win
......
......@@ -8,6 +8,8 @@ PPOCRLabel is a semi-automatic graphic annotation tool suitable for OCR field, w
### Recent Update
- 2022.02:(by [PeterH0323](https://github.com/peterh0323)
- Added KIE mode, for [detection + identification + keyword extraction] labeling.
- 2022.01:(by [PeterH0323](https://github.com/peterh0323)
- Improve user experience: prompt for the number of files and labels, optimize interaction, and fix bugs such as only use CPU when inference
- 2021.11.17:
......@@ -72,7 +74,8 @@ PPOCRLabel
```bash
pip3 install PPOCRLabel
pip3 install opencv-contrib-python-headless==4.2.0.32
PPOCRLabel # run
PPOCRLabel # [Normal mode] for [detection + recognition] labeling
PPOCRLabel --kie True # [KIE mode] for [detection + recognition + keyword extraction] labeling
```
#### 1.2.2 Build and Install the Whl Package Locally
......@@ -87,7 +90,8 @@ pip3 install dist/PPOCRLabel-1.0.2-py2.py3-none-any.whl
```bash
cd ./PPOCRLabel # Switch to the PPOCRLabel directory
python PPOCRLabel.py
python PPOCRLabel.py # [Normal mode] for [detection + recognition] labeling
python PPOCRLabel.py --kie True # [KIE mode] for [detection + recognition + keyword extraction] labeling
```
......@@ -198,21 +202,31 @@ For some data that are difficult to recognize, the recognition results will not
- Enter the following command in the terminal to execute the dataset division script:
```
```
cd ./PPOCRLabel # Change the directory to the PPOCRLabel folder
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --labelRootPath ../train_data/label --detRootPath ../train_data/det --recRootPath ../train_data/rec
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --datasetRootPath ../train_data
```
Parameter Description:
- `trainValTestRatio` is the division ratio of the number of images in the training set, validation set, and test set, set according to your actual situation, the default is `6:2:2`
- `labelRootPath` is the storage path of the dataset labeled by PPOCRLabel, the default is `../train_data/label`
- `detRootPath` is the path where the text detection dataset is divided according to the dataset marked by PPOCRLabel. The default is `../train_data/det`
- `recRootPath` is the path where the character recognition dataset is divided according to the dataset marked by PPOCRLabel. The default is `../train_data/rec`
- `datasetRootPath` is the storage path of the complete dataset labeled by PPOCRLabel. The default path is `PaddleOCR/train_data` .
```
|-train_data
|-crop_img
|- word_001_crop_0.png
|- word_002_crop_0.jpg
|- word_003_crop_0.jpg
| ...
| Label.txt
| rec_gt.txt
|- word_001.png
|- word_002.jpg
|- word_003.jpg
| ...
```
### 3.6 Error message
- If paddleocr is installed with whl, it has a higher priority than calling PaddleOCR class with paddleocr.py, which may cause an exception if whl package is not updated.
......
......@@ -8,6 +8,8 @@ PPOCRLabel是一款适用于OCR领域的半自动化图形标注工具,内置P
#### 近期更新
- 2022.02:(by [PeterH0323](https://github.com/peterh0323)
- 新增:KIE 功能,用于打【检测+识别+关键字提取】的标签
- 2022.01:(by [PeterH0323](https://github.com/peterh0323)
- 提升用户体验:新增文件与标记数目提示、优化交互、修复gpu使用等问题
- 2021.11.17:
......@@ -70,7 +72,8 @@ PPOCRLabel --lang ch
```bash
pip3 install PPOCRLabel
pip3 install opencv-contrib-python-headless==4.2.0.32 # 如果下载过慢请添加"-i https://mirror.baidu.com/pypi/simple"
PPOCRLabel --lang ch # 启动
PPOCRLabel --lang ch # 启动【普通模式】,用于打【检测+识别】场景的标签
PPOCRLabel --lang ch --kie True # 启动 【KIE 模式】,用于打【检测+识别+关键字提取】场景的标签
```
> 如果上述安装出现问题,可以参考3.6节 错误提示
......@@ -89,7 +92,8 @@ pip3 install dist/PPOCRLabel-1.0.2-py2.py3-none-any.whl -i https://mirror.baidu.
```bash
cd ./PPOCRLabel # 切换到PPOCRLabel目录
python PPOCRLabel.py --lang ch
python PPOCRLabel.py --lang ch # 启动【普通模式】,用于打【检测+识别】场景的标签
python PPOCRLabel.py --lang ch --kie True # 启动 【KIE 模式】,用于打【检测+识别+关键字提取】场景的标签
```
......@@ -185,19 +189,29 @@ PPOCRLabel支持三种导出方式:
```
cd ./PPOCRLabel # 将目录切换到PPOCRLabel文件夹下
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --labelRootPath ../train_data/label --detRootPath ../train_data/det --recRootPath ../train_data/rec
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --datasetRootPath ../train_data
```
参数说明:
- `trainValTestRatio` 是训练集、验证集、测试集的图像数量划分比例,根据实际情况设定,默认是`6:2:2`
- `labelRootPath` 是PPOCRLabel标注的数据集存放路径,默认是`../train_data/label`
- `detRootPath` 是根据PPOCRLabel标注的数据集划分后的文本检测数据集存放的路径,默认是`../train_data/det `
- `recRootPath` 是根据PPOCRLabel标注的数据集划分后的字符识别数据集存放的路径,默认是`../train_data/rec`
- `datasetRootPath` 是PPOCRLabel标注的完整数据集存放路径。默认路径是 `PaddleOCR/train_data` 分割数据集前应有如下结构:
```
|-train_data
|-crop_img
|- word_001_crop_0.png
|- word_002_crop_0.jpg
|- word_003_crop_0.jpg
| ...
| Label.txt
| rec_gt.txt
|- word_001.png
|- word_002.jpg
|- word_003.jpg
| ...
```
### 3.6 错误提示
- 如果同时使用whl包安装了paddleocr,其优先级大于通过paddleocr.py调用PaddleOCR类,whl包未更新时会导致程序异常。
......
......@@ -17,15 +17,14 @@ def isCreateOrDeleteFolder(path, flag):
return flagAbsPath
def splitTrainVal(root, dir, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag):
def splitTrainVal(root, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag):
# 按照指定的比例划分训练集、验证集、测试集
labelPath = os.path.join(root, dir)
labelAbsPath = os.path.abspath(labelPath)
dataAbsPath = os.path.abspath(root)
if flag == "det":
labelFilePath = os.path.join(labelAbsPath, args.detLabelFileName)
labelFilePath = os.path.join(dataAbsPath, args.detLabelFileName)
elif flag == "rec":
labelFilePath = os.path.join(labelAbsPath, args.recLabelFileName)
labelFilePath = os.path.join(dataAbsPath, args.recLabelFileName)
labelFileRead = open(labelFilePath, "r", encoding="UTF-8")
labelFileContent = labelFileRead.readlines()
......@@ -38,9 +37,9 @@ def splitTrainVal(root, dir, absTrainRootPath, absValRootPath, absTestRootPath,
imageName = os.path.basename(imageRelativePath)
if flag == "det":
imagePath = os.path.join(labelAbsPath, imageName)
imagePath = os.path.join(dataAbsPath, imageName)
elif flag == "rec":
imagePath = os.path.join(labelAbsPath, "{}\\{}".format(args.recImageDirName, imageName))
imagePath = os.path.join(dataAbsPath, "{}\\{}".format(args.recImageDirName, imageName))
# 按预设的比例划分训练集、验证集、测试集
trainValTestRatio = args.trainValTestRatio.split(":")
......@@ -90,15 +89,20 @@ def genDetRecTrainVal(args):
recValTxt = open(os.path.join(args.recRootPath, "val.txt"), "a", encoding="UTF-8")
recTestTxt = open(os.path.join(args.recRootPath, "test.txt"), "a", encoding="UTF-8")
for root, dirs, files in os.walk(args.labelRootPath):
splitTrainVal(args.datasetRootPath, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath, detTrainTxt, detValTxt,
detTestTxt, "det")
for root, dirs, files in os.walk(args.datasetRootPath):
for dir in dirs:
splitTrainVal(root, dir, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath, detTrainTxt, detValTxt,
detTestTxt, "det")
splitTrainVal(root, dir, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath, recTrainTxt, recValTxt,
recTestTxt, "rec")
if dir == 'crop_img':
splitTrainVal(root, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath, recTrainTxt, recValTxt,
recTestTxt, "rec")
else:
continue
break
if __name__ == "__main__":
# 功能描述:分别划分检测和识别的训练集、验证集、测试集
# 说明:可以根据自己的路径和需求调整参数,图像数据往往多人合作分批标注,每一批图像数据放在一个文件夹内用PPOCRLabel进行标注,
......@@ -110,9 +114,9 @@ if __name__ == "__main__":
default="6:2:2",
help="ratio of trainset:valset:testset")
parser.add_argument(
"--labelRootPath",
"--datasetRootPath",
type=str,
default="../train_data/label",
default="../train_data/",
help="path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3..."
)
parser.add_argument(
......
......@@ -783,7 +783,7 @@ class Canvas(QWidget):
points = [p1+p2 for p1, p2 in zip(self.selectedShape.points, [step]*4)]
return True in map(self.outOfPixmap, points)
def setLastLabel(self, text, line_color = None, fill_color = None):
def setLastLabel(self, text, line_color=None, fill_color=None, key_cls=None):
assert text
self.shapes[-1].label = text
if line_color:
......@@ -791,6 +791,10 @@ class Canvas(QWidget):
if fill_color:
self.shapes[-1].fill_color = fill_color
if key_cls:
self.shapes[-1].key_cls = key_cls
self.storeShapes()
return self.shapes[-1]
......
import re
from PyQt5 import QtCore
from PyQt5 import QtGui
from PyQt5 import QtWidgets
from PyQt5.Qt import QT_VERSION_STR
from libs.utils import newIcon, labelValidator
QT5 = QT_VERSION_STR[0] == '5'
# TODO(unknown):
# - Calculate optimal position so as not to go out of screen area.
class KeyQLineEdit(QtWidgets.QLineEdit):
def setListWidget(self, list_widget):
self.list_widget = list_widget
def keyPressEvent(self, e):
if e.key() in [QtCore.Qt.Key_Up, QtCore.Qt.Key_Down]:
self.list_widget.keyPressEvent(e)
else:
super(KeyQLineEdit, self).keyPressEvent(e)
class KeyDialog(QtWidgets.QDialog):
def __init__(
self,
text="Enter object label",
parent=None,
labels=None,
sort_labels=True,
show_text_field=True,
completion="startswith",
fit_to_content=None,
flags=None,
):
if fit_to_content is None:
fit_to_content = {"row": False, "column": True}
self._fit_to_content = fit_to_content
super(KeyDialog, self).__init__(parent)
self.edit = KeyQLineEdit()
self.edit.setPlaceholderText(text)
self.edit.setValidator(labelValidator())
self.edit.editingFinished.connect(self.postProcess)
if flags:
self.edit.textChanged.connect(self.updateFlags)
layout = QtWidgets.QVBoxLayout()
if show_text_field:
layout_edit = QtWidgets.QHBoxLayout()
layout_edit.addWidget(self.edit, 6)
layout.addLayout(layout_edit)
# buttons
self.buttonBox = bb = QtWidgets.QDialogButtonBox(
QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel,
QtCore.Qt.Horizontal,
self,
)
bb.button(bb.Ok).setIcon(newIcon("done"))
bb.button(bb.Cancel).setIcon(newIcon("undo"))
bb.accepted.connect(self.validate)
bb.rejected.connect(self.reject)
layout.addWidget(bb)
# label_list
self.labelList = QtWidgets.QListWidget()
if self._fit_to_content["row"]:
self.labelList.setHorizontalScrollBarPolicy(
QtCore.Qt.ScrollBarAlwaysOff
)
if self._fit_to_content["column"]:
self.labelList.setVerticalScrollBarPolicy(
QtCore.Qt.ScrollBarAlwaysOff
)
self._sort_labels = sort_labels
if labels:
self.labelList.addItems(labels)
if self._sort_labels:
self.labelList.sortItems()
else:
self.labelList.setDragDropMode(
QtWidgets.QAbstractItemView.InternalMove
)
self.labelList.currentItemChanged.connect(self.labelSelected)
self.labelList.itemDoubleClicked.connect(self.labelDoubleClicked)
self.edit.setListWidget(self.labelList)
layout.addWidget(self.labelList)
# label_flags
if flags is None:
flags = {}
self._flags = flags
self.flagsLayout = QtWidgets.QVBoxLayout()
self.resetFlags()
layout.addItem(self.flagsLayout)
self.edit.textChanged.connect(self.updateFlags)
self.setLayout(layout)
# completion
completer = QtWidgets.QCompleter()
if not QT5 and completion != "startswith":
completion = "startswith"
if completion == "startswith":
completer.setCompletionMode(QtWidgets.QCompleter.InlineCompletion)
# Default settings.
# completer.setFilterMode(QtCore.Qt.MatchStartsWith)
elif completion == "contains":
completer.setCompletionMode(QtWidgets.QCompleter.PopupCompletion)
completer.setFilterMode(QtCore.Qt.MatchContains)
else:
raise ValueError("Unsupported completion: {}".format(completion))
completer.setModel(self.labelList.model())
self.edit.setCompleter(completer)
def addLabelHistory(self, label):
if self.labelList.findItems(label, QtCore.Qt.MatchExactly):
return
self.labelList.addItem(label)
if self._sort_labels:
self.labelList.sortItems()
def labelSelected(self, item):
self.edit.setText(item.text())
def validate(self):
text = self.edit.text()
if hasattr(text, "strip"):
text = text.strip()
else:
text = text.trimmed()
if text:
self.accept()
def labelDoubleClicked(self, item):
self.validate()
def postProcess(self):
text = self.edit.text()
if hasattr(text, "strip"):
text = text.strip()
else:
text = text.trimmed()
self.edit.setText(text)
def updateFlags(self, label_new):
# keep state of shared flags
flags_old = self.getFlags()
flags_new = {}
for pattern, keys in self._flags.items():
if re.match(pattern, label_new):
for key in keys:
flags_new[key] = flags_old.get(key, False)
self.setFlags(flags_new)
def deleteFlags(self):
for i in reversed(range(self.flagsLayout.count())):
item = self.flagsLayout.itemAt(i).widget()
self.flagsLayout.removeWidget(item)
item.setParent(None)
def resetFlags(self, label=""):
flags = {}
for pattern, keys in self._flags.items():
if re.match(pattern, label):
for key in keys:
flags[key] = False
self.setFlags(flags)
def setFlags(self, flags):
self.deleteFlags()
for key in flags:
item = QtWidgets.QCheckBox(key, self)
item.setChecked(flags[key])
self.flagsLayout.addWidget(item)
item.show()
def getFlags(self):
flags = {}
for i in range(self.flagsLayout.count()):
item = self.flagsLayout.itemAt(i).widget()
flags[item.text()] = item.isChecked()
return flags
def popUp(self, text=None, move=True, flags=None):
if self._fit_to_content["row"]:
self.labelList.setMinimumHeight(
self.labelList.sizeHintForRow(0) * self.labelList.count() + 2
)
if self._fit_to_content["column"]:
self.labelList.setMinimumWidth(
self.labelList.sizeHintForColumn(0) + 2
)
# if text is None, the previous label in self.edit is kept
if text is None:
text = self.edit.text()
if flags:
self.setFlags(flags)
else:
self.resetFlags(text)
self.edit.setText(text)
self.edit.setSelection(0, len(text))
items = self.labelList.findItems(text, QtCore.Qt.MatchFixedString)
if items:
if len(items) != 1:
self.labelList.setCurrentItem(items[0])
row = self.labelList.row(items[0])
self.edit.completer().setCurrentRow(row)
self.edit.setFocus(QtCore.Qt.PopupFocusReason)
if move:
self.move(QtGui.QCursor.pos())
if self.exec_():
return self.edit.text(), self.getFlags()
else:
return None, None
import PIL.Image
import numpy as np
def rgb2hsv(rgb):
# type: (np.ndarray) -> np.ndarray
"""Convert rgb to hsv.
Parameters
----------
rgb: numpy.ndarray, (H, W, 3), np.uint8
Input rgb image.
Returns
-------
hsv: numpy.ndarray, (H, W, 3), np.uint8
Output hsv image.
"""
hsv = PIL.Image.fromarray(rgb, mode="RGB")
hsv = hsv.convert("HSV")
hsv = np.array(hsv)
return hsv
def hsv2rgb(hsv):
# type: (np.ndarray) -> np.ndarray
"""Convert hsv to rgb.
Parameters
----------
hsv: numpy.ndarray, (H, W, 3), np.uint8
Input hsv image.
Returns
-------
rgb: numpy.ndarray, (H, W, 3), np.uint8
Output rgb image.
"""
rgb = PIL.Image.fromarray(hsv, mode="HSV")
rgb = rgb.convert("RGB")
rgb = np.array(rgb)
return rgb
def label_colormap(n_label=256, value=None):
"""Label colormap.
Parameters
----------
n_label: int
Number of labels (default: 256).
value: float or int
Value scale or value of label color in HSV space.
Returns
-------
cmap: numpy.ndarray, (N, 3), numpy.uint8
Label id to colormap.
"""
def bitget(byteval, idx):
return (byteval & (1 << idx)) != 0
cmap = np.zeros((n_label, 3), dtype=np.uint8)
for i in range(0, n_label):
id = i
r, g, b = 0, 0, 0
for j in range(0, 8):
r = np.bitwise_or(r, (bitget(id, 0) << 7 - j))
g = np.bitwise_or(g, (bitget(id, 1) << 7 - j))
b = np.bitwise_or(b, (bitget(id, 2) << 7 - j))
id = id >> 3
cmap[i, 0] = r
cmap[i, 1] = g
cmap[i, 2] = b
if value is not None:
hsv = rgb2hsv(cmap.reshape(1, -1, 3))
if isinstance(value, float):
hsv[:, 1:, 2] = hsv[:, 1:, 2].astype(float) * value
else:
assert isinstance(value, int)
hsv[:, 1:, 2] = value
cmap = hsv2rgb(hsv).reshape(-1, 3)
return cmap
......@@ -46,12 +46,13 @@ class Shape(object):
point_size = 8
scale = 1.0
def __init__(self, label=None, line_color=None, difficult=False, paintLabel=False):
def __init__(self, label=None, line_color=None, difficult=False, key_cls="None", paintLabel=False):
self.label = label
self.points = []
self.fill = False
self.selected = False
self.difficult = difficult
self.key_cls = key_cls
self.paintLabel = paintLabel
self.locked = False
self.direction = 0
......@@ -224,6 +225,7 @@ class Shape(object):
if self.fill_color != Shape.fill_color:
shape.fill_color = self.fill_color
shape.difficult = self.difficult
shape.key_cls = self.key_cls
return shape
def __len__(self):
......
# -*- encoding: utf-8 -*-
from PyQt5.QtCore import Qt
from PyQt5 import QtWidgets
class EscapableQListWidget(QtWidgets.QListWidget):
def keyPressEvent(self, event):
super(EscapableQListWidget, self).keyPressEvent(event)
if event.key() == Qt.Key_Escape:
self.clearSelection()
class UniqueLabelQListWidget(EscapableQListWidget):
def mousePressEvent(self, event):
super(UniqueLabelQListWidget, self).mousePressEvent(event)
if not self.indexAt(event.pos()).isValid():
self.clearSelection()
def findItemsByLabel(self, label, get_row=False):
items = []
for row in range(self.count()):
item = self.item(row)
if item.data(Qt.UserRole) == label:
items.append(item)
if get_row:
return row
return items
def createItemFromLabel(self, label):
item = QtWidgets.QListWidgetItem()
item.setData(Qt.UserRole, label)
return item
def setItemLabel(self, item, label, color=None):
qlabel = QtWidgets.QLabel()
if color is None:
qlabel.setText(f"{label}")
else:
qlabel.setText('<font color="#{:02x}{:02x}{:02x}">●</font> {} '.format(*color, label))
qlabel.setAlignment(Qt.AlignBottom)
item.setSizeHint(qlabel.sizeHint())
self.setItemWidget(item, qlabel)
......@@ -10,30 +10,26 @@
# SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
# CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from math import sqrt
from libs.ustr import ustr
import hashlib
import os
import re
import sys
from math import sqrt
import cv2
import numpy as np
import os
from PyQt5.QtCore import QRegExp, QT_VERSION_STR
from PyQt5.QtGui import QIcon, QRegExpValidator, QColor
from PyQt5.QtWidgets import QPushButton, QAction, QMenu
from libs.ustr import ustr
__dir__ = os.path.dirname(os.path.abspath(__file__)) # 获取本程序文件路径
__dir__ = os.path.dirname(os.path.abspath(__file__)) # 获取本程序文件路径
__iconpath__ = os.path.abspath(os.path.join(__dir__, '../resources/icons'))
try:
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
except ImportError:
from PyQt4.QtGui import *
from PyQt4.QtCore import *
def newIcon(icon, iconSize=None):
if iconSize is not None:
return QIcon(QIcon(__iconpath__ + "/" + icon + ".png").pixmap(iconSize,iconSize))
return QIcon(QIcon(__iconpath__ + "/" + icon + ".png").pixmap(iconSize, iconSize))
else:
return QIcon(__iconpath__ + "/" + icon + ".png")
......@@ -105,24 +101,25 @@ def generateColorByText(text):
s = ustr(text)
hashCode = int(hashlib.sha256(s.encode('utf-8')).hexdigest(), 16)
r = int((hashCode / 255) % 255)
g = int((hashCode / 65025) % 255)
b = int((hashCode / 16581375) % 255)
g = int((hashCode / 65025) % 255)
b = int((hashCode / 16581375) % 255)
return QColor(r, g, b, 100)
def have_qstring():
'''p3/qt5 get rid of QString wrapper as py3 has native unicode str type'''
return not (sys.version_info.major >= 3 or QT_VERSION_STR.startswith('5.'))
def util_qt_strlistclass():
return QStringList if have_qstring() else list
def natural_sort(list, key=lambda s:s):
def natural_sort(list, key=lambda s: s):
"""
Sort the list into natural alphanumeric order.
"""
def get_alphanum_key_func(key):
convert = lambda text: int(text) if text.isdigit() else text
return lambda s: [convert(c) for c in re.split('([0-9]+)', key(s))]
sort_key = get_alphanum_key_func(key)
list.sort(key=sort_key)
......@@ -133,8 +130,8 @@ def get_rotate_crop_image(img, points):
d = 0.0
for index in range(-1, 3):
d += -0.5 * (points[index + 1][1] + points[index][1]) * (
points[index + 1][0] - points[index][0])
if d < 0: # counterclockwise
points[index + 1][0] - points[index][0])
if d < 0: # counterclockwise
tmp = np.array(points)
points[1], points[3] = tmp[3], tmp[1]
......@@ -163,10 +160,11 @@ def get_rotate_crop_image(img, points):
except Exception as e:
print(e)
def stepsInfo(lang='en'):
if lang == 'ch':
msg = "1. 安装与运行:使用上述命令安装与运行程序。\n" \
"2. 打开文件夹:在菜单栏点击 “文件” - 打开目录 选择待标记图片的文件夹.\n"\
"2. 打开文件夹:在菜单栏点击 “文件” - 打开目录 选择待标记图片的文件夹.\n" \
"3. 自动标注:点击 ”自动标注“,使用PPOCR超轻量模型对图片文件名前图片状态为 “X” 的图片进行自动标注。\n" \
"4. 手动标注:点击 “矩形标注”(推荐直接在英文模式下点击键盘中的 “W”),用户可对当前图片中模型未检出的部分进行手动" \
"绘制标记框。点击键盘P,则使用四点标注模式(或点击“编辑” - “四点标注”),用户依次点击4个点后,双击左键表示标注完成。\n" \
......@@ -181,25 +179,26 @@ def stepsInfo(lang='en'):
else:
msg = "1. Build and launch using the instructions above.\n" \
"2. Click 'Open Dir' in Menu/File to select the folder of the picture.\n"\
"3. Click 'Auto recognition', use PPOCR model to automatically annotate images which marked with 'X' before the file name."\
"4. Create Box:\n"\
"4.1 Click 'Create RectBox' or press 'W' in English keyboard mode to draw a new rectangle detection box. Click and release left mouse to select a region to annotate the text area.\n"\
"4.2 Press 'P' to enter four-point labeling mode which enables you to create any four-point shape by clicking four points with the left mouse button in succession and DOUBLE CLICK the left mouse as the signal of labeling completion.\n"\
"5. After the marking frame is drawn, the user clicks 'OK', and the detection frame will be pre-assigned a TEMPORARY label.\n"\
"6. Click re-Recognition, model will rewrite ALL recognition results in ALL detection box.\n"\
"7. Double click the result in 'recognition result' list to manually change inaccurate recognition results.\n"\
"8. Click 'Save', the image status will switch to '√',then the program automatically jump to the next.\n"\
"9. Click 'Delete Image' and the image will be deleted to the recycle bin.\n"\
"10. Labeling result: After closing the application or switching the file path, the manually saved label will be stored in *Label.txt* under the opened picture folder.\n"\
" Click PaddleOCR-Save Recognition Results in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*.\n"
"2. Click 'Open Dir' in Menu/File to select the folder of the picture.\n" \
"3. Click 'Auto recognition', use PPOCR model to automatically annotate images which marked with 'X' before the file name." \
"4. Create Box:\n" \
"4.1 Click 'Create RectBox' or press 'W' in English keyboard mode to draw a new rectangle detection box. Click and release left mouse to select a region to annotate the text area.\n" \
"4.2 Press 'P' to enter four-point labeling mode which enables you to create any four-point shape by clicking four points with the left mouse button in succession and DOUBLE CLICK the left mouse as the signal of labeling completion.\n" \
"5. After the marking frame is drawn, the user clicks 'OK', and the detection frame will be pre-assigned a TEMPORARY label.\n" \
"6. Click re-Recognition, model will rewrite ALL recognition results in ALL detection box.\n" \
"7. Double click the result in 'recognition result' list to manually change inaccurate recognition results.\n" \
"8. Click 'Save', the image status will switch to '√',then the program automatically jump to the next.\n" \
"9. Click 'Delete Image' and the image will be deleted to the recycle bin.\n" \
"10. Labeling result: After closing the application or switching the file path, the manually saved label will be stored in *Label.txt* under the opened picture folder.\n" \
" Click PaddleOCR-Save Recognition Results in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*.\n"
return msg
def keysInfo(lang='en'):
if lang == 'ch':
msg = "快捷键\t\t\t说明\n" \
"———————————————————————\n"\
"———————————————————————\n" \
"Ctrl + shift + R\t\t对当前图片的所有标记重新识别\n" \
"W\t\t\t新建矩形框\n" \
"Q\t\t\t新建四点框\n" \
......@@ -223,17 +222,17 @@ def keysInfo(lang='en'):
"———————————————————————\n" \
"Ctrl + shift + R\t\tRe-recognize all the labels\n" \
"\t\t\tof the current image\n" \
"\n"\
"\n" \
"W\t\t\tCreate a rect box\n" \
"Q\t\t\tCreate a four-points box\n" \
"Ctrl + E\t\tEdit label of the selected box\n" \
"Ctrl + R\t\tRe-recognize the selected box\n" \
"Ctrl + C\t\tCopy and paste the selected\n" \
"\t\t\tbox\n" \
"\n"\
"\n" \
"Ctrl + Left Mouse\tMulti select the label\n" \
"Button\t\t\tbox\n" \
"\n"\
"\n" \
"Backspace\t\tDelete the selected box\n" \
"Ctrl + V\t\tCheck image\n" \
"Ctrl + Shift + d\tDelete image\n" \
......@@ -245,4 +244,4 @@ def keysInfo(lang='en'):
"———————————————————————\n" \
"Notice:For Mac users, use the 'Command' key instead of the 'Ctrl' key"
return msg
\ No newline at end of file
return msg
......@@ -106,4 +106,7 @@ undo=Undo
undoLastPoint=Undo Last Point
autoSaveMode=Auto Export Label Mode
lockBox=Lock selected box/Unlock all box
lockBoxDetail=Lock selected box/Unlock all box
\ No newline at end of file
lockBoxDetail=Lock selected box/Unlock all box
keyListTitle=Key List
keyDialogTip=Enter object label
keyChange=Change Box Key
......@@ -107,3 +107,6 @@ undoLastPoint=撤销上个点
autoSaveMode=自动导出标记结果
lockBox=锁定框/解除锁定框
lockBoxDetail=若当前没有框处于锁定状态则锁定选中的框,若存在锁定框则解除所有锁定框的锁定状态
keyListTitle=关键词列表
keyDialogTip=请输入类型名称
keyChange=更改Box关键字类别
\ No newline at end of file
......@@ -152,7 +152,7 @@ For a new language request, please refer to [Guideline for new language_requests
[1] PP-OCR is a practical ultra-lightweight OCR system. It is mainly composed of three parts: DB text detection, detection frame correction and CRNN text recognition. The system adopts 19 effective strategies from 8 aspects including backbone network selection and adjustment, prediction head design, data augmentation, learning rate transformation strategy, regularization parameter selection, pre-training model use, and automatic model tailoring and quantization to optimize and slim down the models of each module (as shown in the green box above). The final results are an ultra-lightweight Chinese and English OCR model with an overall size of 3.5M and a 2.8M English digital OCR model. For more details, please refer to the PP-OCR technical article (https://arxiv.org/abs/2009.09941).
[2] On the basis of PP-OCR, PP-OCRv2 is further optimized in five aspects. The detection model adopts CML(Collaborative Mutual Learning) knowledge distillation strategy and CopyPaste data expansion strategy. The recognition model adopts LCNet lightweight backbone network, U-DML knowledge distillation strategy and enhanced CTC loss function improvement (as shown in the red box above), which further improves the inference speed and prediction effect. For more details, please refer to the technical report of PP-OCRv2 (arXiv link is coming soon).
[2] On the basis of PP-OCR, PP-OCRv2 is further optimized in five aspects. The detection model adopts CML(Collaborative Mutual Learning) knowledge distillation strategy and CopyPaste data expansion strategy. The recognition model adopts LCNet lightweight backbone network, U-DML knowledge distillation strategy and enhanced CTC loss function improvement (as shown in the red box above), which further improves the inference speed and prediction effect. For more details, please refer to the technical report of PP-OCRv2 (https://arxiv.org/abs/2109.03144).
......@@ -181,16 +181,11 @@ For a new language request, please refer to [Guideline for new language_requests
<a name="language_requests"></a>
## Guideline for New Language Requests
If you want to request a new language support, a PR with 2 following files are needed:
If you want to request a new language support, a PR with 1 following files are needed:
1. In folder [ppocr/utils/dict](./ppocr/utils/dict),
it is necessary to submit the dict text to this path and name it with `{language}_dict.txt` that contains a list of all characters. Please see the format example from other files in that folder.
2. In folder [ppocr/utils/corpus](./ppocr/utils/corpus),
it is necessary to submit the corpus to this path and name it with `{language}_corpus.txt` that contains a list of words in your language.
Maybe, 50000 words per language is necessary at least.
Of course, the more, the better.
If your language has unique elements, please tell me in advance within any way, such as useful links, wikipedia and so on.
More details, please refer to [Multilingual OCR Development Plan](https://github.com/PaddlePaddle/PaddleOCR/issues/1048).
......
......@@ -26,35 +26,57 @@ def parse_args():
parser.add_argument(
"--filename", type=str, help="The name of log which need to analysis.")
parser.add_argument(
"--log_with_profiler", type=str, help="The path of train log with profiler")
"--log_with_profiler",
type=str,
help="The path of train log with profiler")
parser.add_argument(
"--profiler_path", type=str, help="The path of profiler timeline log.")
parser.add_argument(
"--keyword", type=str, help="Keyword to specify analysis data")
parser.add_argument(
"--separator", type=str, default=None, help="Separator of different field in log")
"--separator",
type=str,
default=None,
help="Separator of different field in log")
parser.add_argument(
'--position', type=int, default=None, help='The position of data field')
parser.add_argument(
'--range', type=str, default="", help='The range of data field to intercept')
'--range',
type=str,
default="",
help='The range of data field to intercept')
parser.add_argument(
'--base_batch_size', type=int, help='base_batch size on gpu')
parser.add_argument(
'--skip_steps', type=int, default=0, help='The number of steps to be skipped')
'--skip_steps',
type=int,
default=0,
help='The number of steps to be skipped')
parser.add_argument(
'--model_mode', type=int, default=-1, help='Analysis mode, default value is -1')
'--model_mode',
type=int,
default=-1,
help='Analysis mode, default value is -1')
parser.add_argument('--ips_unit', type=str, default=None, help='IPS unit')
parser.add_argument(
'--ips_unit', type=str, default=None, help='IPS unit')
parser.add_argument(
'--model_name', type=str, default=0, help='training model_name, transformer_base')
'--model_name',
type=str,
default=0,
help='training model_name, transformer_base')
parser.add_argument(
'--mission_name', type=str, default=0, help='training mission name')
parser.add_argument(
'--direction_id', type=int, default=0, help='training direction_id')
parser.add_argument(
'--run_mode', type=str, default="sp", help='multi process or single process')
'--run_mode',
type=str,
default="sp",
help='multi process or single process')
parser.add_argument(
'--index', type=int, default=1, help='{1: speed, 2:mem, 3:profiler, 6:max_batch_size}')
'--index',
type=int,
default=1,
help='{1: speed, 2:mem, 3:profiler, 6:max_batch_size}')
parser.add_argument(
'--gpu_num', type=int, default=1, help='nums of training gpus')
args = parser.parse_args()
......@@ -72,7 +94,12 @@ def _is_number(num):
class TimeAnalyzer(object):
def __init__(self, filename, keyword=None, separator=None, position=None, range="-1"):
def __init__(self,
filename,
keyword=None,
separator=None,
position=None,
range="-1"):
if filename is None:
raise Exception("Please specify the filename!")
......@@ -99,7 +126,8 @@ class TimeAnalyzer(object):
# Distil the string from a line.
line = line.strip()
line_words = line.split(self.separator) if self.separator else line.split()
line_words = line.split(
self.separator) if self.separator else line.split()
if args.position:
result = line_words[self.position]
else:
......@@ -108,27 +136,36 @@ class TimeAnalyzer(object):
if line_words[i] == self.keyword:
result = line_words[i + 1]
break
# Distil the result from the picked string.
if not self.range:
result = result[0:]
elif _is_number(self.range):
result = result[0: int(self.range)]
result = result[0:int(self.range)]
else:
result = result[int(self.range.split(":")[0]): int(self.range.split(":")[1])]
result = result[int(self.range.split(":")[0]):int(
self.range.split(":")[1])]
self.records.append(float(result))
except Exception as exc:
print("line is: {}; separator={}; position={}".format(line, self.separator, self.position))
print("line is: {}; separator={}; position={}".format(
line, self.separator, self.position))
print("Extract {} records: separator={}; position={}".format(len(self.records), self.separator, self.position))
print("Extract {} records: separator={}; position={}".format(
len(self.records), self.separator, self.position))
def _get_fps(self, mode, batch_size, gpu_num, avg_of_records, run_mode, unit=None):
def _get_fps(self,
mode,
batch_size,
gpu_num,
avg_of_records,
run_mode,
unit=None):
if mode == -1 and run_mode == 'sp':
assert unit, "Please set the unit when mode is -1."
fps = gpu_num * avg_of_records
elif mode == -1 and run_mode == 'mp':
assert unit, "Please set the unit when mode is -1."
fps = gpu_num * avg_of_records #temporarily, not used now
fps = gpu_num * avg_of_records #temporarily, not used now
print("------------this is mp")
elif mode == 0:
# s/step -> samples/s
......@@ -155,12 +192,20 @@ class TimeAnalyzer(object):
return fps, unit
def analysis(self, batch_size, gpu_num=1, skip_steps=0, mode=-1, run_mode='sp', unit=None):
def analysis(self,
batch_size,
gpu_num=1,
skip_steps=0,
mode=-1,
run_mode='sp',
unit=None):
if batch_size <= 0:
print("base_batch_size should larger than 0.")
return 0, ''
if len(self.records) <= skip_steps: # to address the condition which item of log equals to skip_steps
if len(
self.records
) <= skip_steps: # to address the condition which item of log equals to skip_steps
print("no records")
return 0, ''
......@@ -180,16 +225,20 @@ class TimeAnalyzer(object):
skip_max = self.records[i]
avg_of_records = sum_of_records / float(count)
avg_of_records_skipped = sum_of_records_skipped / float(count - skip_steps)
avg_of_records_skipped = sum_of_records_skipped / float(count -
skip_steps)
fps, fps_unit = self._get_fps(mode, batch_size, gpu_num, avg_of_records, run_mode, unit)
fps_skipped, _ = self._get_fps(mode, batch_size, gpu_num, avg_of_records_skipped, run_mode, unit)
fps, fps_unit = self._get_fps(mode, batch_size, gpu_num, avg_of_records,
run_mode, unit)
fps_skipped, _ = self._get_fps(mode, batch_size, gpu_num,
avg_of_records_skipped, run_mode, unit)
if mode == -1:
print("average ips of %d steps, skip 0 step:" % count)
print("\tAvg: %.3f %s" % (avg_of_records, fps_unit))
print("\tFPS: %.3f %s" % (fps, fps_unit))
if skip_steps > 0:
print("average ips of %d steps, skip %d steps:" % (count, skip_steps))
print("average ips of %d steps, skip %d steps:" %
(count, skip_steps))
print("\tAvg: %.3f %s" % (avg_of_records_skipped, fps_unit))
print("\tMin: %.3f %s" % (skip_min, fps_unit))
print("\tMax: %.3f %s" % (skip_max, fps_unit))
......@@ -199,7 +248,8 @@ class TimeAnalyzer(object):
print("\tAvg: %.3f steps/s" % avg_of_records)
print("\tFPS: %.3f %s" % (fps, fps_unit))
if skip_steps > 0:
print("average latency of %d steps, skip %d steps:" % (count, skip_steps))
print("average latency of %d steps, skip %d steps:" %
(count, skip_steps))
print("\tAvg: %.3f steps/s" % avg_of_records_skipped)
print("\tMin: %.3f steps/s" % skip_min)
print("\tMax: %.3f steps/s" % skip_max)
......@@ -209,7 +259,8 @@ class TimeAnalyzer(object):
print("\tAvg: %.3f s/step" % avg_of_records)
print("\tFPS: %.3f %s" % (fps, fps_unit))
if skip_steps > 0:
print("average latency of %d steps, skip %d steps:" % (count, skip_steps))
print("average latency of %d steps, skip %d steps:" %
(count, skip_steps))
print("\tAvg: %.3f s/step" % avg_of_records_skipped)
print("\tMin: %.3f s/step" % skip_min)
print("\tMax: %.3f s/step" % skip_max)
......@@ -236,7 +287,8 @@ if __name__ == "__main__":
if args.gpu_num == 1:
run_info["log_with_profiler"] = args.log_with_profiler
run_info["profiler_path"] = args.profiler_path
analyzer = TimeAnalyzer(args.filename, args.keyword, args.separator, args.position, args.range)
analyzer = TimeAnalyzer(args.filename, args.keyword, args.separator,
args.position, args.range)
run_info["FINAL_RESULT"], run_info["UNIT"] = analyzer.analysis(
batch_size=args.base_batch_size,
gpu_num=args.gpu_num,
......@@ -245,29 +297,50 @@ if __name__ == "__main__":
run_mode=args.run_mode,
unit=args.ips_unit)
try:
if int(os.getenv('job_fail_flag')) == 1 or int(run_info["FINAL_RESULT"]) == 0:
if int(os.getenv('job_fail_flag')) == 1 or int(run_info[
"FINAL_RESULT"]) == 0:
run_info["JOB_FAIL_FLAG"] = 1
except:
pass
elif args.index == 3:
run_info["FINAL_RESULT"] = {}
records_fo_total = TimeAnalyzer(args.filename, 'Framework overhead', None, 3, '').records
records_fo_ratio = TimeAnalyzer(args.filename, 'Framework overhead', None, 5).records
records_ct_total = TimeAnalyzer(args.filename, 'Computation time', None, 3, '').records
records_gm_total = TimeAnalyzer(args.filename, 'GpuMemcpy Calls', None, 4, '').records
records_gm_ratio = TimeAnalyzer(args.filename, 'GpuMemcpy Calls', None, 6).records
records_gmas_total = TimeAnalyzer(args.filename, 'GpuMemcpyAsync Calls', None, 4, '').records
records_gms_total = TimeAnalyzer(args.filename, 'GpuMemcpySync Calls', None, 4, '').records
run_info["FINAL_RESULT"]["Framework_Total"] = records_fo_total[0] if records_fo_total else 0
run_info["FINAL_RESULT"]["Framework_Ratio"] = records_fo_ratio[0] if records_fo_ratio else 0
run_info["FINAL_RESULT"]["ComputationTime_Total"] = records_ct_total[0] if records_ct_total else 0
run_info["FINAL_RESULT"]["GpuMemcpy_Total"] = records_gm_total[0] if records_gm_total else 0
run_info["FINAL_RESULT"]["GpuMemcpy_Ratio"] = records_gm_ratio[0] if records_gm_ratio else 0
run_info["FINAL_RESULT"]["GpuMemcpyAsync_Total"] = records_gmas_total[0] if records_gmas_total else 0
run_info["FINAL_RESULT"]["GpuMemcpySync_Total"] = records_gms_total[0] if records_gms_total else 0
records_fo_total = TimeAnalyzer(args.filename, 'Framework overhead',
None, 3, '').records
records_fo_ratio = TimeAnalyzer(args.filename, 'Framework overhead',
None, 5).records
records_ct_total = TimeAnalyzer(args.filename, 'Computation time',
None, 3, '').records
records_gm_total = TimeAnalyzer(args.filename,
'GpuMemcpy Calls',
None, 4, '').records
records_gm_ratio = TimeAnalyzer(args.filename,
'GpuMemcpy Calls',
None, 6).records
records_gmas_total = TimeAnalyzer(args.filename,
'GpuMemcpyAsync Calls',
None, 4, '').records
records_gms_total = TimeAnalyzer(args.filename,
'GpuMemcpySync Calls',
None, 4, '').records
run_info["FINAL_RESULT"]["Framework_Total"] = records_fo_total[
0] if records_fo_total else 0
run_info["FINAL_RESULT"]["Framework_Ratio"] = records_fo_ratio[
0] if records_fo_ratio else 0
run_info["FINAL_RESULT"][
"ComputationTime_Total"] = records_ct_total[
0] if records_ct_total else 0
run_info["FINAL_RESULT"]["GpuMemcpy_Total"] = records_gm_total[
0] if records_gm_total else 0
run_info["FINAL_RESULT"]["GpuMemcpy_Ratio"] = records_gm_ratio[
0] if records_gm_ratio else 0
run_info["FINAL_RESULT"][
"GpuMemcpyAsync_Total"] = records_gmas_total[
0] if records_gmas_total else 0
run_info["FINAL_RESULT"]["GpuMemcpySync_Total"] = records_gms_total[
0] if records_gms_total else 0
else:
print("Not support!")
except Exception:
traceback.print_exc()
print("{}".format(json.dumps(run_info))) # it's required, for the log file path insert to the database
traceback.print_exc()
print("{}".format(json.dumps(run_info))
) # it's required, for the log file path insert to the database
......@@ -58,3 +58,4 @@ source ${BENCHMARK_ROOT}/scripts/run_model.sh # 在该脚本中会对符合
_set_params $@
#_train # 如果只想产出训练log,不解析,可取消注释
_run # 该函数在run_model.sh中,执行时会调用_train; 如果不联调只想要产出训练log可以注掉本行,提交时需打开
......@@ -36,3 +36,4 @@ for model_mode in ${model_mode_list[@]}; do
done
Global:
use_gpu: true
use_xpu: false
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 10
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/re_layoutlmv2/
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2048
infer_img: doc/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLMv2"
Transform:
Backbone:
name: LayoutLMv2ForRe
pretrained: True
checkpoints:
Loss:
name: LossFromOutput
key: loss
reduction: mean
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
clip_norm: 10
lr:
learning_rate: 0.00005
warmup_epoch: 10
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQAReTokenLayoutLMPostProcess
Metric:
name: VQAReTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
......@@ -21,7 +21,7 @@ Architecture:
Backbone:
name: LayoutXLMForRe
pretrained: True
checkpoints:
checkpoints:
Loss:
name: LossFromOutput
......@@ -35,6 +35,7 @@ Optimizer:
clip_norm: 10
lr:
learning_rate: 0.00005
warmup_epoch: 10
regularizer:
name: L2
factor: 0.00000
......@@ -81,7 +82,7 @@ Train:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
num_workers: 8
collate_fn: ListCollator
Eval:
......@@ -118,5 +119,5 @@ Eval:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
num_workers: 8
collate_fn: ListCollator
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlmv2/
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: doc/vqa/input/zh_val_0.jpg
save_res_path: ./output/ser/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLMv2"
Transform:
Backbone:
name: LayoutLMv2ForSer
pretrained: True
checkpoints:
num_classes: &num_classes 7
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment