# -*- coding: utf-8 -*-
# @Time    : 2019/8/23 21:59
# @Author  : zhoujun
import functools
import glob
import json
import os
import pathlib
import time

from natsort import natsorted
import cv2
import matplotlib.pyplot as plt
import numpy as np


def get_file_list(folder_path: str, p_postfix: list = None, sub_dir: bool = True) -> list:
    """
    Collect the files under a directory whose extension is in ``p_postfix``.

    Only names containing a dot are considered (the glob pattern is ``*.*``).

    :param folder_path: directory to search; must exist and be a directory
    :param p_postfix: extension or list of extensions including the leading dot
        (defaults to ``['.jpg']``); if it contains ``'.*'`` every match is returned
    :param sub_dir: search sub-directories as well when True, otherwise only
        the top level of ``folder_path``
    :return: the matching paths in natural sort order
    """
    assert os.path.exists(folder_path) and os.path.isdir(folder_path)
    if p_postfix is None:
        p_postfix = ['.jpg']
    if isinstance(p_postfix, str):
        p_postfix = [p_postfix]
    # '**' only expands to nested directories when recursive=True, so the
    # pattern and the flag are switched together to honor sub_dir.
    pattern = folder_path + ('/**/*.*' if sub_dir else '/*.*')
    match_all = '.*' in p_postfix
    file_list = [x for x in glob.glob(pattern, recursive=sub_dir)
                 if match_all or os.path.splitext(x)[-1] in p_postfix]
    return natsorted(file_list)


def setup_logger(log_file_path: str = None):
    """
    Configure and return the ``'DBNet.pytorch'`` logger at DEBUG level.

    A console handler is always present afterwards; a file handler is added
    when ``log_file_path`` is given. Calling this function repeatedly does not
    attach duplicate handlers, so messages are not emitted more than once.

    :param log_file_path: optional path of a log file to write to
    :return: the configured ``logging.Logger``
    """
    import logging
    # Private logging-module attribute kept from the original code;
    # NOTE(review): presumably silences a pre-init stderr warning - confirm.
    logging._warn_preinit_stderr = 0
    logger = logging.getLogger('DBNet.pytorch')
    formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s: %(message)s')

    # FileHandler subclasses StreamHandler, so it is excluded explicitly when
    # looking for an existing console handler.
    has_console = any(
        isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )
    if not has_console:
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    if log_file_path is not None:
        # FileHandler stores the absolute path in baseFilename.
        target = os.path.abspath(log_file_path)
        has_file = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in logger.handlers
        )
        if not has_file:
            file_handle = logging.FileHandler(log_file_path)
            file_handle.setFormatter(formatter)
            logger.addHandler(file_handle)

    logger.setLevel(logging.DEBUG)
    return logger


# --exeTime
def exe_time(func):
    """
    Decorator that prints how long each call of ``func`` took, in seconds.

    :param func: the callable to time
    :return: a wrapper that returns whatever ``func`` returns
    """

    @functools.wraps(func)  # keep func's __name__/__doc__ on the wrapper
    def newFunc(*args, **args2):
        t0 = time.time()
        back = func(*args, **args2)
        print("{} cost {:.3f}s".format(func.__name__, time.time() - t0))
        return back

    return newFunc


def load(file_path: str):
    """
    Load a file, choosing the reader from its suffix.

    ``.txt`` and ``.list`` are read as a list of lines, ``.json`` is parsed.

    :param file_path: path of the file to read
    :return: list of stripped lines, or the decoded JSON value
    :raises AssertionError: if the suffix is not supported
    """
    file_path = pathlib.Path(file_path)
    func_dict = {'.txt': _load_txt, '.json': _load_json, '.list': _load_txt}
    assert file_path.suffix in func_dict
    return func_dict[file_path.suffix](file_path)


def _load_txt(file_path: str):
    """
    Read a UTF-8 text file into a list of whitespace-stripped lines.

    A byte-order mark is removed from each line, both as ``'\\ufeff'`` and as
    the three-character sequence it becomes when mis-decoded.

    :param file_path: path of the text file
    :return: list of cleaned lines
    """
    mojibake_bom = '\xef\xbb\xbf'
    content = []
    with open(file_path, 'r', encoding='utf8') as f:
        for raw in f:
            line = raw.strip().strip('\ufeff')
            # Remove only the exact leading sequence: str.strip() with these
            # characters would also eat legitimate leading/trailing
            # 'ï', '»' or '¿' from the text itself.
            if line.startswith(mojibake_bom):
                line = line[len(mojibake_bom):]
            content.append(line)
    return content


