Commit 019b16be authored by chenxj's avatar chenxj
Browse files

first commit

parent 3019db46
Pipeline #569 canceled with stages
import os
from abc import ABC, abstractmethod
class BaseLogger(ABC):
def __init__(self, save_dir):
self.save_dir = save_dir
os.makedirs(self.save_dir, exist_ok=True)
@abstractmethod
def log_metrics(self, metrics, prefix=None):
pass
@abstractmethod
def close(self):
pass
\ No newline at end of file
from .wandb_logger import WandbLogger
class Loggers(object):
def __init__(self, loggers):
super().__init__()
self.loggers = loggers
def log_metrics(self, metrics, prefix=None, step=None):
for logger in self.loggers:
logger.log_metrics(metrics, prefix=prefix, step=step)
def log_model(self, is_best, prefix, metadata=None):
for logger in self.loggers:
logger.log_model(is_best=is_best, prefix=prefix, metadata=metadata)
def close(self):
for logger in self.loggers:
logger.close()
\ No newline at end of file
from .base_logger import BaseLogger
from visualdl import LogWriter
class VDLLogger(BaseLogger):
def __init__(self, save_dir):
super().__init__(save_dir)
self.vdl_writer = LogWriter(logdir=save_dir)
def log_metrics(self, metrics, prefix=None, step=None):
if not prefix:
prefix = ""
updated_metrics = {prefix + "/" + k: v for k, v in metrics.items()}
for k, v in updated_metrics.items():
self.vdl_writer.add_scalar(k, v, step)
def log_model(self, is_best, prefix, metadata=None):
pass
def close(self):
self.vdl_writer.close()
\ No newline at end of file
import os
from .base_logger import BaseLogger
class WandbLogger(BaseLogger):
def __init__(self,
project=None,
name=None,
id=None,
entity=None,
save_dir=None,
config=None,
**kwargs):
try:
import wandb
self.wandb = wandb
except ModuleNotFoundError:
raise ModuleNotFoundError(
"Please install wandb using `pip install wandb`"
)
self.project = project
self.name = name
self.id = id
self.save_dir = save_dir
self.config = config
self.kwargs = kwargs
self.entity = entity
self._run = None
self._wandb_init = dict(
project=self.project,
name=self.name,
id=self.id,
entity=self.entity,
dir=self.save_dir,
resume="allow"
)
self._wandb_init.update(**kwargs)
_ = self.run
if self.config:
self.run.config.update(self.config)
@property
def run(self):
if self._run is None:
if self.wandb.run is not None:
logger.info(
"There is a wandb run already in progress "
"and newly created instances of `WandbLogger` will reuse"
" this run. If this is not desired, call `wandb.finish()`"
"before instantiating `WandbLogger`."
)
self._run = self.wandb.run
else:
self._run = self.wandb.init(**self._wandb_init)
return self._run
def log_metrics(self, metrics, prefix=None, step=None):
if not prefix:
prefix = ""
updated_metrics = {prefix.lower() + "/" + k: v for k, v in metrics.items()}
self.run.log(updated_metrics, step=step)
def log_model(self, is_best, prefix, metadata=None):
model_path = os.path.join(self.save_dir, prefix + '.pdparams')
artifact = self.wandb.Artifact('model-{}'.format(self.run.id), type='model', metadata=metadata)
artifact.add_file(model_path, name="model_ckpt.pdparams")
aliases = [prefix]
if is_best:
aliases.append("best")
self.run.log_artifact(artifact, aliases=aliases)
def close(self):
self.run.finish()
\ No newline at end of file
import os
import sys
import logging
import functools
import paddle.distributed as dist
logger_initialized = {}
@functools.lru_cache()
def get_logger(name='ppocr', log_file=None, log_level=logging.DEBUG):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified a FileHandler will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time.
Returns:
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
formatter = logging.Formatter(
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
datefmt="%Y/%m/%d %H:%M:%S")
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
if log_file is not None and dist.get_rank() == 0:
log_file_folder = os.path.split(log_file)[0]
os.makedirs(log_file_folder, exist_ok=True)
file_handler = logging.FileHandler(log_file, 'a')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if dist.get_rank() == 0:
logger.setLevel(log_level)
else:
logger.setLevel(logging.ERROR)
logger_initialized[name] = True
logger.propagate = False
return logger
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from shapely.geometry import Polygon
def points2polygon(points):
"""Convert k points to 1 polygon.
Args:
points (ndarray or list): A ndarray or a list of shape (2k)
that indicates k points.
Returns:
polygon (Polygon): A polygon object.
"""
if isinstance(points, list):
points = np.array(points)
assert isinstance(points, np.ndarray)
assert (points.size % 2 == 0) and (points.size >= 8)
point_mat = points.reshape([-1, 2])
return Polygon(point_mat)
def poly_intersection(poly_det, poly_gt, buffer=0.0001):
"""Calculate the intersection area between two polygon.
Args:
poly_det (Polygon): A polygon predicted by detector.
poly_gt (Polygon): A gt polygon.
Returns:
intersection_area (float): The intersection area between two polygons.
"""
assert isinstance(poly_det, Polygon)
assert isinstance(poly_gt, Polygon)
if buffer == 0:
poly_inter = poly_det & poly_gt
else:
poly_inter = poly_det.buffer(buffer) & poly_gt.buffer(buffer)
return poly_inter.area, poly_inter
def poly_union(poly_det, poly_gt):
"""Calculate the union area between two polygon.
Args:
poly_det (Polygon): A polygon predicted by detector.
poly_gt (Polygon): A gt polygon.
Returns:
union_area (float): The union area between two polygons.
"""
assert isinstance(poly_det, Polygon)
assert isinstance(poly_gt, Polygon)
area_det = poly_det.area
area_gt = poly_gt.area
area_inters, _ = poly_intersection(poly_det, poly_gt)
return area_det + area_gt - area_inters
def valid_boundary(x, with_score=True):
num = len(x)
if num < 8:
return False
if num % 2 == 0 and (not with_score):
return True
if num % 2 == 1 and with_score:
return True
return False
def boundary_iou(src, target):
"""Calculate the IOU between two boundaries.
Args:
src (list): Source boundary.
target (list): Target boundary.
Returns:
iou (float): The iou between two boundaries.
"""
assert valid_boundary(src, False)
assert valid_boundary(target, False)
src_poly = points2polygon(src)
target_poly = points2polygon(target)
return poly_iou(src_poly, target_poly)
def poly_iou(poly_det, poly_gt):
"""Calculate the IOU between two polygons.
Args:
poly_det (Polygon): A polygon predicted by detector.
poly_gt (Polygon): A gt polygon.
Returns:
iou (float): The IOU between two polygons.
"""
assert isinstance(poly_det, Polygon)
assert isinstance(poly_gt, Polygon)
area_inters, _ = poly_intersection(poly_det, poly_gt)
area_union = poly_union(poly_det, poly_gt)
if area_union == 0:
return 0.0
return area_inters / area_union
def poly_nms(polygons, threshold):
assert isinstance(polygons, list)
polygons = np.array(sorted(polygons, key=lambda x: x[-1]))
keep_poly = []
index = [i for i in range(polygons.shape[0])]
while len(index) > 0:
keep_poly.append(polygons[index[-1]].tolist())
A = polygons[index[-1]][:-1]
index = np.delete(index, -1)
iou_list = np.zeros((len(index), ))
for i in range(len(index)):
B = polygons[index[i]][:-1]
iou_list[i] = boundary_iou(A, B)
remove_index = np.where(iou_list > threshold)
index = np.delete(index, remove_index)
return keep_poly
'
贿
2
0
8
-
7
>
:
]
,
蹿
1
3
!
诿
线
尿
|
;
H
.
/
*
忿
齿
西
5
4
亿
(
访
6
)
稿
s
u
[
9
岿
广
S
Y
F
D
A
P
T
湿
窿
@
丿
沿
使
绿
%
"
婿
r
=
饿
ˇ
q
÷
椿
寿
?
便
殿
J
l
&
驿
x
耀
仿
鸿
廿
z
±
e
t
§
姿
b
<
退
L
鹿
w
i
h
+
I
B
N
^
_
M
鱿
怀
}
~
Z
槿
C
o
E
f
\
屿
U
a
p
y
n
g
竿
Q
羿
O
宿
k
$
c
v
W
穿
×
轿
R
G
ˉ
d
°
K
X
m
涿
`
V
#
簿
{
j
·
Ë
¥
π
é
Λ
Ο
α
 
鴿
è
Ü
И
»
ä
ɔ
´
í
É
ʌ
Я
Й
粿
®
З
β
á
Ó
ò
貿
𣇉
г
楿
滿
Φ
ε
ü
調
ˋ
ā
ú
ó
ē
Ω
П
ǐ
ō
ǒ
μ
à
ɡ
ī
²
駿
θ
ū
ì
\ No newline at end of file
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import paddle
# A global variable to record the number of calling times for profiler
# functions. It is used to specify the tracing range of training steps.
_profiler_step_id = 0
# A global variable to avoid parsing from string every time.
_profiler_options = None
class ProfilerOptions(object):
'''
Use a string to initialize a ProfilerOptions.
The string should be in the format: "key1=value1;key2=value;key3=value3".
For example:
"profile_path=model.profile"
"batch_range=[50, 60]; profile_path=model.profile"
"batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile"
ProfilerOptions supports following key-value pair:
batch_range - a integer list, e.g. [100, 110].
state - a string, the optional values are 'CPU', 'GPU' or 'All'.
sorted_key - a string, the optional values are 'calls', 'total',
'max', 'min' or 'ave.
tracer_option - a string, the optional values are 'Default', 'OpDetail',
'AllOpDetail'.
profile_path - a string, the path to save the serialized profile data,
which can be used to generate a timeline.
exit_on_finished - a boolean.
'''
def __init__(self, options_str):
assert isinstance(options_str, str)
self._options = {
'batch_range': [10, 20],
'state': 'All',
'sorted_key': 'total',
'tracer_option': 'Default',
'profile_path': '/tmp/profile',
'exit_on_finished': True
}
self._parse_from_string(options_str)
def _parse_from_string(self, options_str):
for kv in options_str.replace(' ', '').split(';'):
key, value = kv.split('=')
if key == 'batch_range':
value_list = value.replace('[', '').replace(']', '').split(',')
value_list = list(map(int, value_list))
if len(value_list) >= 2 and value_list[0] >= 0 and value_list[
1] > value_list[0]:
self._options[key] = value_list
elif key == 'exit_on_finished':
self._options[key] = value.lower() in ("yes", "true", "t", "1")
elif key in [
'state', 'sorted_key', 'tracer_option', 'profile_path'
]:
self._options[key] = value
def __getitem__(self, name):
if self._options.get(name, None) is None:
raise ValueError(
"ProfilerOptions does not have an option named %s." % name)
return self._options[name]
def add_profiler_step(options_str=None):
'''
Enable the operator-level timing using PaddlePaddle's profiler.
The profiler uses a independent variable to count the profiler steps.
One call of this function is treated as a profiler step.
Args:
profiler_options - a string to initialize the ProfilerOptions.
Default is None, and the profiler is disabled.
'''
if options_str is None:
return
global _profiler_step_id
global _profiler_options
if _profiler_options is None:
_profiler_options = ProfilerOptions(options_str)
if _profiler_step_id == _profiler_options['batch_range'][0]:
paddle.utils.profiler.start_profiler(
_profiler_options['state'], _profiler_options['tracer_option'])
elif _profiler_step_id == _profiler_options['batch_range'][1]:
paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'],
_profiler_options['profile_path'])
if _profiler_options['exit_on_finished']:
sys.exit(0)
_profiler_step_id += 1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import errno
import os
import pickle
import six
import paddle
from ppocr.utils.logging import get_logger
__all__ = ['load_model']
def _mkdir_if_not_exist(path, logger):
"""
mkdir if not exists, ignore the exception when multiprocess mkdir together
"""
if not os.path.exists(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno == errno.EEXIST and os.path.isdir(path):
logger.warning(
'be happy if some process has already created {}'.format(
path))
else:
raise OSError('Failed to mkdir {}'.format(path))
def load_model(config, model, optimizer=None, model_type='det'):
"""
load model from checkpoint or pretrained_model
"""
logger = get_logger()
global_config = config['Global']
checkpoints = global_config.get('checkpoints')
pretrained_model = global_config.get('pretrained_model')
best_model_dict = {}
if model_type == 'vqa':
checkpoints = config['Architecture']['Backbone']['checkpoints']
# load vqa method metric
if checkpoints:
if os.path.exists(os.path.join(checkpoints, 'metric.states')):
with open(os.path.join(checkpoints, 'metric.states'),
'rb') as f:
states_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
best_model_dict = states_dict.get('best_model_dict', {})
if 'epoch' in states_dict:
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints))
if optimizer is not None:
if checkpoints[-1] in ['/', '\\']:
checkpoints = checkpoints[:-1]
if os.path.exists(checkpoints + '.pdopt'):
optim_dict = paddle.load(checkpoints + '.pdopt')
optimizer.set_state_dict(optim_dict)
else:
logger.warning(
"{}.pdopt is not exists, params of optimizer is not loaded".
format(checkpoints))
return best_model_dict
if checkpoints:
if checkpoints.endswith('.pdparams'):
checkpoints = checkpoints.replace('.pdparams', '')
assert os.path.exists(checkpoints + ".pdparams"), \
"The {}.pdparams does not exists!".format(checkpoints)
# load params from trained model
params = paddle.load(checkpoints + '.pdparams')
state_dict = model.state_dict()
new_state_dict = {}
for key, value in state_dict.items():
if key not in params:
logger.warning("{} not in loaded params {} !".format(
key, params.keys()))
continue
pre_value = params[key]
if list(value.shape) == list(pre_value.shape):
new_state_dict[key] = pre_value
else:
logger.warning(
"The shape of model params {} {} not matched with loaded params shape {} !".
format(key, value.shape, pre_value.shape))
model.set_state_dict(new_state_dict)
if optimizer is not None:
if os.path.exists(checkpoints + '.pdopt'):
optim_dict = paddle.load(checkpoints + '.pdopt')
optimizer.set_state_dict(optim_dict)
else:
logger.warning(
"{}.pdopt is not exists, params of optimizer is not loaded".
format(checkpoints))
if os.path.exists(checkpoints + '.states'):
with open(checkpoints + '.states', 'rb') as f:
states_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
best_model_dict = states_dict.get('best_model_dict', {})
if 'epoch' in states_dict:
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
load_pretrained_params(model, pretrained_model)
else:
logger.info('train from scratch')
return best_model_dict
def load_pretrained_params(model, path):
logger = get_logger()
if path.endswith('.pdparams'):
path = path.replace('.pdparams', '')
assert os.path.exists(path + ".pdparams"), \
"The {}.pdparams does not exists!".format(path)
params = paddle.load(path + '.pdparams')
state_dict = model.state_dict()
new_state_dict = {}
for k1 in params.keys():
if k1 not in state_dict.keys():
logger.warning("The pretrained params {} not in model".format(k1))
else:
if list(state_dict[k1].shape) == list(params[k1].shape):
new_state_dict[k1] = params[k1]
else:
logger.warning(
"The shape of model params {} {} not matched with loaded params {} {} !".
format(k1, state_dict[k1].shape, k1, params[k1].shape))
model.set_state_dict(new_state_dict)
logger.info("load pretrain successful from {}".format(path))
return model
def save_model(model,
optimizer,
model_path,
logger,
config,
is_best=False,
prefix='ppocr',
**kwargs):
"""
save model to the target path
"""
_mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix)
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
if config['Architecture']["model_type"] != 'vqa':
paddle.save(model.state_dict(), model_prefix + '.pdparams')
metric_prefix = model_prefix
else:
if config['Global']['distributed']:
model._layers.backbone.model.save_pretrained(model_prefix)
else:
model.backbone.model.save_pretrained(model_prefix)
metric_prefix = os.path.join(model_prefix, 'metric')
# save metric and config
with open(metric_prefix + '.states', 'wb') as f:
pickle.dump(kwargs, f, protocol=2)
if is_best:
logger.info('save best model is to {}'.format(model_prefix))
else:
logger.info("save model in {}".format(model_prefix))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import numpy as np
import datetime
__all__ = ['TrainingStats', 'Time']
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size):
self.deque = collections.deque(maxlen=window_size)
def add_value(self, value):
self.deque.append(value)
def get_median_value(self):
return np.median(self.deque)
def Time():
return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
class TrainingStats(object):
def __init__(self, window_size, stats_keys):
self.window_size = window_size
self.smoothed_losses_and_metrics = {
key: SmoothedValue(window_size)
for key in stats_keys
}
def update(self, stats):
for k, v in stats.items():
if k not in self.smoothed_losses_and_metrics:
self.smoothed_losses_and_metrics[k] = SmoothedValue(
self.window_size)
self.smoothed_losses_and_metrics[k].add_value(v)
def get(self, extras=None):
stats = collections.OrderedDict()
if extras:
for k, v in extras.items():
stats[k] = v
for k, v in self.smoothed_losses_and_metrics.items():
stats[k] = round(v.get_median_value(), 6)
return stats
def log(self, extras=None):
d = self.get(extras)
strs = []
for k, v in d.items():
strs.append('{}: {:x<6f}'.format(k, v))
strs = ', '.join(strs)
return strs
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import imghdr
import cv2
import random
import numpy as np
import paddle
def print_dict(d, logger, delimiter=0):
"""
Recursively visualize a dict and
indenting acrrording by the relationship of keys.
"""
for k, v in sorted(d.items()):
if isinstance(v, dict):
logger.info("{}{} : ".format(delimiter * " ", str(k)))
print_dict(v, logger, delimiter + 4)
elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
logger.info("{}{} : ".format(delimiter * " ", str(k)))
for value in v:
print_dict(value, logger, delimiter + 4)
else:
logger.info("{}{} : {}".format(delimiter * " ", k, v))
def get_check_global_params(mode):
check_params = ['use_gpu', 'max_text_length', 'image_shape', \
'image_shape', 'character_type', 'loss_type']
if mode == "train_eval":
check_params = check_params + [ \
'train_batch_size_per_card', 'test_batch_size_per_card']
elif mode == "test":
check_params = check_params + ['test_batch_size_per_card']
return check_params
def _check_image_file(path):
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'}
return any([path.lower().endswith(e) for e in img_end])
def get_image_file_list(img_file):
imgs_lists = []
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'}
if os.path.isfile(img_file) and _check_image_file(img_file):
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
for single_file in os.listdir(img_file):
file_path = os.path.join(img_file, single_file)
if os.path.isfile(file_path) and _check_image_file(file_path):
imgs_lists.append(file_path)
if len(imgs_lists) == 0:
raise Exception("not found any img file in {}".format(img_file))
imgs_lists = sorted(imgs_lists)
return imgs_lists
def check_and_read_gif(img_path):
if os.path.basename(img_path)[-3:] in ['gif', 'GIF']:
gif = cv2.VideoCapture(img_path)
ret, frame = gif.read()
if not ret:
logger = logging.getLogger('ppocr')
logger.info("Cannot read {}. This gif image maybe corrupted.")
return None, False
if len(frame.shape) == 2 or frame.shape[-1] == 1:
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
imgvalue = frame[:, :, ::-1]
return imgvalue, True
return None, False
def load_vqa_bio_label_maps(label_map_path):
with open(label_map_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
lines = [line.strip() for line in lines]
if "O" not in lines:
lines.insert(0, "O")
labels = []
for line in lines:
if line == "O":
labels.append("O")
else:
labels.append("B-" + line)
labels.append("I-" + line)
label2id_map = {label: idx for idx, label in enumerate(labels)}
id2label_map = {idx: label for idx, label in enumerate(labels)}
return label2id_map, id2label_map
def set_seed(seed=1024):
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)
class AverageMeter:
def __init__(self):
self.reset()
def reset(self):
"""reset"""
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
"""update"""
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, __dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
from ppocr.data import build_dataloader
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import load_model
import tools.program as program
from onnxruntime import InferenceSession
def main():
global_config = config['Global']
# build dataloader
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
# build post process
post_process_class = build_post_process(config['PostProcess'],
global_config)
# build model
# for rec algorithm
if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character'))
if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
if config['Architecture']['Models'][key]['Head'][
'name'] == 'MultiHead': # for multi head
out_channels_list = {}
if config['PostProcess'][
'name'] == 'DistillationSARLabelDecode':
char_num = char_num - 2
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Models'][key]['Head'][
'out_channels_list'] = out_channels_list
else:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
elif config['Architecture']['Head'][
'name'] == 'MultiHead': # for multi head
out_channels_list = {}
if config['PostProcess']['name'] == 'SARLabelDecode':
char_num = char_num - 2
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Head'][
'out_channels_list'] = out_channels_list
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num
pretrained_model = global_config.get('pretrained_model')
print("pretrained_model:", pretrained_model)
model = InferenceSession(pretrained_model, providers=[('ROCMExecutionProvider', {'device_id': '4'}),'CPUExecutionProvider'])
# build metric
eval_class = build_metric(config['Metric'])
# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class)
logger.info('metric eval ***************')
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
if __name__ == '__main__':
config, device, logger, vdl_writer, args = program.preprocess()
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment