"doc/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "bd32a9002128e65443c6834e205b7d9b13be0f97"
Commit f2d98c5e authored by weishengyu's avatar weishengyu
Browse files

add style_text_rec

parent b1623d69
from engine.synthesisers import DatasetSynthesiser
def synth_dataset():
dataset_synthesiser = DatasetSynthesiser()
dataset_synthesiser.synth_dataset()
if __name__ == '__main__':
synth_dataset()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import sys
import glob
from engine.synthesisers import ImageSynthesiser
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
def synth_image():
image_synthesiser = ImageSynthesiser()
img = cv2.imread("examples/style_images/1.jpg")
corpus = "PaddleOCR"
language = "en"
synth_result = image_synthesiser.synth_image(corpus, img, language)
fake_fusion = synth_result["fake_fusion"]
fake_text = synth_result["fake_text"]
fake_bg = synth_result["fake_bg"]
cv2.imwrite("fake_fusion.jpg", fake_fusion)
cv2.imwrite("fake_text.jpg", fake_text)
cv2.imwrite("fake_bg.jpg", fake_bg)
def batch_synth_images():
image_synthesiser = ImageSynthesiser()
corpus_file = "../StyleTextRec_data/test_20201208/test_text_list.txt"
style_data_dir = "../StyleTextRec_data/test_20201208/style_images/"
save_path = "./output_data/"
corpus_list = []
with open(corpus_file, "rb") as fin:
lines = fin.readlines()
for line in lines:
substr = line.decode("utf-8").strip("\n").split("\t")
corpus_list.append(substr)
style_img_list = glob.glob("{}/*.jpg".format(style_data_dir))
corpus_num = len(corpus_list)
style_img_num = len(style_img_list)
for cno in range(corpus_num):
for sno in range(style_img_num):
corpus, lang = corpus_list[cno]
style_img_path = style_img_list[sno]
img = cv2.imread(style_img_path)
synth_result = image_synthesiser.synth_image(corpus, img, lang)
fake_fusion = synth_result["fake_fusion"]
fake_text = synth_result["fake_text"]
fake_bg = synth_result["fake_bg"]
for tp in range(2):
if tp == 0:
prefix = "%s/c%d_s%d_" % (save_path, cno, sno)
else:
prefix = "%s/s%d_c%d_" % (save_path, sno, cno)
cv2.imwrite("%s_fake_fusion.jpg" % prefix, fake_fusion)
cv2.imwrite("%s_fake_text.jpg" % prefix, fake_text)
cv2.imwrite("%s_fake_bg.jpg" % prefix, fake_bg)
cv2.imwrite("%s_input_style.jpg" % prefix, img)
print(cno, corpus_num, sno, style_img_num)
if __name__ == '__main__':
# batch_synth_images()
synth_image()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
import os
from collections import OrderedDict
from argparse import ArgumentParser, RawDescriptionHelpFormatter
def override(dl, ks, v):
"""
Recursively replace dict of list
Args:
dl(dict or list): dict or list to be replaced
ks(list): list of keys
v(str): value to be replaced
"""
def str2num(v):
try:
return eval(v)
except Exception:
return v
assert isinstance(dl, (list, dict)), ("{} should be a list or a dict")
assert len(ks) > 0, ('lenght of keys should larger than 0')
if isinstance(dl, list):
k = str2num(ks[0])
if len(ks) == 1:
assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
dl[k] = str2num(v)
else:
override(dl[k], ks[1:], v)
else:
if len(ks) == 1:
#assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
if not ks[0] in dl:
logger.warning('A new filed ({}) detected!'.format(ks[0], dl))
dl[ks[0]] = str2num(v)
else:
assert ks[0] in dl, (
'({}) doesn\'t exist in {}, a new dict field is invalid'.
format(ks[0], dl))
override(dl[ks[0]], ks[1:], v)
def override_config(config, options=None):
"""
Recursively override the config
Args:
config(dict): dict to be replaced
options(list): list of pairs(key0.key1.idx.key2=value)
such as: [
'topk=2',
'VALID.transforms.1.ResizeImage.resize_short=300'
]
Returns:
config(dict): replaced config
"""
if options is not None:
for opt in options:
assert isinstance(opt, str), (
"option({}) should be a str".format(opt))
assert "=" in opt, (
"option({}) should contain a ="
"to distinguish between key and value".format(opt))
pair = opt.split('=')
assert len(pair) == 2, ("there can be only a = in the option")
key, value = pair
keys = key.split('.')
override(config, keys, value)
return config
class ArgsParser(ArgumentParser):
def __init__(self):
super(ArgsParser, self).__init__(
formatter_class=RawDescriptionHelpFormatter)
self.add_argument("-c", "--config", help="configuration file to use")
self.add_argument(
"-t", "--tag", default="0", help="tag for marking worker")
self.add_argument(
'-o',
'--override',
action='append',
default=[],
help='config options to be overridden')
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
assert args.config is not None, \
"Please specify --config=configure_file_path."
return args
def load_config(file_path):
"""
Load config from yml/yaml file.
Args:
file_path (str): Path of the config file to be loaded.
Returns: config
"""
ext = os.path.splitext(file_path)[1]
assert ext in ['.yml', '.yaml'], "only support yaml files for now"
with open(file_path, 'rb') as f:
config = yaml.load(f, Loader=yaml.Loader)
return config
def gen_config():
base_config = {
"Global": {
"algorithm": "SRNet",
"use_gpu": True,
"start_epoch": 1,
"stage1_epoch_num": 100,
"stage2_epoch_num": 100,
"log_smooth_window": 20,
"print_batch_step": 2,
"save_model_dir": "./output/SRNet",
"use_visualdl": False,
"save_epoch_step": 10,
"vgg_pretrain": "./pretrained/VGG19_pretrained",
"vgg_load_static_pretrain": True
},
"Architecture": {
"model_type": "data_aug",
"algorithm": "SRNet",
"net_g": {
"name": "srnet_net_g",
"encode_dim": 64,
"norm": "batch",
"use_dropout": False,
"init_type": "xavier",
"init_gain": 0.02,
"use_dilation": 1
},
# input_nc, ndf, netD,
# n_layers_D=3, norm='instance', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_id='cuda:0'
"bg_discriminator": {
"name": "srnet_bg_discriminator",
"input_nc": 6,
"ndf": 64,
"netD": "basic",
"norm": "none",
"init_type": "xavier",
},
"fusion_discriminator": {
"name": "srnet_fusion_discriminator",
"input_nc": 6,
"ndf": 64,
"netD": "basic",
"norm": "none",
"init_type": "xavier",
}
},
"Loss": {
"lamb": 10,
"perceptual_lamb": 1,
"muvar_lamb": 50,
"style_lamb": 500
},
"Optimizer": {
"name": "Adam",
"learning_rate": {
"name": "lambda",
"lr": 0.0002,
"lr_decay_iters": 50
},
"beta1": 0.5,
"beta2": 0.999,
},
"Train": {
"batch_size_per_card": 8,
"num_workers_per_card": 4,
"dataset": {
"delimiter": "\t",
"data_dir": "/",
"label_file": "tmp/label.txt",
"transforms": [{
"DecodeImage": {
"to_rgb": True,
"to_np": False,
"channel_first": False
}
}, {
"NormalizeImage": {
"scale": 1. / 255.,
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
"order": None
}
}, {
"ToCHWImage": None
}]
}
}
}
with open("config.yml", "w") as f:
yaml.dump(base_config, f)
if __name__ == '__main__':
gen_config()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddle
__all__ = ['load_dygraph_pretrain']
def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
if not os.path.exists(path + '.pdparams'):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
param_state_dict = paddle.load(path + '.pdparams')
model.set_state_dict(param_state_dict)
logger.info("load pretrained model from {}".format(path))
return
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
import functools
import paddle.distributed as dist
logger_initialized = {}
@functools.lru_cache()
def get_logger(name='srnet', log_file=None, log_level=logging.INFO):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified a FileHandler will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time.
Returns:
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
formatter = logging.Formatter(
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
datefmt="%Y/%m/%d %H:%M:%S")
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
if log_file is not None and dist.get_rank() == 0:
log_file_folder = os.path.split(log_file)[0]
os.makedirs(log_file_folder, exist_ok=True)
file_handler = logging.FileHandler(log_file, 'a')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if dist.get_rank() == 0:
logger.setLevel(log_level)
else:
logger.setLevel(logging.ERROR)
logger_initialized[name] = True
return logger
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def compute_mean_covariance(img):
batch_size = img.shape[0]
channel_num = img.shape[1]
height = img.shape[2]
width = img.shape[3]
num_pixels = height * width
# batch_size * channel_num * 1 * 1
mu = img.mean(2, keepdim=True).mean(3, keepdim=True)
# batch_size * channel_num * num_pixels
img_hat = img - mu.expand_as(img)
img_hat = img_hat.reshape([batch_size, channel_num, num_pixels])
# batch_size * num_pixels * channel_num
img_hat_transpose = img_hat.transpose([0, 2, 1])
# batch_size * channel_num * channel_num
covariance = paddle.bmm(img_hat, img_hat_transpose)
covariance = covariance / num_pixels
return mu, covariance
def dice_coefficient(y_true_cls, y_pred_cls, training_mask):
eps = 1e-5
intersection = paddle.sum(y_true_cls * y_pred_cls * training_mask)
union = paddle.sum(y_true_cls * training_mask) + paddle.sum(
y_pred_cls * training_mask) + eps
loss = 1. - (2 * intersection / union)
return loss
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import errno
import paddle
def get_check_global_params(mode):
check_params = [
'use_gpu', 'max_text_length', 'image_shape', 'image_shape',
'character_type', 'loss_type'
]
if mode == "train_eval":
check_params = check_params + [
'train_batch_size_per_card', 'test_batch_size_per_card'
]
elif mode == "test":
check_params = check_params + ['test_batch_size_per_card']
return check_params
def check_gpu(use_gpu):
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
err = "Config use_gpu cannot be set as true while you are " \
"using paddlepaddle cpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
"\t2. Set use_gpu as false in config file to run " \
"model on CPU"
if use_gpu:
try:
if not paddle.is_compiled_with_cuda():
print(err)
sys.exit(1)
except:
print("Fail to check gpu state.")
sys.exit(1)
def _mkdir_if_not_exist(path, logger):
"""
mkdir if not exists, ignore the exception when multiprocess mkdir together
"""
if not os.path.exists(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno == errno.EEXIST and os.path.isdir(path):
logger.warning(
'be happy if some process has already created {}'.format(
path))
else:
raise OSError('Failed to mkdir {}'.format(path))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment