Commit 4b214948 authored by zhiminzhang0830's avatar zhiminzhang0830
Browse files

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into new_branch

parents 917606b4 6e607a0f
......@@ -10,95 +10,65 @@
# SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
# CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#!/usr/bin/env python
# !/usr/bin/env python
# -*- coding: utf-8 -*-
# pyrcc5 -o libs/resources.py resources.qrc
import argparse
import ast
import codecs
import json
import os.path
import platform
import subprocess
import sys
from functools import partial
from collections import defaultdict
import json
import cv2
from PyQt5.QtCore import QSize, Qt, QPoint, QByteArray, QTimer, QFileInfo, QPointF, QProcess
from PyQt5.QtGui import QImage, QCursor, QPixmap, QImageReader
from PyQt5.QtWidgets import QMainWindow, QListWidget, QVBoxLayout, QToolButton, QHBoxLayout, QDockWidget, QWidget, \
QSlider, QGraphicsOpacityEffect, QMessageBox, QListView, QScrollArea, QWidgetAction, QApplication, QLabel, \
QFileDialog, QListWidgetItem, QComboBox, QDialog
__dir__ = os.path.dirname(os.path.abspath(__file__))
import numpy as np
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../PaddleOCR')))
sys.path.append("..")
from paddleocr import PaddleOCR
try:
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
except ImportError:
# needed for py3+qt4
# Ref:
# http://pyqt.sourceforge.net/Docs/PyQt4/incompatible_apis.html
# http://stackoverflow.com/questions/21217399/pyqt4-qtcore-qvariant-object-instead-of-a-string
if sys.version_info.major >= 3:
import sip
sip.setapi('QVariant', 2)
from PyQt4.QtGui import *
from PyQt4.QtCore import *
from combobox import ComboBox
from libs.constants import *
from libs.utils import *
from libs.labelColor import label_colormap
from libs.settings import Settings
from libs.shape import Shape, DEFAULT_LINE_COLOR, DEFAULT_FILL_COLOR,DEFAULT_LOCK_COLOR
from libs.shape import Shape, DEFAULT_LINE_COLOR, DEFAULT_FILL_COLOR, DEFAULT_LOCK_COLOR
from libs.stringBundle import StringBundle
from libs.canvas import Canvas
from libs.zoomWidget import ZoomWidget
from libs.autoDialog import AutoDialog
from libs.labelDialog import LabelDialog
from libs.colorDialog import ColorDialog
from libs.toolBar import ToolBar
from libs.ustr import ustr
from libs.hashableQListWidgetItem import HashableQListWidgetItem
from libs.editinlist import EditInList
from libs.unique_label_qlist_widget import UniqueLabelQListWidget
from libs.keyDialog import KeyDialog
__appname__ = 'PPOCRLabel'
class WindowMixin(object):
def menu(self, title, actions=None):
menu = self.menuBar().addMenu(title)
if actions:
addActions(menu, actions)
return menu
def toolbar(self, title, actions=None):
toolbar = ToolBar(title)
toolbar.setObjectName(u'%sToolBar' % title)
# toolbar.setOrientation(Qt.Vertical)
toolbar.setToolButtonStyle(Qt.ToolButtonTextUnderIcon)
if actions:
addActions(toolbar, actions)
self.addToolBar(Qt.LeftToolBarArea, toolbar)
return toolbar
LABEL_COLORMAP = label_colormap()
class MainWindow(QMainWindow, WindowMixin):
class MainWindow(QMainWindow):
FIT_WINDOW, FIT_WIDTH, MANUAL_ZOOM = list(range(3))
def __init__(self, lang="ch", gpu=False, defaultFilename=None, defaultPrefdefClassFile=None, defaultSaveDir=None):
def __init__(self,
lang="ch",
gpu=False,
kie_mode=False,
default_filename=None,
default_predefined_class_file=None,
default_save_dir=None):
super(MainWindow, self).__init__()
self.setWindowTitle(__appname__)
self.setWindowState(Qt.WindowMaximized) # set window max
......@@ -109,14 +79,27 @@ class MainWindow(QMainWindow, WindowMixin):
self.settings.load()
settings = self.settings
self.lang = lang
# Load string bundle for i18n
if lang not in ['ch', 'en']:
lang = 'en'
self.stringBundle = StringBundle.getBundle(localeStr='zh-CN' if lang=='ch' else 'en') # 'en'
self.stringBundle = StringBundle.getBundle(localeStr='zh-CN' if lang == 'ch' else 'en') # 'en'
getStr = lambda strId: self.stringBundle.getString(strId)
self.defaultSaveDir = defaultSaveDir
self.ocr = PaddleOCR(use_pdserving=False, use_angle_cls=True, det=True, cls=True, use_gpu=gpu, lang=lang, show_log=False)
# KIE setting
self.kie_mode = kie_mode
self.key_previous_text = ""
self.existed_key_cls_set = set()
self.key_dialog_tip = getStr('keyDialogTip')
self.defaultSaveDir = default_save_dir
self.ocr = PaddleOCR(use_pdserving=False,
use_angle_cls=True,
det=True,
cls=True,
use_gpu=gpu,
lang=lang,
show_log=False)
if os.path.exists('./data/paddle.png'):
result = self.ocr.ocr('./data/paddle.png', cls=True, det=True)
......@@ -134,7 +117,6 @@ class MainWindow(QMainWindow, WindowMixin):
self.labelFile = None
self.currIndex = 0
# Whether we need to save or not.
self.dirty = False
......@@ -144,7 +126,7 @@ class MainWindow(QMainWindow, WindowMixin):
self.screencast = "https://github.com/PaddlePaddle/PaddleOCR"
# Load predefined classes to the list
self.loadPredefinedClasses(defaultPrefdefClassFile)
self.loadPredefinedClasses(default_predefined_class_file)
# Main widgets and related state.
self.labelDialog = LabelDialog(parent=self, listItem=self.labelHist)
......@@ -160,12 +142,14 @@ class MainWindow(QMainWindow, WindowMixin):
self.PPreader = None
self.autoSaveNum = 5
################# file list ###############
# ================== File List ==================
filelistLayout = QVBoxLayout()
filelistLayout.setContentsMargins(0, 0, 0, 0)
self.fileListWidget = QListWidget()
self.fileListWidget.itemClicked.connect(self.fileitemDoubleClicked)
self.fileListWidget.setIconSize(QSize(25, 25))
filelistLayout = QVBoxLayout()
filelistLayout.setContentsMargins(0, 0, 0, 0)
filelistLayout.addWidget(self.fileListWidget)
self.AutoRecognition = QToolButton()
......@@ -181,15 +165,29 @@ class MainWindow(QMainWindow, WindowMixin):
fileListContainer = QWidget()
fileListContainer.setLayout(filelistLayout)
self.fileListName = getStr('fileList')
self.filedock = QDockWidget(self.fileListName, self)
self.filedock.setObjectName(getStr('files'))
self.filedock.setWidget(fileListContainer)
self.addDockWidget(Qt.LeftDockWidgetArea, self.filedock)
######## Right area ##########
self.fileDock = QDockWidget(self.fileListName, self)
self.fileDock.setObjectName(getStr('files'))
self.fileDock.setWidget(fileListContainer)
self.addDockWidget(Qt.LeftDockWidgetArea, self.fileDock)
# ================== Key List ==================
if self.kie_mode:
# self.keyList = QListWidget()
self.keyList = UniqueLabelQListWidget()
# self.keyList.itemSelectionChanged.connect(self.keyListSelectionChanged)
# self.keyList.itemDoubleClicked.connect(self.editBox)
# self.keyList.itemChanged.connect(self.keyListItemChanged)
self.keyListDockName = getStr('keyListTitle')
self.keyListDock = QDockWidget(self.keyListDockName, self)
self.keyListDock.setWidget(self.keyList)
self.keyListDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
filelistLayout.addWidget(self.keyListDock)
# ================== Right Area ==================
listLayout = QVBoxLayout()
listLayout.setContentsMargins(0, 0, 0, 0)
# Buttons
self.editButton = QToolButton()
self.reRecogButton = QToolButton()
self.reRecogButton.setIcon(newIcon('reRec', 30))
......@@ -202,44 +200,44 @@ class MainWindow(QMainWindow, WindowMixin):
self.DelButton = QToolButton()
self.DelButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
leftTopToolBox = QHBoxLayout()
leftTopToolBox.addWidget(self.newButton)
leftTopToolBox.addWidget(self.reRecogButton)
leftTopToolBoxContainer = QWidget()
leftTopToolBoxContainer.setLayout(leftTopToolBox)
listLayout.addWidget(leftTopToolBoxContainer)
lefttoptoolbox = QHBoxLayout()
lefttoptoolbox.addWidget(self.newButton)
lefttoptoolbox.addWidget(self.reRecogButton)
lefttoptoolboxcontainer = QWidget()
lefttoptoolboxcontainer.setLayout(lefttoptoolbox)
listLayout.addWidget(lefttoptoolboxcontainer)
################## label list ####################
# ================== Label List ==================
# Create and add a widget for showing current label items
self.labelList = EditInList()
labelListContainer = QWidget()
labelListContainer.setLayout(listLayout)
#self.labelList.itemActivated.connect(self.labelSelectionChanged)
self.labelList.itemSelectionChanged.connect(self.labelSelectionChanged)
self.labelList.clicked.connect(self.labelList.item_clicked)
# Connect to itemChanged to detect checkbox changes.
self.labelList.itemChanged.connect(self.labelItemChanged)
self.labelListDock = QDockWidget(getStr('recognitionResult'),self)
self.labelListDockName = getStr('recognitionResult')
self.labelListDock = QDockWidget(self.labelListDockName, self)
self.labelListDock.setWidget(self.labelList)
self.labelListDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
listLayout.addWidget(self.labelListDock)
################## detection box ####################
# ================== Detection Box ==================
self.BoxList = QListWidget()
#self.BoxList.itemActivated.connect(self.boxSelectionChanged)
# self.BoxList.itemActivated.connect(self.boxSelectionChanged)
self.BoxList.itemSelectionChanged.connect(self.boxSelectionChanged)
self.BoxList.itemDoubleClicked.connect(self.editBox)
# Connect to itemChanged to detect checkbox changes.
self.BoxList.itemChanged.connect(self.boxItemChanged)
self.BoxListDock = QDockWidget(getStr('detectionBoxposition'), self)
self.BoxListDockName = getStr('detectionBoxposition')
self.BoxListDock = QDockWidget(self.BoxListDockName, self)
self.BoxListDock.setWidget(self.BoxList)
self.BoxListDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
listLayout.addWidget(self.BoxListDock)
############ lower right area ############
# ================== Lower Right Area ==================
leftbtmtoolbox = QHBoxLayout()
leftbtmtoolbox.addWidget(self.SaveButton)
leftbtmtoolbox.addWidget(self.DelButton)
......@@ -251,26 +249,26 @@ class MainWindow(QMainWindow, WindowMixin):
self.dock.setObjectName(getStr('labels'))
self.dock.setWidget(labelListContainer)
# ================== Zoom Bar ==================
self.imageSlider = QSlider(Qt.Horizontal)
self.imageSlider.valueChanged.connect(self.CanvasSizeChange)
self.imageSlider.setMinimum(-9)
self.imageSlider.setMaximum(510)
self.imageSlider.setSingleStep(1)
self.imageSlider.setTickPosition(QSlider.TicksBelow)
self.imageSlider.setTickInterval(1)
########## zoom bar #########
self.imgsplider = QSlider(Qt.Horizontal)
self.imgsplider.valueChanged.connect(self.CanvasSizeChange)
self.imgsplider.setMinimum(-150)
self.imgsplider.setMaximum(150)
self.imgsplider.setSingleStep(1)
self.imgsplider.setTickPosition(QSlider.TicksBelow)
self.imgsplider.setTickInterval(1)
op = QGraphicsOpacityEffect()
op.setOpacity(0.2)
self.imgsplider.setGraphicsEffect(op)
# self.imgsplider.setAttribute(Qt.WA_TranslucentBackground)
self.imgsplider.setStyleSheet("background-color:transparent")
self.imgsliderDock = QDockWidget(getStr('ImageResize'), self)
self.imgsliderDock.setObjectName(getStr('IR'))
self.imgsliderDock.setWidget(self.imgsplider)
self.imgsliderDock.setFeatures(QDockWidget.DockWidgetFloatable)
self.imgsliderDock.setAttribute(Qt.WA_TranslucentBackground)
self.addDockWidget(Qt.RightDockWidgetArea, self.imgsliderDock)
self.imageSlider.setGraphicsEffect(op)
self.imageSlider.setStyleSheet("background-color:transparent")
self.imageSliderDock = QDockWidget(getStr('ImageResize'), self)
self.imageSliderDock.setObjectName(getStr('IR'))
self.imageSliderDock.setWidget(self.imageSlider)
self.imageSliderDock.setFeatures(QDockWidget.DockWidgetFloatable)
self.imageSliderDock.setAttribute(Qt.WA_TranslucentBackground)
self.addDockWidget(Qt.RightDockWidgetArea, self.imageSliderDock)
self.zoomWidget = ZoomWidget()
self.colorDialog = ColorDialog(parent=self)
......@@ -278,13 +276,13 @@ class MainWindow(QMainWindow, WindowMixin):
self.msgBox = QMessageBox()
########## thumbnail #########
# ================== Thumbnail ==================
hlayout = QHBoxLayout()
m = (0, 0, 0, 0)
hlayout.setSpacing(0)
hlayout.setContentsMargins(*m)
self.preButton = QToolButton()
self.preButton.setIcon(newIcon("prev",40))
self.preButton.setIcon(newIcon("prev", 40))
self.preButton.setIconSize(QSize(40, 100))
self.preButton.clicked.connect(self.openPrevImg)
self.preButton.setStyleSheet('border: none;')
......@@ -294,10 +292,10 @@ class MainWindow(QMainWindow, WindowMixin):
self.iconlist.setFlow(QListView.TopToBottom)
self.iconlist.setSpacing(10)
self.iconlist.setIconSize(QSize(50, 50))
self.iconlist.setMovement(False)
self.iconlist.setMovement(QListView.Static)
self.iconlist.setResizeMode(QListView.Adjust)
self.iconlist.itemClicked.connect(self.iconitemDoubleClicked)
self.iconlist.setStyleSheet("background-color:transparent; border: none;")
self.iconlist.setStyleSheet("QListWidget{ background-color:transparent; border: none;}")
self.iconlist.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
self.nextButton = QToolButton()
self.nextButton.setIcon(newIcon("next", 40))
......@@ -310,12 +308,11 @@ class MainWindow(QMainWindow, WindowMixin):
hlayout.addWidget(self.iconlist)
hlayout.addWidget(self.nextButton)
iconListContainer = QWidget()
iconListContainer.setLayout(hlayout)
iconListContainer.setFixedHeight(100)
########### Canvas ###########
# ================== Canvas ==================
self.canvas = Canvas(parent=self)
self.canvas.zoomRequest.connect(self.zoomRequest)
self.canvas.setDrawingShapeToSquare(settings.get(SETTING_DRAW_SQUARE, False))
......@@ -338,32 +335,17 @@ class MainWindow(QMainWindow, WindowMixin):
centerLayout = QVBoxLayout()
centerLayout.setContentsMargins(0, 0, 0, 0)
centerLayout.addWidget(scroll)
#centerLayout.addWidget(self.icondock)
centerLayout.addWidget(iconListContainer,0,Qt.AlignCenter)
centercontainer = QWidget()
centercontainer.setLayout(centerLayout)
# self.scrolldock = QDockWidget('WorkSpace',self)
# self.scrolldock.setObjectName('WorkSpace')
# self.scrolldock.setWidget(centercontainer)
# self.scrolldock.setFeatures(QDockWidget.NoDockWidgetFeatures)
# orititle = self.scrolldock.titleBarWidget()
# tmpwidget = QWidget()
# self.scrolldock.setTitleBarWidget(tmpwidget)
# del orititle
self.setCentralWidget(centercontainer) #self.scrolldock
self.addDockWidget(Qt.RightDockWidgetArea, self.dock)
# self.filedock.setFeatures(QDockWidget.DockWidgetFloatable)
self.filedock.setFeatures(self.filedock.features() ^ QDockWidget.DockWidgetFloatable)
centerLayout.addWidget(iconListContainer, 0, Qt.AlignCenter)
centerContainer = QWidget()
centerContainer.setLayout(centerLayout)
self.dockFeatures = QDockWidget.DockWidgetClosable | QDockWidget.DockWidgetFloatable
self.dock.setFeatures(self.dock.features() ^ self.dockFeatures)
self.setCentralWidget(centerContainer)
self.addDockWidget(Qt.RightDockWidgetArea, self.dock)
self.filedock.setFeatures(QDockWidget.NoDockWidgetFeatures)
self.dock.setFeatures(QDockWidget.DockWidgetClosable | QDockWidget.DockWidgetFloatable)
self.fileDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
###### Actions #######
# ================== Actions ==================
action = partial(newAction, self)
quit = action(getStr('quit'), self.close,
'Ctrl+Q', 'quit', getStr('quitApp'))
......@@ -372,20 +354,20 @@ class MainWindow(QMainWindow, WindowMixin):
'Ctrl+u', 'open', getStr('openDir'))
open_dataset_dir = action(getStr('openDatasetDir'), self.openDatasetDirDialog,
'Ctrl+p', 'open', getStr('openDatasetDir'), enabled=False)
'Ctrl+p', 'open', getStr('openDatasetDir'), enabled=False)
save = action(getStr('save'), self.saveFile,
'Ctrl+V', 'verify', getStr('saveDetail'), enabled=False)
alcm = action(getStr('choosemodel'), self.autolcm,
'Ctrl+M', 'next', getStr('tipchoosemodel'))
'Ctrl+M', 'next', getStr('tipchoosemodel'))
deleteImg = action(getStr('deleteImg'), self.deleteImg, 'Ctrl+Shift+D', 'close', getStr('deleteImgDetail'),
enabled=True)
resetAll = action(getStr('resetAll'), self.resetAll, None, 'resetall', getStr('resetAllDetail'))
color1 = action(getStr('boxLineColor'), self.chooseColor1,
color1 = action(getStr('boxLineColor'), self.chooseColor,
'Ctrl+L', 'color_line', getStr('boxLineColorDetail'))
createMode = action(getStr('crtBox'), self.setCreateMode,
......@@ -398,7 +380,7 @@ class MainWindow(QMainWindow, WindowMixin):
delete = action(getStr('delBox'), self.deleteSelectedShape,
'Alt+X', 'delete', getStr('delBoxDetail'), enabled=False)
copy = action(getStr('dupBox'), self.copySelectedShape,
'Ctrl+C', 'copy', getStr('dupBoxDetail'),
enabled=False)
......@@ -409,7 +391,6 @@ class MainWindow(QMainWindow, WindowMixin):
showAll = action(getStr('showBox'), partial(self.togglePolygons, True),
'Ctrl+A', 'hide', getStr('showAllBoxDetail'),
enabled=False)
help = action(getStr('tutorial'), self.showTutorialDialog, None, 'help', getStr('tutorialDetail'))
showInfo = action(getStr('info'), self.showInfoDialog, None, 'help', getStr('info'))
......@@ -447,16 +428,17 @@ class MainWindow(QMainWindow, WindowMixin):
self.MANUAL_ZOOM: lambda: 1,
}
# ================== New Actions ==================
edit = action(getStr('editLabel'), self.editLabel,
'Ctrl+E', 'edit', getStr('editLabelDetail'),
enabled=False)
######## New actions #######
AutoRec = action(getStr('autoRecognition'), self.autoRecognition,
'', 'Auto', getStr('autoRecognition'), enabled=False)
'', 'Auto', getStr('autoRecognition'), enabled=False)
reRec = action(getStr('reRecognition'), self.reRecognition,
'Ctrl+Shift+R', 'reRec', getStr('reRecognition'), enabled=False)
'Ctrl+Shift+R', 'reRec', getStr('reRecognition'), enabled=False)
singleRere = action(getStr('singleRe'), self.singleRerecognition,
'Ctrl+R', 'reRec', getStr('singleRe'), enabled=False)
......@@ -465,23 +447,26 @@ class MainWindow(QMainWindow, WindowMixin):
'q', 'new', getStr('creatPolygon'), enabled=True)
saveRec = action(getStr('saveRec'), self.saveRecResult,
'', 'save', getStr('saveRec'), enabled=False)
'', 'save', getStr('saveRec'), enabled=False)
saveLabel = action(getStr('saveLabel'), self.saveLabelFile, #
'Ctrl+S', 'save', getStr('saveLabel'), enabled=False)
saveLabel = action(getStr('saveLabel'), self.saveLabelFile, #
'Ctrl+S', 'save', getStr('saveLabel'), enabled=False)
undoLastPoint = action(getStr("undoLastPoint"), self.canvas.undoLastPoint,
'Ctrl+Z', "undo", getStr("undoLastPoint"), enabled=False)
rotateLeft = action(getStr("rotateLeft"), partial(self.rotateImgAction,1),
'Ctrl+Alt+L', "rotateLeft", getStr("rotateLeft"), enabled=False)
rotateLeft = action(getStr("rotateLeft"), partial(self.rotateImgAction, 1),
'Ctrl+Alt+L', "rotateLeft", getStr("rotateLeft"), enabled=False)
rotateRight = action(getStr("rotateRight"), partial(self.rotateImgAction,-1),
'Ctrl+Alt+R', "rotateRight", getStr("rotateRight"), enabled=False)
rotateRight = action(getStr("rotateRight"), partial(self.rotateImgAction, -1),
'Ctrl+Alt+R', "rotateRight", getStr("rotateRight"), enabled=False)
undo = action(getStr("undo"), self.undoShapeEdit,
'Ctrl+Z', "undo", getStr("undo"), enabled=False)
change_cls = action(getStr("keyChange"), self.change_box_key,
'Ctrl+B', "edit", getStr("keyChange"), enabled=False)
lock = action(getStr("lockBox"), self.lockSelectedShape,
None, "lock", getStr("lockBoxDetail"),
enabled=False)
......@@ -495,7 +480,7 @@ class MainWindow(QMainWindow, WindowMixin):
# self.preButton.setDefaultAction(openPrevImg)
# self.nextButton.setDefaultAction(openNextImg)
############# Zoom layout ##############
# ================== Zoom layout ==================
zoomLayout = QHBoxLayout()
zoomLayout.addStretch()
self.zoominButton = QToolButton()
......@@ -522,14 +507,12 @@ class MainWindow(QMainWindow, WindowMixin):
icon='color', tip=getStr('shapeFillColorDetail'),
enabled=False)
# Label list context menu.
labelMenu = QMenu()
addActions(labelMenu, (edit, delete))
self.labelList.setContextMenuPolicy(Qt.CustomContextMenu)
self.labelList.customContextMenuRequested.connect(
self.popLabelListMenu)
self.labelList.customContextMenuRequested.connect(self.popLabelListMenu)
# Draw squares/rectangles
self.drawSquaresOption = QAction(getStr('drawSquares'), self)
......@@ -538,39 +521,37 @@ class MainWindow(QMainWindow, WindowMixin):
self.drawSquaresOption.triggered.connect(self.toogleDrawSquare)
# Store actions for further handling.
self.actions = struct(save=save, resetAll=resetAll, deleteImg=deleteImg,
self.actions = struct(save=save, resetAll=resetAll, deleteImg=deleteImg,
lineColor=color1, create=create, delete=delete, edit=edit, copy=copy,
saveRec=saveRec, singleRere=singleRere,AutoRec=AutoRec,reRec=reRec,
saveRec=saveRec, singleRere=singleRere, AutoRec=AutoRec, reRec=reRec,
createMode=createMode, editMode=editMode,
shapeLineColor=shapeLineColor, shapeFillColor=shapeFillColor,
zoom=zoom, zoomIn=zoomIn, zoomOut=zoomOut, zoomOrg=zoomOrg,
fitWindow=fitWindow, fitWidth=fitWidth,
zoomActions=zoomActions, saveLabel=saveLabel,
undo=undo, undoLastPoint=undoLastPoint,open_dataset_dir=open_dataset_dir,
rotateLeft=rotateLeft,rotateRight=rotateRight,lock=lock,
fileMenuActions=(
opendir, open_dataset_dir, saveLabel, resetAll, quit),
zoomActions=zoomActions, saveLabel=saveLabel, change_cls=change_cls,
undo=undo, undoLastPoint=undoLastPoint, open_dataset_dir=open_dataset_dir,
rotateLeft=rotateLeft, rotateRight=rotateRight, lock=lock,
fileMenuActions=(opendir, open_dataset_dir, saveLabel, resetAll, quit),
beginner=(), advanced=(),
editMenu=(createpoly, edit, copy, delete,singleRere,None, undo, undoLastPoint,
None, rotateLeft, rotateRight, None, color1, self.drawSquaresOption,lock),
beginnerContext=(create, edit, copy, delete, singleRere, rotateLeft, rotateRight,lock),
editMenu=(createpoly, edit, copy, delete, singleRere, None, undo, undoLastPoint,
None, rotateLeft, rotateRight, None, color1, self.drawSquaresOption, lock),
beginnerContext=(
create, edit, copy, delete, singleRere, rotateLeft, rotateRight, lock, change_cls),
advancedContext=(createMode, editMode, edit, copy,
delete, shapeLineColor, shapeFillColor),
onLoadActive=(
create, createMode, editMode),
onLoadActive=(create, createMode, editMode),
onShapesPresent=(hideAll, showAll))
# menus
self.menus = struct(
file=self.menu('&'+getStr('mfile')),
edit=self.menu('&'+getStr('medit')),
view=self.menu('&'+getStr('mview')),
file=self.menu('&' + getStr('mfile')),
edit=self.menu('&' + getStr('medit')),
view=self.menu('&' + getStr('mview')),
autolabel=self.menu('&PaddleOCR'),
help=self.menu('&'+getStr('mhelp')),
help=self.menu('&' + getStr('mhelp')),
recentFiles=QMenu('Open &Recent'),
labelList=labelMenu)
self.lastLabel = None
# Add option to enable/disable labels being displayed at the top of bounding boxes
self.displayLabelOption = QAction(getStr('displayLabel'), self)
......@@ -591,33 +572,30 @@ class MainWindow(QMainWindow, WindowMixin):
self.autoSaveOption.triggered.connect(self.autoSaveFunc)
addActions(self.menus.file,
(opendir, open_dataset_dir, None, saveLabel, saveRec, self.autoSaveOption, None, resetAll, deleteImg, quit))
(opendir, open_dataset_dir, None, saveLabel, saveRec, self.autoSaveOption, None, resetAll, deleteImg,
quit))
addActions(self.menus.help, (showKeys,showSteps, showInfo))
addActions(self.menus.help, (showKeys, showSteps, showInfo))
addActions(self.menus.view, (
self.displayLabelOption, self.labelDialogOption,
None,
None,
hideAll, showAll, None,
zoomIn, zoomOut, zoomOrg, None,
fitWindow, fitWidth))
addActions(self.menus.autolabel, (AutoRec, reRec, alcm, None, help)) #
addActions(self.menus.autolabel, (AutoRec, reRec, alcm, None, help))
self.menus.file.aboutToShow.connect(self.updateFileMenu)
# Custom context menu for the canvas widget:
addActions(self.canvas.menus[0], self.actions.beginnerContext)
#addActions(self.canvas.menus[1], (
# action('&Copy here', self.copyShape),
# action('&Move here', self.moveShape)))
self.statusBar().showMessage('%s started.' % __appname__)
self.statusBar().show()
# Application state.
self.image = QImage()
self.filePath = ustr(defaultFilename)
self.filePath = ustr(default_filename)
self.lastOpenDir = None
self.recentFiles = []
self.maxRecent = 7
......@@ -628,7 +606,7 @@ class MainWindow(QMainWindow, WindowMixin):
# Add Chris
self.difficult = False
## Fix the compatible issue for qt4 and qt5. Convert the QStringList to python list
# Fix the compatible issue for qt4 and qt5. Convert the QStringList to python list
if settings.get(SETTING_RECENT_FILES):
if have_qstring():
recentFileQStringList = settings.get(SETTING_RECENT_FILES)
......@@ -657,7 +635,6 @@ class MainWindow(QMainWindow, WindowMixin):
# Add chris
Shape.difficult = self.difficult
# ADD:
# Populate the File menu dynamically.
self.updateFileMenu()
......@@ -668,6 +645,8 @@ class MainWindow(QMainWindow, WindowMixin):
elif self.filePath:
self.queueEvent(partial(self.loadFile, self.filePath or ""))
self.keyDialog = None
# Callbacks:
self.zoomWidget.valueChanged.connect(self.paintCanvas)
......@@ -681,6 +660,12 @@ class MainWindow(QMainWindow, WindowMixin):
if self.filePath and os.path.isdir(self.filePath):
self.openDirDialog(dirpath=self.filePath, silent=True)
def menu(self, title, actions=None):
menu = self.menuBar().addMenu(title)
if actions:
addActions(menu, actions)
return menu
def keyReleaseEvent(self, event):
if event.key() == Qt.Key_Control:
self.canvas.setDrawingShapeToSquare(False)
......@@ -690,11 +675,9 @@ class MainWindow(QMainWindow, WindowMixin):
# Draw rectangle if Ctrl is pressed
self.canvas.setDrawingShapeToSquare(True)
def noShapes(self):
return not self.itemsToShapes
def populateModeActions(self):
self.canvas.menus[0].clear()
addActions(self.canvas.menus[0], self.actions.beginnerContext)
......@@ -702,7 +685,6 @@ class MainWindow(QMainWindow, WindowMixin):
actions = (self.actions.create,) # if self.beginner() else (self.actions.createMode, self.actions.editMode)
addActions(self.menus.edit, actions + self.actions.editMenu)
def setDirty(self):
self.dirty = True
self.actions.save.setEnabled(True)
......@@ -816,10 +798,11 @@ class MainWindow(QMainWindow, WindowMixin):
def rotateImgWarn(self):
if self.lang == 'ch':
self.msgBox.warning (self, "提示", "\n 该图片已经有标注框,旋转操作会打乱标注,建议清除标注框后旋转。")
self.msgBox.warning(self, "提示", "\n 该图片已经有标注框,旋转操作会打乱标注,建议清除标注框后旋转。")
else:
self.msgBox.warning (self, "Warn", "\n The picture already has a label box, and rotation will disrupt the label.\
It is recommended to clear the label box and rotate it.")
self.msgBox.warning(self, "Warn", "\n The picture already has a label box, "
"and rotation will disrupt the label. "
"It is recommended to clear the label box and rotate it.")
def rotateImgAction(self, k=1, _value=False):
......@@ -894,14 +877,13 @@ class MainWindow(QMainWindow, WindowMixin):
self.setDirty()
self.updateComboBox()
######## detection box related functions #######
# =================== detection box related functions ===================
def boxItemChanged(self, item):
shape = self.itemsToShapesbox[item]
box = ast.literal_eval(item.text())
# print('shape in labelItemChanged is',shape.points)
if box != [(p.x(), p.y()) for p in shape.points]:
if box != [(int(p.x()), int(p.y())) for p in shape.points]:
# shape.points = box
shape.points = [QPointF(p[0], p[1]) for p in box]
......@@ -909,7 +891,7 @@ class MainWindow(QMainWindow, WindowMixin):
# shape.line_color = generateColorByText(shape.label)
self.setDirty()
else: # User probably changed item visibility
self.canvas.setShapeVisible(shape, True)#item.checkState() == Qt.Checked
self.canvas.setShapeVisible(shape, True) # item.checkState() == Qt.Checked
def editBox(self): # ADD
if not self.canvas.editing():
......@@ -959,11 +941,10 @@ class MainWindow(QMainWindow, WindowMixin):
def indexTo5Files(self, currIndex):
if currIndex < 2:
return self.mImgList[:5]
elif currIndex > len(self.mImgList)-3:
elif currIndex > len(self.mImgList) - 3:
return self.mImgList[-5:]
else:
return self.mImgList[currIndex - 2 : currIndex + 3]
return self.mImgList[currIndex - 2: currIndex + 3]
# Tzutalin 20160906 : Add file list and dock to move faster
def fileitemDoubleClicked(self, item=None):
......@@ -983,9 +964,8 @@ class MainWindow(QMainWindow, WindowMixin):
self.loadFile(filename)
def CanvasSizeChange(self):
if len(self.mImgList) > 0:
self.zoomWidget.setValue(self.zoomWidgetValue + self.imgsplider.value())
if len(self.mImgList) > 0 and self.imageSlider.hasFocus():
self.zoomWidget.setValue(self.imageSlider.value())
def shapeSelectionChanged(self, selected_shapes):
self._noSelectionSlot = True
......@@ -998,9 +978,15 @@ class MainWindow(QMainWindow, WindowMixin):
self.shapesToItems[shape].setSelected(True)
self.shapesToItemsbox[shape].setSelected(True)
self.labelList.scrollToItem(self.currentItem()) # QAbstractItemView.EnsureVisible
self.labelList.scrollToItem(self.currentItem()) # QAbstractItemView.EnsureVisible
self.BoxList.scrollToItem(self.currentBox())
if self.kie_mode:
if len(self.canvas.selectedShapes) == 1 and self.keyList.count() > 0:
selected_key_item_row = self.keyList.findItemsByLabel(self.canvas.selectedShapes[0].key_cls,
get_row=True)
self.keyList.setCurrentRow(selected_key_item_row)
self._noSelectionSlot = False
n_selected = len(selected_shapes)
self.actions.singleRere.setEnabled(n_selected)
......@@ -1008,6 +994,7 @@ class MainWindow(QMainWindow, WindowMixin):
self.actions.copy.setEnabled(n_selected)
self.actions.edit.setEnabled(n_selected == 1)
self.actions.lock.setEnabled(n_selected)
self.actions.change_cls.setEnabled(n_selected)
def addLabel(self, shape):
shape.paintLabel = self.displayLabelOption.isChecked()
......@@ -1030,6 +1017,10 @@ class MainWindow(QMainWindow, WindowMixin):
action.setEnabled(True)
self.updateComboBox()
# update show counting
self.BoxListDock.setWindowTitle(self.BoxListDockName + f" ({self.BoxList.count()})")
self.labelListDock.setWindowTitle(self.labelListDockName + f" ({self.labelList.count()})")
def remLabels(self, shapes):
if shapes is None:
# print('rm empty label')
......@@ -1050,8 +1041,8 @@ class MainWindow(QMainWindow, WindowMixin):
def loadLabels(self, shapes):
s = []
for label, points, line_color, fill_color, difficult in shapes:
shape = Shape(label=label,line_color=line_color)
for label, points, line_color, key_cls, difficult in shapes:
shape = Shape(label=label, line_color=line_color, key_cls=key_cls)
for x, y in points:
# Ensure the labels are within the bounds of the image. If not, fix them.
......@@ -1061,25 +1052,15 @@ class MainWindow(QMainWindow, WindowMixin):
shape.addPoint(QPointF(x, y))
shape.difficult = difficult
#shape.locked = False
# shape.locked = False
shape.close()
s.append(shape)
# if line_color:
# shape.line_color = QColor(*line_color)
# else:
# shape.line_color = generateColorByText(label)
#
# if fill_color:
# shape.fill_color = QColor(*fill_color)
# else:
# shape.fill_color = generateColorByText(label)
self._update_shape_color(shape)
self.addLabel(shape)
self.updateComboBox()
self.canvas.loadShapes(s)
def singleLabel(self, shape):
if shape is None:
......@@ -1115,14 +1096,16 @@ class MainWindow(QMainWindow, WindowMixin):
line_color=s.line_color.getRgb(),
fill_color=s.fill_color.getRgb(),
points=[(int(p.x()), int(p.y())) for p in s.points], # QPonitF
# add chris
difficult=s.difficult) # bool
difficult=s.difficult,
key_cls=s.key_cls) # bool
shapes = [] if mode == 'Auto' else \
[format_shape(shape) for shape in self.canvas.shapes if shape.line_color != DEFAULT_LOCK_COLOR]
if mode == 'Auto':
shapes = []
else:
shapes = [format_shape(shape) for shape in self.canvas.shapes if shape.line_color != DEFAULT_LOCK_COLOR]
# Can add differrent annotation formats here
for box in self.result_dic :
trans_dic = {"label": box[1][0], "points": box[0], 'difficult': False}
for box in self.result_dic:
trans_dic = {"label": box[1][0], "points": box[0], "difficult": False, "key_cls": "None"}
if trans_dic["label"] == "" and mode == 'Auto':
continue
shapes.append(trans_dic)
......@@ -1130,7 +1113,8 @@ class MainWindow(QMainWindow, WindowMixin):
try:
trans_dic = []
for box in shapes:
trans_dic.append({"transcription": box['label'], "points": box['points'], 'difficult': box['difficult']})
trans_dic.append({"transcription": box['label'], "points": box['points'],
"difficult": box['difficult'], "key_cls": box['key_cls']})
self.PPlabel[annotationFilePath] = trans_dic
if mode == 'Auto':
self.Cachelabel[annotationFilePath] = trans_dic
......@@ -1148,8 +1132,7 @@ class MainWindow(QMainWindow, WindowMixin):
for shape in self.canvas.copySelectedShape():
self.addLabel(shape)
# fix copy and delete
#self.shapeSelectionChanged(True)
# self.shapeSelectionChanged(True)
def labelSelectionChanged(self):
if self._noSelectionSlot:
......@@ -1163,10 +1146,9 @@ class MainWindow(QMainWindow, WindowMixin):
else:
self.canvas.deSelectShape()
def boxSelectionChanged(self):
if self._noSelectionSlot:
#self.BoxList.scrollToItem(self.currentBox(), QAbstractItemView.PositionAtCenter)
# self.BoxList.scrollToItem(self.currentBox(), QAbstractItemView.PositionAtCenter)
return
if self.canvas.editing():
selected_shapes = []
......@@ -1177,7 +1159,6 @@ class MainWindow(QMainWindow, WindowMixin):
else:
self.canvas.deSelectShape()
def labelItemChanged(self, item):
shape = self.itemsToShapes[item]
label = item.text()
......@@ -1185,7 +1166,7 @@ class MainWindow(QMainWindow, WindowMixin):
shape.label = item.text()
# shape.line_color = generateColorByText(shape.label)
self.setDirty()
elif not ((item.checkState()== Qt.Unchecked) ^ (not shape.difficult)):
elif not ((item.checkState() == Qt.Unchecked) ^ (not shape.difficult)):
shape.difficult = True if item.checkState() == Qt.Unchecked else False
self.setDirty()
else: # User probably changed item visibility
......@@ -1199,8 +1180,7 @@ class MainWindow(QMainWindow, WindowMixin):
position MUST be in global coordinates.
"""
if len(self.labelHist) > 0:
self.labelDialog = LabelDialog(
parent=self, listItem=self.labelHist)
self.labelDialog = LabelDialog(parent=self, listItem=self.labelHist)
if value:
text = self.labelDialog.popUp(text=self.prevLabelText)
......@@ -1210,8 +1190,22 @@ class MainWindow(QMainWindow, WindowMixin):
if text is not None:
self.prevLabelText = self.stringBundle.getString('tempLabel')
# generate_color = generateColorByText(text)
shape = self.canvas.setLastLabel(text, None, None)#generate_color, generate_color
shape = self.canvas.setLastLabel(text, None, None, None) # generate_color, generate_color
if self.kie_mode:
key_text, _ = self.keyDialog.popUp(self.key_previous_text)
if key_text is not None:
shape = self.canvas.setLastLabel(text, None, None, key_text) # generate_color, generate_color
self.key_previous_text = key_text
if not self.keyList.findItemsByLabel(key_text):
item = self.keyList.createItemFromLabel(key_text)
self.keyList.addItem(item)
rgb = self._get_rgb_by_label(key_text, self.kie_mode)
self.keyList.setItemLabel(item, key_text, rgb)
self._update_shape_color(shape)
self.keyDialog.addLabelHistory(key_text)
self.addLabel(shape)
if self.beginner(): # Switch to edit mode.
self.canvas.setEditing(True)
......@@ -1226,6 +1220,25 @@ class MainWindow(QMainWindow, WindowMixin):
# self.canvas.undoLastLine()
self.canvas.resetAllLines()
def _update_shape_color(self, shape):
r, g, b = self._get_rgb_by_label(shape.key_cls, self.kie_mode)
shape.line_color = QColor(r, g, b)
shape.vertex_fill_color = QColor(r, g, b)
shape.hvertex_fill_color = QColor(255, 255, 255)
shape.fill_color = QColor(r, g, b, 128)
shape.select_line_color = QColor(255, 255, 255)
shape.select_fill_color = QColor(r, g, b, 155)
def _get_rgb_by_label(self, label, kie_mode):
shift_auto_shape_color = 2 # use for random color
if kie_mode and label != "None":
item = self.keyList.findItemsByLabel(label)[0]
label_id = self.keyList.indexFromItem(item).row() + 1
label_id += shift_auto_shape_color
return LABEL_COLORMAP[label_id % len(LABEL_COLORMAP)]
else:
return (0, 255, 0)
def scrollRequest(self, delta, orientation):
units = - delta / (8 * 15)
bar = self.scrollBars[orientation]
......@@ -1239,6 +1252,7 @@ class MainWindow(QMainWindow, WindowMixin):
def addZoom(self, increment=10):
self.setZoom(self.zoomWidget.value() + increment)
self.imageSlider.setValue(self.zoomWidget.value() + increment) # set zoom slider value
def zoomRequest(self, delta):
# get the current scrollbar positions
......@@ -1324,17 +1338,16 @@ class MainWindow(QMainWindow, WindowMixin):
# unicodeFilePath = os.path.abspath(unicodeFilePath)
# Tzutalin 20160906 : Add file list and dock to move faster
# Highlight the file item
if unicodeFilePath and self.fileListWidget.count() > 0:
if unicodeFilePath in self.mImgList:
index = self.mImgList.index(unicodeFilePath)
fileWidgetItem = self.fileListWidget.item(index)
print('unicodeFilePath is', unicodeFilePath)
fileWidgetItem.setSelected(True)
###
self.iconlist.clear()
self.additems5(None)
for i in range(5):
item_tooltip = self.iconlist.item(i).toolTip()
# print(i,"---",item_tooltip)
......@@ -1385,7 +1398,7 @@ class MainWindow(QMainWindow, WindowMixin):
self.showBoundingBoxFromPPlabel(filePath)
self.setWindowTitle(__appname__ + ' ' + filePath)
# Default : select last item if there is at least one item
if self.labelList.count():
self.labelList.setCurrentItem(self.labelList.item(self.labelList.count() - 1))
......@@ -1394,9 +1407,11 @@ class MainWindow(QMainWindow, WindowMixin):
# show file list image count
select_indexes = self.fileListWidget.selectedIndexes()
if len(select_indexes) > 0:
self.filedock.setWindowTitle(self.fileListName + f" ({select_indexes[0].row() + 1}"
self.fileDock.setWindowTitle(self.fileListName + f" ({select_indexes[0].row() + 1}"
f"/{self.fileListWidget.count()})")
# update show counting
self.BoxListDock.setWindowTitle(self.BoxListDockName + f" ({self.BoxList.count()})")
self.labelListDock.setWindowTitle(self.labelListDockName + f" ({self.labelList.count()})")
self.canvas.setFocus(True)
return True
......@@ -1405,24 +1420,23 @@ class MainWindow(QMainWindow, WindowMixin):
def showBoundingBoxFromPPlabel(self, filePath):
width, height = self.image.width(), self.image.height()
imgidx = self.getImglabelidx(filePath)
shapes =[]
#box['ratio'] of the shapes saved in lockedShapes contains the ratio of the
shapes = []
# box['ratio'] of the shapes saved in lockedShapes contains the ratio of the
# four corner coordinates of the shapes to the height and width of the image
for box in self.canvas.lockedShapes:
if self.canvas.isInTheSameImage:
shapes.append((box['transcription'], [[s[0]*width,s[1]*height]for s in box['ratio']],
DEFAULT_LOCK_COLOR, None, box['difficult']))
shapes.append((box['transcription'], [[s[0] * width, s[1] * height] for s in box['ratio']],
DEFAULT_LOCK_COLOR, box['key_cls'], box['difficult']))
else:
shapes.append(('锁定框:待检测', [[s[0]*width,s[1]*height]for s in box['ratio']],
DEFAULT_LOCK_COLOR, None, box['difficult']))
shapes.append(('锁定框:待检测', [[s[0] * width, s[1] * height] for s in box['ratio']],
DEFAULT_LOCK_COLOR, box['key_cls'], box['difficult']))
if imgidx in self.PPlabel.keys():
for box in self.PPlabel[imgidx]:
shapes.append((box['transcription'], box['points'], None, None, box['difficult']))
shapes.append((box['transcription'], box['points'], None, box['key_cls'], box['difficult']))
self.loadLabels(shapes)
self.canvas.verified = False
def validFilestate(self, filePath):
if filePath not in self.fileStatedict.keys():
return None
......@@ -1433,7 +1447,7 @@ class MainWindow(QMainWindow, WindowMixin):
def resizeEvent(self, event):
if self.canvas and not self.image.isNull() \
and self.zoomMode != self.MANUAL_ZOOM:
and self.zoomMode != self.MANUAL_ZOOM:
self.adjustScale()
super(MainWindow, self).resizeEvent(event)
......@@ -1451,7 +1465,7 @@ class MainWindow(QMainWindow, WindowMixin):
"""Figure out the size of the pixmap in order to fit the main widget."""
e = 2.0 # So that no scrollbars are generated.
w1 = self.centralWidget().width() - e
h1 = self.centralWidget().height() - e -110
h1 = self.centralWidget().height() - e - 110
a1 = w1 / h1
# Calculate a new scale value based on the pixmap's aspect ratio.
w2 = self.canvas.pixmap.width() - 0.0
......@@ -1502,7 +1516,7 @@ class MainWindow(QMainWindow, WindowMixin):
def loadRecent(self, filename):
if self.mayContinue():
print(filename,"======")
print(filename, "======")
self.loadFile(filename)
def scanAllImages(self, folderPath):
......@@ -1517,8 +1531,6 @@ class MainWindow(QMainWindow, WindowMixin):
natural_sort(images, key=lambda x: x.lower())
return images
def openDirDialog(self, _value=False, dirpath=None, silent=False):
if not self.mayContinue():
return
......@@ -1530,15 +1542,15 @@ class MainWindow(QMainWindow, WindowMixin):
defaultOpenDirPath = os.path.dirname(self.filePath) if self.filePath else '.'
if silent != True:
targetDirPath = ustr(QFileDialog.getExistingDirectory(self,
'%s - Open Directory' % __appname__,
defaultOpenDirPath,
QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks))
'%s - Open Directory' % __appname__,
defaultOpenDirPath,
QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks))
else:
targetDirPath = ustr(defaultOpenDirPath)
self.lastOpenDir = targetDirPath
self.importDirImages(targetDirPath)
def openDatasetDirDialog(self,):
def openDatasetDirDialog(self):
if self.lastOpenDir and os.path.exists(self.lastOpenDir):
if platform.system() == 'Windows':
os.startfile(self.lastOpenDir)
......@@ -1550,12 +1562,46 @@ class MainWindow(QMainWindow, WindowMixin):
if self.lang == 'ch':
self.msgBox.warning(self, "提示", "\n 原文件夹已不存在,请从新选择数据集路径!")
else:
self.msgBox.warning(self, "Warn", "\n The original folder no longer exists, please choose the data set path again!")
self.msgBox.warning(self, "Warn",
"\n The original folder no longer exists, please choose the data set path again!")
self.actions.open_dataset_dir.setEnabled(False)
defaultOpenDirPath = os.path.dirname(self.filePath) if self.filePath else '.'
def importDirImages(self, dirpath, isDelete = False):
def init_key_list(self, label_dict):
if not self.kie_mode:
return
# load key_cls
for image, info in label_dict.items():
for box in info:
if "key_cls" not in box:
continue
self.existed_key_cls_set.add(box["key_cls"])
if len(self.existed_key_cls_set) > 0:
for key_text in self.existed_key_cls_set:
if not self.keyList.findItemsByLabel(key_text):
item = self.keyList.createItemFromLabel(key_text)
self.keyList.addItem(item)
rgb = self._get_rgb_by_label(key_text, self.kie_mode)
self.keyList.setItemLabel(item, key_text, rgb)
if self.keyDialog is None:
# key list dialog
self.keyDialog = KeyDialog(
text=self.key_dialog_tip,
parent=self,
labels=self.existed_key_cls_set,
sort_labels=True,
show_text_field=True,
completion="startswith",
fit_to_content={'column': True, 'row': False},
flags=None
)
else:
self.keyDialog.labelList.addItems(self.existed_key_cls_set)
def importDirImages(self, dirpath, isDelete=False):
if not self.mayContinue() or not dirpath:
return
if self.defaultSaveDir and self.defaultSaveDir != dirpath:
......@@ -1563,16 +1609,18 @@ class MainWindow(QMainWindow, WindowMixin):
if not isDelete:
self.loadFilestate(dirpath)
self.PPlabelpath = dirpath+ '/Label.txt'
self.PPlabelpath = dirpath + '/Label.txt'
self.PPlabel = self.loadLabelFile(self.PPlabelpath)
self.Cachelabelpath = dirpath + '/Cache.cach'
self.Cachelabel = self.loadLabelFile(self.Cachelabelpath)
if self.Cachelabel:
self.PPlabel = dict(self.Cachelabel, **self.PPlabel)
self.init_key_list(self.PPlabel)
self.lastOpenDir = dirpath
self.dirname = dirpath
self.defaultSaveDir = dirpath
self.statusBar().showMessage('%s started. Annotation will be saved to %s' %
(__appname__, self.defaultSaveDir))
......@@ -1607,7 +1655,7 @@ class MainWindow(QMainWindow, WindowMixin):
self.actions.rotateRight.setEnabled(True)
self.fileListWidget.setCurrentRow(0) # set list index to first
self.filedock.setWindowTitle(self.fileListName + f" (1/{self.fileListWidget.count()})") # show image count
self.fileDock.setWindowTitle(self.fileListName + f" (1/{self.fileListWidget.count()})") # show image count
def openPrevImg(self, _value=False):
if len(self.mImgList) <= 0:
......@@ -1643,7 +1691,7 @@ class MainWindow(QMainWindow, WindowMixin):
else:
self.mImgList5 = self.indexTo5Files(currIndex)
if filename:
print('file name in openNext is ',filename)
print('file name in openNext is ', filename)
self.loadFile(filename)
def updateFileListIcon(self, filename):
......@@ -1655,30 +1703,6 @@ class MainWindow(QMainWindow, WindowMixin):
imgidx = self.getImglabelidx(self.filePath)
self._saveFile(imgidx, mode=mode)
def saveFileAs(self, _value=False):
assert not self.image.isNull(), "cannot save empty image"
self._saveFile(self.saveFileDialog())
def saveFileDialog(self, removeExt=True):
caption = '%s - Choose File' % __appname__
filters = 'File (*%s)' % LabelFile.suffix
openDialogPath = self.currentPath()
dlg = QFileDialog(self, caption, openDialogPath, filters)
dlg.setDefaultSuffix(LabelFile.suffix[1:])
dlg.setAcceptMode(QFileDialog.AcceptSave)
filenameWithoutExtension = os.path.splitext(self.filePath)[0]
dlg.selectFile(filenameWithoutExtension)
dlg.setOption(QFileDialog.DontUseNativeDialog, False)
if dlg.exec_():
fullFilePath = ustr(dlg.selectedFiles()[0])
if removeExt:
return os.path.splitext(fullFilePath)[0] # Return file path without the extension.
else:
return fullFilePath
return ''
def saveLockedShapes(self):
self.canvas.lockedShapes = []
self.canvas.selectedShapes = []
......@@ -1691,7 +1715,6 @@ class MainWindow(QMainWindow, WindowMixin):
self.canvas.selectedShapes.remove(s)
self.canvas.shapes.remove(s)
def _saveFile(self, annotationFilePath, mode='Manual'):
if len(self.canvas.lockedShapes) != 0:
self.saveLockedShapes()
......@@ -1701,9 +1724,9 @@ class MainWindow(QMainWindow, WindowMixin):
img = cv2.imread(self.filePath)
width, height = self.image.width(), self.image.height()
for shape in self.canvas.lockedShapes:
box = [[int(p[0]*width), int(p[1]*height)] for p in shape['ratio']]
box = [[int(p[0] * width), int(p[1] * height)] for p in shape['ratio']]
assert len(box) == 4
result = [(shape['transcription'],1)]
result = [(shape['transcription'], 1)]
result.insert(0, box)
self.result_dic_locked.append(result)
self.result_dic += self.result_dic_locked
......@@ -1717,7 +1740,7 @@ class MainWindow(QMainWindow, WindowMixin):
item.setIcon(newIcon('done'))
self.fileStatedict[self.filePath] = 1
if len(self.fileStatedict)%self.autoSaveNum ==0:
if len(self.fileStatedict) % self.autoSaveNum == 0:
self.saveFilestate()
self.savePPlabel(mode='Auto')
......@@ -1750,8 +1773,8 @@ class MainWindow(QMainWindow, WindowMixin):
if platform.system() == 'Windows':
from win32com.shell import shell, shellcon
shell.SHFileOperation((0, shellcon.FO_DELETE, deletePath, None,
shellcon.FOF_SILENT | shellcon.FOF_ALLOWUNDO | shellcon.FOF_NOCONFIRMATION,
None, None))
shellcon.FOF_SILENT | shellcon.FOF_ALLOWUNDO | shellcon.FOF_NOCONFIRMATION,
None, None))
# linux
elif platform.system() == 'Linux':
cmd = 'trash ' + deletePath
......@@ -1814,7 +1837,7 @@ class MainWindow(QMainWindow, WindowMixin):
def currentPath(self):
return os.path.dirname(self.filePath) if self.filePath else '.'
def chooseColor1(self):
def chooseColor(self):
color = self.colorDialog.getColor(self.lineColor, u'Choose line color',
default=DEFAULT_LINE_COLOR)
if color:
......@@ -1831,6 +1854,8 @@ class MainWindow(QMainWindow, WindowMixin):
if self.noShapes():
for action in self.actions.onShapesPresent:
action.setEnabled(False)
self.BoxListDock.setWindowTitle(self.BoxListDockName + f" ({self.BoxList.count()})")
self.labelListDock.setWindowTitle(self.labelListDockName + f" ({self.labelList.count()})")
def chshapeLineColor(self):
color = self.colorDialog.getColor(self.lineColor, u'Choose line color',
......@@ -1867,7 +1892,6 @@ class MainWindow(QMainWindow, WindowMixin):
else:
self.labelHist.append(line)
def togglePaintLabelsOption(self):
for shape in self.canvas.shapes:
shape.paintLabel = self.displayLabelOption.isChecked()
......@@ -1896,7 +1920,7 @@ class MainWindow(QMainWindow, WindowMixin):
prelen = lentoken // 2
bfilename = prelen * " " + pfilename + (lentoken - prelen) * " "
# item = QListWidgetItem(QIcon(pix.scaled(100, 100, Qt.KeepAspectRatio, Qt.SmoothTransformation)),filename[:10])
item = QListWidgetItem(QIcon(pix.scaled(100, 100, Qt.IgnoreAspectRatio, Qt.FastTransformation)),pfilename)
item = QListWidgetItem(QIcon(pix.scaled(100, 100, Qt.IgnoreAspectRatio, Qt.FastTransformation)), pfilename)
# item.setForeground(QBrush(Qt.white))
item.setToolTip(file)
self.iconlist.addItem(item)
......@@ -1908,7 +1932,7 @@ class MainWindow(QMainWindow, WindowMixin):
self.iconlist.setMinimumWidth(owidth + 50)
def getImglabelidx(self, filePath):
if platform.system()=='Windows':
if platform.system() == 'Windows':
spliter = '\\'
else:
spliter = '/'
......@@ -1923,13 +1947,14 @@ class MainWindow(QMainWindow, WindowMixin):
self.autoDialog = AutoDialog(parent=self, ocr=self.ocr, mImgList=uncheckedList, lenbar=len(uncheckedList))
self.autoDialog.popUp()
self.currIndex = len(self.mImgList) - 1
self.loadFile(self.filePath) # ADD
self.loadFile(self.filePath) # ADD
self.haveAutoReced = True
self.AutoRecognition.setEnabled(False)
self.actions.AutoRec.setEnabled(False)
self.setDirty()
self.saveCacheLabel()
self.init_key_list(self.Cachelabel)
def reRecognition(self):
img = cv2.imread(self.filePath)
......@@ -1959,24 +1984,27 @@ class MainWindow(QMainWindow, WindowMixin):
print('Can not recognise the box')
if shape.line_color == DEFAULT_LOCK_COLOR:
shape.label = result[0][0]
self.result_dic_locked.append([box,(self.noLabelText,0)])
self.result_dic_locked.append([box, (self.noLabelText, 0)])
else:
self.result_dic.append([box,(self.noLabelText,0)])
self.result_dic.append([box, (self.noLabelText, 0)])
try:
if self.noLabelText == shape.label or result[1][0] == shape.label:
print('label no change')
else:
rec_flag += 1
except IndexError as e:
print('Can not recognise the box')
if (len(self.result_dic) > 0 and rec_flag > 0)or self.canvas.lockedShapes:
self.canvas.isInTheSameImage = True
print('Can not recognise the box')
if (len(self.result_dic) > 0 and rec_flag > 0) or self.canvas.lockedShapes:
self.canvas.isInTheSameImage = True
self.saveFile(mode='Auto')
self.loadFile(self.filePath)
self.canvas.isInTheSameImage = False
self.setDirty()
elif len(self.result_dic) == len(self.canvas.shapes) and rec_flag == 0:
QMessageBox.information(self, "Information", "The recognition result remains unchanged!")
if self.lang == 'ch':
QMessageBox.information(self, "Information", "识别结果保持一致!")
else:
QMessageBox.information(self, "Information", "The recognition result remains unchanged!")
else:
print('Can not recgonise in ', self.filePath)
else:
......@@ -2041,7 +2069,6 @@ class MainWindow(QMainWindow, WindowMixin):
self.AutoRecognition.setEnabled(True)
self.actions.AutoRec.setEnabled(True)
def modelChoose(self):
print(self.comboBox.currentText())
lg_idx = {'Chinese & English': 'ch', 'English': 'en', 'French': 'french', 'German': 'german',
......@@ -2068,14 +2095,12 @@ class MainWindow(QMainWindow, WindowMixin):
self.actions.saveLabel.setEnabled(True)
self.actions.saveRec.setEnabled(True)
def saveFilestate(self):
with open(self.fileStatepath, 'w', encoding='utf-8') as f:
for key in self.fileStatedict:
f.write(key + '\t')
f.write(str(self.fileStatedict[key]) + '\n')
def loadLabelFile(self, labelpath):
labeldict = {}
if not os.path.exists(labelpath):
......@@ -2094,8 +2119,7 @@ class MainWindow(QMainWindow, WindowMixin):
labeldict[file] = []
return labeldict
def savePPlabel(self,mode='Manual'):
def savePPlabel(self, mode='Manual'):
savedfile = [self.getImglabelidx(i) for i in self.fileStatedict.keys()]
with open(self.PPlabelpath, 'w', encoding='utf-8') as f:
for key in self.PPlabel:
......@@ -2137,19 +2161,22 @@ class MainWindow(QMainWindow, WindowMixin):
try:
img = cv2.imread(key)
for i, label in enumerate(self.PPlabel[idx]):
if label['difficult']: continue
if label['difficult']:
continue
img_crop = get_rotate_crop_image(img, np.array(label['points'], np.float32))
img_name = os.path.splitext(os.path.basename(idx))[0] + '_crop_'+str(i)+'.jpg'
cv2.imwrite(crop_img_dir+img_name, img_crop)
f.write('crop_img/'+ img_name + '\t')
img_name = os.path.splitext(os.path.basename(idx))[0] + '_crop_' + str(i) + '.jpg'
cv2.imwrite(crop_img_dir + img_name, img_crop)
f.write('crop_img/' + img_name + '\t')
f.write(label['transcription'] + '\n')
except Exception as e:
ques_img.append(key)
print("Can not read image ",e)
print("Can not read image ", e)
if ques_img:
QMessageBox.information(self, "Information", "The following images can not be saved, "
"please check the image path and labels.\n" + "".join(str(i)+'\n' for i in ques_img))
QMessageBox.information(self, "Information", "Cropped images have been saved in "+str(crop_img_dir))
QMessageBox.information(self,
"Information",
"The following images can not be saved, please check the image path and labels.\n"
+ "".join(str(i) + '\n' for i in ques_img))
QMessageBox.information(self, "Information", "Cropped images have been saved in " + str(crop_img_dir))
def speedChoose(self):
if self.labelDialogOption.isChecked():
......@@ -2162,16 +2189,25 @@ class MainWindow(QMainWindow, WindowMixin):
def autoSaveFunc(self):
if self.autoSaveOption.isChecked():
self.autoSaveNum = 1 # Real auto_Save
self.autoSaveNum = 1 # Real auto_Save
try:
self.saveLabelFile()
except:
pass
print('The program will automatically save once after confirming an image')
else:
self.autoSaveNum = 5 # Used for backup
self.autoSaveNum = 5 # Used for backup
print('The program will automatically save once after confirming 5 images (default)')
def change_box_key(self):
key_text, _ = self.keyDialog.popUp(self.key_previous_text)
if key_text is None:
return
self.key_previous_text = key_text
for shape in self.canvas.selectedShapes:
shape.key_cls = key_text
self._update_shape_color(shape)
def undoShapeEdit(self):
self.canvas.restoreShape()
self.labelList.clear()
......@@ -2186,25 +2222,27 @@ class MainWindow(QMainWindow, WindowMixin):
self.labelList.clearSelection()
self._noSelectionSlot = False
self.canvas.loadShapes(shapes, replace=replace)
print("loadShapes")#1
print("loadShapes") # 1
def lockSelectedShape(self):
"""lock the selsected shapes.
"""lock the selected shapes.
Add self.selectedShapes to lock self.canvas.lockedShapes,
which holds the ratio of the four coordinates of the locked shapes
to the width and height of the image
"""
width, height = self.image.width(), self.image.height()
def format_shape(s):
return dict(label=s.label, # str
line_color=s.line_color.getRgb(),
fill_color=s.fill_color.getRgb(),
ratio=[[int(p.x())/width, int(p.y())/height] for p in s.points], # QPonitF
# add chris
difficult=s.difficult) # bool
#lock
ratio=[[int(p.x()) / width, int(p.y()) / height] for p in s.points], # QPonitF
difficult=s.difficult, # bool
key_cls=s.key_cls, # bool
)
# lock
if len(self.canvas.lockedShapes) == 0:
for s in self.canvas.selectedShapes:
s.line_color = DEFAULT_LOCK_COLOR
......@@ -2212,11 +2250,13 @@ class MainWindow(QMainWindow, WindowMixin):
shapes = [format_shape(shape) for shape in self.canvas.selectedShapes]
trans_dic = []
for box in shapes:
trans_dic.append({"transcription": box['label'], "ratio": box['ratio'], 'difficult': box['difficult']})
trans_dic.append({"transcription": box['label'], "ratio": box['ratio'],
"difficult": box['difficult'],
"key_cls": "None" if "key_cls" not in box else box["key_cls"]})
self.canvas.lockedShapes = trans_dic
self.actions.save.setEnabled(True)
#unlock
# unlock
else:
for s in self.canvas.shapes:
s.line_color = DEFAULT_LINE_COLOR
......@@ -2237,9 +2277,11 @@ def read(filename, default=None):
except:
return default
def str2bool(v):
return v.lower() in ("true", "t", "1")
def get_main_app(argv=[]):
"""
Standard boilerplate Qt application code.
......@@ -2248,23 +2290,26 @@ def get_main_app(argv=[]):
app = QApplication(argv)
app.setApplicationName(__appname__)
app.setWindowIcon(newIcon("app"))
# Tzutalin 201705+: Accept extra agruments to change predefined class file
argparser = argparse.ArgumentParser()
argparser.add_argument("--lang", type=str, default='en', nargs="?")
argparser.add_argument("--gpu", type=str2bool, default=False, nargs="?")
argparser.add_argument("--predefined_classes_file",
default=os.path.join(os.path.dirname(__file__), "data", "predefined_classes.txt"),
nargs="?")
args = argparser.parse_args(argv[1:])
# Usage : labelImg.py image predefClassFile saveDir
win = MainWindow(lang=args.lang, gpu=args.gpu,
defaultPrefdefClassFile=args.predefined_classes_file)
# Tzutalin 201705+: Accept extra arguments to change predefined class file
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--lang", type=str, default='en', nargs="?")
arg_parser.add_argument("--gpu", type=str2bool, default=True, nargs="?")
arg_parser.add_argument("--kie", type=str2bool, default=False, nargs="?")
arg_parser.add_argument("--predefined_classes_file",
default=os.path.join(os.path.dirname(__file__), "data", "predefined_classes.txt"),
nargs="?")
args = arg_parser.parse_args(argv[1:])
win = MainWindow(lang=args.lang,
gpu=args.gpu,
kie_mode=args.kie,
default_predefined_class_file=args.predefined_classes_file)
win.show()
return app, win
def main():
'''construct main app and run it'''
"""construct main app and run it"""
app, _win = get_main_app(sys.argv)
return app.exec_()
......@@ -2276,5 +2321,5 @@ if __name__ == '__main__':
output = os.system('pyrcc5 -o libs/resources.py resources.qrc')
assert output == 0, "operate the cmd have some problems ,please check whether there is a in the lib " \
"directory resources.py "
import libs.resources
sys.exit(main())
......@@ -8,6 +8,10 @@ PPOCRLabel is a semi-automatic graphic annotation tool suitable for OCR field, w
### Recent Update
- 2022.02:(by [PeterH0323](https://github.com/peterh0323)
- Added KIE mode, for [detection + identification + keyword extraction] labeling.
- 2022.01:(by [PeterH0323](https://github.com/peterh0323)
- Improve user experience: prompt for the number of files and labels, optimize interaction, and fix bugs such as only use CPU when inference
- 2021.11.17:
- Support install and start PPOCRLabel through the whl package (by [d2623587501](https://github.com/d2623587501))
- Dataset segmentation: Divide the annotation file into training, verification and testing parts (refer to section 3.5 below, by [MrCuiHao](https://github.com/MrCuiHao))
......@@ -70,14 +74,15 @@ PPOCRLabel
```bash
pip3 install PPOCRLabel
pip3 install opencv-contrib-python-headless==4.2.0.32
PPOCRLabel # run
PPOCRLabel # [Normal mode] for [detection + recognition] labeling
PPOCRLabel --kie True # [KIE mode] for [detection + recognition + keyword extraction] labeling
```
#### 1.2.2 Build and Install the Whl Package Locally
```bash
cd PaddleOCR/PPOCRLabel
python3 setup.py bdist_wheel
python3 setup.py bdist_wheel
pip3 install dist/PPOCRLabel-1.0.2-py2.py3-none-any.whl
```
......@@ -85,7 +90,8 @@ pip3 install dist/PPOCRLabel-1.0.2-py2.py3-none-any.whl
```bash
cd ./PPOCRLabel # Switch to the PPOCRLabel directory
python PPOCRLabel.py
python PPOCRLabel.py # [Normal mode] for [detection + recognition] labeling
python PPOCRLabel.py --kie True # [KIE mode] for [detection + recognition + keyword extraction] labeling
```
......@@ -110,7 +116,7 @@ python PPOCRLabel.py
6. Click 're-Recognition', model will rewrite ALL recognition results in ALL detection box<sup>[3]</sup>.
7. Double click the result in 'recognition result' list to manually change inaccurate recognition results.
7. Single click the result in 'recognition result' list to manually change inaccurate recognition results.
8. **Click "Check", the image status will switch to "√",then the program automatically jump to the next.**
......@@ -143,15 +149,17 @@ python PPOCRLabel.py
### 3.1 Shortcut keys
| Shortcut keys | Description |
|--------------------------| ------------------------------------------------ |
|--------------------------|--------------------------------------------------|
| Ctrl + Shift + R | Re-recognize all the labels of the current image |
| W | Create a rect box |
| Q | Create a four-points box |
| X | Rotate the box anti-clockwise |
| C | Rotate the box clockwise |
| Ctrl + E | Edit label of the selected box |
| Ctrl + R | Re-recognize the selected box |
| Ctrl + C | Copy and paste the selected box |
| Ctrl + Left Mouse Button | Multi select the label box |
| Ctrl + X | Delete the selected box |
| Alt + X | Delete the selected box |
| Ctrl + V | Check image |
| Ctrl + Shift + d | Delete image |
| D | Next image |
......@@ -167,7 +175,7 @@ python PPOCRLabel.py
- Model language switching: Changing the built-in model language is supportable by clicking "PaddleOCR"-"Choose OCR Model" in the menu bar. Currently supported languages​include French, German, Korean, and Japanese.
For specific model download links, please refer to [PaddleOCR Model List](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/models_list_en.md#multilingual-recognition-modelupdating)
- **Custom Model**: If users want to replace the built-in model with their own inference model, they can follow the [Custom Model Code Usage](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_en/whl_en.md#31-use-by-code) by modifying PPOCRLabel.py for [Instantiation of PaddleOCR class](https://github.com/PaddlePaddle/PaddleOCR/blob/release/ 2.3/PPOCRLabel/PPOCRLabel.py#L116) :
- **Custom Model**: If users want to replace the built-in model with their own inference model, they can follow the [Custom Model Code Usage](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_en/whl_en.md#31-use-by-code) by modifying PPOCRLabel.py for [Instantiation of PaddleOCR class](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/PPOCRLabel/PPOCRLabel.py#L86) :
add parameter `det_model_dir` in `self.ocr = PaddleOCR(use_pdserving=False, use_angle_cls=True, det=True, cls=True, use_gpu=gpu, lang=lang) `
......@@ -194,21 +202,31 @@ For some data that are difficult to recognize, the recognition results will not
- Enter the following command in the terminal to execute the dataset division script:
```
```
cd ./PPOCRLabel # Change the directory to the PPOCRLabel folder
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --labelRootPath ../train_data/label --detRootPath ../train_data/det --recRootPath ../train_data/rec
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --datasetRootPath ../train_data
```
Parameter Description:
- `trainValTestRatio` is the division ratio of the number of images in the training set, validation set, and test set, set according to your actual situation, the default is `6:2:2`
- `labelRootPath` is the storage path of the dataset labeled by PPOCRLabel, the default is `../train_data/label`
- `detRootPath` is the path where the text detection dataset is divided according to the dataset marked by PPOCRLabel. The default is `../train_data/det`
- `recRootPath` is the path where the character recognition dataset is divided according to the dataset marked by PPOCRLabel. The default is `../train_data/rec`
- `datasetRootPath` is the storage path of the complete dataset labeled by PPOCRLabel. The default path is `PaddleOCR/train_data` .
```
|-train_data
|-crop_img
|- word_001_crop_0.png
|- word_002_crop_0.jpg
|- word_003_crop_0.jpg
| ...
| Label.txt
| rec_gt.txt
|- word_001.png
|- word_002.jpg
|- word_003.jpg
| ...
```
### 3.6 Error message
- If paddleocr is installed with whl, it has a higher priority than calling PaddleOCR class with paddleocr.py, which may cause an exception if whl package is not updated.
......@@ -231,4 +249,4 @@ For some data that are difficult to recognize, the recognition results will not
### 4. Related
1.[Tzutalin. LabelImg. Git code (2015)](https://github.com/tzutalin/labelImg)
\ No newline at end of file
1.[Tzutalin. LabelImg. Git code (2015)](https://github.com/tzutalin/labelImg)
......@@ -8,6 +8,10 @@ PPOCRLabel是一款适用于OCR领域的半自动化图形标注工具,内置P
#### 近期更新
- 2022.02:(by [PeterH0323](https://github.com/peterh0323)
- 新增:KIE 功能,用于打【检测+识别+关键字提取】的标签
- 2022.01:(by [PeterH0323](https://github.com/peterh0323)
- 提升用户体验:新增文件与标记数目提示、优化交互、修复gpu使用等问题
- 2021.11.17:
- 新增支持通过whl包安装和启动PPOCRLabel(by [d2623587501](https://github.com/d2623587501)
- 标注数据集切分:对标注数据进行训练、验证与测试集划分(参考下方3.5节,by [MrCuiHao](https://github.com/MrCuiHao)
......@@ -24,7 +28,7 @@ PPOCRLabel是一款适用于OCR领域的半自动化图形标注工具,内置P
- 识别结果更改为单击修改。(如果无法修改,请切换为系统自带输入法,或再次切回原输入法)
- 2020.12.18: 支持对单个标记框进行重新识别(by [ninetailskim](https://github.com/ninetailskim)),完善快捷键。
如果您对以上内容感兴趣或对完善工具有不一样的想法,欢迎加入我们的SIG队伍与我们共同开发。可以在[此处](https://github.com/PaddlePaddle/PaddleOCR/issues/1728)完成问卷和前置任务,经过我们确认相关内容后即可正式加入,享受SIG福利,共同为OCR开源事业贡献(特别说明:针对PPOCRLabel的改进也属于PaddleOCR前置任务)
如果您对完善工具有不一样的想法,欢迎通过[社区常规赛](https://github.com/PaddlePaddle/PaddleOCR/issues/4982)报名相关更改,获得积分兑换奖励。
......@@ -68,7 +72,8 @@ PPOCRLabel --lang ch
```bash
pip3 install PPOCRLabel
pip3 install opencv-contrib-python-headless==4.2.0.32 # 如果下载过慢请添加"-i https://mirror.baidu.com/pypi/simple"
PPOCRLabel --lang ch # 启动
PPOCRLabel --lang ch # 启动【普通模式】,用于打【检测+识别】场景的标签
PPOCRLabel --lang ch --kie True # 启动 【KIE 模式】,用于打【检测+识别+关键字提取】场景的标签
```
> 如果上述安装出现问题,可以参考3.6节 错误提示
......@@ -87,7 +92,8 @@ pip3 install dist/PPOCRLabel-1.0.2-py2.py3-none-any.whl -i https://mirror.baidu.
```bash
cd ./PPOCRLabel # 切换到PPOCRLabel目录
python PPOCRLabel.py --lang ch
python PPOCRLabel.py --lang ch # 启动【普通模式】,用于打【检测+识别】场景的标签
python PPOCRLabel.py --lang ch --kie True # 启动 【KIE 模式】,用于打【检测+识别+关键字提取】场景的标签
```
......@@ -102,7 +108,7 @@ python PPOCRLabel.py --lang ch
4. 手动标注:点击 “矩形标注”(推荐直接在英文模式下点击键盘中的 “W”),用户可对当前图片中模型未检出的部分进行手动绘制标记框。点击键盘Q,则使用四点标注模式(或点击“编辑” - “四点标注”),用户依次点击4个点后,双击左键表示标注完成。
5. 标记框绘制完成后,用户点击 “确认”,检测框会先被预分配一个 “待识别” 标签。
6. 重新识别:将图片中的所有检测画绘制/调整完成后,点击 “重新识别”,PPOCR模型会对当前图片中的**所有检测框**重新识别<sup>[3]</sup>
7. 内容更改:击识别结果,对不准确的识别结果进行手动更改。
7. 内容更改:击识别结果,对不准确的识别结果进行手动更改。
8. **确认标记:点击 “确认”,图片状态切换为 “√”,跳转至下一张。**
9. 删除:点击 “删除图像”,图片将会被删除至回收站。
10. 导出结果:用户可以通过菜单中“文件-导出标记结果”手动导出,同时也可以点击“文件 - 自动导出标记结果”开启自动导出。手动确认过的标记将会被存放在所打开图片文件夹下的*Label.txt*中。在菜单栏点击 “文件” - "导出识别结果"后,会将此类图片的识别训练数据保存在*crop_img*文件夹下,识别标签保存在*rec_gt.txt*<sup>[4]</sup>
......@@ -131,23 +137,25 @@ python PPOCRLabel.py --lang ch
### 3.1 快捷键
| 快捷键 | 说明 |
|------------------| ---------------------------- |
| Ctrl + shift + R | 对当前图片的所有标记重新识别 |
| W | 新建矩形框 |
| Q | 新建四点框 |
| Ctrl + E | 编辑所选框标签 |
| Ctrl + R | 重新识别所选标记 |
| 快捷键 | 说明 |
|------------------|----------------|
| Ctrl + shift + R | 对当前图片的所有标记重新识别 |
| W | 新建矩形框 |
| Q | 新建四点框 |
| X | 框逆时针旋转 |
| C | 框顺时针旋转 |
| Ctrl + E | 编辑所选框标签 |
| Ctrl + R | 重新识别所选标记 |
| Ctrl + C | 复制并粘贴选中的标记框 |
| Ctrl + 鼠标左键 | 多选标记框 |
| Ctrl + X | 删除所选框 |
| Ctrl + V | 确认本张图片标记 |
| Ctrl + Shift + d | 删除本张图片 |
| D | 下一张图片 |
| A | 上一张图片 |
| Ctrl++ | 缩小 |
| Ctrl-- | 放大 |
| ↑→↓← | 移动标记框 |
| Ctrl + 鼠标左键 | 多选标记框 |
| Alt + X | 删除所选框 |
| Ctrl + V | 确认本张图片标记 |
| Ctrl + Shift + d | 删除本张图片 |
| D | 下一张图片 |
| A | 上一张图片 |
| Ctrl++ | 缩小 |
| Ctrl-- | 放大 |
| ↑→↓← | 移动标记框 |
### 3.2 内置模型
......@@ -181,19 +189,29 @@ PPOCRLabel支持三种导出方式:
```
cd ./PPOCRLabel # 将目录切换到PPOCRLabel文件夹下
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --labelRootPath ../train_data/label --detRootPath ../train_data/det --recRootPath ../train_data/rec
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --datasetRootPath ../train_data
```
参数说明:
- `trainValTestRatio` 是训练集、验证集、测试集的图像数量划分比例,根据实际情况设定,默认是`6:2:2`
- `labelRootPath` 是PPOCRLabel标注的数据集存放路径,默认是`../train_data/label`
- `detRootPath` 是根据PPOCRLabel标注的数据集划分后的文本检测数据集存放的路径,默认是`../train_data/det `
- `recRootPath` 是根据PPOCRLabel标注的数据集划分后的字符识别数据集存放的路径,默认是`../train_data/rec`
- `datasetRootPath` 是PPOCRLabel标注的完整数据集存放路径。默认路径是 `PaddleOCR/train_data` 分割数据集前应有如下结构:
```
|-train_data
|-crop_img
|- word_001_crop_0.png
|- word_002_crop_0.jpg
|- word_003_crop_0.jpg
| ...
| Label.txt
| rec_gt.txt
|- word_001.png
|- word_002.jpg
|- word_003.jpg
| ...
```
### 3.6 错误提示
- 如果同时使用whl包安装了paddleocr,其优先级大于通过paddleocr.py调用PaddleOCR类,whl包未更新时会导致程序异常。
......
# Copyright (c) <2015-Present> Tzutalin
# Copyright (C) 2013 MIT, Computer Science and Artificial Intelligence Laboratory. Bryan Russell, Antonio Torralba,
# William T. Freeman. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction, including without
# limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
# SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
# CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import sys
try:
from PyQt5.QtWidgets import QWidget, QHBoxLayout, QComboBox
except ImportError:
# needed for py3+qt4
# Ref:
# http://pyqt.sourceforge.net/Docs/PyQt4/incompatible_apis.html
# http://stackoverflow.com/questions/21217399/pyqt4-qtcore-qvariant-object-instead-of-a-string
if sys.version_info.major >= 3:
import sip
sip.setapi('QVariant', 2)
from PyQt4.QtGui import QWidget, QHBoxLayout, QComboBox
class ComboBox(QWidget):
def __init__(self, parent=None, items=[]):
super(ComboBox, self).__init__(parent)
layout = QHBoxLayout()
self.cb = QComboBox()
self.items = items
self.cb.addItems(self.items)
self.cb.currentIndexChanged.connect(parent.comboSelectionChanged)
layout.addWidget(self.cb)
self.setLayout(layout)
def update_items(self, items):
self.items = items
self.cb.clear()
self.cb.addItems(self.items)
......@@ -17,15 +17,14 @@ def isCreateOrDeleteFolder(path, flag):
return flagAbsPath
def splitTrainVal(root, dir, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag):
def splitTrainVal(root, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag):
# 按照指定的比例划分训练集、验证集、测试集
labelPath = os.path.join(root, dir)
labelAbsPath = os.path.abspath(labelPath)
dataAbsPath = os.path.abspath(root)
if flag == "det":
labelFilePath = os.path.join(labelAbsPath, args.detLabelFileName)
labelFilePath = os.path.join(dataAbsPath, args.detLabelFileName)
elif flag == "rec":
labelFilePath = os.path.join(labelAbsPath, args.recLabelFileName)
labelFilePath = os.path.join(dataAbsPath, args.recLabelFileName)
labelFileRead = open(labelFilePath, "r", encoding="UTF-8")
labelFileContent = labelFileRead.readlines()
......@@ -38,9 +37,9 @@ def splitTrainVal(root, dir, absTrainRootPath, absValRootPath, absTestRootPath,
imageName = os.path.basename(imageRelativePath)
if flag == "det":
imagePath = os.path.join(labelAbsPath, imageName)
imagePath = os.path.join(dataAbsPath, imageName)
elif flag == "rec":
imagePath = os.path.join(labelAbsPath, "{}\\{}".format(args.recImageDirName, imageName))
imagePath = os.path.join(dataAbsPath, "{}\\{}".format(args.recImageDirName, imageName))
# 按预设的比例划分训练集、验证集、测试集
trainValTestRatio = args.trainValTestRatio.split(":")
......@@ -90,15 +89,20 @@ def genDetRecTrainVal(args):
recValTxt = open(os.path.join(args.recRootPath, "val.txt"), "a", encoding="UTF-8")
recTestTxt = open(os.path.join(args.recRootPath, "test.txt"), "a", encoding="UTF-8")
for root, dirs, files in os.walk(args.labelRootPath):
splitTrainVal(args.datasetRootPath, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath, detTrainTxt, detValTxt,
detTestTxt, "det")
for root, dirs, files in os.walk(args.datasetRootPath):
for dir in dirs:
splitTrainVal(root, dir, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath, detTrainTxt, detValTxt,
detTestTxt, "det")
splitTrainVal(root, dir, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath, recTrainTxt, recValTxt,
recTestTxt, "rec")
if dir == 'crop_img':
splitTrainVal(root, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath, recTrainTxt, recValTxt,
recTestTxt, "rec")
else:
continue
break
if __name__ == "__main__":
# 功能描述:分别划分检测和识别的训练集、验证集、测试集
# 说明:可以根据自己的路径和需求调整参数,图像数据往往多人合作分批标注,每一批图像数据放在一个文件夹内用PPOCRLabel进行标注,
......@@ -110,9 +114,9 @@ if __name__ == "__main__":
default="6:2:2",
help="ratio of trainset:valset:testset")
parser.add_argument(
"--labelRootPath",
"--datasetRootPath",
type=str,
default="../train_data/label",
default="../train_data/",
help="path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3..."
)
parser.add_argument(
......
......@@ -11,19 +11,13 @@
# CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
try:
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
except ImportError:
from PyQt4.QtGui import *
from PyQt4.QtCore import *
#from PyQt4.QtOpenGL import *
import copy
from PyQt5.QtCore import Qt, pyqtSignal, QPointF, QPoint
from PyQt5.QtGui import QPainter, QBrush, QColor, QPixmap
from PyQt5.QtWidgets import QWidget, QMenu, QApplication
from libs.shape import Shape
from libs.utils import distance
import copy
CURSOR_DEFAULT = Qt.ArrowCursor
CURSOR_POINT = Qt.PointingHandCursor
......@@ -31,8 +25,6 @@ CURSOR_DRAW = Qt.CrossCursor
CURSOR_MOVE = Qt.ClosedHandCursor
CURSOR_GRAB = Qt.OpenHandCursor
# class Canvas(QGLWidget):
class Canvas(QWidget):
zoomRequest = pyqtSignal(int)
......@@ -129,7 +121,6 @@ class Canvas(QWidget):
def selectedVertex(self):
return self.hVertex is not None
def mouseMoveEvent(self, ev):
"""Update line with last point and current coordinates."""
pos = self.transformPos(ev.pos())
......@@ -333,7 +324,6 @@ class Canvas(QWidget):
self.movingShape = False
def endMove(self, copy=False):
assert self.selectedShapes and self.selectedShapesCopy
assert len(self.selectedShapesCopy) == len(self.selectedShapes)
......@@ -410,7 +400,6 @@ class Canvas(QWidget):
self.selectionChanged.emit(shapes)
self.update()
def selectShapePoint(self, point, multiple_selection_mode):
"""Select the first shape created which contains this point."""
if self.selectedVertex(): # A vertex is marked for selection.
......@@ -494,7 +483,6 @@ class Canvas(QWidget):
else:
shape.moveVertexBy(index, shiftPos)
def boundedMoveShape(self, shapes, pos):
if type(shapes).__name__ != 'list': shapes = [shapes]
if self.outOfPixmap(pos):
......@@ -515,6 +503,7 @@ class Canvas(QWidget):
if dp:
for shape in shapes:
shape.moveBy(dp)
shape.close()
self.prevPoint = pos
return True
return False
......@@ -728,6 +717,31 @@ class Canvas(QWidget):
self.moveOnePixel('Up')
elif key == Qt.Key_Down and self.selectedShapes:
self.moveOnePixel('Down')
elif key == Qt.Key_X and self.selectedShapes:
for i in range(len(self.selectedShapes)):
self.selectedShape = self.selectedShapes[i]
if self.rotateOutOfBound(0.01):
continue
self.selectedShape.rotate(0.01)
self.shapeMoved.emit()
self.update()
elif key == Qt.Key_C and self.selectedShapes:
for i in range(len(self.selectedShapes)):
self.selectedShape = self.selectedShapes[i]
if self.rotateOutOfBound(-0.01):
continue
self.selectedShape.rotate(-0.01)
self.shapeMoved.emit()
self.update()
def rotateOutOfBound(self, angle):
for shape in range(len(self.selectedShapes)):
self.selectedShape = self.selectedShapes[shape]
for i, p in enumerate(self.selectedShape.points):
if self.outOfPixmap(self.selectedShape.rotatePoint(p, angle)):
return True
return False
def moveOnePixel(self, direction):
# print(self.selectedShape.points)
......@@ -769,7 +783,7 @@ class Canvas(QWidget):
points = [p1+p2 for p1, p2 in zip(self.selectedShape.points, [step]*4)]
return True in map(self.outOfPixmap, points)
def setLastLabel(self, text, line_color = None, fill_color = None):
def setLastLabel(self, text, line_color=None, fill_color=None, key_cls=None):
assert text
self.shapes[-1].label = text
if line_color:
......@@ -777,6 +791,10 @@ class Canvas(QWidget):
if fill_color:
self.shapes[-1].fill_color = fill_color
if key_cls:
self.shapes[-1].key_cls = key_cls
self.storeShapes()
return self.shapes[-1]
......
import sys, time
from PyQt5 import QtWidgets
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
# !/usr/bin/env python
# -*- coding: utf-8 -*-
from PyQt5.QtCore import QModelIndex
from PyQt5.QtWidgets import QListWidget
class EditInList(QListWidget):
def __init__(self):
super(EditInList,self).__init__()
# click to edit
self.clicked.connect(self.item_clicked)
super(EditInList, self).__init__()
self.edited_item = None
def item_clicked(self, modelindex: QModelIndex):
try:
if self.edited_item is not None:
self.closePersistentEditor(self.edited_item)
except:
self.edited_item = self.currentItem()
def item_clicked(self, modelindex: QModelIndex) -> None:
self.edited_item = self.currentItem()
self.closePersistentEditor(self.edited_item)
item = self.item(modelindex.row())
# time.sleep(0.2)
self.edited_item = item
self.openPersistentEditor(item)
# time.sleep(0.2)
self.editItem(item)
self.edited_item = self.item(modelindex.row())
self.openPersistentEditor(self.edited_item)
self.editItem(self.edited_item)
def mouseDoubleClickEvent(self, event):
# close edit
for i in range(self.count()):
self.closePersistentEditor(self.item(i))
pass
def leaveEvent(self, event):
# close edit
for i in range(self.count()):
self.closePersistentEditor(self.item(i))
\ No newline at end of file
self.closePersistentEditor(self.item(i))
import re
from PyQt5 import QtCore
from PyQt5 import QtGui
from PyQt5 import QtWidgets
from PyQt5.Qt import QT_VERSION_STR
from libs.utils import newIcon, labelValidator
QT5 = QT_VERSION_STR[0] == '5'
# TODO(unknown):
# - Calculate optimal position so as not to go out of screen area.
class KeyQLineEdit(QtWidgets.QLineEdit):
def setListWidget(self, list_widget):
self.list_widget = list_widget
def keyPressEvent(self, e):
if e.key() in [QtCore.Qt.Key_Up, QtCore.Qt.Key_Down]:
self.list_widget.keyPressEvent(e)
else:
super(KeyQLineEdit, self).keyPressEvent(e)
class KeyDialog(QtWidgets.QDialog):
def __init__(
self,
text="Enter object label",
parent=None,
labels=None,
sort_labels=True,
show_text_field=True,
completion="startswith",
fit_to_content=None,
flags=None,
):
if fit_to_content is None:
fit_to_content = {"row": False, "column": True}
self._fit_to_content = fit_to_content
super(KeyDialog, self).__init__(parent)
self.edit = KeyQLineEdit()
self.edit.setPlaceholderText(text)
self.edit.setValidator(labelValidator())
self.edit.editingFinished.connect(self.postProcess)
if flags:
self.edit.textChanged.connect(self.updateFlags)
layout = QtWidgets.QVBoxLayout()
if show_text_field:
layout_edit = QtWidgets.QHBoxLayout()
layout_edit.addWidget(self.edit, 6)
layout.addLayout(layout_edit)
# buttons
self.buttonBox = bb = QtWidgets.QDialogButtonBox(
QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel,
QtCore.Qt.Horizontal,
self,
)
bb.button(bb.Ok).setIcon(newIcon("done"))
bb.button(bb.Cancel).setIcon(newIcon("undo"))
bb.accepted.connect(self.validate)
bb.rejected.connect(self.reject)
layout.addWidget(bb)
# label_list
self.labelList = QtWidgets.QListWidget()
if self._fit_to_content["row"]:
self.labelList.setHorizontalScrollBarPolicy(
QtCore.Qt.ScrollBarAlwaysOff
)
if self._fit_to_content["column"]:
self.labelList.setVerticalScrollBarPolicy(
QtCore.Qt.ScrollBarAlwaysOff
)
self._sort_labels = sort_labels
if labels:
self.labelList.addItems(labels)
if self._sort_labels:
self.labelList.sortItems()
else:
self.labelList.setDragDropMode(
QtWidgets.QAbstractItemView.InternalMove
)
self.labelList.currentItemChanged.connect(self.labelSelected)
self.labelList.itemDoubleClicked.connect(self.labelDoubleClicked)
self.edit.setListWidget(self.labelList)
layout.addWidget(self.labelList)
# label_flags
if flags is None:
flags = {}
self._flags = flags
self.flagsLayout = QtWidgets.QVBoxLayout()
self.resetFlags()
layout.addItem(self.flagsLayout)
self.edit.textChanged.connect(self.updateFlags)
self.setLayout(layout)
# completion
completer = QtWidgets.QCompleter()
if not QT5 and completion != "startswith":
completion = "startswith"
if completion == "startswith":
completer.setCompletionMode(QtWidgets.QCompleter.InlineCompletion)
# Default settings.
# completer.setFilterMode(QtCore.Qt.MatchStartsWith)
elif completion == "contains":
completer.setCompletionMode(QtWidgets.QCompleter.PopupCompletion)
completer.setFilterMode(QtCore.Qt.MatchContains)
else:
raise ValueError("Unsupported completion: {}".format(completion))
completer.setModel(self.labelList.model())
self.edit.setCompleter(completer)
def addLabelHistory(self, label):
if self.labelList.findItems(label, QtCore.Qt.MatchExactly):
return
self.labelList.addItem(label)
if self._sort_labels:
self.labelList.sortItems()
def labelSelected(self, item):
self.edit.setText(item.text())
def validate(self):
text = self.edit.text()
if hasattr(text, "strip"):
text = text.strip()
else:
text = text.trimmed()
if text:
self.accept()
def labelDoubleClicked(self, item):
self.validate()
def postProcess(self):
text = self.edit.text()
if hasattr(text, "strip"):
text = text.strip()
else:
text = text.trimmed()
self.edit.setText(text)
def updateFlags(self, label_new):
# keep state of shared flags
flags_old = self.getFlags()
flags_new = {}
for pattern, keys in self._flags.items():
if re.match(pattern, label_new):
for key in keys:
flags_new[key] = flags_old.get(key, False)
self.setFlags(flags_new)
def deleteFlags(self):
for i in reversed(range(self.flagsLayout.count())):
item = self.flagsLayout.itemAt(i).widget()
self.flagsLayout.removeWidget(item)
item.setParent(None)
def resetFlags(self, label=""):
flags = {}
for pattern, keys in self._flags.items():
if re.match(pattern, label):
for key in keys:
flags[key] = False
self.setFlags(flags)
def setFlags(self, flags):
self.deleteFlags()
for key in flags:
item = QtWidgets.QCheckBox(key, self)
item.setChecked(flags[key])
self.flagsLayout.addWidget(item)
item.show()
def getFlags(self):
flags = {}
for i in range(self.flagsLayout.count()):
item = self.flagsLayout.itemAt(i).widget()
flags[item.text()] = item.isChecked()
return flags
def popUp(self, text=None, move=True, flags=None):
if self._fit_to_content["row"]:
self.labelList.setMinimumHeight(
self.labelList.sizeHintForRow(0) * self.labelList.count() + 2
)
if self._fit_to_content["column"]:
self.labelList.setMinimumWidth(
self.labelList.sizeHintForColumn(0) + 2
)
# if text is None, the previous label in self.edit is kept
if text is None:
text = self.edit.text()
if flags:
self.setFlags(flags)
else:
self.resetFlags(text)
self.edit.setText(text)
self.edit.setSelection(0, len(text))
items = self.labelList.findItems(text, QtCore.Qt.MatchFixedString)
if items:
if len(items) != 1:
self.labelList.setCurrentItem(items[0])
row = self.labelList.row(items[0])
self.edit.completer().setCurrentRow(row)
self.edit.setFocus(QtCore.Qt.PopupFocusReason)
if move:
self.move(QtGui.QCursor.pos())
if self.exec_():
return self.edit.text(), self.getFlags()
else:
return None, None
import PIL.Image
import numpy as np
def rgb2hsv(rgb):
# type: (np.ndarray) -> np.ndarray
"""Convert rgb to hsv.
Parameters
----------
rgb: numpy.ndarray, (H, W, 3), np.uint8
Input rgb image.
Returns
-------
hsv: numpy.ndarray, (H, W, 3), np.uint8
Output hsv image.
"""
hsv = PIL.Image.fromarray(rgb, mode="RGB")
hsv = hsv.convert("HSV")
hsv = np.array(hsv)
return hsv
def hsv2rgb(hsv):
# type: (np.ndarray) -> np.ndarray
"""Convert hsv to rgb.
Parameters
----------
hsv: numpy.ndarray, (H, W, 3), np.uint8
Input hsv image.
Returns
-------
rgb: numpy.ndarray, (H, W, 3), np.uint8
Output rgb image.
"""
rgb = PIL.Image.fromarray(hsv, mode="HSV")
rgb = rgb.convert("RGB")
rgb = np.array(rgb)
return rgb
def label_colormap(n_label=256, value=None):
"""Label colormap.
Parameters
----------
n_label: int
Number of labels (default: 256).
value: float or int
Value scale or value of label color in HSV space.
Returns
-------
cmap: numpy.ndarray, (N, 3), numpy.uint8
Label id to colormap.
"""
def bitget(byteval, idx):
return (byteval & (1 << idx)) != 0
cmap = np.zeros((n_label, 3), dtype=np.uint8)
for i in range(0, n_label):
id = i
r, g, b = 0, 0, 0
for j in range(0, 8):
r = np.bitwise_or(r, (bitget(id, 0) << 7 - j))
g = np.bitwise_or(g, (bitget(id, 1) << 7 - j))
b = np.bitwise_or(b, (bitget(id, 2) << 7 - j))
id = id >> 3
cmap[i, 0] = r
cmap[i, 1] = g
cmap[i, 2] = b
if value is not None:
hsv = rgb2hsv(cmap.reshape(1, -1, 3))
if isinstance(value, float):
hsv[:, 1:, 2] = hsv[:, 1:, 2].astype(float) * value
else:
assert isinstance(value, int)
hsv[:, 1:, 2] = value
cmap = hsv2rgb(hsv).reshape(-1, 3)
return cmap
......@@ -10,19 +10,14 @@
# SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
# CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#!/usr/bin/python
# !/usr/bin/python
# -*- coding: utf-8 -*-
import math
import sys
try:
from PyQt5.QtGui import *
from PyQt5.QtCore import *
except ImportError:
from PyQt4.QtGui import *
from PyQt4.QtCore import *
from PyQt5.QtCore import QPointF
from PyQt5.QtGui import QColor, QPen, QPainterPath, QFont
from libs.utils import distance
import sys
DEFAULT_LINE_COLOR = QColor(0, 255, 0, 128)
DEFAULT_FILL_COLOR = QColor(255, 0, 0, 128)
......@@ -51,14 +46,17 @@ class Shape(object):
point_size = 8
scale = 1.0
def __init__(self, label=None, line_color=None, difficult=False, paintLabel=False):
def __init__(self, label=None, line_color=None, difficult=False, key_cls="None", paintLabel=False):
self.label = label
self.points = []
self.fill = False
self.selected = False
self.difficult = difficult
self.key_cls = key_cls
self.paintLabel = paintLabel
self.locked = False
self.direction = 0
self.center = None
self._highlightIndex = None
self._highlightMode = self.NEAR_VERTEX
self._highlightSettings = {
......@@ -74,7 +72,24 @@ class Shape(object):
# is used for drawing the pending line a different color.
self.line_color = line_color
def rotate(self, theta):
for i, p in enumerate(self.points):
self.points[i] = self.rotatePoint(p, theta)
self.direction -= theta
self.direction = self.direction % (2 * math.pi)
def rotatePoint(self, p, theta):
order = p - self.center
cosTheta = math.cos(theta)
sinTheta = math.sin(theta)
pResx = cosTheta * order.x() + sinTheta * order.y()
pResy = - sinTheta * order.x() + cosTheta * order.y()
pRes = QPointF(self.center.x() + pResx, self.center.y() + pResy)
return pRes
def close(self):
self.center = QPointF((self.points[0].x() + self.points[2].x()) / 2,
(self.points[0].y() + self.points[2].y()) / 2)
self._closed = True
def reachMaxPoints(self):
......@@ -83,7 +98,9 @@ class Shape(object):
return False
def addPoint(self, point):
if not self.reachMaxPoints(): # 4个点时发出close信号
if self.reachMaxPoints():
self.close()
else:
self.points.append(point)
def popPoint(self):
......@@ -112,7 +129,7 @@ class Shape(object):
# Uncommenting the following line will draw 2 paths
# for the 1st vertex, and make it non-filled, which
# may be desirable.
#self.drawVertex(vrtx_path, 0)
# self.drawVertex(vrtx_path, 0)
for i, p in enumerate(self.points):
line_path.lineTo(p)
......@@ -136,9 +153,9 @@ class Shape(object):
font.setPointSize(8)
font.setBold(True)
painter.setFont(font)
if(self.label == None):
if self.label is None:
self.label = ""
if(min_y < MIN_Y_LABEL):
if min_y < MIN_Y_LABEL:
min_y += MIN_Y_LABEL
painter.drawText(min_x, min_y, self.label)
......@@ -198,6 +215,8 @@ class Shape(object):
def copy(self):
shape = Shape("%s" % self.label)
shape.points = [p for p in self.points]
shape.center = self.center
shape.direction = self.direction
shape.fill = self.fill
shape.selected = self.selected
shape._closed = self._closed
......@@ -206,6 +225,7 @@ class Shape(object):
if self.fill_color != Shape.fill_color:
shape.fill_color = self.fill_color
shape.difficult = self.difficult
shape.key_cls = self.key_cls
return shape
def __len__(self):
......
# -*- encoding: utf-8 -*-
from PyQt5.QtCore import Qt
from PyQt5 import QtWidgets
class EscapableQListWidget(QtWidgets.QListWidget):
def keyPressEvent(self, event):
super(EscapableQListWidget, self).keyPressEvent(event)
if event.key() == Qt.Key_Escape:
self.clearSelection()
class UniqueLabelQListWidget(EscapableQListWidget):
def mousePressEvent(self, event):
super(UniqueLabelQListWidget, self).mousePressEvent(event)
if not self.indexAt(event.pos()).isValid():
self.clearSelection()
def findItemsByLabel(self, label, get_row=False):
items = []
for row in range(self.count()):
item = self.item(row)
if item.data(Qt.UserRole) == label:
items.append(item)
if get_row:
return row
return items
def createItemFromLabel(self, label):
item = QtWidgets.QListWidgetItem()
item.setData(Qt.UserRole, label)
return item
def setItemLabel(self, item, label, color=None):
qlabel = QtWidgets.QLabel()
if color is None:
qlabel.setText(f"{label}")
else:
qlabel.setText('<font color="#{:02x}{:02x}{:02x}">●</font> {} '.format(*color, label))
qlabel.setAlignment(Qt.AlignBottom)
item.setSizeHint(qlabel.sizeHint())
self.setItemWidget(item, qlabel)
......@@ -10,30 +10,26 @@
# SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
# CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from math import sqrt
from libs.ustr import ustr
import hashlib
import os
import re
import sys
from math import sqrt
import cv2
import numpy as np
import os
from PyQt5.QtCore import QRegExp, QT_VERSION_STR
from PyQt5.QtGui import QIcon, QRegExpValidator, QColor
from PyQt5.QtWidgets import QPushButton, QAction, QMenu
from libs.ustr import ustr
__dir__ = os.path.dirname(os.path.abspath(__file__)) # 获取本程序文件路径
__dir__ = os.path.dirname(os.path.abspath(__file__)) # 获取本程序文件路径
__iconpath__ = os.path.abspath(os.path.join(__dir__, '../resources/icons'))
try:
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
except ImportError:
from PyQt4.QtGui import *
from PyQt4.QtCore import *
def newIcon(icon, iconSize=None):
if iconSize is not None:
return QIcon(QIcon(__iconpath__ + "/" + icon + ".png").pixmap(iconSize,iconSize))
return QIcon(QIcon(__iconpath__ + "/" + icon + ".png").pixmap(iconSize, iconSize))
else:
return QIcon(__iconpath__ + "/" + icon + ".png")
......@@ -105,24 +101,25 @@ def generateColorByText(text):
s = ustr(text)
hashCode = int(hashlib.sha256(s.encode('utf-8')).hexdigest(), 16)
r = int((hashCode / 255) % 255)
g = int((hashCode / 65025) % 255)
b = int((hashCode / 16581375) % 255)
g = int((hashCode / 65025) % 255)
b = int((hashCode / 16581375) % 255)
return QColor(r, g, b, 100)
def have_qstring():
'''p3/qt5 get rid of QString wrapper as py3 has native unicode str type'''
return not (sys.version_info.major >= 3 or QT_VERSION_STR.startswith('5.'))
def util_qt_strlistclass():
return QStringList if have_qstring() else list
def natural_sort(list, key=lambda s:s):
def natural_sort(list, key=lambda s: s):
"""
Sort the list into natural alphanumeric order.
"""
def get_alphanum_key_func(key):
convert = lambda text: int(text) if text.isdigit() else text
return lambda s: [convert(c) for c in re.split('([0-9]+)', key(s))]
sort_key = get_alphanum_key_func(key)
list.sort(key=sort_key)
......@@ -133,8 +130,8 @@ def get_rotate_crop_image(img, points):
d = 0.0
for index in range(-1, 3):
d += -0.5 * (points[index + 1][1] + points[index][1]) * (
points[index + 1][0] - points[index][0])
if d < 0: # counterclockwise
points[index + 1][0] - points[index][0])
if d < 0: # counterclockwise
tmp = np.array(points)
points[1], points[3] = tmp[3], tmp[1]
......@@ -163,10 +160,11 @@ def get_rotate_crop_image(img, points):
except Exception as e:
print(e)
def stepsInfo(lang='en'):
if lang == 'ch':
msg = "1. 安装与运行:使用上述命令安装与运行程序。\n" \
"2. 打开文件夹:在菜单栏点击 “文件” - 打开目录 选择待标记图片的文件夹.\n"\
"2. 打开文件夹:在菜单栏点击 “文件” - 打开目录 选择待标记图片的文件夹.\n" \
"3. 自动标注:点击 ”自动标注“,使用PPOCR超轻量模型对图片文件名前图片状态为 “X” 的图片进行自动标注。\n" \
"4. 手动标注:点击 “矩形标注”(推荐直接在英文模式下点击键盘中的 “W”),用户可对当前图片中模型未检出的部分进行手动" \
"绘制标记框。点击键盘P,则使用四点标注模式(或点击“编辑” - “四点标注”),用户依次点击4个点后,双击左键表示标注完成。\n" \
......@@ -181,25 +179,26 @@ def stepsInfo(lang='en'):
else:
msg = "1. Build and launch using the instructions above.\n" \
"2. Click 'Open Dir' in Menu/File to select the folder of the picture.\n"\
"3. Click 'Auto recognition', use PPOCR model to automatically annotate images which marked with 'X' before the file name."\
"4. Create Box:\n"\
"4.1 Click 'Create RectBox' or press 'W' in English keyboard mode to draw a new rectangle detection box. Click and release left mouse to select a region to annotate the text area.\n"\
"4.2 Press 'P' to enter four-point labeling mode which enables you to create any four-point shape by clicking four points with the left mouse button in succession and DOUBLE CLICK the left mouse as the signal of labeling completion.\n"\
"5. After the marking frame is drawn, the user clicks 'OK', and the detection frame will be pre-assigned a TEMPORARY label.\n"\
"6. Click re-Recognition, model will rewrite ALL recognition results in ALL detection box.\n"\
"7. Double click the result in 'recognition result' list to manually change inaccurate recognition results.\n"\
"8. Click 'Save', the image status will switch to '√',then the program automatically jump to the next.\n"\
"9. Click 'Delete Image' and the image will be deleted to the recycle bin.\n"\
"10. Labeling result: After closing the application or switching the file path, the manually saved label will be stored in *Label.txt* under the opened picture folder.\n"\
" Click PaddleOCR-Save Recognition Results in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*.\n"
"2. Click 'Open Dir' in Menu/File to select the folder of the picture.\n" \
"3. Click 'Auto recognition', use PPOCR model to automatically annotate images which marked with 'X' before the file name." \
"4. Create Box:\n" \
"4.1 Click 'Create RectBox' or press 'W' in English keyboard mode to draw a new rectangle detection box. Click and release left mouse to select a region to annotate the text area.\n" \
"4.2 Press 'P' to enter four-point labeling mode which enables you to create any four-point shape by clicking four points with the left mouse button in succession and DOUBLE CLICK the left mouse as the signal of labeling completion.\n" \
"5. After the marking frame is drawn, the user clicks 'OK', and the detection frame will be pre-assigned a TEMPORARY label.\n" \
"6. Click re-Recognition, model will rewrite ALL recognition results in ALL detection box.\n" \
"7. Double click the result in 'recognition result' list to manually change inaccurate recognition results.\n" \
"8. Click 'Save', the image status will switch to '√',then the program automatically jump to the next.\n" \
"9. Click 'Delete Image' and the image will be deleted to the recycle bin.\n" \
"10. Labeling result: After closing the application or switching the file path, the manually saved label will be stored in *Label.txt* under the opened picture folder.\n" \
" Click PaddleOCR-Save Recognition Results in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*.\n"
return msg
def keysInfo(lang='en'):
if lang == 'ch':
msg = "快捷键\t\t\t说明\n" \
"———————————————————————\n"\
"———————————————————————\n" \
"Ctrl + shift + R\t\t对当前图片的所有标记重新识别\n" \
"W\t\t\t新建矩形框\n" \
"Q\t\t\t新建四点框\n" \
......@@ -223,17 +222,17 @@ def keysInfo(lang='en'):
"———————————————————————\n" \
"Ctrl + shift + R\t\tRe-recognize all the labels\n" \
"\t\t\tof the current image\n" \
"\n"\
"\n" \
"W\t\t\tCreate a rect box\n" \
"Q\t\t\tCreate a four-points box\n" \
"Ctrl + E\t\tEdit label of the selected box\n" \
"Ctrl + R\t\tRe-recognize the selected box\n" \
"Ctrl + C\t\tCopy and paste the selected\n" \
"\t\t\tbox\n" \
"\n"\
"\n" \
"Ctrl + Left Mouse\tMulti select the label\n" \
"Button\t\t\tbox\n" \
"\n"\
"\n" \
"Backspace\t\tDelete the selected box\n" \
"Ctrl + V\t\tCheck image\n" \
"Ctrl + Shift + d\tDelete image\n" \
......@@ -245,4 +244,4 @@ def keysInfo(lang='en'):
"———————————————————————\n" \
"Notice:For Mac users, use the 'Command' key instead of the 'Ctrl' key"
return msg
\ No newline at end of file
return msg
......@@ -106,4 +106,7 @@ undo=Undo
undoLastPoint=Undo Last Point
autoSaveMode=Auto Export Label Mode
lockBox=Lock selected box/Unlock all box
lockBoxDetail=Lock selected box/Unlock all box
\ No newline at end of file
lockBoxDetail=Lock selected box/Unlock all box
keyListTitle=Key List
keyDialogTip=Enter object label
keyChange=Change Box Key
......@@ -107,3 +107,6 @@ undoLastPoint=撤销上个点
autoSaveMode=自动导出标记结果
lockBox=锁定框/解除锁定框
lockBoxDetail=若当前没有框处于锁定状态则锁定选中的框,若存在锁定框则解除所有锁定框的锁定状态
keyListTitle=关键词列表
keyDialogTip=请输入类型名称
keyChange=更改Box关键字类别
\ No newline at end of file
......@@ -92,7 +92,7 @@ Mobile DEMO experience (based on EasyEdge and Paddle-Lite, supports iOS and Andr
| ------------------------------------------------------------ | ---------------------------- | ----------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| Chinese and English ultra-lightweight PP-OCRv2 model(11.6M) | ch_PP-OCRv2_xx |Mobile & Server|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar)| [inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar) |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar)|
| Chinese and English ultra-lightweight PP-OCR model (9.4M) | ch_ppocr_mobile_v2.0_xx | Mobile & server |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar)|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar) |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar) |
| Chinese and English general PP-OCR model (143.4M) | ch_ppocr_server_v2.0_xx | Server |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar) |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_traingit.tar) |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_train.tar) |
| Chinese and English general PP-OCR model (143.4M) | ch_ppocr_server_v2.0_xx | Server |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar) |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar) |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_train.tar) |
For more model downloads (including multiple languages), please refer to [PP-OCR series model downloads](./doc/doc_en/models_list_en.md).
......@@ -152,7 +152,7 @@ For a new language request, please refer to [Guideline for new language_requests
[1] PP-OCR is a practical ultra-lightweight OCR system. It is mainly composed of three parts: DB text detection, detection frame correction and CRNN text recognition. The system adopts 19 effective strategies from 8 aspects including backbone network selection and adjustment, prediction head design, data augmentation, learning rate transformation strategy, regularization parameter selection, pre-training model use, and automatic model tailoring and quantization to optimize and slim down the models of each module (as shown in the green box above). The final results are an ultra-lightweight Chinese and English OCR model with an overall size of 3.5M and a 2.8M English digital OCR model. For more details, please refer to the PP-OCR technical article (https://arxiv.org/abs/2009.09941).
[2] On the basis of PP-OCR, PP-OCRv2 is further optimized in five aspects. The detection model adopts CML(Collaborative Mutual Learning) knowledge distillation strategy and CopyPaste data expansion strategy. The recognition model adopts LCNet lightweight backbone network, U-DML knowledge distillation strategy and enhanced CTC loss function improvement (as shown in the red box above), which further improves the inference speed and prediction effect. For more details, please refer to the technical report of PP-OCRv2 (arXiv link is coming soon).
[2] On the basis of PP-OCR, PP-OCRv2 is further optimized in five aspects. The detection model adopts CML(Collaborative Mutual Learning) knowledge distillation strategy and CopyPaste data expansion strategy. The recognition model adopts LCNet lightweight backbone network, U-DML knowledge distillation strategy and enhanced CTC loss function improvement (as shown in the red box above), which further improves the inference speed and prediction effect. For more details, please refer to the technical report of PP-OCRv2 (https://arxiv.org/abs/2109.03144).
......@@ -181,16 +181,11 @@ For a new language request, please refer to [Guideline for new language_requests
<a name="language_requests"></a>
## Guideline for New Language Requests
If you want to request a new language support, a PR with 2 following files are needed:
If you want to request a new language support, a PR with 1 following files are needed:
1. In folder [ppocr/utils/dict](./ppocr/utils/dict),
it is necessary to submit the dict text to this path and name it with `{language}_dict.txt` that contains a list of all characters. Please see the format example from other files in that folder.
2. In folder [ppocr/utils/corpus](./ppocr/utils/corpus),
it is necessary to submit the corpus to this path and name it with `{language}_corpus.txt` that contains a list of words in your language.
Maybe, 50000 words per language is necessary at least.
Of course, the more, the better.
If your language has unique elements, please tell me in advance within any way, such as useful links, wikipedia and so on.
More details, please refer to [Multilingual OCR Development Plan](https://github.com/PaddlePaddle/PaddleOCR/issues/1048).
......
......@@ -99,7 +99,7 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力
- [PP-Structure信息提取](./ppstructure/README_ch.md)
- [版面分析](./ppstructure/layout/README_ch.md)
- [表格识别](./ppstructure/table/README_ch.md)
- [DocVQA](./ppstructure/vqa/README_ch.md)
- [DocVQA](./ppstructure/vqa/README.md)
- [关键信息提取](./ppstructure/docs/kie.md)
- OCR学术圈
- [两阶段模型介绍与下载](./doc/doc_ch/algorithm_overview.md)
......
......@@ -26,35 +26,57 @@ def parse_args():
parser.add_argument(
"--filename", type=str, help="The name of log which need to analysis.")
parser.add_argument(
"--log_with_profiler", type=str, help="The path of train log with profiler")
"--log_with_profiler",
type=str,
help="The path of train log with profiler")
parser.add_argument(
"--profiler_path", type=str, help="The path of profiler timeline log.")
parser.add_argument(
"--keyword", type=str, help="Keyword to specify analysis data")
parser.add_argument(
"--separator", type=str, default=None, help="Separator of different field in log")
"--separator",
type=str,
default=None,
help="Separator of different field in log")
parser.add_argument(
'--position', type=int, default=None, help='The position of data field')
parser.add_argument(
'--range', type=str, default="", help='The range of data field to intercept')
'--range',
type=str,
default="",
help='The range of data field to intercept')
parser.add_argument(
'--base_batch_size', type=int, help='base_batch size on gpu')
parser.add_argument(
'--skip_steps', type=int, default=0, help='The number of steps to be skipped')
'--skip_steps',
type=int,
default=0,
help='The number of steps to be skipped')
parser.add_argument(
'--model_mode', type=int, default=-1, help='Analysis mode, default value is -1')
'--model_mode',
type=int,
default=-1,
help='Analysis mode, default value is -1')
parser.add_argument('--ips_unit', type=str, default=None, help='IPS unit')
parser.add_argument(
'--ips_unit', type=str, default=None, help='IPS unit')
parser.add_argument(
'--model_name', type=str, default=0, help='training model_name, transformer_base')
'--model_name',
type=str,
default=0,
help='training model_name, transformer_base')
parser.add_argument(
'--mission_name', type=str, default=0, help='training mission name')
parser.add_argument(
'--direction_id', type=int, default=0, help='training direction_id')
parser.add_argument(
'--run_mode', type=str, default="sp", help='multi process or single process')
'--run_mode',
type=str,
default="sp",
help='multi process or single process')
parser.add_argument(
'--index', type=int, default=1, help='{1: speed, 2:mem, 3:profiler, 6:max_batch_size}')
'--index',
type=int,
default=1,
help='{1: speed, 2:mem, 3:profiler, 6:max_batch_size}')
parser.add_argument(
'--gpu_num', type=int, default=1, help='nums of training gpus')
args = parser.parse_args()
......@@ -72,7 +94,12 @@ def _is_number(num):
class TimeAnalyzer(object):
def __init__(self, filename, keyword=None, separator=None, position=None, range="-1"):
def __init__(self,
filename,
keyword=None,
separator=None,
position=None,
range="-1"):
if filename is None:
raise Exception("Please specify the filename!")
......@@ -99,7 +126,8 @@ class TimeAnalyzer(object):
# Distil the string from a line.
line = line.strip()
line_words = line.split(self.separator) if self.separator else line.split()
line_words = line.split(
self.separator) if self.separator else line.split()
if args.position:
result = line_words[self.position]
else:
......@@ -108,27 +136,36 @@ class TimeAnalyzer(object):
if line_words[i] == self.keyword:
result = line_words[i + 1]
break
# Distil the result from the picked string.
if not self.range:
result = result[0:]
elif _is_number(self.range):
result = result[0: int(self.range)]
result = result[0:int(self.range)]
else:
result = result[int(self.range.split(":")[0]): int(self.range.split(":")[1])]
result = result[int(self.range.split(":")[0]):int(
self.range.split(":")[1])]
self.records.append(float(result))
except Exception as exc:
print("line is: {}; separator={}; position={}".format(line, self.separator, self.position))
print("line is: {}; separator={}; position={}".format(
line, self.separator, self.position))
print("Extract {} records: separator={}; position={}".format(len(self.records), self.separator, self.position))
print("Extract {} records: separator={}; position={}".format(
len(self.records), self.separator, self.position))
def _get_fps(self, mode, batch_size, gpu_num, avg_of_records, run_mode, unit=None):
def _get_fps(self,
mode,
batch_size,
gpu_num,
avg_of_records,
run_mode,
unit=None):
if mode == -1 and run_mode == 'sp':
assert unit, "Please set the unit when mode is -1."
fps = gpu_num * avg_of_records
elif mode == -1 and run_mode == 'mp':
assert unit, "Please set the unit when mode is -1."
fps = gpu_num * avg_of_records #temporarily, not used now
fps = gpu_num * avg_of_records #temporarily, not used now
print("------------this is mp")
elif mode == 0:
# s/step -> samples/s
......@@ -155,12 +192,20 @@ class TimeAnalyzer(object):
return fps, unit
def analysis(self, batch_size, gpu_num=1, skip_steps=0, mode=-1, run_mode='sp', unit=None):
def analysis(self,
batch_size,
gpu_num=1,
skip_steps=0,
mode=-1,
run_mode='sp',
unit=None):
if batch_size <= 0:
print("base_batch_size should larger than 0.")
return 0, ''
if len(self.records) <= skip_steps: # to address the condition which item of log equals to skip_steps
if len(
self.records
) <= skip_steps: # to address the condition which item of log equals to skip_steps
print("no records")
return 0, ''
......@@ -180,16 +225,20 @@ class TimeAnalyzer(object):
skip_max = self.records[i]
avg_of_records = sum_of_records / float(count)
avg_of_records_skipped = sum_of_records_skipped / float(count - skip_steps)
avg_of_records_skipped = sum_of_records_skipped / float(count -
skip_steps)
fps, fps_unit = self._get_fps(mode, batch_size, gpu_num, avg_of_records, run_mode, unit)
fps_skipped, _ = self._get_fps(mode, batch_size, gpu_num, avg_of_records_skipped, run_mode, unit)
fps, fps_unit = self._get_fps(mode, batch_size, gpu_num, avg_of_records,
run_mode, unit)
fps_skipped, _ = self._get_fps(mode, batch_size, gpu_num,
avg_of_records_skipped, run_mode, unit)
if mode == -1:
print("average ips of %d steps, skip 0 step:" % count)
print("\tAvg: %.3f %s" % (avg_of_records, fps_unit))
print("\tFPS: %.3f %s" % (fps, fps_unit))
if skip_steps > 0:
print("average ips of %d steps, skip %d steps:" % (count, skip_steps))
print("average ips of %d steps, skip %d steps:" %
(count, skip_steps))
print("\tAvg: %.3f %s" % (avg_of_records_skipped, fps_unit))
print("\tMin: %.3f %s" % (skip_min, fps_unit))
print("\tMax: %.3f %s" % (skip_max, fps_unit))
......@@ -199,7 +248,8 @@ class TimeAnalyzer(object):
print("\tAvg: %.3f steps/s" % avg_of_records)
print("\tFPS: %.3f %s" % (fps, fps_unit))
if skip_steps > 0:
print("average latency of %d steps, skip %d steps:" % (count, skip_steps))
print("average latency of %d steps, skip %d steps:" %
(count, skip_steps))
print("\tAvg: %.3f steps/s" % avg_of_records_skipped)
print("\tMin: %.3f steps/s" % skip_min)
print("\tMax: %.3f steps/s" % skip_max)
......@@ -209,7 +259,8 @@ class TimeAnalyzer(object):
print("\tAvg: %.3f s/step" % avg_of_records)
print("\tFPS: %.3f %s" % (fps, fps_unit))
if skip_steps > 0:
print("average latency of %d steps, skip %d steps:" % (count, skip_steps))
print("average latency of %d steps, skip %d steps:" %
(count, skip_steps))
print("\tAvg: %.3f s/step" % avg_of_records_skipped)
print("\tMin: %.3f s/step" % skip_min)
print("\tMax: %.3f s/step" % skip_max)
......@@ -236,7 +287,8 @@ if __name__ == "__main__":
if args.gpu_num == 1:
run_info["log_with_profiler"] = args.log_with_profiler
run_info["profiler_path"] = args.profiler_path
analyzer = TimeAnalyzer(args.filename, args.keyword, args.separator, args.position, args.range)
analyzer = TimeAnalyzer(args.filename, args.keyword, args.separator,
args.position, args.range)
run_info["FINAL_RESULT"], run_info["UNIT"] = analyzer.analysis(
batch_size=args.base_batch_size,
gpu_num=args.gpu_num,
......@@ -245,29 +297,50 @@ if __name__ == "__main__":
run_mode=args.run_mode,
unit=args.ips_unit)
try:
if int(os.getenv('job_fail_flag')) == 1 or int(run_info["FINAL_RESULT"]) == 0:
if int(os.getenv('job_fail_flag')) == 1 or int(run_info[
"FINAL_RESULT"]) == 0:
run_info["JOB_FAIL_FLAG"] = 1
except:
pass
elif args.index == 3:
run_info["FINAL_RESULT"] = {}
records_fo_total = TimeAnalyzer(args.filename, 'Framework overhead', None, 3, '').records
records_fo_ratio = TimeAnalyzer(args.filename, 'Framework overhead', None, 5).records
records_ct_total = TimeAnalyzer(args.filename, 'Computation time', None, 3, '').records
records_gm_total = TimeAnalyzer(args.filename, 'GpuMemcpy Calls', None, 4, '').records
records_gm_ratio = TimeAnalyzer(args.filename, 'GpuMemcpy Calls', None, 6).records
records_gmas_total = TimeAnalyzer(args.filename, 'GpuMemcpyAsync Calls', None, 4, '').records
records_gms_total = TimeAnalyzer(args.filename, 'GpuMemcpySync Calls', None, 4, '').records
run_info["FINAL_RESULT"]["Framework_Total"] = records_fo_total[0] if records_fo_total else 0
run_info["FINAL_RESULT"]["Framework_Ratio"] = records_fo_ratio[0] if records_fo_ratio else 0
run_info["FINAL_RESULT"]["ComputationTime_Total"] = records_ct_total[0] if records_ct_total else 0
run_info["FINAL_RESULT"]["GpuMemcpy_Total"] = records_gm_total[0] if records_gm_total else 0
run_info["FINAL_RESULT"]["GpuMemcpy_Ratio"] = records_gm_ratio[0] if records_gm_ratio else 0
run_info["FINAL_RESULT"]["GpuMemcpyAsync_Total"] = records_gmas_total[0] if records_gmas_total else 0
run_info["FINAL_RESULT"]["GpuMemcpySync_Total"] = records_gms_total[0] if records_gms_total else 0
records_fo_total = TimeAnalyzer(args.filename, 'Framework overhead',
None, 3, '').records
records_fo_ratio = TimeAnalyzer(args.filename, 'Framework overhead',
None, 5).records
records_ct_total = TimeAnalyzer(args.filename, 'Computation time',
None, 3, '').records
records_gm_total = TimeAnalyzer(args.filename,
'GpuMemcpy Calls',
None, 4, '').records
records_gm_ratio = TimeAnalyzer(args.filename,
'GpuMemcpy Calls',
None, 6).records
records_gmas_total = TimeAnalyzer(args.filename,
'GpuMemcpyAsync Calls',
None, 4, '').records
records_gms_total = TimeAnalyzer(args.filename,
'GpuMemcpySync Calls',
None, 4, '').records
run_info["FINAL_RESULT"]["Framework_Total"] = records_fo_total[
0] if records_fo_total else 0
run_info["FINAL_RESULT"]["Framework_Ratio"] = records_fo_ratio[
0] if records_fo_ratio else 0
run_info["FINAL_RESULT"][
"ComputationTime_Total"] = records_ct_total[
0] if records_ct_total else 0
run_info["FINAL_RESULT"]["GpuMemcpy_Total"] = records_gm_total[
0] if records_gm_total else 0
run_info["FINAL_RESULT"]["GpuMemcpy_Ratio"] = records_gm_ratio[
0] if records_gm_ratio else 0
run_info["FINAL_RESULT"][
"GpuMemcpyAsync_Total"] = records_gmas_total[
0] if records_gmas_total else 0
run_info["FINAL_RESULT"]["GpuMemcpySync_Total"] = records_gms_total[
0] if records_gms_total else 0
else:
print("Not support!")
except Exception:
traceback.print_exc()
print("{}".format(json.dumps(run_info))) # it's required, for the log file path insert to the database
traceback.print_exc()
print("{}".format(json.dumps(run_info))
) # it's required, for the log file path insert to the database
......@@ -58,3 +58,4 @@ source ${BENCHMARK_ROOT}/scripts/run_model.sh # 在该脚本中会对符合
_set_params $@
#_train # 如果只想产出训练log,不解析,可取消注释
_run # 该函数在run_model.sh中,执行时会调用_train; 如果不联调只想要产出训练log可以注掉本行,提交时需打开
......@@ -36,3 +36,4 @@ for model_mode in ${model_mode_list[@]}; do
done
Global:
use_gpu: true
use_xpu: false
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 10
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment