# train_re.py: LayoutXLM relation-extraction (RE) training script.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

# Make sibling modules (xfun, utils, data_collator, eval_re) and the directory
# two levels up importable when this file is run directly as a script;
# '../..' is presumably the repository root that holds the `ppocr` package.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import random
import time
import numpy as np
import paddle

from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction

from xfun import XFUNDataset
from utils import parse_args, get_bio_label_maps, print_arguments, set_seed
from data_collator import DataCollator
from eval_re import evaluate

from ppocr.utils.logging import get_logger


def _save_checkpoint(model, tokenizer, args, output_dir, distributed, logger):
    """Write model weights, tokenizer files and training args to ``output_dir``.

    Args:
        model: the model being trained. When ``distributed`` is True it is a
            ``paddle.DataParallel`` wrapper and the wrapped network
            (``model._layers``) is what gets saved.
        tokenizer: tokenizer saved next to the weights.
        args: parsed command-line arguments, stored as ``training_args.bin``.
        output_dir: destination directory; created if it does not exist.
        distributed: whether ``model`` is wrapped in ``paddle.DataParallel``.
        logger: logger used to report where the checkpoint was written.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Save the wrapped network itself, not the DataParallel wrapper.
    if distributed:
        model._layers.save_pretrained(output_dir)
    else:
        model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    paddle.save(args, os.path.join(output_dir, "training_args.bin"))
    logger.info("Saving model checkpoint to {}".format(output_dir))


def train(args):
    """Fine-tune LayoutXLM for relation extraction (RE) on an XFUN dataset.

    Builds the tokenizer, model, datasets, dataloaders and optimizer from
    ``args``, then trains for ``args.num_train_epochs`` epochs. Checkpoints:

    * ``<output_dir>/best_model``: written on rank 0 whenever an evaluation
      (run every ``args.eval_steps`` global steps when
      ``args.evaluate_during_training`` is set) reaches an F1 at least as high
      as the best seen so far;
    * ``<output_dir>/latest_model``: written on rank 0 after every epoch.

    Args:
        args: parsed command-line arguments (as produced by ``parse_args``).
    """
    logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
    rank = paddle.distributed.get_rank()
    distributed = paddle.distributed.get_world_size() > 1

    print_arguments(args, logger)

    # Added here for reproducibility (even between python 2 and 3)
    set_seed(args.seed)

    label2id_map, _ = get_bio_label_maps(args.label_map_path)
    # Padding label id = the id CrossEntropyLoss ignores by default.
    pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index

    # dist mode
    if distributed:
        paddle.distributed.init_parallel_env()

    tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
    if not args.resume:
        # Load the LayoutXLM backbone and put a fresh RE head on top of it.
        model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
        model = LayoutXLMForRelationExtraction(model, dropout=None)
        logger.info('train from scratch')
    else:
        logger.info('resume from {}'.format(args.model_name_or_path))
        model = LayoutXLMForRelationExtraction.from_pretrained(
            args.model_name_or_path)

    # dist mode
    if distributed:
        model = paddle.DataParallel(model)

    train_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.train_data_dir,
        label_path=args.train_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        max_seq_len=args.max_seq_length,
        pad_token_label_id=pad_token_label_id,
        contains_re=True,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    eval_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.eval_data_dir,
        label_path=args.eval_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        max_seq_len=args.max_seq_length,
        pad_token_label_id=pad_token_label_id,
        contains_re=True,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    train_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)

    train_dataloader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=args.num_workers,
        use_shared_memory=True,
        collate_fn=DataCollator())

    eval_dataloader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=args.per_gpu_eval_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        collate_fn=DataCollator())

    t_total = len(train_dataloader) * args.num_train_epochs

    # build linear decay with warmup lr sch
    lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
        learning_rate=args.learning_rate,
        decay_steps=t_total,
        end_lr=0.0,
        power=1.0)
    if args.warmup_steps > 0:
        lr_scheduler = paddle.optimizer.lr.LinearWarmup(
            lr_scheduler,
            args.warmup_steps,
            start_lr=0,
            end_lr=args.learning_rate, )
    # NOTE(review): lr_scheduler is built but never given to the optimizer
    # (which receives the constant args.learning_rate) and never stepped, so
    # training runs at a fixed learning rate. To actually use warmup/decay,
    # pass `learning_rate=lr_scheduler` below and call `lr_scheduler.step()`
    # once per iteration. Left unchanged here to preserve training behavior.
    grad_clip = paddle.nn.ClipGradByNorm(clip_norm=10)
    optimizer = paddle.optimizer.Adam(
        learning_rate=args.learning_rate,
        parameters=model.parameters(),
        epsilon=args.adam_epsilon,
        grad_clip=grad_clip,
        weight_decay=args.weight_decay)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = {}".format(len(train_dataset)))
    logger.info("  Num Epochs = {}".format(args.num_train_epochs))
    logger.info("  Instantaneous batch size per GPU = {}".format(
        args.per_gpu_train_batch_size))
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = {}".
        format(args.per_gpu_train_batch_size *
               paddle.distributed.get_world_size()))
    logger.info("  Total optimization steps = {}".format(t_total))

    global_step = 0
    model.clear_gradients()
    train_dataloader_len = len(train_dataloader)
    best_metric = {'f1': 0}
    model.train()

    # Throughput counters, accumulated over the iterations since the last
    # log line (steps_since_log keeps the averages right for any print_step).
    train_reader_cost = 0.0
    train_run_cost = 0.0
    total_samples = 0
    steps_since_log = 0
    reader_start = time.time()

    print_step = 1

    for epoch in range(int(args.num_train_epochs)):
        for step, batch in enumerate(train_dataloader):
            train_reader_cost += time.time() - reader_start
            train_start = time.time()

            outputs = model(**batch)
            # The RE model returns a dict-like output; only 'loss' is used.
            loss = outputs['loss'].mean()

            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            # Stop the clock only after backward + update so avg_batch_cost
            # and ips cover the whole training step, not just the forward.
            train_run_cost += time.time() - train_start

            global_step += 1
            total_samples += batch['image'].shape[0]
            steps_since_log += 1

            if step % print_step == 0:
                if rank == 0:
                    logger.info(
                        "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
                        format(epoch, args.num_train_epochs, step,
                               train_dataloader_len, global_step,
                               np.mean(loss.numpy()),
                               optimizer.get_lr(),
                               train_reader_cost / steps_since_log,
                               (train_reader_cost + train_run_cost) /
                               steps_since_log,
                               total_samples / steps_since_log,
                               total_samples /
                               (train_reader_cost + train_run_cost)))
                # Reset on every rank so the counters never grow unbounded.
                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0
                steps_since_log = 0

            if (rank == 0 and args.eval_steps > 0 and
                    global_step % args.eval_steps == 0 and
                    args.evaluate_during_training):
                # Evaluate on rank 0 only, so metrics come from a single
                # process instead of being averaged across workers.
                results = evaluate(model, eval_dataloader, logger)
                # NOTE(review): evaluate() is defined in eval_re (not visible
                # here); re-enter train mode in case it leaves the model in
                # eval mode. Harmless if it already restores train mode.
                model.train()
                if results['f1'] >= best_metric['f1']:
                    best_metric = results
                    _save_checkpoint(
                        model, tokenizer, args,
                        os.path.join(args.output_dir, "best_model"),
                        distributed, logger)
                logger.info("eval results: {}".format(results))
                logger.info("best_metric: {}".format(best_metric))
            reader_start = time.time()

        if rank == 0:
            # Save model checkpoint
            _save_checkpoint(model, tokenizer, args,
                             os.path.join(args.output_dir, "latest_model"),
                             distributed, logger)
    logger.info("best_metric: {}".format(best_metric))


if __name__ == "__main__":
    # Script entry point: parse CLI options, make sure the output directory
    # exists before train() opens its log file there, then start training.
    cli_args = parse_args()
    os.makedirs(cli_args.output_dir, exist_ok=True)
    train(cli_args)