# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

WenmuZhou's avatar
WenmuZhou committed
19
import os
LDOUBLEV's avatar
LDOUBLEV committed
20
import sys
21
import platform
LDOUBLEV's avatar
LDOUBLEV committed
22
23
import yaml
import time
24
import datetime
WenmuZhou's avatar
WenmuZhou committed
25
26
27
28
29
import paddle
import paddle.distributed as dist
from tqdm import tqdm
from argparse import ArgumentParser, RawDescriptionHelpFormatter

LDOUBLEV's avatar
LDOUBLEV committed
30
31
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
32
from ppocr.utils.utility import print_dict, AverageMeter
dyning's avatar
dyning committed
33
from ppocr.utils.logging import get_logger
34
from ppocr.utils.loggers import VDLLogger, WandbLogger, Loggers
LDOUBLEV's avatar
LDOUBLEV committed
35
from ppocr.utils import profiler
dyning's avatar
dyning committed
36
from ppocr.data import build_dataloader
LDOUBLEV's avatar
LDOUBLEV committed
37

dyning's avatar
dyning committed
38

LDOUBLEV's avatar
LDOUBLEV committed
39
40
41
42
43
44
45
class ArgsParser(ArgumentParser):
    """Command-line parser for the config path, config overrides and
    profiler options."""

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            '-p',
            '--profiler_options',
            type=str,
            default=None,
            help='The option of profiler, which should be in format '
                 '\"key1=value1;key2=value2;key3=value3\".')

    def parse_args(self, argv=None):
        """Parse ``argv`` (``sys.argv[1:]`` when None).

        ``args.opt`` is converted from a list of ``key=value`` strings into a
        dict (empty when no ``-o`` was given).

        Raises:
            AssertionError: if ``--config`` was not supplied.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Turn ``["key=value", ...]`` into ``{key: yaml-parsed value}``.

        Only the first ``=`` separates the key from the value, so values may
        themselves contain ``=``. An entry without ``=`` raises ValueError.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            k, v = s.strip().split('=', 1)
            # NOTE(review): yaml.Loader can construct arbitrary Python
            # objects; acceptable for local CLI input, but switch to
            # yaml.SafeLoader if options can come from an untrusted source.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config


def load_config(file_path):
    """
    Load config from yml/yaml file.
    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: global config
    Raises:
        AssertionError: if the file extension is not .yml or .yaml.
    """
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # Close the handle deterministically instead of leaking it until GC.
    # NOTE(review): yaml.Loader can instantiate arbitrary Python objects;
    # only load config files from trusted sources.
    with open(file_path, 'rb') as f:
        config = yaml.load(f, Loader=yaml.Loader)
    return config


def merge_config(config, opts):
    """
    Merge override options into the global config, in place.
    Args:
        config (dict): Global config to be updated.
        opts (dict): Overrides. A key without "." replaces the top-level
            entry, or updates it when both the override and the existing
            entry are dicts. A dotted key such as "Global.epoch_num" walks
            the nested dicts and assigns the last component.
    Returns: global config (the same object as ``config``)
    Raises:
        AssertionError: if the first component of a dotted key is not a
            top-level key of ``config``.
        KeyError: if an intermediate component of a dotted key is missing.
    """
    for key, value in opts.items():
        if "." not in key:
            # Only merge dict-into-dict; otherwise (missing, None, scalar)
            # the override simply replaces the entry.
            if isinstance(value, dict) and isinstance(config.get(key), dict):
                config[key].update(value)
            else:
                config[key] = value
        else:
            sub_keys = key.split('.')
            assert (
                sub_keys[0] in config
            ), "the sub_keys can only be one of global_config: {}, but get: " \
               "{}, please check your running command".format(
                config.keys(), sub_keys[0])
            cur = config[sub_keys[0]]
            # Descend through the intermediate keys, then assign the leaf.
            for sub_key in sub_keys[1:-1]:
                cur = cur[sub_key]
            cur[sub_keys[-1]] = value
    return config


def check_device(use_gpu, use_xpu=False):
    """
    Print an error and exit when use_gpu / use_xpu is set to true but the
    installed paddlepaddle was not compiled with CUDA / XPU support.
    Only a warning is printed when both flags are set.
    """
    err = "Config {} cannot be set as true while your paddle " \
          "is not compiled with {} ! \nPlease try: \n" \
          "\t1. Install paddlepaddle to run model on {} \n" \
          "\t2. Set {} as false in config file to run " \
          "model on CPU"

    try:
        if use_gpu and use_xpu:
            print("use_xpu and use_gpu can not both be true.")
        if use_gpu and not paddle.is_compiled_with_cuda():
            print(err.format("use_gpu", "cuda", "gpu", "use_gpu"))
            sys.exit(1)
        if use_xpu and not paddle.device.is_compiled_with_xpu():
            print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
            sys.exit(1)
    except Exception:
        # Best effort: errors while querying the paddle build are ignored.
        # sys.exit raises SystemExit, which is not an Exception subclass,
        # so the exits above still propagate.
        pass


def check_xpu(use_xpu):
    """
    Exit with an explanatory message when use_xpu is requested but the
    installed paddlepaddle is a cpu/gpu build without XPU support.
    Any exception raised while performing the check is ignored.
    """
    message = ("Config use_xpu cannot be set as true while you are "
               "using paddlepaddle cpu/gpu version ! \nPlease try: \n"
               "\t1. Install paddlepaddle-xpu to run model on XPU \n"
               "\t2. Set use_xpu as false in config file to run "
               "model on CPU/GPU")

    try:
        # Nothing to do unless XPU was requested and is unavailable.
        if not use_xpu or paddle.is_compiled_with_xpu():
            return
        print(message)
        sys.exit(1)
    except Exception:
        pass


def _forward_batch(model, batch, images, model_type, extra_input):
    """Run one forward pass, feeding the model the inputs its type expects.

    Table models and models flagged with ``extra_input`` get the remaining
    batch items as ``data``; 'kie'/'vqa' models get the whole batch; every
    other model gets only the images. Shared by the AMP and non-AMP paths so
    both dispatch identically.
    """
    if model_type == 'table' or extra_input:
        return model(images, data=batch[1:])
    if model_type in ["kie", 'vqa']:
        return model(batch)
    return model(images)


def train(config,
          train_dataloader,
          valid_dataloader,
          device,
          model,
          loss_class,
          optimizer,
          lr_scheduler,
          post_process_class,
          eval_class,
          pre_best_model_dict,
          logger,
          log_writer=None,
          scaler=None):
    """Run the training loop with periodic logging, evaluation and saving.

    Args:
        config (dict): merged config; reads the 'Global', 'Architecture' and
            'Loss' sections and the top-level 'profiler_options'.
        train_dataloader: training loader; rebuilt at the start of an epoch
            when ``train_dataloader.dataset.need_reset`` is true.
        valid_dataloader: loader passed to ``eval``.
        device: passed to ``build_dataloader`` when the loader is rebuilt.
        model: network being trained.
        loss_class: callable ``(preds, batch)`` returning a dict whose
            'loss' entry is back-propagated.
        optimizer: provides ``get_lr``, ``step`` and ``clear_grad``.
        lr_scheduler: stepped once per iteration unless it is a float.
        post_process_class: converts predictions before metric computation.
        eval_class: metric accumulator; ``main_indicator`` picks the best
            checkpoint.
        pre_best_model_dict (dict): state restored from a checkpoint; may
            hold 'global_step' and 'start_epoch'.
        logger: text logger.
        log_writer: optional metrics/model logger.
        scaler: optional AMP grad scaler; when set, the forward pass runs
            under ``paddle.amp.auto_cast``.
    """
    global_config = config['Global']
    cal_metric_during_train = global_config.get('cal_metric_during_train',
                                                False)
    calc_epoch_interval = global_config.get('calc_epoch_interval', 1)
    log_smooth_window = global_config['log_smooth_window']
    epoch_num = global_config['epoch_num']
    print_batch_step = global_config['print_batch_step']
    eval_batch_step = global_config['eval_batch_step']
    profiler_options = config['profiler_options']

    global_step = pre_best_model_dict.get('global_step', 0)

    # eval_batch_step is either an int interval or [start_step, interval].
    start_eval_step = 0
    if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
        if len(valid_dataloader) == 0:
            logger.info(
                'No Images in eval dataset, evaluation during training ' \
                'will be disabled'
            )
            # Effectively "never": no reachable global_step exceeds this.
            start_eval_step = 1e111
        logger.info(
            "During the training process, after the {}th iteration, " \
            "an evaluation is run every {} iterations".
            format(start_eval_step, eval_batch_step))
    save_epoch_step = global_config['save_epoch_step']
    save_model_dir = global_config['save_model_dir']
    os.makedirs(save_model_dir, exist_ok=True)

    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ['lr'])
    model_average = False
    model.train()

    algorithm = config['Architecture']['algorithm']
    use_srn = algorithm == "SRN"
    # Algorithms whose forward pass also needs the non-image batch items.
    extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
    if algorithm == 'Distillation':
        extra_input = any(
            sub_model['algorithm'] in extra_input_models
            for sub_model in config['Architecture']["Models"].values())
    else:
        extra_input = algorithm in extra_input_models
    # 'model_type' is optional in the Architecture section.
    model_type = config['Architecture'].get('model_type')

    start_epoch = best_model_dict.get('start_epoch', 1)

    total_samples = 0
    train_reader_cost = 0.0
    train_batch_cost = 0.0
    # Iterations since the last log line; the divisor for per-step averages.
    steps_since_log = 0
    reader_start = time.time()
    eta_meter = AverageMeter()

    def _iters_per_epoch(loader):
        # The last batch is skipped on Windows.
        return len(loader) - 1 if platform.system(
        ) == "Windows" else len(loader)

    max_iter = _iters_per_epoch(train_dataloader)

    for epoch in range(start_epoch, epoch_num + 1):
        if train_dataloader.dataset.need_reset:
            train_dataloader = build_dataloader(
                config, 'Train', device, logger, seed=epoch)
            max_iter = _iters_per_epoch(train_dataloader)
        for idx, batch in enumerate(train_dataloader):
            profiler.add_profiler_step(profiler_options)
            train_reader_cost += time.time() - reader_start
            if idx >= max_iter:
                break
            lr = optimizer.get_lr()
            images = batch[0]
            if use_srn:
                model_average = True

            # use amp
            if scaler:
                with paddle.amp.auto_cast():
                    preds = _forward_batch(model, batch, images, model_type,
                                           extra_input)
            else:
                preds = _forward_batch(model, batch, images, model_type,
                                       extra_input)

            loss = loss_class(preds, batch)
            avg_loss = loss['loss']

            if scaler:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)
            else:
                avg_loss.backward()
                optimizer.step()
            optimizer.clear_grad()

            if cal_metric_during_train and epoch % calc_epoch_interval == 0:  # only rec and cls need
                batch = [item.numpy() for item in batch]
                if model_type in ['table', 'kie']:
                    eval_class(preds, batch)
                else:
                    if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
                                                  ]:  # for multi head loss
                        post_result = post_process_class(
                            preds['ctc'], batch[1])  # for CTC head out
                    else:
                        post_result = post_process_class(preds, batch[1])
                    eval_class(post_result, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            train_batch_time = time.time() - reader_start
            train_batch_cost += train_batch_time
            eta_meter.update(train_batch_time)
            global_step += 1
            steps_since_log += 1
            total_samples += len(images)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # logger and visualdl
            stats = {k: v.numpy().mean() for k, v in loss.items()}
            stats['lr'] = lr
            train_stats.update(stats)

            if log_writer is not None and dist.get_rank() == 0:
                log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step)

            # Compare against max_iter (not len(loader)) so the end-of-epoch
            # log also fires on Windows, where the last batch is skipped.
            is_last_iter = idx >= max_iter - 1
            if dist.get_rank() == 0 and (
                (global_step > 0 and global_step % print_batch_step == 0) or
                    is_last_iter):
                logs = train_stats.log()
                eta_sec = ((epoch_num + 1 - epoch) * \
                    len(train_dataloader) - idx - 1) * eta_meter.avg
                eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
                # Average over the steps actually run since the last log,
                # which can be fewer than print_batch_step around epoch ends.
                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
                       '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
                       'ips: {:.5f} samples/s, eta: {}'.format(
                    epoch, epoch_num, global_step, logs,
                    train_reader_cost / steps_since_log,
                    train_batch_cost / steps_since_log,
                    total_samples / steps_since_log,
                    total_samples / train_batch_cost, eta_sec_format)
                logger.info(strs)

                total_samples = 0
                train_reader_cost = 0.0
                train_batch_cost = 0.0
                steps_since_log = 0
            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 \
                    and dist.get_rank() == 0:
                if model_average:
                    # NOTE(review): apply() is called without a `with`
                    # block on a freshly created ModelAverage; confirm
                    # whether it has any effect on the weights here.
                    Model_Average = paddle.incubate.optimizer.ModelAverage(
                        0.15,
                        parameters=model.parameters(),
                        min_average_window=10000,
                        max_average_window=15625)
                    Model_Average.apply()
                cur_metric = eval(
                    model,
                    valid_dataloader,
                    post_process_class,
                    eval_class,
                    model_type,
                    extra_input=extra_input)
                cur_metric_str = 'cur metric, {}'.format(', '.join(
                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                logger.info(cur_metric_str)

                # logger metric
                if log_writer is not None:
                    log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step)

                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model(
                        model,
                        optimizer,
                        save_model_dir,
                        logger,
                        config,
                        is_best=True,
                        prefix='best_accuracy',
                        best_model_dict=best_model_dict,
                        epoch=epoch,
                        global_step=global_step)
                best_str = 'best metric, {}'.format(', '.join([
                    '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                ]))
                logger.info(best_str)
                # logger best metric
                if log_writer is not None:
                    log_writer.log_metrics(metrics={
                        "best_{}".format(main_indicator): best_model_dict[main_indicator]
                        }, prefix="EVAL", step=global_step)

                    log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict)

            reader_start = time.time()
        if dist.get_rank() == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)

            if log_writer is not None:
                log_writer.log_model(is_best=False, prefix="latest")

        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='iter_epoch_{}'.format(epoch),
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
            if log_writer is not None:
                log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch))

    best_str = 'best metric, {}'.format(', '.join(
        ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
    logger.info(best_str)
    if dist.get_rank() == 0 and log_writer is not None:
        log_writer.close()
    return


def eval(model,
         valid_dataloader,
         post_process_class,
         eval_class,
         model_type=None,
         extra_input=False):
    """
    Evaluate ``model`` over ``valid_dataloader`` and return the metrics.

    Args:
        model: network to evaluate; switched to eval mode for the duration
            and back to train mode afterwards, even if evaluation raises.
        valid_dataloader: iterable of batches whose first item is the images.
        post_process_class: converts raw predictions before scoring (not
            used for 'table' / 'kie' models).
        eval_class: metric accumulator called once per batch; its
            ``get_metric()`` result is returned.
        model_type (str): selects the forward and scoring path ('table',
            'kie', 'vqa' or None).
        extra_input (bool): pass ``batch[1:]`` to the model as ``data``.
    Returns:
        dict: the metrics plus 'fps' (images per second over the timed
        forward pass and batch conversion; 0.0 if nothing was timed).
    """
    model.eval()
    pbar = tqdm(
        total=len(valid_dataloader),
        desc='eval model:',
        position=0,
        leave=True)
    total_frame = 0.0
    total_time = 0.0
    try:
        with paddle.no_grad():
            # The last batch is skipped on Windows.
            max_iter = len(valid_dataloader) - 1 if platform.system(
            ) == "Windows" else len(valid_dataloader)
            for idx, batch in enumerate(valid_dataloader):
                if idx >= max_iter:
                    break
                images = batch[0]
                start = time.time()
                if model_type == 'table' or extra_input:
                    preds = model(images, data=batch[1:])
                elif model_type in ["kie", 'vqa']:
                    preds = model(batch)
                else:
                    preds = model(images)

                batch_numpy = [
                    item.numpy() if isinstance(item, paddle.Tensor) else item
                    for item in batch
                ]
                # Timed span: forward pass plus batch conversion.
                total_time += time.time() - start
                # Evaluate the results of the current batch
                if model_type in ['table', 'kie']:
                    eval_class(preds, batch_numpy)
                elif model_type in ['vqa']:
                    post_result = post_process_class(preds, batch_numpy)
                    eval_class(post_result, batch_numpy)
                else:
                    # Obtain usable results from post-processing methods
                    post_result = post_process_class(preds, batch_numpy[1])
                    eval_class(post_result, batch_numpy)

                pbar.update(1)
                total_frame += len(images)
            # Get final metric,eg. acc or hmean
            metric = eval_class.get_metric()
    finally:
        pbar.close()
        model.train()
    # No batch was timed for an empty loader (or a single-batch loader on
    # Windows); report 0 instead of dividing by zero.
    metric['fps'] = total_frame / total_time if total_time > 0 else 0.0
    return metric


def update_center(char_center, post_result, preds):
    """
    Fold features of correctly recognised samples into per-index means.

    Args:
        char_center (dict): maps a predicted class index to
            ``[mean_feature, count]``; updated in place and returned.
        post_result: pair ``(result, label)``; sample ``i`` is used only when
            ``result[i][0] == label[i][0]``.
        preds: pair ``(feats, logits)``; the index at each position is
            ``argmax(logits, axis=-1)``.
    Returns:
        dict: ``char_center``.
    """
    result, label = post_result
    feats, logits = preds
    pred_ids = paddle.argmax(logits, axis=-1)
    feats = feats.numpy()
    pred_ids = pred_ids.numpy()

    for sample_idx, gt in enumerate(label):
        if result[sample_idx][0] != gt[0]:
            continue
        sample_feat = feats[sample_idx]
        for step, char_id in enumerate(pred_ids[sample_idx]):
            if char_id not in char_center:
                char_center[char_id] = [sample_feat[step], 1]
                continue
            entry = char_center[char_id]
            # Incremental running mean: (mean * n + x) / (n + 1).
            entry[0] = (entry[0] * entry[1] + sample_feat[step]) / (
                entry[1] + 1)
            entry[1] += 1
    return char_center


def get_center(model, eval_dataloader, post_process_class):
    """
    Compute the mean feature per predicted class index over a dataset.

    Runs ``model`` on every batch (the last one is skipped on Windows),
    post-processes the predictions against ``batch[1]`` and accumulates
    features with ``update_center``.

    Returns:
        dict: predicted class index -> mean feature.
    """
    pbar = tqdm(total=len(eval_dataloader), desc='get center:')
    max_iter = len(eval_dataloader) - 1 if platform.system(
    ) == "Windows" else len(eval_dataloader)
    char_center = dict()
    try:
        # Outputs are only read back as numpy arrays, so no autograd graph
        # needs to be recorded.
        with paddle.no_grad():
            for idx, batch in enumerate(eval_dataloader):
                if idx >= max_iter:
                    break
                images = batch[0]
                preds = model(images)

                batch = [item.numpy() for item in batch]
                # Obtain usable results from post-processing methods
                post_result = post_process_class(preds, batch[1])

                # update char_center
                char_center = update_center(char_center, post_result, preds)
                pbar.update(1)
    finally:
        pbar.close()
    # Drop the running counts, keeping only the mean features.
    return {key: value[0] for key, value in char_center.items()}


def preprocess(is_train=False):
    """
    Parse CLI arguments, build the config, and set up logging and the device.

    Args:
        is_train (bool): when True, the merged config is written to
            ``<save_model_dir>/config.yml`` and logs go to
            ``<save_model_dir>/train.log``.
    Returns:
        tuple: ``(config, device, logger, log_writer)``. ``log_writer`` wraps
        the enabled VDLLogger / WandbLogger in ``Loggers``, or is None when
        neither is enabled.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    profile_dic = {"profiler_options": FLAGS.profiler_options}
    config = merge_config(config, profile_dic)

    if is_train:
        # save_config
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(log_file=log_file)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    # check if set use_xpu=True in paddlepaddle cpu/gpu version
    use_xpu = config['Global'].get('use_xpu', False)
    check_xpu(use_xpu)

    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
    ]

    if use_xpu:
        device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
    else:
        device = 'gpu:{}'.format(dist.ParallelEnv()
                                 .dev_id) if use_gpu else 'cpu'
    check_device(use_gpu, use_xpu)

    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1

    loggers = []
    global_config = config['Global']
    if global_config.get('use_visualdl'):
        save_model_dir = global_config['save_model_dir']
        loggers.append(VDLLogger(save_model_dir))
    if global_config.get('use_wandb') or 'wandb' in config:
        # Use the locally bound save_dir: save_model_dir is only assigned
        # above when is_train or use_visualdl is set.
        save_dir = global_config['save_model_dir']
        if "wandb" in config:
            wandb_params = config['wandb']
        else:
            wandb_params = dict()
        wandb_params.update({'save_dir': save_dir})
        loggers.append(WandbLogger(**wandb_params, config=config))
    print_dict(config, logger)

    log_writer = Loggers(loggers) if loggers else None

    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, log_writer