Commit 4f8b5113 authored by Leif

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents d73ed79c 370f0fef
@@ -23,8 +23,10 @@ from PIL import Image
import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddleocr import PaddleOCR

# relative reference
from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
@@ -48,48 +50,45 @@ def parse_ocr_info_for_ser(ocr_result):
    return ocr_info


class SerPredictor(object):
    def __init__(self, args):
        self.max_seq_length = args.max_seq_length

        # init ser token and model
        self.tokenizer = LayoutXLMTokenizer.from_pretrained(
            args.model_name_or_path)
        self.model = LayoutXLMForTokenClassification.from_pretrained(
            args.model_name_or_path)
        self.model.eval()

        # init ocr_engine
        self.ocr_engine = PaddleOCR(
            rec_model_dir=args.ocr_rec_model_dir,
            det_model_dir=args.ocr_det_model_dir,
            use_angle_cls=False,
            show_log=False)

        # init dict
        label2id_map, self.id2label_map = get_bio_label_maps(
            args.label_map_path)
        self.label2id_map_for_draw = dict()
        for key in label2id_map:
            if key.startswith("I-"):
                self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
            else:
                self.label2id_map_for_draw[key] = label2id_map[key]

    def __call__(self, img):
        ocr_result = self.ocr_engine.ocr(img, cls=False)
        ocr_info = parse_ocr_info_for_ser(ocr_result)

        inputs = preprocess(
            tokenizer=self.tokenizer,
            ori_img=img,
            ocr_info=ocr_info,
            max_seq_len=self.max_seq_length)

        outputs = self.model(
            input_ids=inputs["input_ids"],
            bbox=inputs["bbox"],
            image=inputs["image"],
@@ -97,25 +96,36 @@ def infer(args):
            attention_mask=inputs["attention_mask"])
        preds = outputs[0]
        preds = postprocess(inputs["attention_mask"], preds, self.id2label_map)
        ocr_info = merge_preds_list_with_ocr_info(
            ocr_info, inputs["segment_offset_id"], preds,
            self.label2id_map_for_draw)
        return ocr_info, inputs


if __name__ == "__main__":
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # get infer img list
    infer_imgs = get_image_file_list(args.infer_imgs)

    # loop for infer
    ser_engine = SerPredictor(args)
    with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
        for idx, img_path in enumerate(infer_imgs):
            print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
            img = cv2.imread(img_path)

            result, _ = ser_engine(img)
            fout.write(img_path + "\t" + json.dumps(
                {
                    "ser_result": result,
                }, ensure_ascii=False) + "\n")

            img_res = draw_ser_results(img, result)
            cv2.imwrite(
                os.path.join(args.output_dir,
                             os.path.splitext(os.path.basename(img_path))[0] +
                             "_ser.jpg"), img_res)
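For reference, a sketch of what one entry of the returned SER result looks like; the field names follow how the entries are consumed by draw_ser_results and SerReSystem below, and the values are hypothetical:

# One SER result entry as consumed downstream (values are made up):
example_ser_entry = {
    "text": "Name:",            # OCR-recognized text of the segment
    "bbox": [60, 40, 180, 70],  # x1, y1, x2, y2 in image pixels
    "pred": "QUESTION",         # predicted label with the B-/I- prefix stripped
    "pred_id": 1,               # numeric label id, used for color lookup when drawing
}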
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import json
import cv2
import numpy as np
from copy import deepcopy
from PIL import Image
import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForRelationExtraction
# relative reference
from utils import parse_args, get_image_file_list, draw_re_results
from infer_ser_e2e import SerPredictor
def make_input(ser_input, ser_result, max_seq_len=512):
entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
entities = ser_input['entities'][0]
assert len(entities) == len(ser_result)
# entities
start = []
end = []
label = []
entity_idx_dict = {}
for i, (res, entity) in enumerate(zip(ser_result, entities)):
if res['pred'] == 'O':
continue
entity_idx_dict[len(start)] = i
start.append(entity['start'])
end.append(entity['end'])
label.append(entities_labels[res['pred']])
entities = dict(start=start, end=end, label=label)
# relations
head = []
tail = []
for i in range(len(entities["label"])):
for j in range(len(entities["label"])):
if entities["label"][i] == 1 and entities["label"][j] == 2:
head.append(i)
tail.append(j)
relations = dict(head=head, tail=tail)
batch_size = ser_input["input_ids"].shape[0]
entities_batch = []
relations_batch = []
for b in range(batch_size):
entities_batch.append(entities)
relations_batch.append(relations)
ser_input['entities'] = entities_batch
ser_input['relations'] = relations_batch
ser_input.pop('segment_offset_id')
return ser_input, entity_idx_dict
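The double loop above proposes every (QUESTION, ANSWER) pair as a candidate relation for the RE model to score. A toy run of that enumeration, with hypothetical entity labels:

labels = [1, 2, 2]  # hypothetical: entity 0 is a QUESTION, entities 1 and 2 are ANSWERs
head, tail = [], []
for i in range(len(labels)):
    for j in range(len(labels)):
        if labels[i] == 1 and labels[j] == 2:
            head.append(i)
            tail.append(j)
print(head, tail)  # [0, 0] [1, 2]: each QUESTION is paired with every ANSWER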
class SerReSystem(object):
def __init__(self, args):
self.ser_engine = SerPredictor(args)
self.tokenizer = LayoutXLMTokenizer.from_pretrained(
args.re_model_name_or_path)
self.model = LayoutXLMForRelationExtraction.from_pretrained(
args.re_model_name_or_path)
self.model.eval()
def __call__(self, img):
ser_result, ser_inputs = self.ser_engine(img)
re_input, entity_idx_dict = make_input(ser_inputs, ser_result)
re_result = self.model(**re_input)
pred_relations = re_result['pred_relations'][0]
# convert the predicted relations back to pairs of OCR info entries
result = []
used_tail_id = []
for relation in pred_relations:
if relation['tail_id'] in used_tail_id:
continue
used_tail_id.append(relation['tail_id'])
ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
result.append((ocr_info_head, ocr_info_tail))
return result
if __name__ == "__main__":
args = parse_args()
os.makedirs(args.output_dir, exist_ok=True)
# get infer img list
infer_imgs = get_image_file_list(args.infer_imgs)
# loop for infer
ser_re_engine = SerReSystem(args)
with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
for idx, img_path in enumerate(infer_imgs):
print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
img = cv2.imread(img_path)
result = ser_re_engine(img)
fout.write(img_path + "\t" + json.dumps(
{
"result": result,
}, ensure_ascii=False) + "\n")
img_res = draw_re_results(img, result)
cv2.imwrite(
os.path.join(args.output_dir,
os.path.splitext(os.path.basename(img_path))[0] +
"_re.jpg"), img_res)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import numpy as np
import logging
logger = logging.getLogger(__name__)
PREFIX_CHECKPOINT_DIR = "checkpoint"
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
def get_last_checkpoint(folder):
content = os.listdir(folder)
checkpoints = [
path for path in content
if _re_checkpoint.search(path) is not None and os.path.isdir(
os.path.join(folder, path))
]
if len(checkpoints) == 0:
return
return os.path.join(
folder,
max(checkpoints,
key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
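A quick sketch of the resolution rule, assuming a hypothetical output directory that contains numbered checkpoints:

# folder contents (hypothetical): checkpoint-50, checkpoint-100, checkpoint-best
# _re_checkpoint only matches "checkpoint-<digits>", so "checkpoint-best" is
# skipped and the numerically largest step wins:
# get_last_checkpoint("output")  ->  "output/checkpoint-100"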
def re_score(pred_relations, gt_relations, mode="strict"):
"""Evaluate RE predictions
Args:
pred_relations (list) : list of list of predicted relations (several relations in each sentence)
gt_relations (list) : list of list of ground truth relations
rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
"tail": (start_idx (inclusive), end_idx (exclusive)),
"head_type": ent_type,
"tail_type": ent_type,
"type": rel_type}
vocab (Vocab) : dataset vocabulary
mode (str) : in 'strict' or 'boundaries'"""
assert mode in ["strict", "boundaries"]
relation_types = [1]  # only one non-null relation type in this task
scores = {
rel: {
"tp": 0,
"fp": 0,
"fn": 0
}
for rel in relation_types + ["ALL"]
}
# Count GT relations and Predicted relations
n_sents = len(gt_relations)
n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
# Count TP, FP and FN per type
for pred_sent, gt_sent in zip(pred_relations, gt_relations):
for rel_type in relation_types:
# strict mode takes argument types into account
if mode == "strict":
pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
rel["tail_type"])
for rel in pred_sent if rel["type"] == rel_type}
gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
rel["tail_type"])
for rel in gt_sent if rel["type"] == rel_type}
# boundaries mode only takes argument spans into account
elif mode == "boundaries":
pred_rels = {(rel["head"], rel["tail"])
for rel in pred_sent if rel["type"] == rel_type}
gt_rels = {(rel["head"], rel["tail"])
for rel in gt_sent if rel["type"] == rel_type}
scores[rel_type]["tp"] += len(pred_rels & gt_rels)
scores[rel_type]["fp"] += len(pred_rels - gt_rels)
scores[rel_type]["fn"] += len(gt_rels - pred_rels)
# Compute per entity Precision / Recall / F1
for rel_type in scores.keys():
if scores[rel_type]["tp"]:
scores[rel_type]["p"] = scores[rel_type]["tp"] / (
scores[rel_type]["fp"] + scores[rel_type]["tp"])
scores[rel_type]["r"] = scores[rel_type]["tp"] / (
scores[rel_type]["fn"] + scores[rel_type]["tp"])
else:
scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
scores[rel_type]["f1"] = (
2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
(scores[rel_type]["p"] + scores[rel_type]["r"]))
else:
scores[rel_type]["f1"] = 0
# Compute micro F1 Scores
tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
if tp:
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
else:
precision, recall, f1 = 0, 0, 0
scores["ALL"]["p"] = precision
scores["ALL"]["r"] = recall
scores["ALL"]["f1"] = f1
scores["ALL"]["tp"] = tp
scores["ALL"]["fp"] = fp
scores["ALL"]["fn"] = fn
# Compute Macro F1 Scores
scores["ALL"]["Macro_f1"] = np.mean(
[scores[ent_type]["f1"] for ent_type in relation_types])
scores["ALL"]["Macro_p"] = np.mean(
[scores[ent_type]["p"] for ent_type in relation_types])
scores["ALL"]["Macro_r"] = np.mean(
[scores[ent_type]["r"] for ent_type in relation_types])
# logger.info(f"RE Evaluation in *** {mode.upper()} *** mode")
# logger.info(
# "processed {} sentences with {} relations; found: {} relations; correct: {}.".format(
# n_sents, n_rels, n_found, tp
# )
# )
# logger.info(
# "\tALL\t TP: {};\tFP: {};\tFN: {}".format(scores["ALL"]["tp"], scores["ALL"]["fp"], scores["ALL"]["fn"])
# )
# logger.info("\t\t(m avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (micro)".format(precision, recall, f1))
# logger.info(
# "\t\t(M avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (Macro)\n".format(
# scores["ALL"]["Macro_p"], scores["ALL"]["Macro_r"], scores["ALL"]["Macro_f1"]
# )
# )
# for rel_type in relation_types:
# logger.info(
# "\t{}: \tTP: {};\tFP: {};\tFN: {};\tprecision: {:.2f};\trecall: {:.2f};\tf1: {:.2f};\t{}".format(
# rel_type,
# scores[rel_type]["tp"],
# scores[rel_type]["fp"],
# scores[rel_type]["fn"],
# scores[rel_type]["p"],
# scores[rel_type]["r"],
# scores[rel_type]["f1"],
# scores[rel_type]["tp"] + scores[rel_type]["fp"],
# )
# )
return scores
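A small worked example of the micro scores, using hypothetical span boundaries; in "boundaries" mode only the (head, tail) spans are compared:

pred = [[{"head": (0, 2), "tail": (3, 5), "head_type": 1, "tail_type": 2, "type": 1},
         {"head": (0, 2), "tail": (6, 8), "head_type": 1, "tail_type": 2, "type": 1}]]
gt = [[{"head": (0, 2), "tail": (3, 5), "head_type": 1, "tail_type": 2, "type": 1},
       {"head": (9, 11), "tail": (3, 5), "head_type": 1, "tail_type": 2, "type": 1}]]
scores = re_score(pred, gt, mode="boundaries")
print(scores["ALL"])  # tp=1, fp=1, fn=1 -> p = r = f1 = 0.5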
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import random
import numpy as np
import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset
from utils import parse_args, get_bio_label_maps, print_arguments
from data_collator import DataCollator
from metric import re_score
from ppocr.utils.logging import get_logger
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)
def cal_metric(re_preds, re_labels, entities):
gt_relations = []
for b in range(len(re_labels)):
rel_sent = []
for head, tail in zip(re_labels[b]["head"], re_labels[b]["tail"]):
rel = {}
rel["head_id"] = head
rel["head"] = (entities[b]["start"][rel["head_id"]],
entities[b]["end"][rel["head_id"]])
rel["head_type"] = entities[b]["label"][rel["head_id"]]
rel["tail_id"] = tail
rel["tail"] = (entities[b]["start"][rel["tail_id"]],
entities[b]["end"][rel["tail_id"]])
rel["tail_type"] = entities[b]["label"][rel["tail_id"]]
rel["type"] = 1
rel_sent.append(rel)
gt_relations.append(rel_sent)
re_metrics = re_score(re_preds, gt_relations, mode="boundaries")
return re_metrics
def evaluate(model, eval_dataloader, logger, prefix=""):
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = {}".format(len(eval_dataloader.dataset)))
re_preds = []
re_labels = []
entities = []
eval_loss = 0.0
model.eval()
for idx, batch in enumerate(eval_dataloader):
with paddle.no_grad():
outputs = model(**batch)
loss = outputs['loss'].mean().item()
if paddle.distributed.get_rank() == 0:
logger.info("[Eval] process: {}/{}, loss: {:.5f}".format(
idx, len(eval_dataloader), loss))
eval_loss += loss
re_preds.extend(outputs['pred_relations'])
re_labels.extend(batch['relations'])
entities.extend(batch['entities'])
re_metrics = cal_metric(re_preds, re_labels, entities)
re_metrics = {
"precision": re_metrics["ALL"]["p"],
"recall": re_metrics["ALL"]["r"],
"f1": re_metrics["ALL"]["f1"],
}
model.train()
return re_metrics
def train(args):
logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
print_arguments(args, logger)
# Added here for reproducibility (even between python 2 and 3)
set_seed(args.seed)
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
# dist mode
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
model = LayoutXLMForRelationExtraction(model, dropout=None)
# dist mode
if paddle.distributed.get_world_size() > 1:
model = paddle.distributed.DataParallel(model)
train_dataset = XFUNDataset(
tokenizer,
data_dir=args.train_data_dir,
label_path=args.train_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
max_seq_len=args.max_seq_length,
pad_token_label_id=pad_token_label_id,
contains_re=True,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
eval_dataset = XFUNDataset(
tokenizer,
data_dir=args.eval_data_dir,
label_path=args.eval_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
max_seq_len=args.max_seq_length,
pad_token_label_id=pad_token_label_id,
contains_re=True,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
train_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
args.train_batch_size = args.per_gpu_train_batch_size * \
max(1, paddle.distributed.get_world_size())
train_dataloader = paddle.io.DataLoader(
train_dataset,
batch_sampler=train_sampler,
num_workers=8,
use_shared_memory=True,
collate_fn=DataCollator())
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=8,
shuffle=False,
collate_fn=DataCollator())
t_total = len(train_dataloader) * args.num_train_epochs
# build linear decay with warmup lr sch
lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
learning_rate=args.learning_rate,
decay_steps=t_total,
end_lr=0.0,
power=1.0)
if args.warmup_steps > 0:
lr_scheduler = paddle.optimizer.lr.LinearWarmup(
lr_scheduler,
args.warmup_steps,
start_lr=0,
end_lr=args.learning_rate, )
grad_clip = paddle.nn.ClipGradByNorm(clip_norm=10)
optimizer = paddle.optimizer.Adam(
learning_rate=args.learning_rate,
parameters=model.parameters(),
epsilon=args.adam_epsilon,
grad_clip=grad_clip,
weight_decay=args.weight_decay)
# Train!
logger.info("***** Running training *****")
logger.info(" Num examples = {}".format(len(train_dataset)))
logger.info(" Num Epochs = {}".format(args.num_train_epochs))
logger.info(" Instantaneous batch size per GPU = {}".format(
args.per_gpu_train_batch_size))
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = {}".
format(args.train_batch_size * paddle.distributed.get_world_size()))
logger.info(" Total optimization steps = {}".format(t_total))
global_step = 0
model.clear_gradients()
train_dataloader_len = len(train_dataloader)
best_metric = {'f1': 0}
model.train()
for epoch in range(int(args.num_train_epochs)):
for step, batch in enumerate(train_dataloader):
outputs = model(**batch)
# model outputs are dict-like in ppnlp; 'loss' holds the relation-extraction loss
loss = outputs['loss']
loss = loss.mean()
logger.info(
"epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {}, lr: {}".
format(epoch, args.num_train_epochs, step, train_dataloader_len,
global_step, np.mean(loss.numpy()), optimizer.get_lr()))
loss.backward()
optimizer.step()
optimizer.clear_grad()
# lr_scheduler.step() # Update learning rate schedule
global_step += 1
if (paddle.distributed.get_rank() == 0 and args.eval_steps > 0 and
global_step % args.eval_steps == 0):
# Log metrics
# Only evaluate when single GPU otherwise metrics may not average well
if (paddle.distributed.get_rank() == 0 and args.evaluate_during_training):
results = evaluate(model, eval_dataloader, logger)
if results['f1'] > best_metric['f1']:
best_metric = results
output_dir = os.path.join(args.output_dir,
"checkpoint-best")
os.makedirs(output_dir, exist_ok=True)
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
paddle.save(args,
os.path.join(output_dir,
"training_args.bin"))
logger.info("Saving model checkpoint to {}".format(
output_dir))
logger.info("eval results: {}".format(results))
logger.info("best_metirc: {}".format(best_metirc))
if (paddle.distributed.get_rank() == 0 and args.save_steps > 0 and
global_step % args.save_steps == 0):
# Save model checkpoint
output_dir = os.path.join(args.output_dir, "checkpoint-latest")
os.makedirs(output_dir, exist_ok=True)
if paddle.distributed.get_rank() == 0:
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
paddle.save(args,
os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to {}".format(
output_dir))
logger.info("best_metirc: {}".format(best_metirc))
if __name__ == "__main__":
args = parse_args()
os.makedirs(args.output_dir, exist_ok=True)
train(args)
@@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import random
import copy
import logging
@@ -26,8 +31,9 @@ from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLM
from xfun import XFUNDataset
from utils import parse_args
from utils import get_bio_label_maps
from utils import print_arguments

from ppocr.utils.logging import get_logger


def set_seed(args):
@@ -38,17 +44,8 @@ def set_seed(args):

def train(args):
    os.makedirs(args.output_dir, exist_ok=True)
    logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
    print_arguments(args, logger)

    label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
    pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
@@ -136,10 +133,10 @@ def train(args):
            loss = outputs[0]
            loss = loss.mean()
            logger.info(
                "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {}, lr: {}".
                format(epoch_id, args.num_train_epochs, step,
                       len(train_dataloader), global_step,
                       loss.numpy()[0], lr_scheduler.get_lr()))
            loss.backward()
            tr_loss += loss.item()
@@ -154,13 +151,9 @@ def train(args):
                # Only evaluate when single GPU otherwise metrics may not average well
                if paddle.distributed.get_rank(
                ) == 0 and args.evaluate_during_training:
                    results, _ = evaluate(args, model, tokenizer, label2id_map,
                                          id2label_map, pad_token_label_id,
                                          logger)

                    if best_metrics is None or results["f1"] >= best_metrics[
                            "f1"]:
@@ -204,6 +197,7 @@ def evaluate(args,
             label2id_map,
             id2label_map,
             pad_token_label_id,
             logger,
             prefix=""):
    eval_dataset = XFUNDataset(
        tokenizer,
@@ -299,15 +293,6 @@ def evaluate(args,
    return results, preds_list


if __name__ == "__main__":
    args = parse_args()
    train(args)
@@ -24,8 +24,6 @@ import paddle
from PIL import Image, ImageDraw, ImageFont


def get_bio_label_maps(label_map_path):
    with open(label_map_path, "r") as fin:
@@ -66,9 +64,9 @@ def get_image_file_list(img_file):

def draw_ser_results(image,
                     ocr_results,
                     font_path="../../doc/fonts/simfang.ttf",
                     font_size=18):
    np.random.seed(2021)
    color = (np.random.permutation(range(255)),
             np.random.permutation(range(255)),
             np.random.permutation(range(255)))
@@ -82,38 +80,64 @@ def draw_ser_results(image,
    draw = ImageDraw.Draw(img_new)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

    for ocr_info in ocr_results:
        if ocr_info["pred_id"] not in color_map:
            continue
        color = color_map[ocr_info["pred_id"]]
        text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])

        draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)

    img_new = Image.blend(image, img_new, 0.5)
    return np.array(img_new)


def draw_box_txt(bbox, text, draw, font, font_size, color):
    # draw ocr results outline
    bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
    draw.rectangle(bbox, fill=color)

    # draw ocr results
    start_y = max(0, bbox[0][1] - font_size)
    tw = font.getsize(text)[0]
    draw.rectangle(
        [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)],
        fill=(0, 0, 255))
    draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)


def draw_re_results(image,
                    result,
                    font_path="../../doc/fonts/simfang.ttf",
                    font_size=18):
    np.random.seed(0)
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    img_new = image.copy()
    draw = ImageDraw.Draw(img_new)

    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    color_head = (0, 0, 255)
    color_tail = (255, 0, 0)
    color_line = (0, 255, 0)

    for ocr_info_head, ocr_info_tail in result:
        draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
                     font_size, color_head)
        draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
                     font_size, color_tail)

        center_head = (
            (ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
            (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
        center_tail = (
            (ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
            (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)

        draw.line([center_head, center_tail], fill=color_line, width=5)

    img_new = Image.blend(image, img_new, 0.5)
    return np.array(img_new)


# pad sentences
@@ -162,6 +186,9 @@ def split_page(encoded_inputs, max_seq_len=512):
    truncate is often used in training process
    """
    for key in encoded_inputs:
        if key == 'entities':
            encoded_inputs[key] = [encoded_inputs[key]]
            continue
        encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
        if encoded_inputs[key].ndim <= 1:  # for input_ids, att_mask and so on
            encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
@@ -184,14 +211,14 @@ def preprocess(
    height = ori_img.shape[0]
    width = ori_img.shape[1]
    img = cv2.resize(ori_img, img_size).transpose([2, 0, 1]).astype(np.float32)

    segment_offset_id = []
    words_list = []
    bbox_list = []
    input_ids_list = []
    token_type_ids_list = []
    entities = []

    for info in ocr_info:
        # x1, y1, x2, y2
@@ -211,6 +238,13 @@ def preprocess(
        encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
        encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]

        # for re
        entities.append({
            "start": len(input_ids_list),
            "end": len(input_ids_list) + len(encode_res["input_ids"]),
            "label": "O",
        })

        input_ids_list.extend(encode_res["input_ids"])
        token_type_ids_list.extend(encode_res["token_type_ids"])
        bbox_list.extend([bbox] * len(encode_res["input_ids"]))
@@ -222,6 +256,7 @@ def preprocess(
        "token_type_ids": token_type_ids_list,
        "bbox": bbox_list,
        "attention_mask": [1] * len(input_ids_list),
        "entities": entities
    }

    encoded_inputs = pad_sentences(
@@ -294,35 +329,64 @@ def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list,
    return ocr_info


def print_arguments(args, logger=None):
    """print arguments"""
    print_func = logger.info if logger is not None else print
    print_func('----------- Configuration Arguments -----------')
    for arg, value in sorted(vars(args).items()):
        print_func('%s: %s' % (arg, value))
    print_func('------------------------------------------------')


def parse_args():
    parser = argparse.ArgumentParser()
    # Required parameters
    # yapf: disable
    parser.add_argument("--model_name_or_path",
                        default=None, type=str, required=True,)
    parser.add_argument("--re_model_name_or_path",
                        default=None, type=str, required=False,)
    parser.add_argument("--train_data_dir", default=None,
                        type=str, required=False,)
    parser.add_argument("--train_label_path", default=None,
                        type=str, required=False,)
    parser.add_argument("--eval_data_dir", default=None,
                        type=str, required=False,)
    parser.add_argument("--eval_label_path", default=None,
                        type=str, required=False,)
    parser.add_argument("--output_dir", default=None, type=str, required=True,)
    parser.add_argument("--max_seq_length", default=512, type=int,)
    parser.add_argument("--evaluate_during_training", action="store_true",)
    parser.add_argument("--per_gpu_train_batch_size", default=8,
                        type=int, help="Batch size per GPU/CPU for training.",)
    parser.add_argument("--per_gpu_eval_batch_size", default=8,
                        type=int, help="Batch size per GPU/CPU for eval.",)
    parser.add_argument("--learning_rate", default=5e-5,
                        type=float, help="The initial learning rate for Adam.",)
    parser.add_argument("--weight_decay", default=0.0,
                        type=float, help="Weight decay if we apply some.",)
    parser.add_argument("--adam_epsilon", default=1e-8,
                        type=float, help="Epsilon for Adam optimizer.",)
    parser.add_argument("--max_grad_norm", default=1.0,
                        type=float, help="Max gradient norm.",)
    parser.add_argument("--num_train_epochs", default=3, type=int,
                        help="Total number of training epochs to perform.",)
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.",)
    parser.add_argument("--eval_steps", type=int, default=10,
                        help="eval every X updates steps.",)
    parser.add_argument("--save_steps", type=int, default=50,
                        help="Save checkpoint every X updates steps.",)
    parser.add_argument("--seed", type=int, default=2048,
                        help="random seed for initialization",)
    parser.add_argument("--ocr_rec_model_dir", default=None, type=str, )
    parser.add_argument("--ocr_det_model_dir", default=None, type=str, )
    parser.add_argument("--label_map_path", default="./labels/labels_ser.txt",
                        type=str, required=False, )
    parser.add_argument("--infer_imgs", default=None, type=str, required=False)
    parser.add_argument("--ocr_json_path", default=None, type=str,
                        required=False, help="ocr prediction results")
    # yapf: enable
    args = parser.parse_args()
    return args
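A hypothetical end-to-end invocation wiring these flags together (the script name and all paths are placeholders, not taken from the commit):

# python3.7 infer_ser_re_e2e.py \
#     --model_name_or_path ./pretrained/ser_model \
#     --re_model_name_or_path ./pretrained/re_model \
#     --ocr_det_model_dir ./inference/det \
#     --ocr_rec_model_dir ./inference/rec \
#     --label_map_path ./labels/labels_ser.txt \
#     --infer_imgs ./images \
#     --output_dir ./output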
===========================train_params===========================
model_name:PPOCRv2_det
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:null
quant_export:deploy/slim/quantization/export_model.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o
fpgm_export:
...
@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
inference_dir:Student
infer_model:./inference/ch_PP-OCRv2_rec_infer
infer_export:null
infer_quant:False
inference:tools/infer/predict_rec.py
@@ -45,7 +45,7 @@ inference:tools/infer/predict_rec.py
--use_tensorrt:False|True
--precision:fp32|fp16|int8
--rec_model_dir:
--image_dir:./inference/rec_inference
null:null
--benchmark:True
null:null
...
@@ -6,15 +6,15 @@ Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:pact_train
norm_train:null
pact_train:deploy/slim/quantization/quant.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
fpgm_train:null
distill_train:null
null:null
@@ -27,14 +27,14 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.pretrained_model:
norm_export:null
quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:Student
infer_model:./inference/ch_PP-OCRv2_rec_infer
infer_export:null
infer_quant:True
inference:tools/infer/predict_rec.py
@@ -45,7 +45,7 @@ inference:tools/infer/predict_rec.py
--use_tensorrt:False|True
--precision:fp32|fp16|int8
--rec_model_dir:
--image_dir:./inference/rec_inference
null:null
--benchmark:True
null:null
...
@@ -4,7 +4,7 @@ python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=100|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null
...
@@ -4,7 +4,7 @@ python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=20|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:null
quant_export:deploy/slim/quantization/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o
fpgm_export:null
...
@@ -28,7 +28,7 @@ null:null
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:null
quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/rec_chinese_lite_train_v2.0.yml -o
fpgm_export:null
distill_export:null
export1:null
...
@@ -12,22 +12,22 @@ train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
quant_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.pretrained_model:
norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
quant_export:null
fpgm_export:null
distill_export:null
...
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/det_r50_vd_pse/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_pse_v2.0/det_r50_vd_pse.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
...
@@ -62,7 +62,7 @@ Train:
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list: [1.0]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
...
@@ -48,4 +48,4 @@ inference:tools/infer/predict_det.py
--image_dir:./inference/ch_det_data_50/all-sum-510/
null:null
--benchmark:True
--det_algorithm:SAST
@@ -48,4 +48,4 @@ inference:tools/infer/predict_det.py
--image_dir:./inference/ch_det_data_50/all-sum-510/
null:null
--benchmark:True
--det_algorithm:SAST
@@ -4,7 +4,7 @@ python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=5|whole_train_whole_infer=500
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=14
Global.pretrained_model:null
@@ -42,7 +42,7 @@ inference:tools/infer/predict_e2e.py
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32|fp16|int8
--e2e_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
...
@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
...
@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
...
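For readers unfamiliar with these TIPC test configs, a brief reading of the format, inferred from the entries above: each line is key:value, where the key is an option forwarded to the train or infer command and "|" separates the values exercised by the different test modes, e.g.

# Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
#   -> batch size 16 in the quick lite_train_lite_infer mode,
#      128 in the full whole_train_whole_infer mode.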