def _load_json(file_path: str):
    """
    Parse a UTF-8 JSON file.

    :param file_path: path of the JSON file
    :return: the decoded JSON value
    """
    with open(file_path, 'r', encoding='utf8') as f:
        content = json.load(f)
    return content


def save(data, file_path):
    """
    Save ``data`` to a file, choosing the writer from the suffix.

    :param data: value to write (lines for ``.txt``, JSON-serializable for ``.json``)
    :param file_path: destination path; suffix must be ``.txt`` or ``.json``
    :return: whatever the selected writer returns (currently None)
    :raises AssertionError: if the suffix is not supported
    """
    file_path = pathlib.Path(file_path)
    func_dict = {'.txt': _save_txt, '.json': _save_json}
    assert file_path.suffix in func_dict
    return func_dict[file_path.suffix](data, file_path)


def _save_txt(data, file_path):
    """
    Write a list of strings to a text file, one item per line.

    :param data: list of strings; any non-list value is wrapped in a list
    :param file_path: destination path
    :return: None
    """
    if not isinstance(data, list):
        data = [data]
    with open(file_path, mode='w', encoding='utf8') as f:
        f.write('\n'.join(data))


def _save_json(data, file_path):
    """
    Write ``data`` as indented JSON, keeping non-ASCII characters unescaped.

    :param data: JSON-serializable value
    :param file_path: destination path
    :return: None
    """
    with open(file_path, 'w', encoding='utf-8') as json_file:
        json.dump(data, json_file, ensure_ascii=False, indent=4)


def show_img(imgs: np.ndarray, title='img'):
    """
    Display an array with matplotlib.

    The array is shown in color when it is 3-dimensional with a last axis of
    length 3, otherwise with the gray colormap.

    :param imgs: image array
    :param title: title prefix; the figure is titled ``'<title>_<index>'``
    :return: None
    """
    color = (len(imgs.shape) == 3 and imgs.shape[-1] == 3)
    # A leading axis is added so the loop below sees exactly one image.
    imgs = np.expand_dims(imgs, axis=0)
    for i, img in enumerate(imgs):
        plt.figure()
        plt.title('{}_{}'.format(title, i))
        plt.imshow(img, cmap=None if color else 'gray')
    plt.show()


def draw_bbox(img_path, result, color=(255, 0, 0), thickness=2):
    """
    Draw closed polygons on a copy of an image.

    :param img_path: path of an image file (read with ``cv2.imread``) or an
        already loaded image array
    :param result: iterable of point arrays, one per polygon; each is cast to int
    :param color: line color passed to ``cv2.polylines``
    :param thickness: line thickness in pixels
    :return: a copy of the image with the polygons drawn; the input array is
        not modified
    """
    if isinstance(img_path, str):
        img_path = cv2.imread(img_path)
        # img_path = cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB)
    img_path = img_path.copy()
    for point in result:
        point = point.astype(int)
        cv2.polylines(img_path, [point], True, color, thickness)
    return img_path


def cal_text_score(texts, gt_texts, training_masks, running_metric_text, thred=0.5):
    """
    Binarize masked predictions and update a running metric with them.

    NOTE(review): ``texts``, ``gt_texts`` and ``training_masks`` are accessed
    via ``.data.cpu().numpy()``, so they are presumably torch tensors - confirm.

    :param texts: predicted text map
    :param gt_texts: ground-truth text map
    :param training_masks: mask multiplied into both prediction and ground truth
    :param running_metric_text: object providing ``update(gt, pred)`` and
        ``get_scores()``
    :param thred: values above this threshold become 1, the rest 0
    :return: the first element returned by ``running_metric_text.get_scores()``
    """
    training_masks = training_masks.data.cpu().numpy()
    pred_text = texts.data.cpu().numpy() * training_masks
    pred_text[pred_text <= thred] = 0
    pred_text[pred_text > thred] = 1
    pred_text = pred_text.astype(np.int32)
    gt_text = gt_texts.data.cpu().numpy() * training_masks
    gt_text = gt_text.astype(np.int32)
    running_metric_text.update(gt_text, pred_text)
    score_text, _ = running_metric_text.get_scores()
    return score_text


