Commit 41a1b292 authored by Leif's avatar Leif
Browse files

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents 9471054e 3d30899b
......@@ -2551,7 +2551,7 @@
"\n",
"Paddle Serving是飞桨为方便开发者进行服务化部署而打造的工具,本节主要介绍基于Paddle Serving的PP-OCRv2系统服务化部署过程。\n",
"\n",
"## 4.1 Padde Serving简介\n",
"## 4.1 Paddle Serving简介\n",
"\n",
"Paddle Serving作为飞桨(PaddlePaddle)开源的服务化部署框架,长期目标就是围绕着人工智能落地的最后一公里提供越来越专业、可靠、易用的服务。Paddle Serving目前提供了两套框架C++ Serving和Python Pipeline。Python Pipeline框架倾向于二次开发的便捷性,C++ Serving框架更倾向于追求极致性能。\n",
"\n",
......@@ -42,12 +42,14 @@ __all__ = [
]
SUPPORT_DET_MODEL = ['DB']
VERSION = '2.3.0.2'
VERSION = '2.4'
SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/")
DEFAULT_OCR_MODEL_VERSION = 'PP-OCR'
SUPPORT_OCR_MODEL_VERSION = ['PP-OCR', 'PP-OCRv2']
DEFAULT_STRUCTURE_MODEL_VERSION = 'STRUCTURE'
SUPPORT_STRUCTURE_MODEL_VERSION = ['STRUCTURE']
MODEL_URLS = {
'OCR': {
'PP-OCRv2': {
......@@ -190,6 +192,7 @@ def parse_args(mMain=True):
parser.add_argument(
"--ocr_version",
type=str,
choices=SUPPORT_OCR_MODEL_VERSION,
default='PP-OCRv2',
help='OCR Model version, the current model support list is as follows: '
'1. PP-OCRv2 Support Chinese detection and recognition model. '
......@@ -198,6 +201,7 @@ def parse_args(mMain=True):
parser.add_argument(
"--structure_version",
type=str,
choices=SUPPORT_STRUCTURE_MODEL_VERSION,
default='STRUCTURE',
help='Model version, the current model support list is as follows:'
' 1. STRUCTURE Support en table structure model.')
......@@ -257,26 +261,20 @@ def get_model_config(type, version, model_type, lang):
DEFAULT_MODEL_VERSION = DEFAULT_STRUCTURE_MODEL_VERSION
else:
raise NotImplementedError
model_urls = MODEL_URLS[type]
if version not in model_urls:
logger.warning('version {} not in {}, auto switch to version {}'.format(
version, model_urls.keys(), DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION
if model_type not in model_urls[version]:
if model_type in model_urls[DEFAULT_MODEL_VERSION]:
logger.warning(
'version {} not support {} models, auto switch to version {}'.
format(version, model_type, DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION
else:
logger.error('{} models is not support, we only support {}'.format(
model_type, model_urls[DEFAULT_MODEL_VERSION].keys()))
sys.exit(-1)
if lang not in model_urls[version][model_type]:
if lang in model_urls[DEFAULT_MODEL_VERSION][model_type]:
logger.warning(
'lang {} is not support in {}, auto switch to version {}'.
format(lang, version, DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION
else:
logger.error(
......@@ -296,6 +294,8 @@ class PaddleOCR(predict_system.TextSystem):
"""
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
assert params.ocr_version in SUPPORT_OCR_MODEL_VERSION, "ocr_version must in {}, but get {}".format(
SUPPORT_OCR_MODEL_VERSION, params.ocr_version)
params.use_gpu = check_gpu(params.use_gpu)
if not params.show_log:
......@@ -347,8 +347,9 @@ class PaddleOCR(predict_system.TextSystem):
ocr with paddleocr
args:
img: img for ocr, support ndarray, img_path and list or ndarray
det: use text detection or not, if false, only rec will be exec. default is True
rec: use text recognition or not, if false, only det will be exec. default is True
det: use text detection or not. If false, only rec will be exec. Default is True
rec: use text recognition or not. If false, only det will be exec. Default is True
cls: use angle classifier or not. Default is True. If true, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
"""
assert isinstance(img, (np.ndarray, list, str))
if isinstance(img, list) and det == True:
......@@ -398,6 +399,8 @@ class PPStructure(OCRSystem):
def __init__(self, **kwargs):
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
assert params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION, "structure_version must in {}, but get {}".format(
SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version)
params.use_gpu = check_gpu(params.use_gpu)
if not params.show_log:
......
......@@ -20,6 +20,7 @@ from __future__ import unicode_literals
import os
import sys
import numpy as np
import skimage
import paddle
import signal
import random
......@@ -86,13 +87,19 @@ def build_dataloader(config, mode, device, logger, seed=None):
shuffle=shuffle,
drop_last=drop_last)
if 'collate_fn' in loader_config:
from . import collate_fn
collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
else:
collate_fn = None
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
num_workers=num_workers,
return_list=True,
use_shared_memory=use_shared_memory)
use_shared_memory=use_shared_memory,
collate_fn=collate_fn)
# support exit using ctrl+c
signal.signal(signal.SIGINT, term_mp)
......
......@@ -15,20 +15,20 @@
import paddle
import numbers
import numpy as np
from collections import defaultdict
class DataCollator:
class DictCollator(object):
"""
data batch
"""
def __call__(self, batch):
data_dict = {}
# todo:support batch operators
data_dict = defaultdict(list)
to_tensor_keys = []
for sample in batch:
for k, v in sample.items():
if k not in data_dict:
data_dict[k] = []
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if k not in to_tensor_keys:
to_tensor_keys.append(k)
......@@ -36,3 +36,23 @@ class DataCollator:
for k in to_tensor_keys:
data_dict[k] = paddle.to_tensor(data_dict[k])
return data_dict
class ListCollator(object):
"""
data batch
"""
def __call__(self, batch):
# todo:support batch operators
data_dict = defaultdict(list)
to_tensor_idxs = []
for sample in batch:
for idx, v in enumerate(sample):
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if idx not in to_tensor_idxs:
to_tensor_idxs.append(idx)
data_dict[idx].append(v)
for idx in to_tensor_idxs:
data_dict[idx] = paddle.to_tensor(data_dict[idx])
return list(data_dict.values())
......@@ -34,6 +34,8 @@ from .sast_process import *
from .pg_process import *
from .gen_table_mask import *
from .vqa import *
def transform(data, ops=None):
""" transform """
......
......@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import copy
import numpy as np
import string
from shapely.geometry import LineString, Point, Polygon
......@@ -736,7 +737,7 @@ class TableLabelEncode(object):
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
% char_or_elem
return idx
......@@ -782,3 +783,176 @@ class SARLabelEncode(BaseRecLabelEncode):
def get_ignored_tokens(self):
return [self.padding_idx]
class VQATokenLabelEncode(object):
"""
Label encode for NLP VQA methods
"""
def __init__(self,
class_path,
contains_re=False,
add_special_ids=False,
algorithm='LayoutXLM',
infer_mode=False,
ocr_engine=None,
**kwargs):
super(VQATokenLabelEncode, self).__init__()
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer
from ppocr.utils.utility import load_vqa_bio_label_maps
tokenizer_dict = {
'LayoutXLM': {
'class': LayoutXLMTokenizer,
'pretrained_model': 'layoutxlm-base-uncased'
},
'LayoutLM': {
'class': LayoutLMTokenizer,
'pretrained_model': 'layoutlm-base-uncased'
}
}
self.contains_re = contains_re
tokenizer_config = tokenizer_dict[algorithm]
self.tokenizer = tokenizer_config['class'].from_pretrained(
tokenizer_config['pretrained_model'])
self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
self.add_special_ids = add_special_ids
self.infer_mode = infer_mode
self.ocr_engine = ocr_engine
def __call__(self, data):
# load bbox and label info
ocr_info = self._load_ocr_info(data)
height, width, _ = data['image'].shape
words_list = []
bbox_list = []
input_ids_list = []
token_type_ids_list = []
segment_offset_id = []
gt_label_list = []
entities = []
# for re
train_re = self.contains_re and not self.infer_mode
if train_re:
relations = []
id2label = {}
entity_id_to_index_map = {}
empty_entity = set()
data['ocr_info'] = copy.deepcopy(ocr_info)
for info in ocr_info:
if train_re:
# for re
if len(info["text"]) == 0:
empty_entity.add(info["id"])
continue
id2label[info["id"]] = info["label"]
relations.extend([tuple(sorted(l)) for l in info["linking"]])
# smooth_box
bbox = self._smooth_box(info["bbox"], height, width)
text = info["text"]
encode_res = self.tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True)
if not self.add_special_ids:
# TODO: use tok.all_special_ids to remove
encode_res["input_ids"] = encode_res["input_ids"][1:-1]
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
-1]
encode_res["attention_mask"] = encode_res["attention_mask"][1:
-1]
# parse label
if not self.infer_mode:
label = info['label']
gt_label = self._parse_label(label, encode_res)
# construct entities for re
if train_re:
if gt_label[0] != self.label2id_map["O"]:
entity_id_to_index_map[info["id"]] = len(entities)
label = label.upper()
entities.append({
"start": len(input_ids_list),
"end":
len(input_ids_list) + len(encode_res["input_ids"]),
"label": label.upper(),
})
else:
entities.append({
"start": len(input_ids_list),
"end": len(input_ids_list) + len(encode_res["input_ids"]),
"label": 'O',
})
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
words_list.append(text)
segment_offset_id.append(len(input_ids_list))
if not self.infer_mode:
gt_label_list.extend(gt_label)
data['input_ids'] = input_ids_list
data['token_type_ids'] = token_type_ids_list
data['bbox'] = bbox_list
data['attention_mask'] = [1] * len(input_ids_list)
data['labels'] = gt_label_list
data['segment_offset_id'] = segment_offset_id
data['tokenizer_params'] = dict(
padding_side=self.tokenizer.padding_side,
pad_token_type_id=self.tokenizer.pad_token_type_id,
pad_token_id=self.tokenizer.pad_token_id)
data['entities'] = entities
if train_re:
data['relations'] = relations
data['id2label'] = id2label
data['empty_entity'] = empty_entity
data['entity_id_to_index_map'] = entity_id_to_index_map
return data
def _load_ocr_info(self, data):
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]
if self.infer_mode:
ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
ocr_info = []
for res in ocr_result:
ocr_info.append({
"text": res[1][0],
"bbox": trans_poly_to_bbox(res[0]),
"poly": res[0],
})
return ocr_info
else:
info = data['label']
# read text info
info_dict = json.loads(info)
return info_dict["ocr_info"]
def _smooth_box(self, bbox, height, width):
bbox[0] = int(bbox[0] * 1000.0 / width)
bbox[2] = int(bbox[2] * 1000.0 / width)
bbox[1] = int(bbox[1] * 1000.0 / height)
bbox[3] = int(bbox[3] * 1000.0 / height)
return bbox
def _parse_label(self, label, encode_res):
gt_label = []
if label.lower() == "other":
gt_label.extend([0] * len(encode_res["input_ids"]))
else:
gt_label.append(self.label2id_map[("b-" + label).upper()])
gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
(len(encode_res["input_ids"]) - 1))
return gt_label
......@@ -23,7 +23,6 @@ import sys
import six
import cv2
import numpy as np
import fasttext
class DecodeImage(object):
......@@ -136,6 +135,7 @@ class ToCHWImage(object):
class Fasttext(object):
def __init__(self, path="None", **kwargs):
import fasttext
self.fast_model = fasttext.load_model(path)
def __call__(self, data):
......@@ -170,17 +170,19 @@ class Resize(object):
def __call__(self, data):
img = data['image']
text_polys = data['polys']
if 'polys' in data:
text_polys = data['polys']
img_resize, [ratio_h, ratio_w] = self.resize_image(img)
new_boxes = []
for box in text_polys:
new_box = []
for cord in box:
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
new_boxes.append(new_box)
if 'polys' in data:
new_boxes = []
for box in text_polys:
new_box = []
for cord in box:
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
new_boxes.append(new_box)
data['polys'] = np.array(new_boxes, dtype=np.float32)
data['image'] = img_resize
data['polys'] = np.array(new_boxes, dtype=np.float32)
return data
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
__all__ = [
'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation'
]
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk
from .vqa_token_pad import VQATokenPad
from .vqa_token_relation import VQAReTokenRelation
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class VQASerTokenChunk(object):
def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
self.max_seq_len = max_seq_len
self.infer_mode = infer_mode
def __call__(self, data):
encoded_inputs_all = []
seq_len = len(data['input_ids'])
for index in range(0, seq_len, self.max_seq_len):
chunk_beg = index
chunk_end = min(index + self.max_seq_len, seq_len)
encoded_inputs_example = {}
for key in data:
if key in [
'label', 'input_ids', 'labels', 'token_type_ids',
'bbox', 'attention_mask'
]:
if self.infer_mode and key == 'labels':
encoded_inputs_example[key] = data[key]
else:
encoded_inputs_example[key] = data[key][chunk_beg:
chunk_end]
else:
encoded_inputs_example[key] = data[key]
encoded_inputs_all.append(encoded_inputs_example)
return encoded_inputs_all[0]
class VQAReTokenChunk(object):
def __init__(self,
max_seq_len=512,
entities_labels=None,
infer_mode=False,
**kwargs):
self.max_seq_len = max_seq_len
self.entities_labels = {
'HEADER': 0,
'QUESTION': 1,
'ANSWER': 2
} if entities_labels is None else entities_labels
self.infer_mode = infer_mode
def __call__(self, data):
# prepare data
entities = data.pop('entities')
relations = data.pop('relations')
encoded_inputs_all = []
for index in range(0, len(data["input_ids"]), self.max_seq_len):
item = {}
for key in data:
if key in [
'label', 'input_ids', 'labels', 'token_type_ids',
'bbox', 'attention_mask'
]:
if self.infer_mode and key == 'labels':
item[key] = data[key]
else:
item[key] = data[key][index:index + self.max_seq_len]
else:
item[key] = data[key]
# select entity in current chunk
entities_in_this_span = []
global_to_local_map = {} #
for entity_id, entity in enumerate(entities):
if (index <= entity["start"] < index + self.max_seq_len and
index <= entity["end"] < index + self.max_seq_len):
entity["start"] = entity["start"] - index
entity["end"] = entity["end"] - index
global_to_local_map[entity_id] = len(entities_in_this_span)
entities_in_this_span.append(entity)
# select relations in current chunk
relations_in_this_span = []
for relation in relations:
if (index <= relation["start_index"] < index + self.max_seq_len
and index <= relation["end_index"] <
index + self.max_seq_len):
relations_in_this_span.append({
"head": global_to_local_map[relation["head"]],
"tail": global_to_local_map[relation["tail"]],
"start_index": relation["start_index"] - index,
"end_index": relation["end_index"] - index,
})
item.update({
"entities": self.reformat(entities_in_this_span),
"relations": self.reformat(relations_in_this_span),
})
item['entities']['label'] = [
self.entities_labels[x] for x in item['entities']['label']
]
encoded_inputs_all.append(item)
return encoded_inputs_all[0]
def reformat(self, data):
new_data = {}
for item in data:
for k, v in item.items():
if k not in new_data:
new_data[k] = []
new_data[k].append(v)
return new_data
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
class VQATokenPad(object):
def __init__(self,
max_seq_len=512,
pad_to_max_seq_len=True,
return_attention_mask=True,
return_token_type_ids=True,
truncation_strategy="longest_first",
return_overflowing_tokens=False,
return_special_tokens_mask=False,
infer_mode=False,
**kwargs):
self.max_seq_len = max_seq_len
self.pad_to_max_seq_len = max_seq_len
self.return_attention_mask = return_attention_mask
self.return_token_type_ids = return_token_type_ids
self.truncation_strategy = truncation_strategy
self.return_overflowing_tokens = return_overflowing_tokens
self.return_special_tokens_mask = return_special_tokens_mask
self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
self.infer_mode = infer_mode
def __call__(self, data):
needs_to_be_padded = self.pad_to_max_seq_len and len(data[
"input_ids"]) < self.max_seq_len
if needs_to_be_padded:
if 'tokenizer_params' in data:
tokenizer_params = data.pop('tokenizer_params')
else:
tokenizer_params = dict(
padding_side='right', pad_token_type_id=0, pad_token_id=1)
difference = self.max_seq_len - len(data["input_ids"])
if tokenizer_params['padding_side'] == 'right':
if self.return_attention_mask:
data["attention_mask"] = [1] * len(data[
"input_ids"]) + [0] * difference
if self.return_token_type_ids:
data["token_type_ids"] = (
data["token_type_ids"] +
[tokenizer_params['pad_token_type_id']] * difference)
if self.return_special_tokens_mask:
data["special_tokens_mask"] = data[
"special_tokens_mask"] + [1] * difference
data["input_ids"] = data["input_ids"] + [
tokenizer_params['pad_token_id']
] * difference
if not self.infer_mode:
data["labels"] = data[
"labels"] + [self.pad_token_label_id] * difference
data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference
elif tokenizer_params['padding_side'] == 'left':
if self.return_attention_mask:
data["attention_mask"] = [0] * difference + [
1
] * len(data["input_ids"])
if self.return_token_type_ids:
data["token_type_ids"] = (
[tokenizer_params['pad_token_type_id']] * difference +
data["token_type_ids"])
if self.return_special_tokens_mask:
data["special_tokens_mask"] = [
1
] * difference + data["special_tokens_mask"]
data["input_ids"] = [tokenizer_params['pad_token_id']
] * difference + data["input_ids"]
if not self.infer_mode:
data["labels"] = [self.pad_token_label_id
] * difference + data["labels"]
data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"]
else:
if self.return_attention_mask:
data["attention_mask"] = [1] * len(data["input_ids"])
for key in data:
if key in [
'input_ids', 'labels', 'token_type_ids', 'bbox',
'attention_mask'
]:
if self.infer_mode:
if key != 'labels':
length = min(len(data[key]), self.max_seq_len)
data[key] = data[key][:length]
else:
continue
data[key] = np.array(data[key], dtype='int64')
return data
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class VQAReTokenRelation(object):
def __init__(self, **kwargs):
pass
def __call__(self, data):
"""
build relations
"""
entities = data['entities']
relations = data['relations']
id2label = data.pop('id2label')
empty_entity = data.pop('empty_entity')
entity_id_to_index_map = data.pop('entity_id_to_index_map')
relations = list(set(relations))
relations = [
rel for rel in relations
if rel[0] not in empty_entity and rel[1] not in empty_entity
]
kv_relations = []
for rel in relations:
pair = [id2label[rel[0]], id2label[rel[1]]]
if pair == ["question", "answer"]:
kv_relations.append({
"head": entity_id_to_index_map[rel[0]],
"tail": entity_id_to_index_map[rel[1]]
})
elif pair == ["answer", "question"]:
kv_relations.append({
"head": entity_id_to_index_map[rel[1]],
"tail": entity_id_to_index_map[rel[0]]
})
else:
continue
relations = sorted(
[{
"head": rel["head"],
"tail": rel["tail"],
"start_index": self.get_relation_span(rel, entities)[0],
"end_index": self.get_relation_span(rel, entities)[1],
} for rel in kv_relations],
key=lambda x: x["head"], )
data['relations'] = relations
return data
def get_relation_span(self, rel, entities):
bound = []
for entity_index in [rel["head"], rel["tail"]]:
bound.append(entities[entity_index]["start"])
bound.append(entities[entity_index]["end"])
return min(bound), max(bound)
......@@ -38,6 +38,9 @@ class LMDBDataSet(Dataset):
np.random.shuffle(self.data_idx_order_list)
self.ops = create_operators(dataset_config['transforms'], global_config)
ratio_list = dataset_config.get("ratio_list", [1.0])
self.need_reset = True in [x < 1 for x in ratio_list]
def load_hierarchical_lmdb_dataset(self, data_dir):
lmdb_sets = {}
dataset_idx = 0
......
......@@ -49,6 +49,8 @@ class PGDataSet(Dataset):
self.ops = create_operators(dataset_config['transforms'], global_config)
self.need_reset = True in [x < 1 for x in ratio_list]
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
......
......@@ -53,6 +53,9 @@ class PubTabDataSet(Dataset):
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
ratio_list = dataset_config.get("ratio_list", [1.0])
self.need_reset = True in [x < 1 for x in ratio_list]
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
......@@ -70,7 +73,7 @@ class PubTabDataSet(Dataset):
prob = self.img_select_prob[file_name]
if prob < random.uniform(0, 1):
select_flag = False
if self.table_select_type:
structure = info['html']['structure']['tokens'].copy()
structure_str = ''.join(structure)
......@@ -79,13 +82,17 @@ class PubTabDataSet(Dataset):
table_type = "complex"
if table_type == "complex":
if self.table_select_prob < random.uniform(0, 1):
select_flag = False
select_flag = False
if select_flag:
cells = info['html']['cells'].copy()
structure = info['html']['structure'].copy()
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'cells': cells, 'structure':structure}
data = {
'img_path': img_path,
'cells': cells,
'structure': structure
}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
......
......@@ -41,7 +41,6 @@ class SimpleDataSet(Dataset):
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
......@@ -50,6 +49,8 @@ class SimpleDataSet(Dataset):
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
self.need_reset = True in [x < 1 for x in ratio_list]
def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str):
file_list = [file_list]
......@@ -69,6 +70,16 @@ class SimpleDataSet(Dataset):
random.shuffle(self.data_lines)
return
def _try_parse_filename_list(self, file_name):
# multiple images -> one gt label
if len(file_name) > 0 and file_name[0] == "[":
try:
info = json.loads(file_name)
file_name = random.choice(info)
except:
pass
return file_name
def get_ext_data(self):
ext_data_num = 0
for op in self.ops:
......@@ -85,6 +96,7 @@ class SimpleDataSet(Dataset):
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
file_name = self._try_parse_filename_list(file_name)
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
......@@ -95,7 +107,7 @@ class SimpleDataSet(Dataset):
data['image'] = img
data = transform(data, load_data_ops)
if data is None or data['polys'].shape[1]!=4:
if data is None or data['polys'].shape[1] != 4:
continue
ext_data.append(data)
return ext_data
......@@ -107,6 +119,7 @@ class SimpleDataSet(Dataset):
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
file_name = self._try_parse_filename_list(file_name)
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
......
......@@ -16,6 +16,9 @@ import copy
import paddle
import paddle.nn as nn
# basic_loss
from .basic_loss import LossFromOutput
# det loss
from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss
......@@ -46,12 +49,16 @@ from .combined_loss import CombinedLoss
# table loss
from .table_att_loss import TableAttentionLoss
# vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
def build_loss(config):
support_dict = [
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss'
'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
......
......@@ -133,3 +133,18 @@ class DistanceLoss(nn.Layer):
def forward(self, x, y):
return self.loss_func(x, y)
class LossFromOutput(nn.Layer):
def __init__(self, key='loss', reduction='none'):
super().__init__()
self.key = key
self.reduction = reduction
def forward(self, predicts, batch):
loss = predicts[self.key]
if self.reduction == 'mean':
loss = paddle.mean(loss)
elif self.reduction == 'sum':
loss = paddle.sum(loss)
return {'loss': loss}
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -12,24 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
class SERLoss(nn.Layer):
class VQASerTokenLayoutLMLoss(nn.Layer):
def __init__(self, num_classes):
super().__init__()
self.loss_class = nn.CrossEntropyLoss()
self.num_classes = num_classes
self.ignore_index = self.loss_class.ignore_index
def forward(self, labels, outputs, attention_mask):
def forward(self, predicts, batch):
labels = batch[1]
attention_mask = batch[4]
if attention_mask is not None:
active_loss = attention_mask.reshape([-1, ]) == 1
active_outputs = outputs.reshape(
active_outputs = predicts.reshape(
[-1, self.num_classes])[active_loss]
active_labels = labels.reshape([-1, ])[active_loss]
loss = self.loss_class(active_outputs, active_labels)
else:
loss = self.loss_class(
outputs.reshape([-1, self.num_classes]), labels.reshape([-1, ]))
return loss
predicts.reshape([-1, self.num_classes]),
labels.reshape([-1, ]))
return {'loss': loss}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment