Commit 8551c38e authored by silencealiang's avatar silencealiang
Browse files

bug fix

parent bcb9d73e
......@@ -2,6 +2,7 @@ import os
from collections import OrderedDict
from typing import Optional
from functools import wraps
import torch
from torch import Tensor
......
......@@ -2,34 +2,8 @@ import os
import argparse
from typing import Union
from megatron.training.arguments import (
_add_network_size_args,
_add_regularization_args,
_add_training_args,
_add_initialization_args,
_add_learning_rate_args,
_add_checkpointing_args,
_add_mixed_precision_args,
_add_distributed_args,
_add_validation_args,
_add_data_args,
_add_tokenizer_args,
_add_autoresume_args,
_add_biencoder_args,
_add_vision_args,
_add_moe_args,
_add_mla_args,
_add_logging_args,
_add_straggler_detector_args,
_add_inference_args,
_add_transformer_engine_args,
_add_retro_args,
_add_experimental_args,
_add_one_logger_args,
_add_ft_package_args,
_add_config_logger_args,
_add_rerun_machine_args,
)
from megatron.training.arguments import add_megatron_arguments
from megatron.core.msc_utils import MultiStorageClientFeature
def remove_original_params(parser, param_names: Union[list, str]):
......@@ -44,44 +18,24 @@ def remove_original_params(parser, param_names: Union[list, str]):
del parser._option_string_actions[option_string]
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
"""Parse all arguments."""
parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
allow_abbrev=False)
def add_megatron_arguments_patch(parser: argparse.ArgumentParser):
    """Register the stock Megatron-LM arguments plus this project's extras.

    Args:
        parser: The ``argparse.ArgumentParser`` to populate.

    Returns:
        The parser after all argument groups have been added.
    """
    # Standard arguments.
    # NOTE(review): assumes add_megatron_arguments already registers every
    # upstream group (network size, training, MoE, MLA, MTP, ...), as it does
    # in upstream Megatron-LM -- confirm against the installed version.
    # Calling the individual upstream `_add_*_args` helpers again after it
    # would register the same option strings twice, which argparse rejects
    # with a "conflicting option string" error.
    parser = add_megatron_arguments(parser)

    # Project-specific arguments layered on top of the upstream groups.
    parser = _add_extra_network_size_args(parser)
    parser = _add_extra_training_args(parser)
    parser = _add_extra_distributed_args(parser)
    parser = _add_extra_tokenizer_args(parser)
    parser = _add_flux_args(parser)

    return parser
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
"""Parse all arguments."""
parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
allow_abbrev=False)
parser = add_megatron_arguments_patch(parser)
# Custom arguments.
if extra_args_provider is not None:
......@@ -101,8 +55,14 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
args = load_yaml(args.yaml_cfg)
# Args from environment
#args.rank = int(os.getenv('RANK', '0'))
#args.world_size = int(os.getenv("WORLD_SIZE", '1'))
# args.rank = int(os.getenv('RANK', '0'))
# args.world_size = int(os.getenv("WORLD_SIZE", '1'))
# Args to disable MSC
if not args.enable_msc:
MultiStorageClientFeature.disable()
assert MultiStorageClientFeature.is_enabled() is False
print('WARNING: The MSC feature is disabled.')
return args
......@@ -168,21 +128,6 @@ def _add_extra_tokenizer_args(parser):
return parser
def _add_mtp_args(parser):
group = parser.add_argument_group(title='multi token prediction')
group.add_argument('--mtp-num-layers', type=int, default=None,
help='Number of Multi-Token Prediction (MTP) Layers.'
'MTP extends the prediction scope to multiple future tokens at each position.'
'This MTP implementation sequentially predict additional tokens '
'by using D sequential modules to predict D additional tokens.')
group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.3,
help='Scaling factor of Multi-Token Prediction (MTP) loss. '
'We compute the average of the MTP losses across all depths, '
'and multiply it the scaling factor to obtain the overall MTP loss, '
'which serves as an additional training objective.')
return parser
def _add_flux_args(parser):
group = parser.add_argument_group(title='flux args')
group.add_argument('--flux-transpose-weight', action='store_true', default=False,
......
......@@ -10,8 +10,10 @@ from megatron.core.utils import (
StragglerDetector,
)
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel as custom_FSDP
from megatron.core.distributed import finalize_model_grads
from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.training.initialize import write_args_to_tensorboard
from megatron.core.num_microbatches_calculator import (
get_current_global_batch_size,
......@@ -36,6 +38,7 @@ from megatron.training import one_logger_utils
from megatron.training import ft_integration
from megatron.training.training import (
print_datetime,
should_disable_forward_pre_hook,
disable_forward_pre_hook,
train_step,
save_checkpoint_and_time,
......
......@@ -8,7 +8,6 @@ do
fi
done
CURRENT_DIR="$( cd "$( dirname "$0" )" && pwd )"
echo $CURRENT_DIR
MEGATRON_PATH=$( dirname $( dirname ${CURRENT_DIR}))
export CUDA_DEVICE_MAX_CONNECTIONS=1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment