Commit f1506916 authored by sugon_cxj's avatar sugon_cxj
Browse files

first commit

parent 55c28ed5
Pipeline #266 canceled with stages
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import os
import sys
# Put this file's directory and its parent directory on sys.path so the
# plain-module imports that follow (extract_textpoint_slow / _fast) can be
# resolved when this file is used outside of a package context.
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
from extract_textpoint_slow import *
from extract_textpoint_fast import generate_pivot_list_fast, restore_poly
class PGNet_PostProcess(object):
    """Decode raw PGNet outputs into text polygons and recognized strings.

    Two decoders are offered, ``pg_postprocess_fast`` and
    ``pg_postprocess_slow``. Both read the same four maps from ``outs_dict``
    and return a dict of the form ``{'points': ..., 'texts': ...}``.
    """

    def __init__(self, character_dict_path, valid_set, score_thresh, outs_dict,
                 shape_list):
        """Store the decoding inputs and load the character table.

        Args:
            character_dict_path: Path handed to ``get_dict`` to build the
                lexicon table.
            valid_set: Dataset tag. The slow decoder accepts ``'totaltext'``
                and ``'partvgg'``.
            score_thresh: Score threshold forwarded to the pivot-list
                generators.
            outs_dict: Mapping with the keys ``'f_score'``, ``'f_border'``,
                ``'f_char'`` and ``'f_direction'``.
            shape_list: Sequence whose first item unpacks to
                ``(src_h, src_w, ratio_h, ratio_w)``.
        """
        self.Lexicon_Table = get_dict(character_dict_path)
        self.valid_set = valid_set
        self.score_thresh = score_thresh
        self.outs_dict = outs_dict
        self.shape_list = shape_list

    def _unpack_outputs(self):
        """Return the first-sample (score, border, char, direction) maps.

        When the score output is a ``paddle.Tensor`` every map is converted
        with ``.numpy()``; otherwise the first element is used as-is.
        """
        names = ('f_score', 'f_border', 'f_char', 'f_direction')
        maps = [self.outs_dict[name][0] for name in names]
        if isinstance(self.outs_dict['f_score'], paddle.Tensor):
            maps = [item.numpy() for item in maps]
        return tuple(maps)

    def pg_postprocess_fast(self):
        """Decode with ``generate_pivot_list_fast`` and ``restore_poly``.

        Returns:
            dict: ``{'points': poly_list, 'texts': keep_str_list}``.
        """
        p_score, p_border, p_char, p_direction = self._unpack_outputs()
        src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
        instance_yxs_list, seq_strs = generate_pivot_list_fast(
            p_score,
            p_char,
            p_direction,
            self.Lexicon_Table,
            score_thresh=self.score_thresh)
        poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
                                                p_border, ratio_w, ratio_h,
                                                src_w, src_h, self.valid_set)
        return {
            'points': poly_list,
            'texts': keep_str_list,
        }

    def pg_postprocess_slow(self):
        """Decode with ``generate_pivot_list_slow`` and rebuild polygons here.

        Instances whose decoded string is shorter than two characters are
        dropped. For ``valid_set == 'partvgg'`` each polygon is reduced to four
        corner points; for ``'totaltext'`` the full polygon is kept.

        Returns:
            dict: ``{'points': poly_list, 'texts': keep_str_list}``.

        Raises:
            SystemExit: If a kept instance is reached while ``valid_set`` is
                neither ``'partvgg'`` nor ``'totaltext'``.
        """
        p_score, p_border, p_char, p_direction = self._unpack_outputs()
        src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
        is_curved = self.valid_set == "totaltext"
        char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow(
            p_score,
            p_char,
            p_direction,
            score_thresh=self.score_thresh,
            is_backbone=True,
            is_curved=is_curved)
        seq_strs = [
            ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
            for char_idx_set in char_seq_idx_set
        ]

        # Loop-invariant values: border expansion factor and the (x, y)
        # ratios used to map points back to the source image.
        offset_expand = 1.2 if self.valid_set == 'totaltext' else 1.0
        xy_ratio = np.array([ratio_w, ratio_h]).reshape(-1, 2)

        poly_list = []
        keep_str_list = []
        for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
            # Too-short strings are discarded anyway, so skip them before
            # doing any polygon work.
            if len(keep_str) < 2:
                continue
            # A polygon needs at least two center points; duplicate a lone one.
            if len(yx_center_line) == 1:
                yx_center_line.append(yx_center_line[-1])

            point_pair_list = []
            for batch_id, y, x in yx_center_line:
                offset = p_border[:, y, x].reshape(2, 2)
                if offset_expand != 1.0:
                    offset_length = np.linalg.norm(
                        offset, axis=1, keepdims=True)
                    expand_length = np.clip(
                        offset_length * (offset_expand - 1),
                        a_min=0.5,
                        a_max=3.0)
                    offset_detal = offset / offset_length * expand_length
                    offset = offset + offset_detal
                ori_yx = np.array([y, x], dtype=np.float32)
                # Flip (y, x) to (x, y), then scale by 4.0 and divide by the
                # resize ratios.
                point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / xy_ratio
                point_pair_list.append(point_pair)

            detected_poly, _ = point_pair2poly(point_pair_list)
            detected_poly = expand_poly_along_width(
                detected_poly, shrink_ratio_of_width=0.2)
            detected_poly[:, 0] = np.clip(
                detected_poly[:, 0], a_min=0, a_max=src_w)
            detected_poly[:, 1] = np.clip(
                detected_poly[:, 1], a_min=0, a_max=src_h)

            keep_str_list.append(keep_str)
            detected_poly = np.round(detected_poly).astype('int32')
            if self.valid_set == 'partvgg':
                middle_point = len(detected_poly) // 2
                detected_poly = detected_poly[
                    [0, middle_point - 1, middle_point, -1], :]
                poly_list.append(detected_poly)
            elif self.valid_set == 'totaltext':
                poly_list.append(detected_poly)
            else:
                print('--> Not supported format.')
                sys.exit(-1)
        return {
            'points': poly_list,
            'texts': keep_str_list,
        }
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
import time
def resize_image(im, max_side_len=512):
    """
    Scale an image so its longer side becomes ``max_side_len``, then round
    both sides up to a multiple of the network stride (128).
    :param im: the input image, an array whose shape unpacks to (h, w, c)
    :param max_side_len: target length for the longer side before rounding
    :return: the resized image and the (ratio_h, ratio_w) resize ratios
    """
    src_h, src_w, _ = im.shape
    scale = float(max_side_len) / max(src_h, src_w)
    stride = 128
    # Truncate the scaled size first, then round up to the stride.
    dst_h = (int(src_h * scale) + stride - 1) // stride * stride
    dst_w = (int(src_w * scale) + stride - 1) // stride * stride
    im = cv2.resize(im, (dst_w, dst_h))
    return im, (dst_h / float(src_h), dst_w / float(src_w))
def resize_image_min(im, max_side_len=512):
    """
    Scale an image so its shorter side becomes ``max_side_len``, then round
    both sides up to a multiple of 128.
    :param im: the input image, an array whose shape unpacks to (h, w, c)
    :param max_side_len: target length for the shorter side before rounding
    :return: the resized image and the (ratio_h, ratio_w) resize ratios
    """
    src_h, src_w, _ = im.shape
    scale = float(max_side_len) / min(src_h, src_w)
    stride = 128
    # Truncate the scaled size first, then round up to the stride.
    dst_h = (int(src_h * scale) + stride - 1) // stride * stride
    dst_w = (int(src_w * scale) + stride - 1) // stride * stride
    im = cv2.resize(im, (dst_w, dst_h))
    return im, (dst_h / float(src_h), dst_w / float(src_w))
def resize_image_for_totaltext(im, max_side_len=512):
    """
    Enlarge an image by 1.25x, unless that would make its height exceed
    ``max_side_len``, in which case the height is scaled to ``max_side_len``.
    Both sides are then rounded up to a multiple of 128.
    :param im: the input image, an array whose shape unpacks to (h, w, c)
    :param max_side_len: upper bound applied to the scaled height
    :return: the resized image and the (ratio_h, ratio_w) resize ratios
    """
    src_h, src_w, _ = im.shape
    if src_h * 1.25 > max_side_len:
        scale = float(max_side_len) / src_h
    else:
        scale = 1.25
    stride = 128
    # Truncate the scaled size first, then round up to the stride.
    dst_h = (int(src_h * scale) + stride - 1) // stride * stride
    dst_w = (int(src_w * scale) + stride - 1) // stride * stride
    im = cv2.resize(im, (dst_w, dst_h))
    return im, (dst_h / float(src_h), dst_w / float(src_w))
