Commit 5e9fb50d authored by tink2123
Browse files

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into multi_languages

parents 45117f90 5a5d627d
...@@ -20,7 +20,7 @@ from .imaug import transform, create_operators ...@@ -20,7 +20,7 @@ from .imaug import transform, create_operators
class SimpleDataSet(Dataset): class SimpleDataSet(Dataset):
def __init__(self, config, mode, logger): def __init__(self, config, mode, logger, seed=None):
super(SimpleDataSet, self).__init__() super(SimpleDataSet, self).__init__()
self.logger = logger self.logger = logger
...@@ -41,6 +41,7 @@ class SimpleDataSet(Dataset): ...@@ -41,6 +41,7 @@ class SimpleDataSet(Dataset):
self.data_dir = dataset_config['data_dir'] self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle'] self.do_shuffle = loader_config['shuffle']
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_list) logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines))) self.data_idx_order_list = list(range(len(self.data_lines)))
...@@ -55,6 +56,7 @@ class SimpleDataSet(Dataset): ...@@ -55,6 +56,7 @@ class SimpleDataSet(Dataset):
for idx, file in enumerate(file_list): for idx, file in enumerate(file_list):
with open(file, "rb") as f: with open(file, "rb") as f:
lines = f.readlines() lines = f.readlines()
random.seed(self.seed)
lines = random.sample(lines, lines = random.sample(lines,
round(len(lines) * ratio_list[idx])) round(len(lines) * ratio_list[idx]))
data_lines.extend(lines) data_lines.extend(lines)
...@@ -62,6 +64,7 @@ class SimpleDataSet(Dataset): ...@@ -62,6 +64,7 @@ class SimpleDataSet(Dataset):
def shuffle_data_random(self): def shuffle_data_random(self):
if self.do_shuffle: if self.do_shuffle:
random.seed(self.seed)
random.shuffle(self.data_lines) random.shuffle(self.data_lines)
return return
......
...@@ -213,16 +213,14 @@ class GridGenerator(nn.Layer): ...@@ -213,16 +213,14 @@ class GridGenerator(nn.Layer):
def build_P_paddle(self, I_r_size): def build_P_paddle(self, I_r_size):
I_r_height, I_r_width = I_r_size I_r_height, I_r_width = I_r_size
I_r_grid_x = paddle.divide( I_r_grid_x = (paddle.arange(
paddle.arange( -I_r_width, I_r_width, 2, dtype='float64') + 1.0
-I_r_width, I_r_width, 2, dtype='float64') + 1.0, ) / paddle.to_tensor(np.array([I_r_width]))
paddle.to_tensor(
I_r_width, dtype='float64')) I_r_grid_y = (paddle.arange(
I_r_grid_y = paddle.divide( -I_r_height, I_r_height, 2, dtype='float64') + 1.0
paddle.arange( ) / paddle.to_tensor(np.array([I_r_height]))
-I_r_height, I_r_height, 2, dtype='float64') + 1.0,
paddle.to_tensor(
I_r_height, dtype='float64')) # self.I_r_height
# P: self.I_r_width x self.I_r_height x 2 # P: self.I_r_width x self.I_r_height x 2
P = paddle.stack(paddle.meshgrid(I_r_grid_x, I_r_grid_y), axis=2) P = paddle.stack(paddle.meshgrid(I_r_grid_x, I_r_grid_y), axis=2)
P = paddle.transpose(P, perm=[1, 0, 2]) P = paddle.transpose(P, perm=[1, 0, 2])
......
...@@ -109,7 +109,7 @@ class CTCLabelDecode(BaseRecLabelDecode): ...@@ -109,7 +109,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
preds_idx = preds.argmax(axis=2) preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2) preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob) text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
if label is None: if label is None:
return text return text
label = self.decode(label) label = self.decode(label)
......
...@@ -182,8 +182,8 @@ def train(config, ...@@ -182,8 +182,8 @@ def train(config,
start_epoch = 1 start_epoch = 1
for epoch in range(start_epoch, epoch_num + 1): for epoch in range(start_epoch, epoch_num + 1):
if epoch > 0: train_dataloader = build_dataloader(
train_dataloader = build_dataloader(config, 'Train', device, logger) config, 'Train', device, logger, seed=epoch)
train_batch_cost = 0.0 train_batch_cost = 0.0
train_reader_cost = 0.0 train_reader_cost = 0.0
batch_sum = 0 batch_sum = 0
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment