Commit 41a1b292 authored by Leif

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents 9471054e 3d30899b
@@ -2551,7 +2551,7 @@
"\n",
"Paddle Serving is the tool PaddlePaddle provides to make service-oriented deployment easier for developers. This section mainly covers deploying the PP-OCRv2 system as a service with Paddle Serving.\n",
"\n",
-"## 4.1 Introduction to Padde Serving\n",
+"## 4.1 Introduction to Paddle Serving\n",
"\n",
"As PaddlePaddle's open-source serving framework, Paddle Serving's long-term goal is to provide increasingly professional, reliable, and easy-to-use services for the last mile of putting AI into production. Paddle Serving currently offers two frameworks, C++ Serving and Python Pipeline: the Python Pipeline framework favors ease of secondary development, while C++ Serving pursues maximum performance.\n",
"\n",
@@ -42,12 +42,14 @@ __all__ = [
]
SUPPORT_DET_MODEL = ['DB']
-VERSION = '2.3.0.2'
+VERSION = '2.4'
SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/")
DEFAULT_OCR_MODEL_VERSION = 'PP-OCR'
+SUPPORT_OCR_MODEL_VERSION = ['PP-OCR', 'PP-OCRv2']
DEFAULT_STRUCTURE_MODEL_VERSION = 'STRUCTURE'
+SUPPORT_STRUCTURE_MODEL_VERSION = ['STRUCTURE']
MODEL_URLS = {
'OCR': {
'PP-OCRv2': {
@@ -190,6 +192,7 @@ def parse_args(mMain=True):
parser.add_argument(
"--ocr_version",
type=str,
+choices=SUPPORT_OCR_MODEL_VERSION,
default='PP-OCRv2',
help='OCR Model version, the current model support list is as follows: '
'1. PP-OCRv2 Support Chinese detection and recognition model. '
@@ -198,6 +201,7 @@ def parse_args(mMain=True):
parser.add_argument(
"--structure_version",
type=str,
+choices=SUPPORT_STRUCTURE_MODEL_VERSION,
default='STRUCTURE',
help='Model version, the current model support list is as follows:'
' 1. STRUCTURE Support en table structure model.')
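For context, the net effect of the new SUPPORT_OCR_MODEL_VERSION / SUPPORT_STRUCTURE_MODEL_VERSION lists and the choices= guards is that an unknown version is rejected early instead of silently falling back. A minimal usage sketch (the image path is illustrative):

from paddleocr import PaddleOCR

ocr = PaddleOCR(ocr_version='PP-OCRv2', lang='ch')   # accepted: in SUPPORT_OCR_MODEL_VERSION
# PaddleOCR(ocr_version='PP-OCRv9')                  # would now fail fast on this branch
result = ocr.ocr('./doc/imgs/11.jpg')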
@@ -257,26 +261,20 @@ def get_model_config(type, version, model_type, lang):
DEFAULT_MODEL_VERSION = DEFAULT_STRUCTURE_MODEL_VERSION
else:
raise NotImplementedError
model_urls = MODEL_URLS[type]
if version not in model_urls:
-logger.warning('version {} not in {}, auto switch to version {}'.format(
-version, model_urls.keys(), DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION
if model_type not in model_urls[version]:
if model_type in model_urls[DEFAULT_MODEL_VERSION]:
-logger.warning(
-'version {} not support {} models, auto switch to version {}'.
-format(version, model_type, DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION
else:
logger.error('{} models is not support, we only support {}'.format(
model_type, model_urls[DEFAULT_MODEL_VERSION].keys()))
sys.exit(-1)
if lang not in model_urls[version][model_type]:
if lang in model_urls[DEFAULT_MODEL_VERSION][model_type]:
-logger.warning(
-'lang {} is not support in {}, auto switch to version {}'.
-format(lang, version, DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION
else:
logger.error(
@@ -296,6 +294,8 @@ class PaddleOCR(predict_system.TextSystem):
"""
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
+assert params.ocr_version in SUPPORT_OCR_MODEL_VERSION, "ocr_version must in {}, but get {}".format(
+SUPPORT_OCR_MODEL_VERSION, params.ocr_version)
params.use_gpu = check_gpu(params.use_gpu)
if not params.show_log:
@@ -347,8 +347,9 @@ class PaddleOCR(predict_system.TextSystem):
ocr with paddleocr
args:
img: img for ocr, support ndarray, img_path and list or ndarray
-det: use text detection or not, if false, only rec will be exec. default is True
-rec: use text recognition or not, if false, only det will be exec. default is True
+det: use text detection or not. If false, only rec will be exec. Default is True
+rec: use text recognition or not. If false, only det will be exec. Default is True
+cls: use angle classifier or not. Default is True. If true, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
"""
assert isinstance(img, (np.ndarray, list, str))
if isinstance(img, list) and det == True:
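A short sketch of the det/rec/cls switches documented above; the paths are illustrative, and the result layout ([box, (text, score)] per detected line) matches the ocr() return value on this branch.

from paddleocr import PaddleOCR

ocr = PaddleOCR(use_angle_cls=True, lang='ch')
# full pipeline: detection + angle classification + recognition
result = ocr.ocr('./doc/imgs/11.jpg', det=True, rec=True, cls=True)
for box, (text, score) in result:
    print(box, text, score)
# recognition only, e.g. for an already-cropped text line image
print(ocr.ocr('./doc/imgs_words/ch/word_1.jpg', det=False, cls=True))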
@@ -398,6 +399,8 @@ class PPStructure(OCRSystem):
def __init__(self, **kwargs):
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
+assert params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION, "structure_version must in {}, but get {}".format(
+SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version)
params.use_gpu = check_gpu(params.use_gpu)
if not params.show_log:
@@ -20,6 +20,7 @@ from __future__ import unicode_literals
import os
import sys
import numpy as np
+import skimage
import paddle
import signal
import random
@@ -86,13 +87,19 @@ def build_dataloader(config, mode, device, logger, seed=None):
shuffle=shuffle,
drop_last=drop_last)
+if 'collate_fn' in loader_config:
+from . import collate_fn
+collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
+else:
+collate_fn = None
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
num_workers=num_workers,
return_list=True,
-use_shared_memory=use_shared_memory)
+use_shared_memory=use_shared_memory,
+collate_fn=collate_fn)
# support exit using ctrl+c
signal.signal(signal.SIGINT, term_mp)
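The new branch above resolves the collator class by name from the loader config; roughly, a Train.loader section with a collate_fn entry is handled like this (the config values are illustrative, the module path is the one this commit adds):

import ppocr.data.collate_fn as collate_fn_module   # module added by this commit

loader_config = {'batch_size_per_card': 8, 'collate_fn': 'ListCollator'}  # illustrative
if 'collate_fn' in loader_config:
    collate_fn = getattr(collate_fn_module, loader_config['collate_fn'])()
else:
    collate_fn = None   # DataLoader keeps its default batching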
@@ -15,20 +15,20 @@
import paddle
import numbers
import numpy as np
+from collections import defaultdict
-class DataCollator:
+class DictCollator(object):
"""
data batch
"""
def __call__(self, batch):
-data_dict = {}
+# todo:support batch operators
+data_dict = defaultdict(list)
to_tensor_keys = []
for sample in batch:
for k, v in sample.items():
-if k not in data_dict:
-data_dict[k] = []
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if k not in to_tensor_keys:
to_tensor_keys.append(k)
@@ -36,3 +36,23 @@ class DataCollator:
for k in to_tensor_keys:
data_dict[k] = paddle.to_tensor(data_dict[k])
return data_dict
class ListCollator(object):
"""
data batch
"""
def __call__(self, batch):
# todo:support batch operators
data_dict = defaultdict(list)
to_tensor_idxs = []
for sample in batch:
for idx, v in enumerate(sample):
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if idx not in to_tensor_idxs:
to_tensor_idxs.append(idx)
data_dict[idx].append(v)
for idx in to_tensor_idxs:
data_dict[idx] = paddle.to_tensor(data_dict[idx])
return list(data_dict.values())
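A small sketch of what the two collators produce for a tiny batch; the ppocr.data.collate_fn import path is assumed from where this file lives in the repo.

import numpy as np
from ppocr.data.collate_fn import DictCollator, ListCollator

batch = [
    {'input_ids': np.array([1, 2, 3]), 'label': 0, 'name': 'a.jpg'},
    {'input_ids': np.array([4, 5, 6]), 'label': 1, 'name': 'b.jpg'},
]
out = DictCollator()(batch)
# numeric fields are stacked into paddle Tensors (out['input_ids'] has shape [2, 3]);
# non-numeric fields such as 'name' stay as plain Python lists
print(out['input_ids'].shape, out['name'])

# ListCollator does the same for tuple/list samples, keyed by position
list_batch = [(np.array([1, 2]), 'a.jpg'), (np.array([3, 4]), 'b.jpg')]
print(ListCollator()(list_batch)[0].shape)   # [2, 2]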
@@ -34,6 +34,8 @@ from .sast_process import *
from .pg_process import *
from .gen_table_mask import *
+from .vqa import *
def transform(data, ops=None):
""" transform """
@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
+import copy
import numpy as np
import string
from shapely.geometry import LineString, Point, Polygon
@@ -736,7 +737,7 @@ class TableLabelEncode(object):
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
@@ -782,3 +783,176 @@ class SARLabelEncode(BaseRecLabelEncode):
def get_ignored_tokens(self):
return [self.padding_idx]
class VQATokenLabelEncode(object):
"""
Label encode for NLP VQA methods
"""
def __init__(self,
class_path,
contains_re=False,
add_special_ids=False,
algorithm='LayoutXLM',
infer_mode=False,
ocr_engine=None,
**kwargs):
super(VQATokenLabelEncode, self).__init__()
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer
from ppocr.utils.utility import load_vqa_bio_label_maps
tokenizer_dict = {
'LayoutXLM': {
'class': LayoutXLMTokenizer,
'pretrained_model': 'layoutxlm-base-uncased'
},
'LayoutLM': {
'class': LayoutLMTokenizer,
'pretrained_model': 'layoutlm-base-uncased'
}
}
self.contains_re = contains_re
tokenizer_config = tokenizer_dict[algorithm]
self.tokenizer = tokenizer_config['class'].from_pretrained(
tokenizer_config['pretrained_model'])
self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
self.add_special_ids = add_special_ids
self.infer_mode = infer_mode
self.ocr_engine = ocr_engine
def __call__(self, data):
# load bbox and label info
ocr_info = self._load_ocr_info(data)
height, width, _ = data['image'].shape
words_list = []
bbox_list = []
input_ids_list = []
token_type_ids_list = []
segment_offset_id = []
gt_label_list = []
entities = []
# for re
train_re = self.contains_re and not self.infer_mode
if train_re:
relations = []
id2label = {}
entity_id_to_index_map = {}
empty_entity = set()
data['ocr_info'] = copy.deepcopy(ocr_info)
for info in ocr_info:
if train_re:
# for re
if len(info["text"]) == 0:
empty_entity.add(info["id"])
continue
id2label[info["id"]] = info["label"]
relations.extend([tuple(sorted(l)) for l in info["linking"]])
# smooth_box
bbox = self._smooth_box(info["bbox"], height, width)
text = info["text"]
encode_res = self.tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True)
if not self.add_special_ids:
# TODO: use tok.all_special_ids to remove
encode_res["input_ids"] = encode_res["input_ids"][1:-1]
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
-1]
encode_res["attention_mask"] = encode_res["attention_mask"][1:
-1]
# parse label
if not self.infer_mode:
label = info['label']
gt_label = self._parse_label(label, encode_res)
# construct entities for re
if train_re:
if gt_label[0] != self.label2id_map["O"]:
entity_id_to_index_map[info["id"]] = len(entities)
label = label.upper()
entities.append({
"start": len(input_ids_list),
"end":
len(input_ids_list) + len(encode_res["input_ids"]),
"label": label.upper(),
})
else:
entities.append({
"start": len(input_ids_list),
"end": len(input_ids_list) + len(encode_res["input_ids"]),
"label": 'O',
})
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
words_list.append(text)
segment_offset_id.append(len(input_ids_list))
if not self.infer_mode:
gt_label_list.extend(gt_label)
data['input_ids'] = input_ids_list
data['token_type_ids'] = token_type_ids_list
data['bbox'] = bbox_list
data['attention_mask'] = [1] * len(input_ids_list)
data['labels'] = gt_label_list
data['segment_offset_id'] = segment_offset_id
data['tokenizer_params'] = dict(
padding_side=self.tokenizer.padding_side,
pad_token_type_id=self.tokenizer.pad_token_type_id,
pad_token_id=self.tokenizer.pad_token_id)
data['entities'] = entities
if train_re:
data['relations'] = relations
data['id2label'] = id2label
data['empty_entity'] = empty_entity
data['entity_id_to_index_map'] = entity_id_to_index_map
return data
def _load_ocr_info(self, data):
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]
if self.infer_mode:
ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
ocr_info = []
for res in ocr_result:
ocr_info.append({
"text": res[1][0],
"bbox": trans_poly_to_bbox(res[0]),
"poly": res[0],
})
return ocr_info
else:
info = data['label']
# read text info
info_dict = json.loads(info)
return info_dict["ocr_info"]
def _smooth_box(self, bbox, height, width):
bbox[0] = int(bbox[0] * 1000.0 / width)
bbox[2] = int(bbox[2] * 1000.0 / width)
bbox[1] = int(bbox[1] * 1000.0 / height)
bbox[3] = int(bbox[3] * 1000.0 / height)
return bbox
def _parse_label(self, label, encode_res):
gt_label = []
if label.lower() == "other":
gt_label.extend([0] * len(encode_res["input_ids"]))
else:
gt_label.append(self.label2id_map[("b-" + label).upper()])
gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
(len(encode_res["input_ids"]) - 1))
return gt_label
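_smooth_box above rescales pixel boxes onto the 0-1000 grid that LayoutLM/LayoutXLM expect; a standalone illustration of the same arithmetic (the numbers are made up):

def smooth_box(bbox, height, width):
    x1, y1, x2, y2 = bbox
    return [int(x1 * 1000.0 / width), int(y1 * 1000.0 / height),
            int(x2 * 1000.0 / width), int(y2 * 1000.0 / height)]

# a word box in an 800x600 page mapped onto the 1000x1000 grid
print(smooth_box([200, 300, 300, 340], height=600, width=800))  # [250, 500, 375, 566]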
@@ -23,7 +23,6 @@ import sys
import six
import cv2
import numpy as np
-import fasttext
class DecodeImage(object):
@@ -136,6 +135,7 @@ class ToCHWImage(object):
class Fasttext(object):
def __init__(self, path="None", **kwargs):
+import fasttext
self.fast_model = fasttext.load_model(path)
def __call__(self, data):
@@ -170,17 +170,19 @@ class Resize(object):
def __call__(self, data):
img = data['image']
-text_polys = data['polys']
+if 'polys' in data:
+text_polys = data['polys']
img_resize, [ratio_h, ratio_w] = self.resize_image(img)
-new_boxes = []
-for box in text_polys:
-new_box = []
-for cord in box:
-new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
-new_boxes.append(new_box)
+if 'polys' in data:
+new_boxes = []
+for box in text_polys:
+new_box = []
+for cord in box:
+new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
+new_boxes.append(new_box)
+data['polys'] = np.array(new_boxes, dtype=np.float32)
data['image'] = img_resize
-data['polys'] = np.array(new_boxes, dtype=np.float32)
return data
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
__all__ = [
'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation'
]
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk
from .vqa_token_pad import VQATokenPad
from .vqa_token_relation import VQAReTokenRelation
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class VQASerTokenChunk(object):
def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
self.max_seq_len = max_seq_len
self.infer_mode = infer_mode
def __call__(self, data):
encoded_inputs_all = []
seq_len = len(data['input_ids'])
for index in range(0, seq_len, self.max_seq_len):
chunk_beg = index
chunk_end = min(index + self.max_seq_len, seq_len)
encoded_inputs_example = {}
for key in data:
if key in [
'label', 'input_ids', 'labels', 'token_type_ids',
'bbox', 'attention_mask'
]:
if self.infer_mode and key == 'labels':
encoded_inputs_example[key] = data[key]
else:
encoded_inputs_example[key] = data[key][chunk_beg:
chunk_end]
else:
encoded_inputs_example[key] = data[key]
encoded_inputs_all.append(encoded_inputs_example)
return encoded_inputs_all[0]
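For reference, the chunk windows iterated above for a long sequence look like this; note that, as written, only the first chunk (encoded_inputs_all[0]) is returned.

seq_len, max_seq_len = 1200, 512
chunks = [(beg, min(beg + max_seq_len, seq_len))
          for beg in range(0, seq_len, max_seq_len)]
print(chunks)   # [(0, 512), (512, 1024), (1024, 1200)] -- only the first window is kept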
class VQAReTokenChunk(object):
def __init__(self,
max_seq_len=512,
entities_labels=None,
infer_mode=False,
**kwargs):
self.max_seq_len = max_seq_len
self.entities_labels = {
'HEADER': 0,
'QUESTION': 1,
'ANSWER': 2
} if entities_labels is None else entities_labels
self.infer_mode = infer_mode
def __call__(self, data):
# prepare data
entities = data.pop('entities')
relations = data.pop('relations')
encoded_inputs_all = []
for index in range(0, len(data["input_ids"]), self.max_seq_len):
item = {}
for key in data:
if key in [
'label', 'input_ids', 'labels', 'token_type_ids',
'bbox', 'attention_mask'
]:
if self.infer_mode and key == 'labels':
item[key] = data[key]
else:
item[key] = data[key][index:index + self.max_seq_len]
else:
item[key] = data[key]
# select entity in current chunk
entities_in_this_span = []
global_to_local_map = {} #
for entity_id, entity in enumerate(entities):
if (index <= entity["start"] < index + self.max_seq_len and
index <= entity["end"] < index + self.max_seq_len):
entity["start"] = entity["start"] - index
entity["end"] = entity["end"] - index
global_to_local_map[entity_id] = len(entities_in_this_span)
entities_in_this_span.append(entity)
# select relations in current chunk
relations_in_this_span = []
for relation in relations:
if (index <= relation["start_index"] < index + self.max_seq_len
and index <= relation["end_index"] <
index + self.max_seq_len):
relations_in_this_span.append({
"head": global_to_local_map[relation["head"]],
"tail": global_to_local_map[relation["tail"]],
"start_index": relation["start_index"] - index,
"end_index": relation["end_index"] - index,
})
item.update({
"entities": self.reformat(entities_in_this_span),
"relations": self.reformat(relations_in_this_span),
})
item['entities']['label'] = [
self.entities_labels[x] for x in item['entities']['label']
]
encoded_inputs_all.append(item)
return encoded_inputs_all[0]
def reformat(self, data):
new_data = {}
for item in data:
for k, v in item.items():
if k not in new_data:
new_data[k] = []
new_data[k].append(v)
return new_data
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
class VQATokenPad(object):
def __init__(self,
max_seq_len=512,
pad_to_max_seq_len=True,
return_attention_mask=True,
return_token_type_ids=True,
truncation_strategy="longest_first",
return_overflowing_tokens=False,
return_special_tokens_mask=False,
infer_mode=False,
**kwargs):
self.max_seq_len = max_seq_len
self.pad_to_max_seq_len = max_seq_len
self.return_attention_mask = return_attention_mask
self.return_token_type_ids = return_token_type_ids
self.truncation_strategy = truncation_strategy
self.return_overflowing_tokens = return_overflowing_tokens
self.return_special_tokens_mask = return_special_tokens_mask
self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
self.infer_mode = infer_mode
def __call__(self, data):
needs_to_be_padded = self.pad_to_max_seq_len and len(data[
"input_ids"]) < self.max_seq_len
if needs_to_be_padded:
if 'tokenizer_params' in data:
tokenizer_params = data.pop('tokenizer_params')
else:
tokenizer_params = dict(
padding_side='right', pad_token_type_id=0, pad_token_id=1)
difference = self.max_seq_len - len(data["input_ids"])
if tokenizer_params['padding_side'] == 'right':
if self.return_attention_mask:
data["attention_mask"] = [1] * len(data[
"input_ids"]) + [0] * difference
if self.return_token_type_ids:
data["token_type_ids"] = (
data["token_type_ids"] +
[tokenizer_params['pad_token_type_id']] * difference)
if self.return_special_tokens_mask:
data["special_tokens_mask"] = data[
"special_tokens_mask"] + [1] * difference
data["input_ids"] = data["input_ids"] + [
tokenizer_params['pad_token_id']
] * difference
if not self.infer_mode:
data["labels"] = data[
"labels"] + [self.pad_token_label_id] * difference
data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference
elif tokenizer_params['padding_side'] == 'left':
if self.return_attention_mask:
data["attention_mask"] = [0] * difference + [
1
] * len(data["input_ids"])
if self.return_token_type_ids:
data["token_type_ids"] = (
[tokenizer_params['pad_token_type_id']] * difference +
data["token_type_ids"])
if self.return_special_tokens_mask:
data["special_tokens_mask"] = [
1
] * difference + data["special_tokens_mask"]
data["input_ids"] = [tokenizer_params['pad_token_id']
] * difference + data["input_ids"]
if not self.infer_mode:
data["labels"] = [self.pad_token_label_id
] * difference + data["labels"]
data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"]
else:
if self.return_attention_mask:
data["attention_mask"] = [1] * len(data["input_ids"])
for key in data:
if key in [
'input_ids', 'labels', 'token_type_ids', 'bbox',
'attention_mask'
]:
if self.infer_mode:
if key != 'labels':
length = min(len(data[key]), self.max_seq_len)
data[key] = data[key][:length]
else:
continue
data[key] = np.array(data[key], dtype='int64')
return data
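A worked example of the right-side padding branch above, with max_seq_len shrunk to 8 for readability (real configs use 512; -100 is the ignore_index of paddle's CrossEntropyLoss):

max_seq_len = 8
input_ids = [101, 7, 8, 9, 102]
pad_token_id, pad_token_type_id, pad_token_label_id = 1, 0, -100

difference = max_seq_len - len(input_ids)                  # 3
attention_mask = [1] * len(input_ids) + [0] * difference   # [1, 1, 1, 1, 1, 0, 0, 0]
input_ids = input_ids + [pad_token_id] * difference        # [101, 7, 8, 9, 102, 1, 1, 1]
labels = [3, 4, 4, 4, 4] + [pad_token_label_id] * difference
bbox = [[10, 10, 20, 20]] * 5 + [[0, 0, 0, 0]] * difference
print(attention_mask, input_ids, labels, len(bbox))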
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class VQAReTokenRelation(object):
def __init__(self, **kwargs):
pass
def __call__(self, data):
"""
build relations
"""
entities = data['entities']
relations = data['relations']
id2label = data.pop('id2label')
empty_entity = data.pop('empty_entity')
entity_id_to_index_map = data.pop('entity_id_to_index_map')
relations = list(set(relations))
relations = [
rel for rel in relations
if rel[0] not in empty_entity and rel[1] not in empty_entity
]
kv_relations = []
for rel in relations:
pair = [id2label[rel[0]], id2label[rel[1]]]
if pair == ["question", "answer"]:
kv_relations.append({
"head": entity_id_to_index_map[rel[0]],
"tail": entity_id_to_index_map[rel[1]]
})
elif pair == ["answer", "question"]:
kv_relations.append({
"head": entity_id_to_index_map[rel[1]],
"tail": entity_id_to_index_map[rel[0]]
})
else:
continue
relations = sorted(
[{
"head": rel["head"],
"tail": rel["tail"],
"start_index": self.get_relation_span(rel, entities)[0],
"end_index": self.get_relation_span(rel, entities)[1],
} for rel in kv_relations],
key=lambda x: x["head"], )
data['relations'] = relations
return data
def get_relation_span(self, rel, entities):
bound = []
for entity_index in [rel["head"], rel["tail"]]:
bound.append(entities[entity_index]["start"])
bound.append(entities[entity_index]["end"])
return min(bound), max(bound)
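The loop above keeps only question-to-answer links and flips answer-to-question ones; a standalone sketch of that normalisation with made-up ids and labels:

id2label = {1: 'question', 2: 'answer', 3: 'header'}        # made-up OCR segment ids
entity_id_to_index_map = {1: 0, 2: 1, 3: 2}
relations = [(1, 2), (2, 1), (1, 3)]                         # raw linking pairs

kv_relations = []
for rel in set(relations):
    pair = [id2label[rel[0]], id2label[rel[1]]]
    if pair == ['question', 'answer']:
        kv_relations.append({'head': entity_id_to_index_map[rel[0]],
                             'tail': entity_id_to_index_map[rel[1]]})
    elif pair == ['answer', 'question']:
        kv_relations.append({'head': entity_id_to_index_map[rel[1]],
                             'tail': entity_id_to_index_map[rel[0]]})
    # any other pairing (e.g. question-header) is dropped
print(kv_relations)   # both kept links collapse to {'head': 0, 'tail': 1}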
@@ -38,6 +38,9 @@ class LMDBDataSet(Dataset):
np.random.shuffle(self.data_idx_order_list)
self.ops = create_operators(dataset_config['transforms'], global_config)
+ratio_list = dataset_config.get("ratio_list", [1.0])
+self.need_reset = True in [x < 1 for x in ratio_list]
def load_hierarchical_lmdb_dataset(self, data_dir):
lmdb_sets = {}
dataset_idx = 0
@@ -49,6 +49,8 @@ class PGDataSet(Dataset):
self.ops = create_operators(dataset_config['transforms'], global_config)
+self.need_reset = True in [x < 1 for x in ratio_list]
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
@@ -53,6 +53,9 @@ class PubTabDataSet(Dataset):
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
+ratio_list = dataset_config.get("ratio_list", [1.0])
+self.need_reset = True in [x < 1 for x in ratio_list]
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
@@ -70,7 +73,7 @@ class PubTabDataSet(Dataset):
prob = self.img_select_prob[file_name]
if prob < random.uniform(0, 1):
select_flag = False
if self.table_select_type:
structure = info['html']['structure']['tokens'].copy()
structure_str = ''.join(structure)
@@ -79,13 +82,17 @@ class PubTabDataSet(Dataset):
table_type = "complex"
if table_type == "complex":
if self.table_select_prob < random.uniform(0, 1):
select_flag = False
if select_flag:
cells = info['html']['cells'].copy()
structure = info['html']['structure'].copy()
img_path = os.path.join(self.data_dir, file_name)
-data = {'img_path': img_path, 'cells': cells, 'structure':structure}
+data = {
+'img_path': img_path,
+'cells': cells,
+'structure': structure
+}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
@@ -41,7 +41,6 @@ class SimpleDataSet(Dataset):
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
@@ -50,6 +49,8 @@ class SimpleDataSet(Dataset):
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
+self.need_reset = True in [x < 1 for x in ratio_list]
def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str):
file_list = [file_list]
@@ -69,6 +70,16 @@ class SimpleDataSet(Dataset):
random.shuffle(self.data_lines)
return
def _try_parse_filename_list(self, file_name):
# multiple images -> one gt label
if len(file_name) > 0 and file_name[0] == "[":
try:
info = json.loads(file_name)
file_name = random.choice(info)
except:
pass
return file_name
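_try_parse_filename_list lets one annotation line carry several candidate image crops for the same label; a sketch of the tab-separated line format it handles (file names are illustrative):

import json
import random

line = '["crop_0.jpg", "crop_1.jpg", "crop_2.jpg"]\tHello'   # illustrative label line
file_name, label = line.strip('\n').split('\t')
if file_name.startswith('['):
    file_name = random.choice(json.loads(file_name))          # pick one crop at random
print(file_name, label)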
def get_ext_data(self):
ext_data_num = 0
for op in self.ops:
@@ -85,6 +96,7 @@ class SimpleDataSet(Dataset):
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
+file_name = self._try_parse_filename_list(file_name)
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
@@ -95,7 +107,7 @@ class SimpleDataSet(Dataset):
data['image'] = img
data = transform(data, load_data_ops)
-if data is None or data['polys'].shape[1]!=4:
+if data is None or data['polys'].shape[1] != 4:
continue
ext_data.append(data)
return ext_data
@@ -107,6 +119,7 @@ class SimpleDataSet(Dataset):
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
+file_name = self._try_parse_filename_list(file_name)
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
@@ -16,6 +16,9 @@ import copy
import paddle
import paddle.nn as nn
+# basic_loss
+from .basic_loss import LossFromOutput
# det loss
from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss
@@ -46,12 +49,16 @@ from .combined_loss import CombinedLoss
# table loss
from .table_att_loss import TableAttentionLoss
+# vqa token loss
+from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
def build_loss(config):
support_dict = [
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
-'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss'
+'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
+'VQASerTokenLayoutLMLoss', 'LossFromOutput'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
@@ -133,3 +133,18 @@ class DistanceLoss(nn.Layer):
def forward(self, x, y):
return self.loss_func(x, y)
class LossFromOutput(nn.Layer):
def __init__(self, key='loss', reduction='none'):
super().__init__()
self.key = key
self.reduction = reduction
def forward(self, predicts, batch):
loss = predicts[self.key]
if self.reduction == 'mean':
loss = paddle.mean(loss)
elif self.reduction == 'sum':
loss = paddle.sum(loss)
return {'loss': loss}
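LossFromOutput simply pulls a loss that the model head has already computed (as the LayoutXLM heads do) and optionally reduces it; a minimal usage sketch using the import path registered in losses/__init__.py:

import paddle
from ppocr.losses.basic_loss import LossFromOutput

loss_fn = LossFromOutput(key='loss', reduction='mean')
predicts = {'loss': paddle.to_tensor([0.4, 0.8])}   # e.g. per-sample loss from the head
print(loss_fn(predicts, None)['loss'])              # mean -> 0.6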
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,24 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
from paddle import nn
-class SERLoss(nn.Layer):
+class VQASerTokenLayoutLMLoss(nn.Layer):
def __init__(self, num_classes):
super().__init__()
self.loss_class = nn.CrossEntropyLoss()
self.num_classes = num_classes
self.ignore_index = self.loss_class.ignore_index
-def forward(self, labels, outputs, attention_mask):
+def forward(self, predicts, batch):
+labels = batch[1]
+attention_mask = batch[4]
if attention_mask is not None:
active_loss = attention_mask.reshape([-1, ]) == 1
-active_outputs = outputs.reshape(
+active_outputs = predicts.reshape(
[-1, self.num_classes])[active_loss]
active_labels = labels.reshape([-1, ])[active_loss]
loss = self.loss_class(active_outputs, active_labels)
else:
loss = self.loss_class(
-outputs.reshape([-1, self.num_classes]), labels.reshape([-1, ]))
-return loss
+predicts.reshape([-1, self.num_classes]),
+labels.reshape([-1, ]))
+return {'loss': loss}
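A minimal sketch of the new forward(predicts, batch) contract, where batch positions 1 and 4 carry the labels and attention mask; the import path is the one added to losses/__init__.py, and the shapes are illustrative:

import paddle
from ppocr.losses.vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss

num_classes = 7
loss_fn = VQASerTokenLayoutLMLoss(num_classes)
predicts = paddle.randn([2, 4, num_classes])        # [batch, seq_len, num_classes]
labels = paddle.zeros([2, 4], dtype='int64')
attention_mask = paddle.ones([2, 4], dtype='int64')
batch = [None, labels, None, None, attention_mask]  # positions 1 and 4, as read above
print(loss_fn(predicts, batch)['loss'])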