"deploy/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "83dc7ae9008bcc4deae7bb44660d12de210c58ae"
Commit 8485edfc authored by tink2123

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into paddle2onnx

parents 48a6ebad 0b37c118
@@ -212,15 +212,15 @@ def train(config,
     for epoch in range(start_epoch, epoch_num + 1):
         train_dataloader = build_dataloader(
             config, 'Train', device, logger, seed=epoch)
-        train_batch_cost = 0.0
         train_reader_cost = 0.0
-        batch_sum = 0
-        batch_start = time.time()
+        train_run_cost = 0.0
+        total_samples = 0
+        reader_start = time.time()
         max_iter = len(train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(train_dataloader)
         for idx, batch in enumerate(train_dataloader):
             profiler.add_profiler_step(profiler_options)
-            train_reader_cost += time.time() - batch_start
+            train_reader_cost += time.time() - reader_start
             if idx >= max_iter:
                 break
             lr = optimizer.get_lr()
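This hunk splits the old single train_batch_cost accumulator into a reader cost (time spent waiting on the dataloader) and a run cost (time spent in the actual train step). A minimal self-contained sketch of the pattern; timed_epoch, train_step, and log_every are illustrative names, not from this commit:

import time

def timed_epoch(dataloader, train_step, log_every=10):
    # Accumulate data-loading and compute time separately, mirroring
    # the train_reader_cost / train_run_cost split introduced here.
    train_reader_cost = 0.0
    train_run_cost = 0.0
    total_samples = 0
    reader_start = time.time()
    for idx, batch in enumerate(dataloader):
        # the gap between loop iterations is time spent in the dataloader
        train_reader_cost += time.time() - reader_start
        train_start = time.time()
        train_step(batch)  # forward + backward + optimizer update
        train_run_cost += time.time() - train_start
        total_samples += len(batch)
        if (idx + 1) % log_every == 0:
            batch_cost = (train_reader_cost + train_run_cost) / log_every
            ips = total_samples / (train_reader_cost + train_run_cost)
            print('reader_cost: {:.5f} s, batch_cost: {:.5f} s, ips: {:.5f}'
                  .format(train_reader_cost / log_every, batch_cost, ips))
            train_reader_cost = train_run_cost = 0.0
            total_samples = 0
        reader_start = time.time()  # restart the reader clock

Resetting reader_start at the end of the loop body, rather than reusing a single batch_start, is what keeps the two costs disjoint.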
@@ -228,6 +228,7 @@ def train(config,
             if use_srn:
                 model_average = True
+            train_start = time.time()
             # use amp
             if scaler:
                 with paddle.amp.auto_cast():
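For context on the `# use amp` branch: a sketch of a Paddle AMP train step using the same GradScaler pattern, assuming model, optimizer, loss_fn, images, and labels already exist; init_loss_scaling=1024 is an illustrative value, not taken from this commit:

import paddle

scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

def amp_step(model, optimizer, loss_fn, images, labels):
    with paddle.amp.auto_cast():      # run eligible ops in float16
        preds = model(images)
        loss = loss_fn(preds, labels)
    scaled = scaler.scale(loss)       # scale loss to avoid fp16 underflow
    scaled.backward()
    scaler.minimize(optimizer, scaled)  # unscales grads, then steps
    optimizer.clear_grad()
    return loss

Placing train_start = time.time() just before this branch means the AMP and fp32 paths are timed identically.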
@@ -252,8 +253,8 @@ def train(config,
             optimizer.step()
             optimizer.clear_grad()
-            train_batch_cost += time.time() - batch_start
-            batch_sum += len(images)
+            train_run_cost += time.time() - train_start
+            total_samples += len(images)
             if not isinstance(lr_scheduler, float):
                 lr_scheduler.step()
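The isinstance guard in this hunk's context exists because the configured learning rate may be either a plain float or a scheduler object, and only the latter needs a per-iteration step(). A hedged sketch; the model and CosineAnnealingDecay settings are illustrative:

import paddle

model = paddle.nn.Linear(8, 2)
lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
    learning_rate=0.001, T_max=1000)
optimizer = paddle.optimizer.Adam(
    learning_rate=lr_scheduler, parameters=model.parameters())

for step in range(10):
    # ... forward/backward/optimizer.step() would go here ...
    if not isinstance(lr_scheduler, float):
        lr_scheduler.step()  # advance the schedule once per iteration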
@@ -284,12 +285,13 @@ def train(config,
                 logs = train_stats.log()
                 strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
                     epoch, epoch_num, global_step, logs, train_reader_cost /
-                    print_batch_step, train_batch_cost / print_batch_step,
-                    batch_sum, batch_sum / train_batch_cost)
+                    print_batch_step, (train_reader_cost + train_run_cost) /
+                    print_batch_step, total_samples,
+                    total_samples / (train_reader_cost + train_run_cost))
                 logger.info(strs)
-                train_batch_cost = 0.0
                 train_reader_cost = 0.0
-                batch_sum = 0
+                train_run_cost = 0.0
+                total_samples = 0
             # eval
             if global_step > start_eval_step and \
                     (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
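For scale, with hypothetical values: if print_batch_step = 10 and the logging window accumulated train_reader_cost = 0.8 s, train_run_cost = 3.2 s, and total_samples = 640, the log line reports reader_cost = 0.8 / 10 = 0.08 s, batch_cost = (0.8 + 3.2) / 10 = 0.40 s, and ips = 640 / 4.0 = 160 samples per second. Note that batch_cost includes reader time, so a reader_cost that is a large fraction of batch_cost flags a dataloader bottleneck.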
@@ -342,7 +344,7 @@ def train(config,
                                   global_step)
             global_step += 1
             optimizer.clear_grad()
-            batch_start = time.time()
+            reader_start = time.time()
     if dist.get_rank() == 0:
         save_model(
             model,
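The `if dist.get_rank() == 0:` guard around save_model keeps multiple distributed workers from racing to write the same checkpoint. A minimal sketch of the same idea with plain paddle.save, assuming model and optimizer exist; the output paths are illustrative:

import paddle
import paddle.distributed as dist

if dist.get_rank() == 0:  # only the rank-0 worker touches the filesystem
    paddle.save(model.state_dict(), 'output/latest.pdparams')
    paddle.save(optimizer.state_dict(), 'output/latest.pdopt')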
@@ -383,7 +385,11 @@ def eval(model,
     with paddle.no_grad():
         total_frame = 0.0
         total_time = 0.0
-        pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
+        pbar = tqdm(
+            total=len(valid_dataloader),
+            desc='eval model:',
+            position=0,
+            leave=True)
         max_iter = len(valid_dataloader) - 1 if platform.system(
         ) == "Windows" else len(valid_dataloader)
         for idx, batch in enumerate(valid_dataloader):
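The extra position=0 and leave=True arguments pin the progress bar to a single console row and keep the finished bar visible, which avoids the stacked or duplicated bars tqdm can produce when other output interleaves with it. A standalone illustration:

import time
from tqdm import tqdm

pbar = tqdm(total=100, desc='eval model:', position=0, leave=True)
for _ in range(100):
    time.sleep(0.01)  # stand-in for one evaluation batch
    pbar.update(1)
pbar.close()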
@@ -452,8 +458,6 @@ def get_center(model, eval_dataloader, post_process_class):
         batch = [item.numpy() for item in batch]
         # Obtain usable results from post-processing methods
-        total_time += time.time() - start
-        # Evaluate the results of the current batch
         post_result = post_process_class(preds, batch[1])
         #update char_center
...