def point_pair2poly(point_pair_list):
    """
    Turn a list of point pairs into one closed polygon.

    The first point of every pair is emitted in order, followed by the
    second point of every pair in reverse order. Also returns the
    (max, min, mean) distance between the two points of each pair.
    """
    lengths = np.array(
        [np.linalg.norm(pair[0] - pair[1]) for pair in point_pair_list])
    pair_info = (lengths.max(), lengths.min(), lengths.mean())
    forward_side = [pair[0] for pair in point_pair_list]
    backward_side = [pair[1] for pair in reversed(point_pair_list)]
    return np.array(forward_side + backward_side).reshape(-1, 2), pair_info
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
    """
    Cut a quad along its width.

    Points are interpolated on edge 0->1 and on edge 3->2 at the two given
    ratios; the result is ordered like the input quad.
    """
    ratios = np.array(
        [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
    edge_01 = quad[0] + (quad[1] - quad[0]) * ratios
    edge_32 = quad[3] + (quad[2] - quad[3]) * ratios
    begin_01, end_01 = edge_01[0], edge_01[1]
    begin_32, end_32 = edge_32[0], edge_32[1]
    return np.array([begin_01, end_01, end_32, begin_32])
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
    """
    Push both ends of a polygon outward along its width.

    The end quads (first/last two points, and the four points around the
    middle) are extended by ``shrink_ratio_of_width`` times their height.
    ``poly`` is modified in place and also returned.
    """
    half = poly.shape[0] // 2

    head_quad = np.array(
        [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
    head_height = np.linalg.norm(head_quad[0] - head_quad[3])
    head_width = np.linalg.norm(head_quad[0] - head_quad[1]) + 1e-6
    head_ratio = -shrink_ratio_of_width * head_height / head_width
    head_expanded = shrink_quad_along_width(head_quad, head_ratio, 1.0)

    tail_quad = np.array(
        [poly[half - 2], poly[half - 1], poly[half], poly[half + 1]],
        dtype=np.float32)
    tail_height = np.linalg.norm(tail_quad[0] - tail_quad[3])
    tail_width = np.linalg.norm(tail_quad[0] - tail_quad[1]) + 1e-6
    tail_ratio = 1.0 + shrink_ratio_of_width * tail_height / tail_width
    tail_expanded = shrink_quad_along_width(tail_quad, 0.0, tail_ratio)

    # Write back only after both quads were read from the untouched polygon.
    poly[0] = head_expanded[0]
    poly[-1] = head_expanded[-1]
    poly[half - 1] = tail_expanded[1]
    poly[half] = tail_expanded[2]
    return poly
def norm2(x, axis=None):
    """Return the Euclidean (L2) norm of ``x``.

    Args:
        x: Array whose elements are squared and summed.
        axis: Axis (or axes) to reduce over. ``None`` reduces over all
            elements. ``axis=0`` is honored as a real axis rather than being
            treated as "no axis".

    Returns:
        The square root of the sum of squares, reduced along ``axis``.
    """
    # Compare against None explicitly: a truthiness test would drop axis=0.
    if axis is not None:
        return np.sqrt(np.sum(x**2, axis=axis))
    return np.sqrt(np.sum(x**2))
def cos(p1, p2):
    """Return the cosine of the angle between vectors ``p1`` and ``p2``."""
    dot_product = (p1 * p2).sum()
    magnitude = norm2(p1) * norm2(p2)
    return dot_product / magnitude
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
\
]
^
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
{
|
}
~
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/whai362/PSENet/blob/python3/models/loss/iou.py
"""
import paddle
# Small constant added to the union sum in iou_single so the IoU ratio never
# divides by zero.
EPS = 1e-6
def iou_single(a, b, mask, n_class):
    """Return the mean per-class IoU of one sample.

    Only positions where ``mask == 1`` are compared. For each class id in
    ``range(n_class)`` the IoU is ``sum(a == i and b == i)`` divided by
    ``sum(a == i or b == i) + EPS``; the class IoUs are then averaged.
    """
    keep = mask == 1
    a = a.masked_select(keep)
    b = b.masked_select(keep)
    # When nothing is selected, both the intersection and the union are zero.
    nothing_selected = a.shape == [0] and a.shape == b.shape

    per_class = []
    for cls_id in range(n_class):
        if nothing_selected:
            overlap = paddle.to_tensor(0.0)
            combined = paddle.to_tensor(0.0)
        else:
            a_hit = a == cls_id
            b_hit = b == cls_id
            overlap = a_hit.logical_and(b_hit).astype('float32')
            combined = a_hit.logical_or(b_hit).astype('float32')
        per_class.append(paddle.sum(overlap) / (paddle.sum(combined) + EPS))
    return sum(per_class) / len(per_class)
def iou(a, b, mask, n_class=2, reduce=True):
    """Compute the IoU of every sample in a batch via ``iou_single``.

    Inputs are flattened to ``[batch_size, -1]``. Returns the batch mean when
    ``reduce`` is true, otherwise the per-sample float32 tensor.
    """
    batch_size = a.shape[0]
    flat_shape = [batch_size, -1]
    a = a.reshape(flat_shape)
    b = b.reshape(flat_shape)
    mask = mask.reshape(flat_shape)

    scores = paddle.zeros((batch_size, ), dtype='float32')
    for sample in range(batch_size):
        scores[sample] = iou_single(a[sample], b[sample], mask[sample],
                                    n_class)
    if not reduce:
        return scores
    return paddle.mean(scores)
from .vdl_logger import VDLLogger
from .wandb_logger import WandbLogger
from .loggers import Loggers
import os
from abc import ABC, abstractmethod
class BaseLogger(ABC):
    """Abstract base for metric loggers that write under ``save_dir``.

    Subclasses must implement ``log_metrics`` and ``close``. ``log_model`` has
    a no-op default so a backend without model-artifact support can still be
    driven through the same interface.
    """

    def __init__(self, save_dir):
        """Remember ``save_dir`` and create it if it does not exist yet."""
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

    @abstractmethod
    def log_metrics(self, metrics, prefix=None, step=None):
        """Record a dict of metrics.

        Args:
            metrics: Mapping of metric name to value.
            prefix: Optional namespace for the metric names.
            step: Optional step index the metrics belong to.
        """
        pass

    def log_model(self, is_best, prefix, metadata=None):
        """Record a saved model; does nothing unless overridden."""
        pass

    @abstractmethod
    def close(self):
        """Release any resources held by the logger."""
        pass
\ No newline at end of file
from .wandb_logger import WandbLogger
class Loggers(object):
    """Fan-out wrapper that forwards every call to each wrapped logger."""

    def __init__(self, loggers):
        """Keep the iterable of logger backends to forward to."""
        super().__init__()
        self.loggers = loggers

    def log_metrics(self, metrics, prefix=None, step=None):
        """Send ``metrics`` to every backend, in order."""
        for backend in self.loggers:
            backend.log_metrics(metrics, prefix=prefix, step=step)

    def log_model(self, is_best, prefix, metadata=None):
        """Ask every backend to record the saved model, in order."""
        for backend in self.loggers:
            backend.log_model(
                is_best=is_best, prefix=prefix, metadata=metadata)

    def close(self):
        """Close every backend, in order."""
        for backend in self.loggers:
            backend.close()
\ No newline at end of file
from .base_logger import BaseLogger
from visualdl import LogWriter
class VDLLogger(BaseLogger):
    """Logger backend that writes scalars through a VisualDL ``LogWriter``."""

    def __init__(self, save_dir):
        """Create ``save_dir`` via the base class and open a writer on it."""
        super().__init__(save_dir)
        self.vdl_writer = LogWriter(logdir=save_dir)

    def log_metrics(self, metrics, prefix=None, step=None):
        """Write each metric as a scalar tagged ``"<prefix>/<name>"``."""
        prefix = prefix or ""
        tagged = {prefix + "/" + name: value for name, value in metrics.items()}
        for tag, value in tagged.items():
            self.vdl_writer.add_scalar(tag, value, step)

    def log_model(self, is_best, prefix, metadata=None):
        """Model logging is a no-op for this backend."""
        pass

    def close(self):
        """Close the underlying writer."""
        self.vdl_writer.close()
\ No newline at end of file
import os
from .base_logger import BaseLogger
class WandbLogger(BaseLogger):
    """Logger backend that reports metrics and model files to Weights & Biases.

    The ``wandb`` package is imported lazily in ``__init__`` so the module can
    be imported without it installed.
    """

    def __init__(self,
                 project=None,
                 name=None,
                 id=None,
                 entity=None,
                 save_dir=None,
                 config=None,
                 **kwargs):
        """Collect the ``wandb.init`` arguments and start (or reuse) a run.

        Args:
            project, name, id, entity: Forwarded to ``wandb.init``.
            save_dir: Forwarded to ``wandb.init`` as ``dir``; also the folder
                ``log_model`` reads ``<prefix>.pdparams`` from.
            config: Optional mapping pushed into ``run.config``.
            **kwargs: Extra ``wandb.init`` arguments; they override the
                defaults assembled here.

        Raises:
            ModuleNotFoundError: If ``wandb`` is not installed.
        """
        try:
            import wandb
            self.wandb = wandb
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "Please install wandb using `pip install wandb`"
            )

        self.project = project
        self.name = name
        self.id = id
        self.save_dir = save_dir
        self.config = config
        self.kwargs = kwargs
        self.entity = entity
        self._run = None
        self._wandb_init = dict(
            project=self.project,
            name=self.name,
            id=self.id,
            entity=self.entity,
            dir=self.save_dir,
            resume="allow"
        )
        self._wandb_init.update(**kwargs)

        # Touch the property so the run is created (or adopted) right away.
        _ = self.run

        if self.config:
            self.run.config.update(self.config)

    @property
    def run(self):
        """Return the active wandb run, creating it on first access.

        If ``wandb.run`` already exists it is reused instead of calling
        ``wandb.init`` again.
        """
        if self._run is None:
            if self.wandb.run is not None:
                # This module defines no `logger`, so obtain one from the
                # standard logging module here; the bare name used to raise
                # NameError on this path.
                import logging
                logging.getLogger(__name__).info(
                    "There is a wandb run already in progress "
                    "and newly created instances of `WandbLogger` will reuse"
                    " this run. If this is not desired, call `wandb.finish()`"
                    "before instantiating `WandbLogger`."
                )
                self._run = self.wandb.run
            else:
                self._run = self.wandb.init(**self._wandb_init)
        return self._run

    def log_metrics(self, metrics, prefix=None, step=None):
        """Log each metric under the key ``"<prefix lower-cased>/<name>"``."""
        if not prefix:
            prefix = ""
        updated_metrics = {
            prefix.lower() + "/" + k: v
            for k, v in metrics.items()
        }

        self.run.log(updated_metrics, step=step)

    def log_model(self, is_best, prefix, metadata=None):
        """Upload ``<save_dir>/<prefix>.pdparams`` as a model artifact.

        The artifact is aliased with ``prefix``, plus ``"best"`` when
        ``is_best`` is true.
        """
        model_path = os.path.join(self.save_dir, prefix + '.pdparams')
        artifact = self.wandb.Artifact(
            'model-{}'.format(self.run.id), type='model', metadata=metadata)
        artifact.add_file(model_path, name="model_ckpt.pdparams")

        aliases = [prefix]
        if is_best:
            aliases.append("best")

        self.run.log_artifact(artifact, aliases=aliases)

    def close(self):
        """Finish the wandb run."""
        self.run.finish()
\ No newline at end of file
import os
import sys
import logging
import functools
import paddle.distributed as dist
# Registry of logger names that get_logger has already configured
# (name -> True); consulted so handlers are not added twice.
logger_initialized = {}
@functools.lru_cache()
def get_logger(name='ppocr', log_file=None, log_level=logging.DEBUG):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified a FileHandler will also be added.
    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger (on rank 0 only). A bare filename
            without a directory part is accepted.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # A name that starts with an already-initialized logger name is returned
    # as-is, without adding handlers of its own.
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    formatter = logging.Formatter(
        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
        datefmt="%Y/%m/%d %H:%M:%S")

    stream_handler = logging.StreamHandler(stream=sys.stdout)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    if log_file is not None and dist.get_rank() == 0:
        log_file_folder = os.path.split(log_file)[0]
        # os.makedirs('') raises FileNotFoundError, so only create the folder
        # when the path actually has a directory component.
        if log_file_folder:
            os.makedirs(log_file_folder, exist_ok=True)
        file_handler = logging.FileHandler(log_file, 'a')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    if dist.get_rank() == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)
    logger_initialized[name] = True
    # Keep records from also being passed on to ancestor loggers' handlers.
    logger.propagate = False
    return logger
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from shapely.geometry import Polygon
def points2polygon(points):
    """Build a shapely polygon from a flat sequence of coordinates.
    Args:
        points (ndarray or list): 2k values that describe k points; at least
            8 values (4 points) are required.
    Returns:
        polygon (Polygon): A polygon object.
    """
    if isinstance(points, list):
        points = np.array(points)

    assert isinstance(points, np.ndarray)
    assert (points.size % 2 == 0) and (points.size >= 8)

    vertices = points.reshape([-1, 2])
    return Polygon(vertices)
def poly_intersection(poly_det, poly_gt, buffer=0.0001):
    """Intersect two polygons.
    Args:
        poly_det (Polygon): A polygon predicted by detector.
        poly_gt (Polygon): A gt polygon.
        buffer (float): Distance passed to ``Polygon.buffer`` for both
            polygons before intersecting; 0 skips the buffering.
    Returns:
        tuple: ``(intersection_area, intersection_geometry)``.
    """
    assert isinstance(poly_det, Polygon)
    assert isinstance(poly_gt, Polygon)

    if buffer != 0:
        overlap = poly_det.buffer(buffer) & poly_gt.buffer(buffer)
    else:
        overlap = poly_det & poly_gt
    return overlap.area, overlap
def poly_union(poly_det, poly_gt):
    """Compute the union area of two polygons.

    The area is obtained as area(det) + area(gt) - area(intersection).
    Args:
        poly_det (Polygon): A polygon predicted by detector.
        poly_gt (Polygon): A gt polygon.
    Returns:
        union_area (float): The union area between two polygons.
    """
    assert isinstance(poly_det, Polygon)
    assert isinstance(poly_gt, Polygon)

    det_area = poly_det.area
    gt_area = poly_gt.area
    shared_area = poly_intersection(poly_det, poly_gt)[0]
    return det_area + gt_area - shared_area
def valid_boundary(x, with_score=True):
    """Check that ``x`` has a plausible boundary length.

    A boundary needs at least 8 values. With a trailing score the length
    must be odd; without one it must be even.
    """
    count = len(x)
    if count < 8:
        return False
    has_odd_length = count % 2 == 1
    if with_score:
        return has_odd_length
    return not has_odd_length
def boundary_iou(src, target):
    """Compute the IoU of two boundaries given as flat coordinate lists.
    Args:
        src (list): Source boundary (no trailing score).
        target (list): Target boundary (no trailing score).
    Returns:
        iou (float): The iou between two boundaries.
    """
    assert valid_boundary(src, False)
    assert valid_boundary(target, False)
    return poly_iou(points2polygon(src), points2polygon(target))
def poly_iou(poly_det, poly_gt):
    """Compute intersection-over-union of two polygons.
    Args:
        poly_det (Polygon): A polygon predicted by detector.
        poly_gt (Polygon): A gt polygon.
    Returns:
        iou (float): The IOU between two polygons; 0.0 when the union area
            is zero.
    """
    assert isinstance(poly_det, Polygon)
    assert isinstance(poly_gt, Polygon)

    shared_area = poly_intersection(poly_det, poly_gt)[0]
    total_area = poly_union(poly_det, poly_gt)
    return 0.0 if total_area == 0 else shared_area / total_area
def poly_nms(polygons, threshold):
    """Non-maximum suppression over polygons with a trailing score.

    Each polygon is a flat sequence whose last element is used as its score.
    The highest-scoring remaining polygon is kept repeatedly, and every other
    polygon whose boundary IoU with it exceeds ``threshold`` is discarded.
    Returns the kept polygons as lists, best first.
    """
    assert isinstance(polygons, list)

    # Ascending by score, so the best candidate is always the last index.
    ranked = np.array(sorted(polygons, key=lambda item: item[-1]))
    kept = []
    remaining = list(range(ranked.shape[0]))
    while len(remaining) > 0:
        best = ranked[remaining[-1]]
        kept.append(best.tolist())
        best_boundary = best[:-1]
        remaining = np.delete(remaining, -1)

        overlaps = np.zeros((len(remaining), ))
        for pos, candidate in enumerate(remaining):
            overlaps[pos] = boundary_iou(best_boundary,
                                         ranked[candidate][:-1])
        suppressed = np.where(overlaps > threshold)
        remaining = np.delete(remaining, suppressed)
    return kept
'
贿
2
0
8
-
7
>
:
]
,
蹿
1
3
!
诿
线
尿
|
;
H
.
/
*
忿
齿
西
5
4
亿
(
访
6
)
稿
s
u
[
9
岿
广
S
Y
F
D
A
P
T
湿
窿
@
丿
沿
使
绿
%
"
婿
r
=
饿
ˇ
q
÷
椿
寿
?
便
殿
J
l
&
驿
x
耀
仿
鸿
廿
z
±
e
t
§
姿
b
<
退
L
鹿
w
i
h
+
I
B
N
^
_
M
鱿
怀
}
~
Z
槿
C
o
E
f
\
屿
U
a
p
y
n
g
竿
Q
羿
O
宿
k
$
c
v
W
穿
×
轿
R
G
ˉ
d
°
K
X
m
涿
`
V
#
簿
{
j
·
Ë
¥
π
é
Λ
Ο
α
 
鴿
è
Ü
И
»
ä
ɔ
´
í
É
ʌ
Я
Й
粿
®
З
β
á
Ó
ò
貿
𣇉
г
楿
滿
Φ
ε
ü
調
ˋ
ā
ú
ó
ē
Ω
П
ǐ
ō
ǒ
μ
à
ɡ
ī
²
駿
θ
ū
ì
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import paddle
# A global variable to record the number of calling times for profiler
# functions. It is used to specify the tracing range of training steps.
_profiler_step_id = 0
# A global variable to avoid parsing from string every time.
_profiler_options = None
class ProfilerOptions(object):
    '''
    Use a string to initialize a ProfilerOptions.
    The string should be in the format: "key1=value1;key2=value;key3=value3".
    For example:
      "profile_path=model.profile"
      "batch_range=[50, 60]; profile_path=model.profile"
      "batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile"
    Spaces are ignored, empty segments (e.g. a trailing ';') are skipped, and
    only the first '=' of a segment separates key from value.
    ProfilerOptions supports following key-value pair:
      batch_range      - a integer list, e.g. [100, 110].
      state            - a string, the optional values are 'CPU', 'GPU' or 'All'.
      sorted_key       - a string, the optional values are 'calls', 'total',
                         'max', 'min' or 'ave'.
      tracer_option    - a string, the optional values are 'Default', 'OpDetail',
                         'AllOpDetail'.
      profile_path     - a string, the path to save the serialized profile data,
                         which can be used to generate a timeline.
      exit_on_finished - a boolean.
    '''

    def __init__(self, options_str):
        """Start from the defaults and override them from ``options_str``."""
        assert isinstance(options_str, str)

        self._options = {
            'batch_range': [10, 20],
            'state': 'All',
            'sorted_key': 'total',
            'tracer_option': 'Default',
            'profile_path': '/tmp/profile',
            'exit_on_finished': True
        }
        self._parse_from_string(options_str)

    def _parse_from_string(self, options_str):
        """Apply every ``key=value`` segment of ``options_str``.

        Unknown keys are ignored. A ``batch_range`` that is not
        ``0 <= start < end`` is ignored and the default kept.

        Raises:
            ValueError: If a non-empty segment has no '=', or a
                ``batch_range`` item is not an integer.
        """
        for kv in options_str.replace(' ', '').split(';'):
            # An empty string or a trailing ';' yields empty segments.
            if not kv:
                continue
            # Split on the first '=' only so values may contain '='.
            key, sep, value = kv.partition('=')
            if not sep:
                raise ValueError(
                    "Invalid profiler option '%s', expected key=value." % kv)
            if key == 'batch_range':
                value_list = value.replace('[', '').replace(']', '').split(',')
                value_list = list(map(int, value_list))
                if len(value_list) >= 2 and value_list[0] >= 0 and value_list[
                        1] > value_list[0]:
                    self._options[key] = value_list
            elif key == 'exit_on_finished':
                self._options[key] = value.lower() in ("yes", "true", "t", "1")
            elif key in [
                    'state', 'sorted_key', 'tracer_option', 'profile_path'
            ]:
                self._options[key] = value

    def __getitem__(self, name):
        """Return the option ``name``; raise ValueError if it is unknown."""
        if self._options.get(name, None) is None:
            raise ValueError(
                "ProfilerOptions does not have an option named %s." % name)
        return self._options[name]
def add_profiler_step(options_str=None):
    '''
    Advance the profiler step counter, starting or stopping PaddlePaddle's
    operator-level profiler when the counter hits the configured range.
    The profiler uses an independent module-level counter; each call of this
    function counts as one profiler step.
    Args:
      options_str - a string to initialize the ProfilerOptions (parsed once,
                    on the first call). Default is None, which leaves the
                    profiler disabled and the counter untouched.
    '''
    global _profiler_step_id, _profiler_options

    if options_str is None:
        return

    if _profiler_options is None:
        _profiler_options = ProfilerOptions(options_str)

    first_step = _profiler_options['batch_range'][0]
    last_step = _profiler_options['batch_range'][1]
    if _profiler_step_id == first_step:
        paddle.utils.profiler.start_profiler(
            _profiler_options['state'], _profiler_options['tracer_option'])
    elif _profiler_step_id == last_step:
        paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'],
                                            _profiler_options['profile_path'])
        if _profiler_options['exit_on_finished']:
            sys.exit(0)

    _profiler_step_id += 1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import errno
import os
import pickle
import six
import paddle
from ppocr.utils.logging import get_logger
__all__ = ['load_model']
def _mkdir_if_not_exist(path, logger):
    """
    Create ``path`` (with parents) unless it already exists.

    If another process creates the directory between the existence check and
    the ``makedirs`` call, a warning is logged instead of failing. Any other
    OSError is re-raised as ``OSError('Failed to mkdir <path>')``.
    """
    if os.path.exists(path):
        return
    try:
        os.makedirs(path)
    except OSError as err:
        lost_race = err.errno == errno.EEXIST and os.path.isdir(path)
        if not lost_race:
            raise OSError('Failed to mkdir {}'.format(path))
        logger.warning(
            'be happy if some process has already created {}'.format(path))
def load_model(config, model, optimizer=None, model_type='det'):
    """
    load model from checkpoint or pretrained_model

    Args:
        config: Configuration mapping. ``config['Global']`` may hold
            ``checkpoints`` and ``pretrained_model``; for ``model_type ==
            'vqa'`` the checkpoint folder is read from
            ``config['Architecture']['Backbone']['checkpoints']``.
        model: Model whose state dict is updated in place.
        optimizer: Optional optimizer; its state is restored from
            ``<checkpoints>.pdopt`` when that file exists.
        model_type: ``'vqa'`` selects the folder-style checkpoint layout.

    Returns:
        dict: The saved ``best_model_dict`` (with ``start_epoch`` added when
        an epoch was stored), or an empty dict when no state file is found.
    """
    logger = get_logger()
    global_config = config['Global']
    checkpoints = global_config.get('checkpoints')
    pretrained_model = global_config.get('pretrained_model')
    best_model_dict = {}

    if model_type == 'vqa':
        checkpoints = config['Architecture']['Backbone']['checkpoints']
        # load vqa method metric
        if checkpoints:
            if os.path.exists(os.path.join(checkpoints, 'metric.states')):
                # NOTE: unpickling can execute arbitrary code; only load
                # checkpoints that come from a trusted source.
                with open(os.path.join(checkpoints, 'metric.states'),
                          'rb') as f:
                    states_dict = pickle.load(f) if six.PY2 else pickle.load(
                        f, encoding='latin1')
                best_model_dict = states_dict.get('best_model_dict', {})
                if 'epoch' in states_dict:
                    best_model_dict['start_epoch'] = states_dict['epoch'] + 1
            logger.info("resume from {}".format(checkpoints))

            if optimizer is not None:
                # Drop one trailing path separator before appending '.pdopt'.
                if checkpoints[-1] in ['/', '\\']:
                    checkpoints = checkpoints[:-1]
                if os.path.exists(checkpoints + '.pdopt'):
                    optim_dict = paddle.load(checkpoints + '.pdopt')
                    optimizer.set_state_dict(optim_dict)
                else:
                    logger.warning(
                        "{}.pdopt is not exists, params of optimizer is not loaded".
                        format(checkpoints))
        return best_model_dict

    if checkpoints:
        if checkpoints.endswith('.pdparams'):
            # Strip only the trailing suffix; str.replace would also remove
            # '.pdparams' occurring elsewhere in the path.
            checkpoints = checkpoints[:-len('.pdparams')]
        assert os.path.exists(checkpoints + ".pdparams"), \
            "The {}.pdparams does not exists!".format(checkpoints)

        # load params from trained model
        params = paddle.load(checkpoints + '.pdparams')
        state_dict = model.state_dict()
        new_state_dict = {}
        for key, value in state_dict.items():
            if key not in params:
                logger.warning("{} not in loaded params {} !".format(
                    key, params.keys()))
                continue
            pre_value = params[key]
            # Only take over parameters whose shape matches the model's.
            if list(value.shape) == list(pre_value.shape):
                new_state_dict[key] = pre_value
            else:
                logger.warning(
                    "The shape of model params {} {} not matched with loaded params shape {} !".
                    format(key, value.shape, pre_value.shape))
        model.set_state_dict(new_state_dict)

        if optimizer is not None:
            if os.path.exists(checkpoints + '.pdopt'):
                optim_dict = paddle.load(checkpoints + '.pdopt')
                optimizer.set_state_dict(optim_dict)
            else:
                logger.warning(
                    "{}.pdopt is not exists, params of optimizer is not loaded".
                    format(checkpoints))

        if os.path.exists(checkpoints + '.states'):
            # NOTE: unpickling can execute arbitrary code; only load
            # checkpoints that come from a trusted source.
            with open(checkpoints + '.states', 'rb') as f:
                states_dict = pickle.load(f) if six.PY2 else pickle.load(
                    f, encoding='latin1')
            best_model_dict = states_dict.get('best_model_dict', {})
            if 'epoch' in states_dict:
                best_model_dict['start_epoch'] = states_dict['epoch'] + 1
        logger.info("resume from {}".format(checkpoints))
    elif pretrained_model:
        load_pretrained_params(model, pretrained_model)
    else:
        logger.info('train from scratch')
    return best_model_dict
def load_pretrained_params(model, path):
    """
    Load pretrained parameters from ``<path>.pdparams`` into ``model``.

    ``path`` may be given with or without the ``.pdparams`` suffix. Parameters
    that are missing from the model or whose shape differs are skipped with a
    warning.

    Returns:
        The same ``model`` object, after ``set_state_dict`` was called on it.
    """
    logger = get_logger()
    if path.endswith('.pdparams'):
        # Strip only the trailing suffix; str.replace would also remove
        # '.pdparams' occurring elsewhere in the path.
        path = path[:-len('.pdparams')]
    assert os.path.exists(path + ".pdparams"), \
        "The {}.pdparams does not exists!".format(path)

    params = paddle.load(path + '.pdparams')
    state_dict = model.state_dict()
    new_state_dict = {}
    for k1 in params.keys():
        if k1 not in state_dict.keys():
            logger.warning("The pretrained params {} not in model".format(k1))
        else:
            if list(state_dict[k1].shape) == list(params[k1].shape):
                new_state_dict[k1] = params[k1]
            else:
                logger.warning(
                    "The shape of model params {} {} not matched with loaded params {} {} !".
                    format(k1, state_dict[k1].shape, k1, params[k1].shape))
    model.set_state_dict(new_state_dict)
    logger.info("load pretrain successful from {}".format(path))
    return model
def save_model(model,
               optimizer,
               model_path,
               logger,
               config,
               is_best=False,
               prefix='ppocr',
               **kwargs):
    """
    Write the optimizer state, model weights and metric states to disk.

    Files are written under ``model_path`` using ``prefix`` as the base name.
    For the 'vqa' model type the backbone saves itself into a folder named
    after the prefix and the metric file goes inside that folder. ``kwargs``
    is pickled as the ``.states`` file.
    """
    _mkdir_if_not_exist(model_path, logger)
    model_prefix = os.path.join(model_path, prefix)
    paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')

    if config['Architecture']["model_type"] == 'vqa':
        # In distributed mode the backbone is reached through model._layers.
        if config['Global']['distributed']:
            backbone_owner = model._layers
        else:
            backbone_owner = model
        backbone_owner.backbone.model.save_pretrained(model_prefix)
        metric_prefix = os.path.join(model_prefix, 'metric')
    else:
        paddle.save(model.state_dict(), model_prefix + '.pdparams')
        metric_prefix = model_prefix

    # save metric and config
    with open(metric_prefix + '.states', 'wb') as states_file:
        pickle.dump(kwargs, states_file, protocol=2)

    if is_best:
        logger.info('save best model is to {}'.format(model_prefix))
    else:
        logger.info("save model in {}".format(model_prefix))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment