Commit 95f13c48 authored by dongcl's avatar dongcl
Browse files

support for removing args

parent 0b492884
import os
import argparse
from typing import Union
from megatron.training.arguments import (
_add_network_size_args,
_add_regularization_args,
......@@ -31,6 +32,18 @@ from megatron.training.arguments import (
)
def remove_original_params(parser, param_names: Union[list, str]):
if isinstance(param_names, str):
param_names = [param_names]
for action in parser._actions:
if action.dest in param_names:
parser._actions.remove(action)
for option_string in action.option_strings:
if option_string in parser._option_string_actions:
del parser._option_string_actions[option_string]
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
"""Parse all arguments."""
parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
......@@ -103,14 +116,25 @@ def _add_extra_distributed_args(parser):
return parser
def _add_extra_training_args(parser):
group = parser.add_argument_group(title='extra training args')
group.add_argument('--use-hip-profiler', action='store_true',
help='Use HIP PROFILER',
dest='use_hip_profiler')
group.add_argument('--profile-dir', type=str, default="./",
help='profile dir to save.')
return parser
def _add_extra_tokenizer_args(parser):
# 删除原参数
for action in parser._actions:
if action.dest == 'tokenizer_type':
parser._actions.remove(action)
remove_original_params(parser, ["tokenizer_type"])
# 重定义参数
group = parser.add_argument_group(title='extra tokenizer args')
group.add_argument('--extra-vocab-size', type=int, default=0,
help="--extra-vocab-size")
group.add_argument('--tokenizer-type', type=str,
default=None,
choices=['BertWordPieceLowerCase',
......@@ -130,17 +154,6 @@ def _add_extra_tokenizer_args(parser):
return parser
def _add_extra_training_args(parser):
group = parser.add_argument_group(title='extra training args')
group.add_argument('--use-hip-profiler', action='store_true',
help='Use HIP PROFILER',
dest='use_hip_profiler')
group.add_argument('--profile-dir', type=str, default="./",
help='profile dir to save.')
return parser
def _add_mtp_args(parser):
group = parser.add_argument_group(title='multi token prediction')
group.add_argument('--num-nextn-predict-layers', type=int, default=0, help='Multi-Token prediction layer num')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment