"models/vision/vscode:/vscode.git/clone" did not exist on "2db090ded903d9b78eba3a42fdb21a1893a4cc86"
Commit 8259d256 authored by WenmuZhou

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into fix_vqa

parents 5feb969e 3b646d4f
@@ -24,7 +24,7 @@ import paddle
 from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
 from xfun import XFUNDataset
-from utils import parse_args, get_bio_label_maps, print_arguments
+from vqa_utils import parse_args, get_bio_label_maps, print_arguments
 from data_collator import DataCollator
 from metric import re_score
...
@@ -33,7 +33,7 @@ from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMFor
 from xfun import XFUNDataset
 from losses import SERLoss
-from utils import parse_args, get_bio_label_maps, print_arguments
+from vqa_utils import parse_args, get_bio_label_maps, print_arguments
 from ppocr.utils.logging import get_logger
...
@@ -15,7 +15,7 @@ import paddle
 from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
 from xfun import XFUNDataset
-from utils import parse_args, get_bio_label_maps, draw_re_results
+from vqa_utils import parse_args, get_bio_label_maps, draw_re_results
 from data_collator import DataCollator
 from ppocr.utils.logging import get_logger
...
@@ -14,6 +14,10 @@
 import os
 import sys
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
 import json
 import cv2
 import numpy as np
@@ -22,7 +26,7 @@ from copy import deepcopy
 import paddle
 # relative reference
-from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
+from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
 from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
 from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
...
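Note on the two added lines above: appending the script's own directory to sys.path is what lets these entry-point scripts import their sibling modules (vqa_utils, xfun, data_collator) no matter which working directory they are launched from. A minimal sketch of the pattern (Python), assuming a sibling module named vqa_utils sits next to the script:

    import os
    import sys

    # Resolve the directory containing this script and put it on the module
    # search path, so sibling modules import cleanly even when the script is
    # run as e.g. `python some/path/infer_script.py` (hypothetical path).
    __dir__ = os.path.dirname(os.path.abspath(__file__))
    sys.path.append(__dir__)

    # A sibling module such as vqa_utils (as in the diff above) now resolves.
    from vqa_utils import parse_args  # noqa: E402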
@@ -14,6 +14,10 @@
 import os
 import sys
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
 import json
 import cv2
 import numpy as np
@@ -25,9 +29,16 @@ from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLM
 from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
 # relative reference
-from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
-from utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
+from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
+from vqa_utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
+MODELS = {
+    'LayoutXLM':
+    (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
+    'LayoutLM':
+    (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
+}
 MODELS = {
     'LayoutXLM':
...
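The added MODELS dictionary maps an architecture name to its (tokenizer, base model, task head) classes from paddlenlp. A minimal sketch of how such a registry is typically consumed; the build_ser_model helper and the pretrained_dir argument are illustrative, not part of this diff:

    from paddlenlp.transformers import (
        LayoutLMForTokenClassification, LayoutLMModel, LayoutLMTokenizer,
        LayoutXLMForTokenClassification, LayoutXLMModel, LayoutXLMTokenizer)

    MODELS = {
        'LayoutXLM':
        (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
        'LayoutLM':
        (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
    }

    def build_ser_model(model_type, pretrained_dir):
        # Look up the class triple by name and instantiate the tokenizer and
        # token-classification model from a pretrained checkpoint directory
        # (the base model class is unused in this sketch).
        tokenizer_class, _base_model_class, model_class = MODELS[model_type]
        tokenizer = tokenizer_class.from_pretrained(pretrained_dir)
        model = model_class.from_pretrained(pretrained_dir)
        return tokenizer, model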
@@ -24,7 +24,7 @@ import paddle
 from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForRelationExtraction
 # relative reference
-from utils import parse_args, get_image_file_list, draw_re_results
+from vqa_utils import parse_args, get_image_file_list, draw_re_results
 from infer_ser_e2e import SerPredictor
...
@@ -27,7 +27,7 @@ import paddle
 from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
 from xfun import XFUNDataset
-from utils import parse_args, get_bio_label_maps, print_arguments, set_seed
+from vqa_utils import parse_args, get_bio_label_maps, print_arguments, set_seed
 from data_collator import DataCollator
 from eval_re import evaluate
...
@@ -32,7 +32,7 @@ from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLM
 from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
 from xfun import XFUNDataset
-from utils import parse_args, get_bio_label_maps, print_arguments, set_seed
+from vqa_utils import parse_args, get_bio_label_maps, print_arguments, set_seed
 from eval_ser import evaluate
 from losses import SERLoss
 from ppocr.utils.logging import get_logger
...
@@ -126,9 +126,6 @@ def main():
                 otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
                 fout.write(otstr.encode())
-                save_det_path = os.path.dirname(config['Global'][
-                    'save_res_path']) + "/det_results/"
-                draw_det_res(boxes, config, src_img, file, save_det_path)
     logger.info("success!")
...
@@ -33,8 +33,9 @@ import paddle
 from ppocr.data import create_operators, transform
 from ppocr.modeling.architectures import build_model
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
 import tools.program as program
+import time

 def read_class_list(filepath):
@@ -80,7 +81,8 @@ def draw_kie_result(batch, node, idx_to_cls, count):
     vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
     vis_img[:, :w] = img
     vis_img[:, w:] = pred_img
-    save_kie_path = os.path.dirname(config['Global']['save_res_path']) + "/kie_results/"
+    save_kie_path = os.path.dirname(config['Global'][
+        'save_res_path']) + "/kie_results/"
     if not os.path.exists(save_kie_path):
         os.makedirs(save_kie_path)
     save_path = os.path.join(save_kie_path, str(count) + ".png")
@@ -93,7 +95,7 @@ def main():
     # build model
     model = build_model(config['Architecture'])
-    init_model(config, model, logger)
+    load_model(config, model)
     # create data ops
     transforms = []
@@ -111,10 +113,15 @@ def main():
         os.makedirs(os.path.dirname(save_res_path))
     model.eval()
+    warmup_times = 0
+    count_t = []
     with open(save_res_path, "wb") as fout:
         with open(config['Global']['infer_img'], "rb") as f:
             lines = f.readlines()
             for index, data_line in enumerate(lines):
+                if index == 10:
+                    warmup_t = time.time()
                 data_line = data_line.decode('utf-8')
                 substr = data_line.strip("\n").split("\t")
                 img_path, label = data_dir + "/" + substr[0], substr[1]
@@ -122,16 +129,23 @@ def main():
                 with open(data['img_path'], 'rb') as f:
                     img = f.read()
                     data['image'] = img
+                st = time.time()
                 batch = transform(data, ops)
                 batch_pred = [0] * len(batch)
                 for i in range(len(batch)):
                     batch_pred[i] = paddle.to_tensor(
                         np.expand_dims(
                             batch[i], axis=0))
+                st = time.time()
                 node, edge = model(batch_pred)
                 node = F.softmax(node, -1)
+                count_t.append(time.time() - st)
                 draw_kie_result(batch, node, idx_to_cls, index)
     logger.info("success!")
+    logger.info("It took {} s for predict {} images.".format(
+        np.sum(count_t), len(count_t)))
+    ips = len(count_t[warmup_times:]) / np.sum(count_t[warmup_times:])
+    logger.info("The ips is {} images/s".format(ips))

 if __name__ == '__main__':
...
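The timing added to the KIE inference loop above starts the clock right before the forward pass (st = time.time() followed by node, edge = model(batch_pred)), appends the elapsed time to count_t, and reports throughput while skipping the first warmup_times entries. A self-contained sketch of the same accounting (Python), with a dummy computation standing in for the model call:

    import time

    import numpy as np

    def timed_inference(num_images=30, warmup_times=0):
        count_t = []
        for _ in range(num_images):
            st = time.time()
            # Dummy workload standing in for `node, edge = model(batch_pred)`.
            _ = np.linalg.inv(np.random.rand(64, 64) + 64 * np.eye(64))
            count_t.append(time.time() - st)
        print("It took {} s for predict {} images.".format(
            np.sum(count_t), len(count_t)))
        # Exclude the first `warmup_times` measurements from the rate.
        ips = len(count_t[warmup_times:]) / np.sum(count_t[warmup_times:])
        print("The ips is {} images/s".format(ips))

    if __name__ == '__main__':
        timed_inference()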
@@ -227,10 +227,6 @@ def train(config,
             images = batch[0]
             if use_srn:
                 model_average = True
-            if model_type == 'table' or extra_input:
-                preds = model(images, data=batch[1:])
-            if model_type == "kie":
-                preds = model(batch)
             train_start = time.time()
             # use amp
...
@@ -243,6 +239,8 @@ def train(config,
             else:
                 if model_type == 'table' or extra_input:
                     preds = model(images, data=batch[1:])
+                elif model_type == "kie":
+                    preds = model(batch)
                 else:
                     preds = model(images)
             loss = loss_class(preds, batch)
@@ -403,7 +401,7 @@ def eval(model,
             start = time.time()
             if model_type == 'table' or extra_input:
                 preds = model(images, data=batch[1:])
-            if model_type == "kie":
+            elif model_type == "kie":
                 preds = model(batch)
             else:
                 preds = model(images)
...
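The last three hunks tidy the per-batch forward dispatch in train() and eval(): the duplicated forward calls that ran before the timing section are removed, and the "kie" case becomes an elif, so a 'table'/extra_input batch no longer falls through and has its preds overwritten by the generic model(images) call. A minimal sketch of the resulting dispatch; forward_one_batch is an illustrative wrapper, not a function from this diff:

    def forward_one_batch(model, batch, model_type, extra_input=False):
        # Exactly one branch runs per batch, mirroring the if/elif/else
        # chain introduced by the diff above.
        images = batch[0]
        if model_type == 'table' or extra_input:
            preds = model(images, data=batch[1:])
        elif model_type == "kie":
            preds = model(batch)
        else:
            preds = model(images)
        return preds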