# config_qqp.py

from omegaconf import OmegaConf

from configs.common.data.bert_dataset import tokenization
from configs.common.models.bert import cfg as qqp_cfg
from configs.common.optim import optim
from configs.common.train import train
from configs.common.models.graph import graph
from libai.config import LazyCall
from libai.data.build import build_nlp_test_loader, build_nlp_train_loader
from projects.QQP.dataset.qqp_dataset import QQPDataset
from projects.QQP.modeling.model import Classification
from projects.QQP.tokenizer.tokenizer import _BertCNWWMTokenizer

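# Tokenizer: Chinese whole-word-masking BERT tokenizer with lower-casing; the
# vocabulary size is padded to a multiple of 128 so that it lines up with
# vocab_size=21248 in the model config below.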
tokenization.tokenizer = LazyCall(_BertCNWWMTokenizer)(
    vocab_file="projects/QQP/QQP_DATA/bert-base-chinese-vocab.txt",
    lower_case=True,
)
tokenization.append_eod = False
tokenization.make_vocab_size_divisible_by = 128

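# Dataloaders: the train/dev TSV splits are read through QQPDataset and wrapped
# by LibAI's NLP loader helpers, with sequence length capped at 512 tokens.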
dataloader = OmegaConf.create()
dataloader.train = LazyCall(build_nlp_train_loader)(
    dataset=[
        LazyCall(QQPDataset)(
            dataset_name="QQP_TRAIN",
            data_paths=[
                "projects/QQP/QQP_DATA/train.tsv",
            ],
            max_seq_length=512,
        ),
    ],
    num_workers=4,
)
dataloader.test = [
    LazyCall(build_nlp_test_loader)(
        dataset=LazyCall(QQPDataset)(
            dataset_name="QQP_TEST",
            data_paths=[
                "projects/QQP/QQP_DATA/dev.tsv",
            ],
            max_seq_length=512,
        ),
        num_workers=4,
    ),
]

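# Model: reuse the common BERT config and override it to a BERT-large-sized
# encoder (24 layers, hidden size 1024, 16 attention heads) with a 2-way
# classification head; pretrain_megatron_weight may point to a Megatron
# checkpoint to initialize from.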
qqp_cfg.update(
    dict(
        # existing keys overridden for this task
        vocab_size=21248,
        hidden_size=1024,
        hidden_layers=24,
        num_attention_heads=16,
        # new keys added for classification
        num_classes=2,
        pretrain_megatron_weight=None,  # "path/to/model_optim_rng.pt",
    )
)
model = LazyCall(Classification)(cfg=qqp_cfg)

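# Finetuning optimizer overrides: a small learning rate with weight decay.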
optim.lr = 1e-6
optim.weight_decay = 0.1

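# Training schedule and runtime: AMP and activation checkpointing enabled, one
# finetuning epoch with periodic evaluation/logging, and no model parallelism
# (data/tensor/pipeline parallel sizes all set to 1).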
train.update(
    dict(
        activation_checkpoint=dict(enabled=True),
        amp=dict(enabled=True),
        output_dir="output/finetune_qqp/",
        train_micro_batch_size=16,
        test_micro_batch_size=4,
        train_epoch=1,
        train_iter=0,
        eval_period=100,
        log_period=10,
        warmup_ratio=0.01,
        topk=(1,),
        dist=dict(
            data_parallel_size=1,
            tensor_parallel_size=1,
            pipeline_parallel_size=1,
        ),
    )
)
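
# Example launch (illustrative; the launcher script, config path, and GPU count
# depend on the local LibAI checkout and where this file is placed):
#   bash tools/train.sh tools/train_net.py projects/QQP/configs/config_qqp.py 1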