"vscode:/vscode.git/clone" did not exist on "7b0dde4d461aac7526d574685dff5352f509a0d0"
Commit 8551c38e authored by silencealiang

bug fix

parent bcb9d73e
@@ -2,6 +2,7 @@ import os
 from collections import OrderedDict
 from typing import Optional
+from functools import wraps
 import torch
 from torch import Tensor
...
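The only change visible in this hunk is the new functools import, which implies a decorator was added somewhere below the fold. As a reminder of the pattern, here is a minimal sketch of how functools.wraps preserves a wrapped function's metadata; the decorator below is a hypothetical illustration, not code from this commit:

    from functools import wraps

    def log_calls(fn):
        @wraps(fn)  # copy fn.__name__, __doc__, etc. onto the wrapper
        def wrapper(*args, **kwargs):
            print(f'calling {fn.__name__}')
            return fn(*args, **kwargs)
        return wrapper

    @log_calls
    def forward(x):
        return x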
@@ -2,34 +2,8 @@ import os
 import argparse
 from typing import Union
-from megatron.training.arguments import (
-    _add_network_size_args,
-    _add_regularization_args,
-    _add_training_args,
-    _add_initialization_args,
-    _add_learning_rate_args,
-    _add_checkpointing_args,
-    _add_mixed_precision_args,
-    _add_distributed_args,
-    _add_validation_args,
-    _add_data_args,
-    _add_tokenizer_args,
-    _add_autoresume_args,
-    _add_biencoder_args,
-    _add_vision_args,
-    _add_moe_args,
-    _add_mla_args,
-    _add_logging_args,
-    _add_straggler_detector_args,
-    _add_inference_args,
-    _add_transformer_engine_args,
-    _add_retro_args,
-    _add_experimental_args,
-    _add_one_logger_args,
-    _add_ft_package_args,
-    _add_config_logger_args,
-    _add_rerun_machine_args,
-)
+from megatron.training.arguments import add_megatron_arguments
+from megatron.core.msc_utils import MultiStorageClientFeature

 def remove_original_params(parser, param_names: Union[list, str]):
@@ -44,44 +18,24 @@ def remove_original_params(parser, param_names: Union[list, str]):
         del parser._option_string_actions[option_string]

-def parse_args(extra_args_provider=None, ignore_unknown_args=False):
-    """Parse all arguments."""
-    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
-                                     allow_abbrev=False)
-
-    # Standard arguments.
-    parser = _add_network_size_args(parser)
+def add_megatron_arguments_patch(parser: argparse.ArgumentParser):
+    parser = add_megatron_arguments(parser)
+
+    # add extra arguments
     parser = _add_extra_network_size_args(parser)
-    parser = _add_regularization_args(parser)
-    parser = _add_training_args(parser)
     parser = _add_extra_training_args(parser)
-    parser = _add_initialization_args(parser)
-    parser = _add_learning_rate_args(parser)
-    parser = _add_checkpointing_args(parser)
-    parser = _add_mixed_precision_args(parser)
-    parser = _add_distributed_args(parser)
     parser = _add_extra_distributed_args(parser)
-    parser = _add_validation_args(parser)
-    parser = _add_data_args(parser)
-    parser = _add_tokenizer_args(parser)
     parser = _add_extra_tokenizer_args(parser)
-    parser = _add_autoresume_args(parser)
-    parser = _add_biencoder_args(parser)
-    parser = _add_vision_args(parser)
-    parser = _add_moe_args(parser)
-    parser = _add_mla_args(parser)
-    parser = _add_mtp_args(parser)
-    parser = _add_logging_args(parser)
-    parser = _add_straggler_detector_args(parser)
-    parser = _add_inference_args(parser)
-    parser = _add_transformer_engine_args(parser)
-    parser = _add_retro_args(parser)
-    parser = _add_experimental_args(parser)
-    parser = _add_one_logger_args(parser)
-    parser = _add_ft_package_args(parser)
-    parser = _add_config_logger_args(parser)
-    parser = _add_rerun_machine_args(parser)
-    parser = _add_flux_args(parser)
+
+    return parser
+
+def parse_args(extra_args_provider=None, ignore_unknown_args=False):
+    """Parse all arguments."""
+    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
+                                     allow_abbrev=False)
+
+    parser = add_megatron_arguments_patch(parser)

     # Custom arguments.
     if extra_args_provider is not None:
@@ -101,8 +55,14 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
         args = load_yaml(args.yaml_cfg)

     # Args from environment
-    #args.rank = int(os.getenv('RANK', '0'))
-    #args.world_size = int(os.getenv("WORLD_SIZE", '1'))
+    # args.rank = int(os.getenv('RANK', '0'))
+    # args.world_size = int(os.getenv("WORLD_SIZE", '1'))
+
+    # Args to disable MSC
+    if not args.enable_msc:
+        MultiStorageClientFeature.disable()
+        assert MultiStorageClientFeature.is_enabled() is False
+        print('WARNING: The MSC feature is disabled.')

     return args
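The new guard disables the Multi-Storage Client unless --enable-msc is passed. Below is a self-contained sketch of the same disable-and-verify pattern; the FeatureToggle class is a stand-in for megatron.core.msc_utils.MultiStorageClientFeature, and the sketch assumes an --enable-msc flag is registered elsewhere in the real parser:

    import argparse

    class FeatureToggle:
        # stand-in for MultiStorageClientFeature's class-level switch
        _enabled = True

        @classmethod
        def disable(cls):
            cls._enabled = False

        @classmethod
        def is_enabled(cls):
            return cls._enabled

    parser = argparse.ArgumentParser()
    parser.add_argument('--enable-msc', action='store_true')
    args = parser.parse_args([])  # flag absent: the feature gets disabled

    if not args.enable_msc:
        FeatureToggle.disable()
        assert FeatureToggle.is_enabled() is False
        print('WARNING: The MSC feature is disabled.')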
@@ -168,21 +128,6 @@ def _add_extra_tokenizer_args(parser):
     return parser

-def _add_mtp_args(parser):
-    group = parser.add_argument_group(title='multi token prediction')
-    group.add_argument('--mtp-num-layers', type=int, default=None,
-                       help='Number of Multi-Token Prediction (MTP) Layers.'
-                            'MTP extends the prediction scope to multiple future tokens at each position.'
-                            'This MTP implementation sequentially predict additional tokens '
-                            'by using D sequential modules to predict D additional tokens.')
-    group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.3,
-                       help='Scaling factor of Multi-Token Prediction (MTP) loss. '
-                            'We compute the average of the MTP losses across all depths, '
-                            'and multiply it the scaling factor to obtain the overall MTP loss, '
-                            'which serves as an additional training objective.')
-    return parser
 def _add_flux_args(parser):
     group = parser.add_argument_group(title='flux args')
     group.add_argument('--flux-transpose-weight', action='store_true', default=False,
...
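Taken together, the changes in this file collapse the long list of upstream _add_* calls into a single add_megatron_arguments(parser) call (now imported from megatron.training.arguments), with only the project-local _add_extra_* groups layered on top. A sketch of the resulting flow; the helper body below is hypothetical, standing in for the real _add_extra_* functions, and the placement of --enable-msc is an assumption:

    import argparse

    def _add_extra_training_args(parser):
        # hypothetical local extension, standing in for the real helpers
        group = parser.add_argument_group(title='extra training')
        group.add_argument('--enable-msc', action='store_true',
                           help='Keep the Multi-Storage Client enabled.')
        return parser

    def add_megatron_arguments_patch(parser):
        # upstream registration happens first in the real code:
        #   parser = add_megatron_arguments(parser)
        parser = _add_extra_training_args(parser)  # then local extras on top
        return parser

    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                     allow_abbrev=False)
    parser = add_megatron_arguments_patch(parser)
    print(parser.parse_args(['--enable-msc']).enable_msc)  # True

This keeps the patch file decoupled from upstream: new upstream argument groups arrive automatically through add_megatron_arguments instead of requiring the _add_* list here to be kept in sync by hand.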
@@ -10,8 +10,10 @@ from megatron.core.utils import (
     StragglerDetector,
 )
 from megatron.core.distributed import DistributedDataParallel as DDP
+from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel as custom_FSDP
 from megatron.core.distributed import finalize_model_grads
+from megatron.core.rerun_state_machine import get_rerun_state_machine
 from megatron.training.initialize import write_args_to_tensorboard
 from megatron.core.num_microbatches_calculator import (
     get_current_global_batch_size,
@@ -36,6 +38,7 @@ from megatron.training import one_logger_utils
 from megatron.training import ft_integration
 from megatron.training.training import (
     print_datetime,
+    should_disable_forward_pre_hook,
     disable_forward_pre_hook,
     train_step,
     save_checkpoint_and_time,
...
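The new should_disable_forward_pre_hook import suggests the patched training loop now asks upstream Megatron-LM when forward pre-hooks (used by the distributed optimizer to overlap parameter all-gathers with compute) should be turned off, rather than deciding locally. As background, a minimal pure-PyTorch sketch of registering and removing a forward pre-hook; the hook body is illustrative only:

    import torch
    from torch import nn

    model = nn.Linear(4, 4)

    def pre_hook(module, inputs):
        # in Megatron, this is where a parameter all-gather would be kicked off
        print('pre-forward hook fired')

    handle = model.register_forward_pre_hook(pre_hook)
    model(torch.randn(2, 4))  # prints: pre-forward hook fired

    handle.remove()           # what a disable_forward_pre_hook helper boils down to
    model(torch.randn(2, 4))  # hook no longer fires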
@@ -8,7 +8,6 @@ do
     fi
 done
 CURRENT_DIR="$( cd "$( dirname "$0" )" && pwd )"
-echo $CURRENT_DIR
 MEGATRON_PATH=$( dirname $( dirname ${CURRENT_DIR}))
 export CUDA_DEVICE_MAX_CONNECTIONS=1
...
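The script change only drops a debug echo; the surviving lines resolve the repository root two directories above the script. For reference, the same path arithmetic expressed in Python (illustrative only, since the original is a shell script):

    from pathlib import Path

    # CURRENT_DIR="$( cd "$( dirname "$0" )" && pwd )"
    current_dir = Path(__file__).resolve().parent
    # MEGATRON_PATH=$( dirname $( dirname ${CURRENT_DIR}))
    megatron_path = current_dir.parent.parent
    print(megatron_path)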