def order_points_clockwise(pts):
    """
    Order four 2-D points using their coordinate sum and difference.

    Slot 0 gets the point with the smallest coordinate sum, slot 2 the largest;
    slot 1 gets the point with the smallest ``np.diff`` value (second minus
    first coordinate), slot 3 the largest.
    NOTE(review): with (x, y) rows in image coordinates this presumably yields
    top-left, top-right, bottom-right, bottom-left - confirm.

    :param pts: array of points, one per row
    :return: ``(4, 2)`` float32 array of the ordered points
    """
    rect = np.zeros((4, 2), dtype="float32")
    s = pts.sum(axis=1)
    rect[0] = pts[np.argmin(s)]
    rect[2] = pts[np.argmax(s)]
    diff = np.diff(pts, axis=1)
    rect[1] = pts[np.argmin(diff)]
    rect[3] = pts[np.argmax(diff)]
    return rect


def order_points_clockwise_list(pts):
    """
    Order points by sorting instead of by coordinate sum/difference.

    Points are sorted by (second coordinate, first coordinate); the first two
    are then ordered by ascending first coordinate and the remaining ones by
    descending first coordinate.

    :param pts: array of points, one per row
    :return: array of the reordered points
    """
    pts = pts.tolist()
    pts.sort(key=lambda x: (x[1], x[0]))
    pts[:2] = sorted(pts[:2], key=lambda x: x[0])
    pts[2:] = sorted(pts[2:], key=lambda x: -x[0])
    pts = np.array(pts)
    return pts


def get_datalist(train_data_path):
    """
    Build the list of (image, label) path pairs for training/validation.

    :param train_data_path: list of dataset index files; each line is stored
        as ``'path/to/img\\tlabel'`` (``'.jpg '`` is also accepted as separator)
    :return: list of ``(img_path, label_path)`` string tuples, keeping only
        pairs where both files exist and are non-empty
    """
    train_data = []
    for p in train_data_path:
        with open(p, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip('\n').replace('.jpg ', '.jpg\t').split('\t')
                if len(line) > 1:
                    img_path = pathlib.Path(line[0].strip(' '))
                    label_path = pathlib.Path(line[1].strip(' '))
                    if img_path.exists() and img_path.stat().st_size > 0 \
                            and label_path.exists() and label_path.stat().st_size > 0:
                        train_data.append((str(img_path), str(label_path)))
    return train_data


def parse_config(config: dict) -> dict:
    """
    Resolve the ``'base'`` files of a config and merge ``config`` on top.

    Base files are loaded with ``anyconfig`` in order, recursively resolving
    their own ``'base'`` entries; each later file is merged with the result of
    the earlier ones, and finally ``config`` is merged into the result.

    :param config: config dict containing a ``'base'`` list of file paths;
        the ``'base'`` key is removed from it
    :return: the merged configuration
    :raises KeyError: if ``config`` has no ``'base'`` key
    """
    import anyconfig
    base_file_list = config.pop('base')
    base_config = {}
    for base_file in base_file_list:
        # Close the file once parsed instead of leaking the handle.
        with open(base_file, 'rb') as base_stream:
            tmp_config = anyconfig.load(base_stream)
        if 'base' in tmp_config:
            tmp_config = parse_config(tmp_config)
        anyconfig.merge(tmp_config, base_config)
        base_config = tmp_config
    anyconfig.merge(base_config, config)
    return base_config


def save_result(result_path, box_list, score_list, is_output_polygon):
    """
    Write detection results as ``'c0,c1,...,score'`` lines.

    Each box is flattened and its coordinates truncated to int.

    :param result_path: destination text file
    :param box_list: sequence of box/polygon arrays
    :param score_list: scores, indexed in parallel with ``box_list``
    :param is_output_polygon: accepted for interface compatibility; the output
        format is the same for both values
    :return: None
    """
    with open(result_path, 'wt') as res:
        for i, box in enumerate(box_list):
            score = score_list[i]
            box = box.reshape(-1).tolist()
            result = ",".join([str(int(x)) for x in box])
            res.write(result + ',' + str(score) + "\n")


def expand_polygon(polygon):
    """
    Widen the box of a polygon that contains only a single character.

    The minimum-area rectangle of the polygon is computed and its width is
    increased by its height before converting back to corner points.

    :param polygon: polygon points, converted with ``np.float32``
    :return: the four corner points ordered by ``order_points_clockwise``
    """
    (x, y), (w, h), angle = cv2.minAreaRect(np.float32(polygon))
    if angle < -45:
        # Swap the sides and rotate by 90 degrees so that the same rectangle
        # is described with an angle of at least -45.
        w, h = h, w
        angle += 90
    new_w = w + h
    box = ((x, y), (new_w, h), angle)
    points = cv2.boxPoints(box)
    return order_points_clockwise(points)


if __name__ == '__main__':
    img = np.zeros((1, 3, 640, 640))
    show_img(img[0][0])
    plt.show()