import os
import argparse

from typing import Union
from megatron.training.arguments import (
    _add_network_size_args,
    _add_regularization_args,
    _add_training_args,
    _add_initialization_args,
    _add_learning_rate_args,
    _add_checkpointing_args,
    _add_mixed_precision_args,
    _add_distributed_args,
    _add_validation_args,
    _add_data_args,
    _add_tokenizer_args,
    _add_autoresume_args,
    _add_biencoder_args,
    _add_vision_args,
    _add_moe_args,
    _add_mla_args,
    _add_logging_args,
    _add_straggler_detector_args,
    _add_inference_args,
    _add_transformer_engine_args,
    _add_retro_args,
    _add_experimental_args,
    _add_one_logger_args,
    _add_ft_package_args,
    _add_config_logger_args,
    _add_rerun_machine_args,
)


def remove_original_params(parser, param_names: Union[list, str]):
    """Unregister arguments from *parser* so they can be redefined.

    Args:
        parser: an ``argparse.ArgumentParser`` whose actions are pruned
            in place.
        param_names: a single ``dest`` name or a list of ``dest`` names
            identifying the arguments to remove.

    Note: this reaches into argparse internals (``_actions``,
    ``_option_string_actions``, ``_group_actions``); there is no public
    API for removing an argument.
    """
    if isinstance(param_names, str):
        param_names = [param_names]
    targets = set(param_names)

    # Iterate over a snapshot: removing from parser._actions while
    # iterating the live list would skip the action that follows each
    # removal (e.g. the second of two adjacent matches survived).
    for action in list(parser._actions):
        if action.dest in targets:
            parser._actions.remove(action)
            for option_string in action.option_strings:
                # pop() with default: tolerate option strings that were
                # never registered (or already removed).
                parser._option_string_actions.pop(option_string, None)
            # Also drop the action from its argument group, otherwise the
            # stale entry still shows up in --help output.
            for group in parser._action_groups:
                if action in group._group_actions:
                    group._group_actions.remove(action)


def parse_args(extra_args_provider=None, ignore_unknown_args=False):
    """Parse all arguments.

    Builds the Megatron argument parser, interleaving this module's
    extra/override groups with the stock ones, then parses the command
    line (optionally tolerating unknown flags) and applies the optional
    yaml configuration.
    """
    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                     allow_abbrev=False)

    # Standard argument groups. The _add_extra_* adders must run right
    # after their stock counterparts so redefinitions (e.g.
    # --normalization, --tokenizer-type) replace the originals.
    group_adders = (
        _add_network_size_args,
        _add_extra_network_size_args,
        _add_regularization_args,
        _add_training_args,
        _add_extra_training_args,
        _add_initialization_args,
        _add_learning_rate_args,
        _add_checkpointing_args,
        _add_mixed_precision_args,
        _add_distributed_args,
        _add_extra_distributed_args,
        _add_validation_args,
        _add_data_args,
        _add_tokenizer_args,
        _add_extra_tokenizer_args,
        _add_autoresume_args,
        _add_biencoder_args,
        _add_vision_args,
        _add_moe_args,
        _add_mla_args,
        _add_mtp_args,
        _add_logging_args,
        _add_straggler_detector_args,
        _add_inference_args,
        _add_transformer_engine_args,
        _add_retro_args,
        _add_experimental_args,
        _add_one_logger_args,
        _add_ft_package_args,
        _add_config_logger_args,
        _add_rerun_machine_args,
        _add_flux_args,
    )
    for add_group_args in group_adders:
        parser = add_group_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    # Parse.
    if ignore_unknown_args:
        args, _ = parser.parse_known_args()
    else:
        args = parser.parse_args()

    # Experimental yaml
    if args.yaml_cfg is not None:
        # Imported lazily so the yaml path is only required when used.
        from megatron.training.yaml_arguments import load_yaml
        assert args.yaml_cfg and not args.use_legacy_models, \
            "Yaml config is not supported with legacy models."
        args = load_yaml(args.yaml_cfg)

    # Args from environment
    #args.rank = int(os.getenv('RANK', '0'))
    #args.world_size = int(os.getenv("WORLD_SIZE", '1'))

    return args


def _add_extra_network_size_args(parser):
    """Redefine --normalization so it also accepts LightopRMSNorm."""
    # Drop the stock definition first so the flag can be re-registered.
    remove_original_params(parser, ["normalization"])

    # Re-register with the extended choice list.
    group = parser.add_argument_group(title='extra network size args')
    group.add_argument('--normalization',
                       choices=['LayerNorm', 'RMSNorm', 'LightopRMSNorm'],
                       default='LayerNorm',
                       help='Which normalization technique to use.')
    return parser


def _add_extra_distributed_args(parser):
    group = parser.add_argument_group(title='extra distributed args')
dongcl's avatar
dongcl committed
124
125
126
127
128
129
130
131
132
    group.add_argument('--rank', default=-1, type=int,
                       help='node rank for distributed training')
    group.add_argument('--world-size', type=int, default=8,
                       help='number of nodes for distributed training')
    group.add_argument('--dist-url',
                       help='Which master node url for distributed training.')
    return parser


def _add_extra_training_args(parser):
    group = parser.add_argument_group(title='extra training args')
    group.add_argument('--use-hip-profiler', action='store_true',
                       help='Use HIP PROFILER',
                       dest='use_hip_profiler')
    group.add_argument('--profile-dir', type=str, default="./",
                       help='profile dir to save.')

    return parser


def _add_extra_tokenizer_args(parser):
    """Redefine --tokenizer-type with extra tokenizers and add --extra-vocab-size."""
    # Drop the stock definition first so the flag can be re-registered.
    remove_original_params(parser, ["tokenizer_type"])

    # Re-register with the extended choice list.
    group = parser.add_argument_group(title='extra tokenizer args')
    group.add_argument('--extra-vocab-size', default=0, type=int,
                       help="--extra-vocab-size")
    tokenizer_choices = [
        'BertWordPieceLowerCase',
        'BertWordPieceCase',
        'GPT2BPETokenizer',
        'SentencePieceTokenizer',
        'GPTSentencePieceTokenizer',
        'HuggingFaceTokenizer',
        'Llama2Tokenizer',
        'Llama3Tokenizer',
        'QwenTokenizer',
        'TikTokenizer',
        'MultimodalTokenizer',
        'NullTokenizer',
        'DeepSeekV2Tokenizer',
    ]
    group.add_argument('--tokenizer-type', default=None, type=str,
                       choices=tokenizer_choices,
                       help='What type of tokenizer to use.')
    return parser


def _add_mtp_args(parser):
    group = parser.add_argument_group(title='multi token prediction')
    group.add_argument('--num-nextn-predict-layers', type=int, default=0, help='Multi-Token prediction layer num')
    group.add_argument('--mtp-loss-scale', type=float, default=0.3, help='Multi-Token prediction loss scale')
    group.add_argument('--recompute-mtp-norm', action='store_true', default=False,
                       help='Multi-Token prediction recompute norm')
    group.add_argument('--recompute-mtp-layer', action='store_true', default=False,
                       help='Multi-Token prediction recompute layer')
    group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False,
                       help='Main model share embedding and output weight with mtp layer.')
dongcl's avatar
dongcl committed
181
    return parser


def _add_flux_args(parser):
dongcl's avatar
dongcl committed
185
    group = parser.add_argument_group(title='flux args')
dongcl's avatar
dongcl committed
186
187
188
    group.add_argument('--flux-transpose-weight', action='store_true', default=False,
                       help='Whether to transpose weight when using flux kernel')
    return parser