Commit 03bb378f authored by LDOUBLEV's avatar LDOUBLEV
Browse files

fix TRT8 core bug

parents a2a12fe4 2e9abcb9
...@@ -32,7 +32,7 @@ import paddle ...@@ -32,7 +32,7 @@ import paddle
from ppocr.data import create_operators, transform from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list from ppocr.utils.utility import get_image_file_list
import tools.program as program import tools.program as program
...@@ -47,7 +47,7 @@ def main(): ...@@ -47,7 +47,7 @@ def main():
# build model # build model
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model) load_model(config, model)
# create data ops # create data ops
transforms = [] transforms = []
......
...@@ -34,7 +34,7 @@ import paddle ...@@ -34,7 +34,7 @@ import paddle
from ppocr.data import create_operators, transform from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model, load_dygraph_params from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list from ppocr.utils.utility import get_image_file_list
import tools.program as program import tools.program as program
...@@ -59,7 +59,7 @@ def main(): ...@@ -59,7 +59,7 @@ def main():
# build model # build model
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
_ = load_dygraph_params(config, model, logger, None) load_model(config, model)
# build post process # build post process
post_process_class = build_post_process(config['PostProcess']) post_process_class = build_post_process(config['PostProcess'])
......
...@@ -34,7 +34,7 @@ import paddle ...@@ -34,7 +34,7 @@ import paddle
from ppocr.data import create_operators, transform from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list from ppocr.utils.utility import get_image_file_list
import tools.program as program import tools.program as program
...@@ -68,7 +68,7 @@ def main(): ...@@ -68,7 +68,7 @@ def main():
# build model # build model
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model) load_model(config, model)
# build post process # build post process
post_process_class = build_post_process(config['PostProcess'], post_process_class = build_post_process(config['PostProcess'],
......
...@@ -33,7 +33,7 @@ import paddle ...@@ -33,7 +33,7 @@ import paddle
from ppocr.data import create_operators, transform from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list from ppocr.utils.utility import get_image_file_list
import tools.program as program import tools.program as program
...@@ -58,7 +58,7 @@ def main(): ...@@ -58,7 +58,7 @@ def main():
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model) load_model(config, model)
# create data ops # create data ops
transforms = [] transforms = []
...@@ -75,9 +75,7 @@ def main(): ...@@ -75,9 +75,7 @@ def main():
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2' 'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
] ]
elif config['Architecture']['algorithm'] == "SAR": elif config['Architecture']['algorithm'] == "SAR":
op[op_name]['keep_keys'] = [ op[op_name]['keep_keys'] = ['image', 'valid_ratio']
'image', 'valid_ratio'
]
else: else:
op[op_name]['keep_keys'] = ['image'] op[op_name]['keep_keys'] = ['image']
transforms.append(op) transforms.append(op)
......
...@@ -34,11 +34,12 @@ from paddle.jit import to_static ...@@ -34,11 +34,12 @@ from paddle.jit import to_static
from ppocr.data import create_operators, transform from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list from ppocr.utils.utility import get_image_file_list
import tools.program as program import tools.program as program
import cv2 import cv2
def main(config, device, logger, vdl_writer): def main(config, device, logger, vdl_writer):
global_config = config['Global'] global_config = config['Global']
...@@ -53,7 +54,7 @@ def main(config, device, logger, vdl_writer): ...@@ -53,7 +54,7 @@ def main(config, device, logger, vdl_writer):
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model, logger) load_model(config, model)
# create data ops # create data ops
transforms = [] transforms = []
...@@ -104,4 +105,3 @@ def main(config, device, logger, vdl_writer): ...@@ -104,4 +105,3 @@ def main(config, device, logger, vdl_writer):
if __name__ == '__main__': if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess() config, device, logger, vdl_writer = program.preprocess()
main(config, device, logger, vdl_writer) main(config, device, logger, vdl_writer)
...@@ -212,15 +212,15 @@ def train(config, ...@@ -212,15 +212,15 @@ def train(config,
for epoch in range(start_epoch, epoch_num + 1): for epoch in range(start_epoch, epoch_num + 1):
train_dataloader = build_dataloader( train_dataloader = build_dataloader(
config, 'Train', device, logger, seed=epoch) config, 'Train', device, logger, seed=epoch)
train_batch_cost = 0.0
train_reader_cost = 0.0 train_reader_cost = 0.0
batch_sum = 0 train_run_cost = 0.0
batch_start = time.time() total_samples = 0
reader_start = time.time()
max_iter = len(train_dataloader) - 1 if platform.system( max_iter = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader) ) == "Windows" else len(train_dataloader)
for idx, batch in enumerate(train_dataloader): for idx, batch in enumerate(train_dataloader):
profiler.add_profiler_step(profiler_options) profiler.add_profiler_step(profiler_options)
train_reader_cost += time.time() - batch_start train_reader_cost += time.time() - reader_start
if idx >= max_iter: if idx >= max_iter:
break break
lr = optimizer.get_lr() lr = optimizer.get_lr()
...@@ -228,6 +228,7 @@ def train(config, ...@@ -228,6 +228,7 @@ def train(config,
if use_srn: if use_srn:
model_average = True model_average = True
train_start = time.time()
# use amp # use amp
if scaler: if scaler:
with paddle.amp.auto_cast(): with paddle.amp.auto_cast():
...@@ -252,8 +253,8 @@ def train(config, ...@@ -252,8 +253,8 @@ def train(config,
optimizer.step() optimizer.step()
optimizer.clear_grad() optimizer.clear_grad()
train_batch_cost += time.time() - batch_start train_run_cost += time.time() - train_start
batch_sum += len(images) total_samples += len(images)
if not isinstance(lr_scheduler, float): if not isinstance(lr_scheduler, float):
lr_scheduler.step() lr_scheduler.step()
...@@ -284,12 +285,13 @@ def train(config, ...@@ -284,12 +285,13 @@ def train(config,
logs = train_stats.log() logs = train_stats.log()
strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format( strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
epoch, epoch_num, global_step, logs, train_reader_cost / epoch, epoch_num, global_step, logs, train_reader_cost /
print_batch_step, train_batch_cost / print_batch_step, print_batch_step, (train_reader_cost + train_run_cost) /
batch_sum, batch_sum / train_batch_cost) print_batch_step, total_samples,
total_samples / (train_reader_cost + train_run_cost))
logger.info(strs) logger.info(strs)
train_batch_cost = 0.0
train_reader_cost = 0.0 train_reader_cost = 0.0
batch_sum = 0 train_run_cost = 0.0
total_samples = 0
# eval # eval
if global_step > start_eval_step and \ if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
...@@ -342,7 +344,7 @@ def train(config, ...@@ -342,7 +344,7 @@ def train(config,
global_step) global_step)
global_step += 1 global_step += 1
optimizer.clear_grad() optimizer.clear_grad()
batch_start = time.time() reader_start = time.time()
if dist.get_rank() == 0: if dist.get_rank() == 0:
save_model( save_model(
model, model,
...@@ -383,7 +385,11 @@ def eval(model, ...@@ -383,7 +385,11 @@ def eval(model,
with paddle.no_grad(): with paddle.no_grad():
total_frame = 0.0 total_frame = 0.0
total_time = 0.0 total_time = 0.0
pbar = tqdm(total=len(valid_dataloader), desc='eval model:') pbar = tqdm(
total=len(valid_dataloader),
desc='eval model:',
position=0,
leave=True)
max_iter = len(valid_dataloader) - 1 if platform.system( max_iter = len(valid_dataloader) - 1 if platform.system(
) == "Windows" else len(valid_dataloader) ) == "Windows" else len(valid_dataloader)
for idx, batch in enumerate(valid_dataloader): for idx, batch in enumerate(valid_dataloader):
...@@ -452,8 +458,6 @@ def get_center(model, eval_dataloader, post_process_class): ...@@ -452,8 +458,6 @@ def get_center(model, eval_dataloader, post_process_class):
batch = [item.numpy() for item in batch] batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods # Obtain usable results from post-processing methods
total_time += time.time() - start
# Evaluate the results of the current batch
post_result = post_process_class(preds, batch[1]) post_result = post_process_class(preds, batch[1])
#update char_center #update char_center
......
...@@ -35,7 +35,7 @@ from ppocr.losses import build_loss ...@@ -35,7 +35,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model, load_dygraph_params from ppocr.utils.save_load import load_model
import tools.program as program import tools.program as program
dist.get_world_size() dist.get_world_size()
...@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer): ...@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model # load pretrain model
pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer) pre_best_model_dict = load_model(config, model, optimizer)
logger.info('train dataloader has {} iters'.format(len(train_dataloader))) logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None: if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format( logger.info('valid dataloader has {} iters'.format(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment