Commit 160bf237 authored by wangxj

Update 0.12

parent b01809dd
Pipeline #2448 failed with stages
File mode changed from 100755 to 100644 (14 files)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.inference.modelopt_support.gpt.model_specs import get_gpt_layer_modelopt_spec
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
@@ -5,6 +5,14 @@ def add_modelopt_args(parser):
    """Add additional arguments for using TensorRT Model Optimizer (modelopt) features."""
    group = parser.add_argument_group(title="modelopt-generic")
+    # Model and Checkpoint Compatibility
+    group.add_argument(
+        "--export-model-type",
+        type=str,
+        default="GPTModel",
+        choices=["GPTModel", "MambaModel"],
+        help="Model type to use in model_provider.",
+    )
    group.add_argument(
        "--export-legacy-megatron",
        action="store_true",
@@ -15,13 +23,34 @@ def add_modelopt_args(parser):
        action="store_true",
        help="Export a megatron-core transformer-engine checkpoint.",
    )
+    group.add_argument(
+        "--export-force-local-attention",
+        action="store_true",
+        help="Forcing local DotProductAttention; otherwise TEDotProductAttention is used.",
+    )
+    # Quantization
+    group.add_argument(
+        "--export-kv-cache-quant",
+        action="store_true",
+        help="Whether or not to perform KV-cache quantization.",
+    )
+    group.add_argument(
+        "--export-real-quant-cfg",
+        type=str,
+        default="None",
+        choices=["fp8_real_quant", "fp8_blockwise_real_quant", "None"],
+        help="Specify a real quantization config from the supported choices.",
+    )
    group.add_argument(
        "--export-quant-cfg",
        type=str,
        default=None,
-        choices=["int8", "int8_sq", "fp8", "int4_awq", "w4a8_awq", "int4", "None"],
+        choices=["int8", "int8_sq", "fp8", "fp8_real_quant", "fp8_blockwise", "fp8_blockwise_real_quant", "int4_awq", "w4a8_awq", "int4", "fp4", "None"],
        help="Specify a quantization config from the supported choices.",
    )
+    # Knowledge Distillation
    group.add_argument(
        '--export-kd-cfg',
        type=str,
@@ -39,4 +68,27 @@ def add_modelopt_args(parser):
        help='Export original student class back from a loaded distillation model.',
    )
+    # Speculative decoding
+    group.add_argument(
+        '--export-num-medusa-heads',
+        type=int,
+        default=0,
+        help='Number of Medusa heads for speculative decoding.',
+    )
+    group.add_argument(
+        '--export-num-eagle-layers',
+        type=int,
+        default=0,
+        help='Number of EAGLE layers for speculative decoding.',
+    )
+    # Finetuning
+    group.add_argument(
+        "--finetune-hf-dataset",
+        type=str,
+        default=None,
+        help="HF dataset used for finetuning."
+    )
    return parser
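
For reference, a minimal usage sketch (not part of this commit) showing how the newly added flags could be parsed; it assumes add_modelopt_args is importable from the patched module and that no other flags are required:

import argparse

# Build a parser and register the modelopt flags added in the hunks above.
parser = argparse.ArgumentParser()
parser = add_modelopt_args(parser)

# Exercise a few of the new options (values are illustrative only).
args = parser.parse_args([
    "--export-model-type", "MambaModel",
    "--export-quant-cfg", "fp8_blockwise",
    "--export-kv-cache-quant",
    "--export-num-eagle-layers", "1",
])
assert args.export_model_type == "MambaModel"
assert args.export_kv_cache_quant is True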
@@ -2,8 +2,9 @@
import os
from pathlib import Path
-from typing import Dict, Optional
+from typing import Dict, Optional, Tuple

+import torch
import torch.nn as nn

from megatron.core import dist_checkpointing
@@ -12,14 +13,52 @@ from megatron.training.checkpointing import _load_base_checkpoint, load_checkpoint
from megatron.training.utils import print_rank_0, unwrap_model

try:
+    import modelopt
    from modelopt.torch.opt.plugins import (
        get_sharded_modelopt_state,
        restore_modelopt_state_metadata,
    )
+    from modelopt.torch.opt.plugins.mcore_dist_checkpointing import _get_gpt_sharded_modelopt_state
except ImportError as e:
    raise ImportError("Required `\"nvidia-modelopt[torch]\"` is not installed!") from e

+NEMO_WEIGHT_DIR_NAMES = {
+    "model_weights": "model.",
+    "weights": "module.",
+}
+
+
+def get_sharded_load_dir(load_dir: str) -> Tuple[str, str]:
+    """Return the sharded checkpoint directory and additional sharded prefix for load_dir."""
+    sharded_prefix = ""
+    sharded_load_dir = None
+
+    # Read the tracker file and set the iteration if this is a MLM sharded checkpoint.
+    tracker_filename = os.path.join(load_dir, 'latest_checkpointed_iteration.txt')
+
+    # If no tracker file, assuming that it is a NeMo sharded checkpoint.
+    if os.path.isfile(tracker_filename):
+        with open(tracker_filename, 'r') as f:
+            metastring = f.read().strip()
+        try:
+            iteration = int(metastring)
+            sharded_load_dir = Path(load_dir) / 'iter_{:07d}'.format(iteration)
+        except ValueError:
+            sharded_load_dir = Path(load_dir) / metastring
+    else:
+        for nemo_dir_name, prefix in NEMO_WEIGHT_DIR_NAMES.items():
+            nemo_weight_dir = Path(load_dir) / nemo_dir_name
+            if os.path.isdir(nemo_weight_dir):
+                sharded_prefix = prefix
+                sharded_load_dir = nemo_weight_dir
+                break
+
+    if sharded_load_dir is None:
+        raise ValueError("{} is not a MLM or NeMo sharded checkpoint!".format(load_dir))
+
+    return sharded_load_dir, sharded_prefix


def load_modelopt_state(load_dir: Optional[str] = None, model: Optional[nn.Module] = None) -> Dict:
    """Loading modelopt_state without loading the model.
@@ -39,25 +78,23 @@ def load_modelopt_state(load_dir: Optional[str] = None, model: Optional[nn.Module] = None) -> Dict:
    if args.use_dist_ckpt:
        assert model is not None, "`model` argument required when `args.use_dist_ckpt is True`"
-        # Read the tracker file and set the iteration.
-        tracker_filename = os.path.join(load_dir, 'latest_checkpointed_iteration.txt')
-        # If no tracker file, assuming that it is a .nemo checkpoint.
-        if not os.path.isfile(tracker_filename):
-            sharded_load_dir = Path(load_dir) / "model_weights"
-        else:
-            with open(tracker_filename, 'r') as f:
-                metastring = f.read().strip()
-            try:
-                iteration = int(metastring)
-                sharded_load_dir = Path(load_dir) / 'iter_{:07d}'.format(iteration)
-            except ValueError:
-                sharded_load_dir = Path(load_dir) / metastring
+        sharded_load_dir, _ = get_sharded_load_dir(load_dir)
        modelopt_state_dir = sharded_load_dir / "modelopt_state"
        if modelopt_state_dir.exists():
+            common_modelopt_state = torch.load(modelopt_state_dir / "common.pt")
+            extra_kwargs = {}
+            for mode, mode_cfg in common_modelopt_state["modelopt_state_dict"]:
+                if mode == "medusa":
+                    extra_kwargs.update({"num_medusa_heads": mode_cfg["config"]["medusa_num_heads"]})
+                if mode == "eagle" and modelopt.__version__ >= "0.20.0":
+                    print("eagle_mode", mode_cfg["config"])
+                    extra_kwargs.update({"num_eagle_layers": mode_cfg["config"]["eagle_num_layers"]})
            print_rank_0("Loading sharded modelopt_state ({})".format(modelopt_state_dir))
            modelopt_state = restore_modelopt_state_metadata(
                dist_checkpointing.load(
-                    get_sharded_modelopt_state(num_layers=args.num_layers, model=model),
+                    _get_gpt_sharded_modelopt_state(
+                        num_layers=args.num_layers, **extra_kwargs
+                    ),
                    modelopt_state_dir,
                )
            )
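
For clarity, an illustrative sketch (assumed, not taken from this commit) of the common.pt payload the new loop above consumes; the mode names come from the loop itself, the values are placeholders:

# "modelopt_state_dict" is iterated as (mode, mode_cfg) pairs; modes other than
# "medusa" and "eagle" are ignored by the loop above.
common_modelopt_state = {
    "modelopt_state_dict": [
        ("medusa", {"config": {"medusa_num_heads": 4}}),   # -> num_medusa_heads=4
        ("eagle", {"config": {"eagle_num_layers": 1}}),    # -> num_eagle_layers=1 (modelopt >= 0.20.0)
    ],
}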
@@ -84,7 +121,7 @@ def load_modelopt_checkpoint(
    optimizer=None,
    opt_param_scheduler=None,
    strict: bool = True,
-    additional_sharded_prefix: str = "model.",
+    additional_sharded_prefix: str = "",
    load_arg: str = "load",
) -> None:
    """Load a sharded (untar .nemo or megatron --use-dist-ckpt) or unsharded checkpoint.
@@ -120,22 +157,23 @@ def load_modelopt_checkpoint(
    args = get_args()
    load_dir = getattr(args, load_arg)

-    sharded_load_dir = Path(load_dir) / "model_weights"
+    sharded_load_dir, additional_sharded_prefix = get_sharded_load_dir(load_dir)
+    unwrapped_model = unwrap_model(model)

-    if sharded_load_dir.exists() and optimizer is None and opt_param_scheduler is None:
-        unwrapped_model = unwrap_model(model)
-        # Set this attribute will alter the sharded_offsets of transformer_block.
-        unwrapped_model[0].decoder.config.non_homogeneous_layers = False
+    if args.ckpt_format == "torch":
+        state_dict, checkpoint_name, release, ckpt_type = _load_base_checkpoint(
+            load_dir, args, rank0=False,
+        )
+        model_state_dict = state_dict["model"]
+        unwrapped_model[0].load_state_dict(model_state_dict, strict=False)
+    elif sharded_load_dir.exists() and optimizer is None and opt_param_scheduler is None:
        sharded_state_dict = unwrapped_model[0].sharded_state_dict(prefix=additional_sharded_prefix)
        if additional_sharded_prefix:
            unwrapped_model[0]._register_load_state_dict_pre_hook(
                _remove_prefix_state_dict_pre_hook
            )
-        unwrapped_model[0].load_state_dict(
-            dist_checkpointing.load(sharded_state_dict, sharded_load_dir)
-        )
-        # Set the attribute to True such that by-default we are storing the heterogenous arch.
-        unwrapped_model[0].decoder.config.non_homogeneous_layers = True
+        model_state_dict = dist_checkpointing.load(sharded_state_dict, sharded_load_dir)
+        unwrapped_model[0].load_state_dict(model_state_dict, strict=False)
    else:
        _ = load_checkpoint(model, optimizer, opt_param_scheduler, strict=strict, load_arg=load_arg)
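
A hypothetical call sketch (not part of this commit) for the rewritten loader, showing a weights-only load; the keyword names follow the signature in this diff, everything else is illustrative:

# With optimizer/scheduler set to None, the call takes either the "torch"
# ckpt_format path or the sharded dist-checkpoint path selected above.
load_modelopt_checkpoint(
    model,                    # list of model chunks, as elsewhere in Megatron-LM
    optimizer=None,           # model-only load; optimizer state is skipped
    opt_param_scheduler=None,
    strict=False,
)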