Commit 0cd2527c authored by WenmuZhou's avatar WenmuZhou
Browse files

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into update_requirements

parents df05d1fd 479b7672
doc/imgs_results/whl/12_det.jpg

166 KB | W: | H:

doc/imgs_results/whl/12_det.jpg

410 KB | W: | H:

doc/imgs_results/whl/12_det.jpg
doc/imgs_results/whl/12_det.jpg
doc/imgs_results/whl/12_det.jpg
doc/imgs_results/whl/12_det.jpg
  • 2-up
  • Swipe
  • Onion skin
doc/joinus.PNG

111 KB | W: | H:

doc/joinus.PNG

109 KB | W: | H:

doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
  • 2-up
  • Swipe
  • Onion skin
This diff is collapsed.
...@@ -34,6 +34,7 @@ import paddle.distributed as dist ...@@ -34,6 +34,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet from ppocr.data.lmdb_dataset import LMDBDataSet
from ppocr.data.pgnet_dataset import PGDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators'] __all__ = ['build_dataloader', 'transform', 'create_operators']
...@@ -54,7 +55,7 @@ signal.signal(signal.SIGTERM, term_mp) ...@@ -54,7 +55,7 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None): def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config) config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDataSet'] support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet']
module_name = config[mode]['dataset']['name'] module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict)) 'DataSet only support {}'.format(support_dict))
...@@ -72,14 +73,14 @@ def build_dataloader(config, mode, device, logger, seed=None): ...@@ -72,14 +73,14 @@ def build_dataloader(config, mode, device, logger, seed=None):
else: else:
use_shared_memory = True use_shared_memory = True
if mode == "Train": if mode == "Train":
#Distribute data to multiple cards # Distribute data to multiple cards
batch_sampler = DistributedBatchSampler( batch_sampler = DistributedBatchSampler(
dataset=dataset, dataset=dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=shuffle, shuffle=shuffle,
drop_last=drop_last) drop_last=drop_last)
else: else:
#Distribute data to single card # Distribute data to single card
batch_sampler = BatchSampler( batch_sampler = BatchSampler(
dataset=dataset, dataset=dataset,
batch_size=batch_size, batch_size=batch_size,
......
...@@ -28,6 +28,7 @@ from .label_ops import * ...@@ -28,6 +28,7 @@ from .label_ops import *
from .east_process import * from .east_process import *
from .sast_process import * from .sast_process import *
from .pg_process import *
def transform(data, ops=None): def transform(data, ops=None):
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -117,13 +117,16 @@ class RawRandAugment(object): ...@@ -117,13 +117,16 @@ class RawRandAugment(object):
class RandAugment(RawRandAugment): class RandAugment(RawRandAugment):
""" RandAugment wrapper to auto fit different img types """ """ RandAugment wrapper to auto fit different img types """
def __init__(self, *args, **kwargs): def __init__(self, prob=0.5, *args, **kwargs):
self.prob = prob
if six.PY2: if six.PY2:
super(RandAugment, self).__init__(*args, **kwargs) super(RandAugment, self).__init__(*args, **kwargs)
else: else:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def __call__(self, data): def __call__(self, data):
if np.random.rand() > self.prob:
return data
img = data['image'] img = data['image']
if not isinstance(img, Image.Image): if not isinstance(img, Image.Image):
img = np.ascontiguousarray(img) img = np.ascontiguousarray(img)
......
This diff is collapsed.
This diff is collapsed.
...@@ -29,10 +29,11 @@ def build_loss(config): ...@@ -29,10 +29,11 @@ def build_loss(config):
# cls loss # cls loss
from .cls_loss import ClsLoss from .cls_loss import ClsLoss
# e2e loss
from .e2e_pg_loss import PGLoss
support_dict = [ support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss' 'SRNLoss', 'PGLoss']
]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
...@@ -200,6 +200,6 @@ def ohem_batch(scores, gt_texts, training_masks, ohem_ratio): ...@@ -200,6 +200,6 @@ def ohem_batch(scores, gt_texts, training_masks, ohem_ratio):
i, :, :], ohem_ratio)) i, :, :], ohem_ratio))
selected_masks = np.concatenate(selected_masks, 0) selected_masks = np.concatenate(selected_masks, 0)
selected_masks = paddle.to_variable(selected_masks) selected_masks = paddle.to_tensor(selected_masks)
return selected_masks return selected_masks
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment