Commit 8259d256 authored by WenmuZhou's avatar WenmuZhou
Browse files

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into fix_vqa

parents 5feb969e 3b646d4f
......@@ -24,7 +24,7 @@ import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset
from utils import parse_args, get_bio_label_maps, print_arguments
from vqa_utils import parse_args, get_bio_label_maps, print_arguments
from data_collator import DataCollator
from metric import re_score
......
......@@ -33,7 +33,7 @@ from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMFor
from xfun import XFUNDataset
from losses import SERLoss
from utils import parse_args, get_bio_label_maps, print_arguments
from vqa_utils import parse_args, get_bio_label_maps, print_arguments
from ppocr.utils.logging import get_logger
......
......@@ -15,7 +15,7 @@ import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset
from utils import parse_args, get_bio_label_maps, draw_re_results
from vqa_utils import parse_args, get_bio_label_maps, draw_re_results
from data_collator import DataCollator
from ppocr.utils.logging import get_logger
......
......@@ -14,6 +14,10 @@
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
import json
import cv2
import numpy as np
......@@ -22,7 +26,7 @@ from copy import deepcopy
import paddle
# relative reference
from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
......
......@@ -14,6 +14,10 @@
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
import json
import cv2
import numpy as np
......@@ -25,9 +29,16 @@ from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLM
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
# relative reference
from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from vqa_utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
from utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
MODELS = {
'LayoutXLM':
(LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
'LayoutLM':
(LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
MODELS = {
'LayoutXLM':
......
......@@ -24,7 +24,7 @@ import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForRelationExtraction
# relative reference
from utils import parse_args, get_image_file_list, draw_re_results
from vqa_utils import parse_args, get_image_file_list, draw_re_results
from infer_ser_e2e import SerPredictor
......
......@@ -27,7 +27,7 @@ import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset
from utils import parse_args, get_bio_label_maps, print_arguments, set_seed
from vqa_utils import parse_args, get_bio_label_maps, print_arguments, set_seed
from data_collator import DataCollator
from eval_re import evaluate
......
......@@ -32,7 +32,7 @@ from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLM
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
from xfun import XFUNDataset
from utils import parse_args, get_bio_label_maps, print_arguments, set_seed
from vqa_utils import parse_args, get_bio_label_maps, print_arguments, set_seed
from eval_ser import evaluate
from losses import SERLoss
from ppocr.utils.logging import get_logger
......
......@@ -126,9 +126,6 @@ def main():
otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
fout.write(otstr.encode())
save_det_path = os.path.dirname(config['Global'][
'save_res_path']) + "/det_results/"
draw_det_res(boxes, config, src_img, file, save_det_path)
logger.info("success!")
......
......@@ -33,8 +33,9 @@ import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
import tools.program as program
import time
def read_class_list(filepath):
......@@ -80,7 +81,8 @@ def draw_kie_result(batch, node, idx_to_cls, count):
vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
vis_img[:, :w] = img
vis_img[:, w:] = pred_img
save_kie_path = os.path.dirname(config['Global']['save_res_path']) + "/kie_results/"
save_kie_path = os.path.dirname(config['Global'][
'save_res_path']) + "/kie_results/"
if not os.path.exists(save_kie_path):
os.makedirs(save_kie_path)
save_path = os.path.join(save_kie_path, str(count) + ".png")
......@@ -93,7 +95,7 @@ def main():
# build model
model = build_model(config['Architecture'])
init_model(config, model, logger)
load_model(config, model)
# create data ops
transforms = []
......@@ -111,10 +113,15 @@ def main():
os.makedirs(os.path.dirname(save_res_path))
model.eval()
warmup_times = 0
count_t = []
with open(save_res_path, "wb") as fout:
with open(config['Global']['infer_img'], "rb") as f:
lines = f.readlines()
for index, data_line in enumerate(lines):
if index == 10:
warmup_t = time.time()
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split("\t")
img_path, label = data_dir + "/" + substr[0], substr[1]
......@@ -122,16 +129,23 @@ def main():
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
st = time.time()
batch = transform(data, ops)
batch_pred = [0] * len(batch)
for i in range(len(batch)):
batch_pred[i] = paddle.to_tensor(
np.expand_dims(
batch[i], axis=0))
st = time.time()
node, edge = model(batch_pred)
node = F.softmax(node, -1)
count_t.append(time.time() - st)
draw_kie_result(batch, node, idx_to_cls, index)
logger.info("success!")
logger.info("It took {} s for predict {} images.".format(
np.sum(count_t), len(count_t)))
ips = len(count_t[warmup_times:]) / np.sum(count_t[warmup_times:])
logger.info("The ips is {} images/s".format(ips))
if __name__ == '__main__':
......
......@@ -227,10 +227,6 @@ def train(config,
images = batch[0]
if use_srn:
model_average = True
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
if model_type == "kie":
preds = model(batch)
train_start = time.time()
# use amp
......@@ -243,6 +239,8 @@ def train(config,
else:
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type == "kie":
preds = model(batch)
else:
preds = model(images)
loss = loss_class(preds, batch)
......@@ -403,7 +401,7 @@ def eval(model,
start = time.time()
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
if model_type == "kie":
elif model_type == "kie":
preds = model(batch)
else:
preds = model(images)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment