Commit 4c942eaf authored by silencealiang

bug fix

parent 770fa304
@@ -24,13 +24,13 @@ class MegatronAdaptation:
         adaptation.execute()
         MegatronAdaptation.apply()
-        from .patch_utils import MegatronPatchesManager
-        args = get_adaptor_args()
-        for feature in FEATURES_LIST:
-            if (getattr(args, feature.feature_name, None) and feature.optimization_level > 0) or feature.optimization_level == 0:
-                feature.register_patches(MegatronPatchesManager, args)
-        MindSpeedPatchesManager.apply_patches()
+        # from .patch_utils import MegatronPatchesManager
+        # args = get_adaptor_args()
+        # for feature in FEATURES_LIST:
+        #     if (getattr(args, feature.feature_name, None) and feature.optimization_level > 0) or feature.optimization_level == 0:
+        #         feature.register_patches(MegatronPatchesManager, args)
+        # MindSpeedPatchesManager.apply_patches()
         # MegatronAdaptation.post_execute()
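For context, the block commented out above is the feature-driven patch-registration path. A minimal sketch of the control flow it implemented, using the names from the removed lines (FEATURES_LIST, get_adaptor_args, the patch managers); the Feature shape itself is an assumption:

# Hedged sketch of the registration loop disabled by this commit. Only the
# loop logic comes from the removed lines; the Feature class is assumed.
class Feature:
    def __init__(self, feature_name, optimization_level):
        self.feature_name = feature_name              # matches a CLI flag on args
        self.optimization_level = optimization_level  # 0 = always applied

    def register_patches(self, patches_manager, args):
        pass  # a real feature registers its monkey-patches here

def apply_feature_patches(features_list, args, patches_manager):
    for feature in features_list:
        flag_set = getattr(args, feature.feature_name, None)
        # Level-0 features always apply; others only when their flag is set.
        if (flag_set and feature.optimization_level > 0) or feature.optimization_level == 0:
            feature.register_patches(patches_manager, args)
    patches_manager.apply_patches()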
@@ -142,9 +142,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
                                     torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
                                     apply_wrapper=True)
-        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
-                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
-                                    apply_wrapper=True)
+        # MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
+        #                             torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
+        #                             apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)
...
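These registrations pass a torch.compile(...) partial as the wrapper for a dotted target path. A minimal sketch of what apply_wrapper=True plausibly does; the register internals here are an assumption, only the target paths and compile options come from the diff:

import importlib
import torch

def register(target_path, wrapper, apply_wrapper=False):
    # Hedged sketch: resolve the dotted path, then replace the attribute with
    # wrapper(original) when apply_wrapper is set. The real MegatronAdaptation
    # manager likely defers and batches these patches instead.
    module_path, attr = target_path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    original = getattr(module, attr)
    patched = wrapper(original) if apply_wrapper else wrapper
    setattr(module, attr, patched)

# Usage mirroring the surviving registration above (torch.compile called with
# only keyword arguments returns a decorator, so wrapper(original) compiles it):
# register('megatron.core.transformer.moe.moe_utils.permute',
#          torch.compile(mode='max-autotune-no-cudagraphs'),
#          apply_wrapper=True)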
 import warnings
-from typing import Optional
+from typing import Optional, Union
 from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
 from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
...
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+import os
 from collections import OrderedDict
 from typing import Dict, Literal, Optional
@@ -320,7 +322,7 @@ class GPTModel(LanguageModule):
         )
         if (
-            self.num_nextn_predict_layers
+            self.mtp_process is not None
             and getattr(self.decoder, "main_final_layernorm", None) is not None
         ):
             # move block main model final norms here
...
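The condition rewrite above looks like the actual bug fix: the old guard tested the truthiness of num_nextn_predict_layers, the new one tests that an MTP processor object exists. A small reproduction of the difference, with attribute semantics assumed from the diff (both classes are hypothetical stand-ins):

# Assumed semantics: num_nextn_predict_layers is an int (0 makes the old guard
# falsy), mtp_process is an object or None when MTP is disabled.
class Decoder:
    main_final_layernorm = object()

class Model:
    num_nextn_predict_layers = 0   # zero extra layers configured...
    mtp_process = object()         # ...but an MTP processor is still attached
    decoder = Decoder()

m = Model()
old_guard = bool(m.num_nextn_predict_layers) \
    and getattr(m.decoder, "main_final_layernorm", None) is not None
new_guard = m.mtp_process is not None \
    and getattr(m.decoder, "main_final_layernorm", None) is not None
print(old_guard, new_guard)  # False True -> the final-norm move now executes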
+from typing import Optional
 from functools import wraps
 from dataclasses import dataclass
...
 import torch
 import torch.nn.functional as F
+from functools import wraps
 from megatron.training import get_args
 from megatron.core import tensor_parallel
 from megatron.legacy.model.enums import AttnType
...
@@ -175,7 +175,7 @@ def _add_mtp_args(parser):
                             'MTP extends the prediction scope to multiple future tokens at each position.'
                             'This MTP implementation sequentially predict additional tokens '
                             'by using D sequential modules to predict D additional tokens.')
-    group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.1,
+    group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.3,
                        help='Scaling factor of Multi-Token Prediction (MTP) loss. '
                             'We compute the average of the MTP losses across all depths, '
                             'and multiply it the scaling factor to obtain the overall MTP loss, '
...
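The default bump above (0.1 to 0.3) feeds the combination rule spelled out in the help text: average the MTP losses across all depths, then multiply by the scaling factor. A minimal sketch, with function and variable names assumed:

import torch

def overall_mtp_loss(mtp_losses, scaling_factor=0.3):
    # Per the --mtp-loss-scaling-factor help text: average the MTP losses
    # across all prediction depths, then scale the mean.
    return scaling_factor * torch.stack(list(mtp_losses)).mean()

# Example with three depths:
losses = [torch.tensor(2.0), torch.tensor(1.5), torch.tensor(1.0)]
print(overall_mtp_loss(losses))  # 0.3 * 1.5 = tensor(0.4500)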