arguments.py
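"""Command-line argument definitions for the pretraining entry script."""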
import colossalai

__all__ = ['parse_args']


def parse_args():
    parser = colossalai.get_default_parser()
    
    parser.add_argument(
        '--lr', 
        type=float, 
        required=True,
        help='initial learning rate')
    parser.add_argument(
        '--epoch', 
        type=int, 
        required=True,
        help='number of epochs')
    parser.add_argument(
        '--data_path_prefix', 
        type=str, 
        required=True,
        help="location of the train data corpus")
    parser.add_argument(
        '--eval_data_path_prefix', 
        type=str, 
        required=True,
        help='location of the evaluation data corpus')
    parser.add_argument(
        '--tokenizer_path', 
        type=str, 
        required=True,
        help='location of the tokenizer')
    parser.add_argument(
        '--max_seq_length', 
        type=int, 
        default=512,
        help='maximum sequence length')
    parser.add_argument(
        '--refresh_bucket_size',
        type=int,
        default=1,
        help="ensure a task is repeated for this many steps to optimize \
        back-propagation speed with APEX's DistributedDataParallel")
    parser.add_argument(
        "--max_predictions_per_seq",
        "--max_pred",
        default=80,
        type=int,
        help=
        "The maximum number of masked tokens in a sequence to be predicted.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        default=1,
        type=int,
        help="accumulation_steps")
    parser.add_argument(
        "--train_micro_batch_size_per_gpu",
        default=2,
        type=int,
        required=True,
        help="train batch size")
    parser.add_argument(
        "--eval_micro_batch_size_per_gpu",
        default=2,
        type=int,
        required=True,
        help="eval batch size")
    parser.add_argument(
        "--num_workers",
        default=8,
        type=int,
        help="")
    parser.add_argument(
        "--async_worker",
        action='store_true',
        help="")
    parser.add_argument(
        "--bert_config",
        required=True,
        type=str,
        help="location of config.json")
    parser.add_argument(
        "--wandb",
        action='store_true',
        help="use wandb to watch model")
    parser.add_argument(
        "--wandb_project_name",
        default='roberta',
        help="wandb project name")
    parser.add_argument(
        "--log_interval",
        default=100,
        type=int,
        help="report interval")
    parser.add_argument(
        "--log_path",
        type=str,
        required=True,
        help="log file which records train step")
    parser.add_argument(
        "--tensorboard_path",
        type=str,
        required=True,
        help="location of tensorboard file")
    parser.add_argument(
        "--colossal_config",
        type=str,
        required=True,
        help="colossal config, which contains zero config and so on")
    parser.add_argument(
        "--ckpt_path",
        type=str,
        required=True,
        help="location of saving checkpoint, which contains model and optimizer")
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help="random seed for initialization")
    parser.add_argument(
        '--vscode_debug',
        action='store_true',
        help="use vscode to debug")
    parser.add_argument(
        '--load_pretrain_model',
        default='',
        type=str,
        help="location of model's checkpoin")
    parser.add_argument(
        '--load_optimizer_lr',
        default='',
        type=str,
        help="location of checkpoint, which contains optimerzier, learning rate, epoch, shard and global_step")
    parser.add_argument(
        '--resume_train',
        action='store_true',
        help="whether resume training from a early checkpoint")
    parser.add_argument(
        '--mlm',
        default='bert',
        type=str,
        help="model type, bert or deberta")
    parser.add_argument(
        '--checkpoint_activations',
        action='store_true',
        help="whether to use gradient checkpointing")

    args = parser.parse_args()
    return args
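

# Minimal usage sketch (illustrative only; `colossalai.launch_from_torch` is
# an assumption based on older ColossalAI examples and may differ by version):
if __name__ == '__main__':
    args = parse_args()  # every required flag must be supplied on the CLI
    # A training script would typically hand the parsed namespace to
    # ColossalAI initialization, roughly:
    #   colossalai.launch_from_torch(config=args.colossal_config)
    print(args)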