evt_fugx1 / dcu_megatron · Commits

Commit 5890bb4c
Authored May 14, 2025 by dongcl
Parent: 4bb958ec

    fix import error

Showing 17 changed files with 119 additions and 89 deletions (+119 -89)
dcu_megatron/adaptor/features_manager.py                  +28 -13
dcu_megatron/adaptor/megatron_adaptor.py                  +7  -7
dcu_megatron/core/extensions/transformer_engine.py        +5  -5
dcu_megatron/core/models/gpt/fine_grained_schedule.py     +20 -35
dcu_megatron/core/models/gpt/gpt_model.py                 +1  -1
dcu_megatron/core/pipeline_parallel/combined_1f1b.py      +1  -1
dcu_megatron/core/pipeline_parallel/schedules.py          +16 -8
dcu_megatron/core/transformer/mlp.py                      +1  -3
dcu_megatron/core/transformer/moe/experts.py              +1  -3
dcu_megatron/core/transformer/moe/moe_layer.py            +1  -4
dcu_megatron/core/transformer/moe/token_dispatcher.py     +16 -1
dcu_megatron/core/transformer/multi_latent_attention.py   +1  -4
dcu_megatron/core/transformer/transformer_block.py        +1  -1
dcu_megatron/core/transformer/transformer_config.py       +3  -0
dcu_megatron/core/transformer/transformer_layer.py        +10 -2
dcu_megatron/legacy/model/rms_norm.py                     +5  -1
dcu_megatron/training/arguments.py                        +2  -0
dcu_megatron/adaptor/features_manager.py

@@ -5,13 +5,17 @@ def a2a_overlap_adaptation(patches_manager):
     """
     patches_manager: MegatronPatchesManager
     """
+    from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear
     from ..core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
     from ..core.transformer.transformer_block import TransformerBlock
     from ..core.transformer.transformer_layer import TransformerLayer
     from ..core.models.gpt.gpt_model import GPTModel
     from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
-    from ..core.extensions.transformer_engine import _get_extra_te_kwargs_wrapper, TELinear, TELayerNormColumnParallelLinear
+    from ..core.extensions.transformer_engine import (
+        _get_extra_te_kwargs_wrapper,
+        TELinear,
+        TELayerNormColumnParallelLinear,
+    )
     from ..core.transformer.multi_latent_attention import MLASelfAttention
     from ..core.transformer.mlp import MLP
     from ..core.transformer.moe.experts import TEGroupedMLP

@@ -38,23 +42,34 @@ def a2a_overlap_adaptation(patches_manager):
                                    GPTModel
     )

     # backward_dw
-    patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
-                                   _get_extra_te_kwargs_wrapper,
-                                   apply_wrapper=True)
+    # patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
+    #                                _get_extra_te_kwargs_wrapper,
+    #                                apply_wrapper=True)
     patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
                                    TELinear)
     patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
                                    TELayerNormColumnParallelLinear)
+    TEColumnParallelLinear.__bases__ = (TELinear,)
+    TERowParallelLinear.__bases__ = (TELinear,)
     if is_te_min_version("1.9.0.dev0"):
+        from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear, TERowParallelGroupedLinear
         from ..core.extensions.transformer_engine import TEGroupedLinear
         patches_manager.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear',
                                        TEGroupedLinear)
+        TEColumnParallelGroupedLinear.__bases__ = (TEGroupedLinear,)
+        TERowParallelGroupedLinear.__bases__ = (TEGroupedLinear,)

-    patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention',
-                                   MLASelfAttention)
-    patches_manager.register_patch('megatron.core.transformer.mlp.MLP',
-                                   MLP)
-    patches_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP',
-                                   TEGroupedMLP)
-    patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer',
-                                   MoELayer)
+    patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention.backward_dw',
+                                   MLASelfAttention.backward_dw,
+                                   create_dummy=True)
+    patches_manager.register_patch('megatron.core.transformer.mlp.MLP.backward_dw',
+                                   MLP.backward_dw,
+                                   create_dummy=True)
+    patches_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP.backward_dw',
+                                   TEGroupedMLP.backward_dw,
+                                   create_dummy=True)
+    patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_dw',
+                                   MoELayer.backward_dw,
+                                   create_dummy=True)
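
The `__bases__` assignments are the substantive addition here: rather than replacing the TE column/row classes outright, the patched `TELinear` / `TEGroupedLinear` are spliced in as their base classes so every existing subclass picks up the new behaviour. A minimal sketch of that rebasing trick, with hypothetical class names rather than the repository's code:

# Illustrative sketch (not repository code): swapping a class's base at runtime.
class Original:
    def forward(self, x):
        return x + 1

class Patched(Original):
    # The patched parent adds behaviour, e.g. a delayed-weight-gradient hook.
    def backward_dw(self):
        return "wgrad deferred"

class Derived(Original):
    """Stands in for a class such as TEColumnParallelLinear."""
    pass

# Splice the patched parent in; existing subclasses now inherit its methods.
Derived.__bases__ = (Patched,)

d = Derived()
assert d.forward(1) == 2                      # inherited behaviour preserved
assert d.backward_dw() == "wgrad deferred"    # new method visible on the subclass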
dcu_megatron/adaptor/megatron_adaptor.py

@@ -104,7 +104,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
                                     gpt_model_init_wrapper,
                                     apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
                                     gpt_model_forward)

     def patch_core_transformers(self):
@@ -122,12 +122,12 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     MLATransformerConfigPatch)

         # Moe
-        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
-                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
-                                    apply_wrapper=True)
+        # MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
+        #                             torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
+        #                             apply_wrapper=True)
-        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
-                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
-                                    apply_wrapper=True)
+        # MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
+        #                             torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
+        #                             apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)
...
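
These registrations hand `torch.compile(...)` to the adaptation framework as a wrapper (`apply_wrapper=True`). The sketch below shows only the public PyTorch behaviour this relies on: `torch.compile` called without a target returns a decorator that can wrap an existing function. `permute_reference` is a made-up stand-in, not megatron's `moe_utils.permute`.

# Sketch, public PyTorch API only.
import torch

def permute_reference(tokens: torch.Tensor, order: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the wrapped function; reorders rows by `order`.
    return tokens.index_select(0, order)

compile_wrapper = torch.compile(mode='max-autotune-no-cudagraphs')  # returns a decorator
permute_compiled = compile_wrapper(permute_reference)

tokens = torch.randn(8, 4)
order = torch.randperm(8)
assert torch.equal(permute_compiled(tokens, order), permute_reference(tokens, order))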
dcu_megatron/core/extensions/transformer_engine.py

@@ -3,6 +3,7 @@ import torch
 import dataclasses
 import transformer_engine as te
+from functools import wraps
 from typing import Any, Optional, Callable
 from packaging.version import Version as PkgVersion
...
@@ -18,7 +19,6 @@ from megatron.core.extensions.transformer_engine import TELinear as MegatronCoreTELinear
 from megatron.core.extensions.transformer_engine import TELayerNormColumnParallelLinear as MegatronCoreTELayerNormColumnParallelLinear
 from megatron.core.parallel_state import (
-    get_context_parallel_global_ranks,
     get_context_parallel_group,
     get_hierarchical_context_parallel_groups,
     get_tensor_model_parallel_group,
...
@@ -29,7 +29,7 @@ def _get_extra_te_kwargs_wrapper(fn):
     @wraps(fn)
     def wrapper(config: TransformerConfig):
         extra_transformer_engine_kwargs = fn(config)
-        extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.get("split_bw", False)
+        extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.split_bw if hasattr(config, "split_bw") else False
         return extra_transformer_engine_kwargs
     return wrapper
...
@@ -66,7 +66,7 @@ class TELinear(MegatronCoreTELinear):
         is_expert: bool = False,
         tp_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
-        self.split_bw = config.get("split_bw", False)
+        self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
         assert not self.split_bw, "split_bw is currently not supported"
         super().__init__(
...
@@ -109,7 +109,7 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
         tp_comm_buffer_name: Optional[str] = None,
         tp_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
-        self.split_bw = config.get("split_bw", False)
+        self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
         assert not self.split_bw, "split_bw is currently not supported"
         super().__init__(
...
@@ -314,7 +314,7 @@ if is_te_min_version("1.9.0.dev0"):
         tp_comm_buffer_name: Optional[str] = None,
         tp_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
-        self.split_bw = config.get("split_bw", False)
+        self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
         assert not self.split_bw, "split_bw is currently not supported"
         super().__init__(
...
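
The repeated change in this file swaps `config.get("split_bw", False)` for a `hasattr` check: the config object is a dataclass, not a dict, so it has no `.get()`. A small sketch with a dummy config (not the real `TransformerConfig`), showing the equivalent `getattr` one-liner:

# Sketch; DummyConfig is an assumption standing in for a dataclass-style config.
from dataclasses import dataclass

@dataclass
class DummyConfig:
    hidden_size: int = 1024   # hypothetical field, for illustration only

cfg = DummyConfig()

# cfg.get("split_bw", False)  # would raise AttributeError on a dataclass

split_bw = cfg.split_bw if hasattr(cfg, "split_bw") else False   # pattern from the diff
split_bw_alt = getattr(cfg, "split_bw", False)                   # equivalent one-liner
assert split_bw == split_bw_alt == False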
dcu_megatron/core/models/gpt/fine_grained_schedule.py

 import contextlib
 import weakref
-from typing import Any, Callable, Optional, Tuple, Union
+from collections import OrderedDict
+from typing import Optional

 import torch
 from torch import Tensor

-from megatron.core.pipeline_parallel.combined_1f1b import (
+from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
+from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.transformer import transformer_layer
+from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.utils import deprecate_inference_params
+from dcu_megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllPerBatchState
+from dcu_megatron.core.pipeline_parallel.combined_1f1b import (
     AbstractSchedulePlan,
     ScheduleNode,
     get_com_stream,
-    get_comp_stream,
     make_viewless,
 )
-from megatron.core.transformer import transformer_layer
-from megatron.core.transformer.module import float16_to_fp32
-from megatron.core.transformer.moe.moe_layer import MoELayer
-from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllPerBatchState


 def weak_method(method):
...
@@ -43,6 +48,7 @@ class PreProcessNode(ScheduleNode):
         input_ids = self.model_chunk_state.input_ids
         position_ids = self.model_chunk_state.position_ids
         inference_context = self.model_chunk_state.inference_context
+        inference_params = self.model_chunk_state.inference_params
         packed_seq_params = self.model_chunk_state.packed_seq_params
         inference_context = deprecate_inference_params(inference_context, inference_params)
...
@@ -121,22 +127,6 @@ class PostProcessNode(ScheduleNode):
         self.gpt_model = gpt_model
         self.model_chunk_state = model_chunk_state
-        state.input_ids = input_ids
-        state.position_ids = position_ids
-        state.attention_mask = attention_mask
-        state.decoder_input = decoder_input
-        state.labels = labels
-        state.inference_context = inference_context
-        state.packed_seq_params = packed_seq_params
-        state.extra_block_kwargs = extra_block_kwargs
-        state.runtime_gather_output = runtime_gather_output
-        state.inference_params = inference_params
-        state.loss_mask = loss_mask
-        state.context = None
-        state.context_mask = None
-        state.attention_bias = None

     def forward_impl(self, hidden_states):
         gpt_model = self.gpt_model
...
@@ -145,11 +135,13 @@ class PostProcessNode(ScheduleNode):
         labels = self.model_chunk_state.labels
         loss_mask = self.model_chunk_state.loss_mask
         attention_mask = self.model_chunk_state.attention_mask
+        decoder_input = self.model_chunk_state.decoder_input
         inference_params = self.model_chunk_state.inference_params
         rotary_pos_emb = self.model_chunk_state.rotary_pos_emb
         rotary_pos_cos = self.model_chunk_state.rotary_pos_cos
         rotary_pos_sin = self.model_chunk_state.rotary_pos_sin
         packed_seq_params = self.model_chunk_state.packed_seq_params
+        extra_block_kwargs = self.model_chunk_state.extra_block_kwargs
         sequence_len_offset = self.model_chunk_state.sequence_len_offset
         runtime_gather_output = self.model_chunk_state.runtime_gather_output
         inference_context = self.model_chunk_state.inference_context
...
@@ -267,6 +259,9 @@ class TransformerLayerNode(ScheduleNode):
     def backward_impl(self, outputs, output_grad):
         detached_grad = tuple([e.grad for e in self.detached])
         grads = output_grad + detached_grad
+        # if len(detached_grad):
+        #     print(f"output_grad: {grads}")
         self.default_backward_func(outputs + self.before_detached, grads)
         self.before_detached = None
         self.detached = None
...
@@ -296,7 +291,6 @@ class MoeAttnNode(TransformerLayerNode):
             tokens_per_expert,
             permutated_local_input_tokens,
             permuted_probs,
-            probs,
         ) = self.layer._submodule_attention_router_compound_forward(
             hidden_states,
             attention_mask=attention_mask,
...
@@ -312,7 +306,6 @@ class MoeAttnNode(TransformerLayerNode):
         self.common_state.tokens_per_expert = tokens_per_expert
         # detached here
-        self.common_state.probs = self.detach(probs)
         self.common_state.residual = self.detach(hidden_states)
         self.common_state.pre_mlp_layernorm_output = self.detach(pre_mlp_layernorm_output)
...
@@ -334,7 +327,7 @@ class MoeDispatchNode(TransformerLayerNode):
         )
         # release tensor not used by backward
         # inputs.untyped_storage().resize_(0)
-        self.common_state.tokens_per_expert == tokens_per_expert
+        self.common_state.tokens_per_expert = tokens_per_expert
         return global_input_tokens, global_probs
...
@@ -345,7 +338,7 @@ class MoeMlPNode(TransformerLayerNode):
         token_dispatcher = self.layer.mlp.token_dispatcher
         with token_dispatcher.per_batch_state_context(self.common_state):
             expert_output, shared_expert_output, mlp_bias = self.layer._submodule_moe_forward(
-                self.common_state.tokens_per_expert, global_input_tokens, global_prob, pre_mlp_layernorm_output
+                self.common_state.tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output
             )
             assert mlp_bias is None
...
@@ -372,9 +365,7 @@ class MoeCombineNode(TransformerLayerNode):
         )
         cur_stream = torch.cuda.current_stream()
         self.common_state.residual.record_stream(cur_stream)
-        self.common_state.probs.record_stream(cur_stream)
         self.common_state.residual = None
-        self.common_state.probs = None
         return output
...
@@ -554,21 +545,18 @@ def schedule_layer_1f1b(
     f_context = f_context if f_context is not None else contextlib.nullcontext()
     b_context = b_context if b_context is not None else contextlib.nullcontext()
     if pre_forward is not None:
         assert f_input is None
         # combine from last iter
         f_input = pre_forward()
         del pre_forward
     if pre_backward is not None:
         # attn backward from last iter
         assert b_grad is None
         b_grad = pre_backward()
         del pre_backward
     if b_layer is not None:
         with b_context:
             b_grad = b_layer.combine.backward(b_grad)
...
@@ -577,7 +565,6 @@ def schedule_layer_1f1b(
             pre_backward_dw()
             del pre_backward_dw
     if f_layer is not None:
         with f_context:
             f_input = f_layer.attn.forward(f_input)
...
@@ -592,7 +579,6 @@ def schedule_layer_1f1b(
             b_grad = b_layer.dispatch.backward(b_grad)
             b_layer.mlp.dw()
     if f_layer is not None:
         with f_context:
             f_input = f_layer.mlp.forward(f_input)
...
@@ -614,7 +600,6 @@ def schedule_layer_1f1b(
         with b_context:
             b_layer.attn.dw()
     if f_layer and b_layer:
         return next_iter_pre_forward, next_iter_pre_backward, next_iter_pre_backward_dw
     else:
...
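
The schedule nodes in this file split a transformer layer's forward into attn / dispatch / mlp / combine pieces and stitch the backward passes together across `detach()` boundaries: `backward_impl` feeds the `.grad` of each detached boundary tensor back into autograd. A self-contained sketch of that mechanic with a hypothetical two-node split, not the repository's classes:

# Sketch: resuming backward through a detach boundary (public PyTorch API only).
import torch

x = torch.randn(4, requires_grad=True)

# --- node 1 forward ---
h = x * 2
h_detached = h.detach().requires_grad_(True)   # boundary tensor handed to node 2

# --- node 2 forward ---
y = (h_detached ** 2).sum()

# --- node 2 backward: populates h_detached.grad, not x.grad ---
y.backward()
assert x.grad is None

# --- node 1 backward: resume from the boundary gradient ---
torch.autograd.backward(h, grad_tensors=h_detached.grad)
assert torch.allclose(x.grad, 4 * h_detached.detach())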
dcu_megatron/core/models/gpt/gpt_model.py

@@ -7,10 +7,10 @@ from functools import wraps
 import torch
 from torch import Tensor

-from megatron.core import InferenceParams, tensor_parallel
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.utils import WrappedTensor, deprecate_inference_params
 from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel
...
dcu_megatron/core/pipeline_parallel/combined_1f1b.py

@@ -427,7 +427,7 @@ def forward_backward_step(
     if f_model:
         with f_context:
             num_tokens = torch.tensor(0, dtype=torch.int)
-            if parallel_state.is_pipeline_last_stage():
+            if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
                 if not collect_non_loss_data:
                     loss_node = ScheduleNode(
                         loss_func,
...
dcu_megatron/core/pipeline_parallel/schedules.py

@@ -2,17 +2,13 @@ import contextlib
 from typing import Callable, Iterator, List, Optional, Union

 import torch
-from torch.autograd.variable import Variable

 from megatron.training import get_args
 from megatron.core import parallel_state
 from megatron.core.enums import ModelType
 from megatron.core.pipeline_parallel import p2p_communication
 from megatron.core.transformer.cuda_graphs import create_cudagraphs
-from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
-from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
 from megatron.core.utils import (
-    drain_embedding_wgrad_compute,
     get_attr_wrapped_model,
     get_model_config,
     get_model_type,
...
@@ -32,6 +28,18 @@ from megatron.core.pipeline_parallel.schedules import (
 from .combined_1f1b import VppContextManager, forward_backward_step, set_streams, wrap_forward_func


+def set_current_microbatch(model, microbatch_id):
+    """Set the current microbatch."""
+    decoder_exists = True
+    decoder = None
+    try:
+        decoder = get_attr_wrapped_model(model, "decoder")
+    except RuntimeError:
+        decoder_exists = False
+    if decoder_exists and decoder is not None:
+        decoder.current_microbatch = microbatch_id
+
+
 def get_pp_rank_microbatches(
     num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage, forward_only=False
 ):
...
@@ -541,7 +549,7 @@ def forward_backward_pipelining_with_interleaving(
             )

             # forward step
-            if parallel_state.is_pipeline_first_stage():
+            if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
                 if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
                     input_tensors[model_chunk_id].append(None)
...
@@ -573,7 +581,7 @@ def forward_backward_pipelining_with_interleaving(
                     enable_grad_sync()
                     synchronized_model_chunks.add(model_chunk_id)

-            if parallel_state.is_pipeline_last_stage():
+            if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
                 if len(output_tensor_grads[model_chunk_id]) == 0:
                     output_tensor_grads[model_chunk_id].append(None)
             b_input_tensor = input_tensors[model_chunk_id].pop(0)
...
@@ -679,7 +687,6 @@ def forward_backward_pipelining_with_interleaving(
                 post_backward=post_backward,
             )
         else:
-            output_tensor = None
             input_tensor_grad = None
             if f_virtual_microbatch_id is not None:
                 # forward pass
...
@@ -704,7 +711,7 @@ def forward_backward_pipelining_with_interleaving(
                 input_tensor_grad = backward_step_helper(b_virtual_microbatch_id)
                 if post_backward is not None:
                     input_tensor_grad = post_backward(input_tensor_grad)
-        return output_tensor, input_tensor_grad
+        return output_tensor if f_virtual_microbatch_id is not None else None, input_tensor_grad

     # Run warmup forward passes.
     parallel_state.set_virtual_pipeline_model_parallel_rank(0)
...
@@ -890,6 +897,7 @@ def forward_backward_pipelining_with_interleaving(
             output_tensor_grads[num_model_chunks - 1].append(bwd_recv_buffer[-1])

     # Run 1F1B in steady state.
+    output_tensor = None
     for k in range(num_microbatches_remaining):
         # Forward pass.
         forward_k = k + num_warmup_microbatches
...
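
The changed `return` statement relies on Python precedence: the conditional expression binds tighter than the tuple comma, so the function still returns a 2-tuple. A tiny generic check (not repository code):

# Precedence sketch: `a if cond else None, b` parses as `((a if cond else None), b)`.
def combined_step(output_tensor, input_tensor_grad, did_forward):
    return output_tensor if did_forward else None, input_tensor_grad

assert combined_step("out", "grad", True) == ("out", "grad")
assert combined_step("out", "grad", False) == (None, "grad")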
dcu_megatron/core/transformer/mlp.py

-from megatron.core.transformer.mlp import MLP as MegatronCoreMLP
-
-class MLP(MegatronCoreMLP):
+class MLP():

     def backward_dw(self):
         self.linear_fc2.backward_dw()
         self.linear_fc1.backward_dw()
dcu_megatron/core/transformer/moe/experts.py

-from megatron.core.transformer.experts import TEGroupedMLP as MegatronCoreTEGroupedMLP
-
-class TEGroupedMLP(MegatronCoreTEGroupedMLP):
+class TEGroupedMLP():

     def backward_dw(self):
         self.linear_fc2.backward_dw()
         self.linear_fc1.backward_dw()
dcu_megatron/core/transformer/moe/moe_layer.py

-from megatron.core.transformer.moe.moe_layer import MoELayer as MegatronCoreMoELayer
-
-
-class MoELayer(MegatronCoreMoELayer):
+class MoELayer():

     def backward_dw(self):
         self.experts.backward_dw()
         self.shared_experts.backward_dw()
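
mlp.py, experts.py and moe_layer.py now keep only thin `backward_dw` holders whose methods are grafted onto the core classes by `register_patch(..., create_dummy=True)` in features_manager.py. A hedged sketch of the delegation pattern with made-up stand-ins (the real `linear_fc*` objects are Transformer Engine layers, not these fakes):

# Sketch: backward_dw fans out to the sub-modules that own the deferred wgrad work,
# fc2 before fc1, mirroring reverse execution order.
class FakeLinear:
    def __init__(self, name):
        self.name = name
        self.wgrad_done = False

    def backward_dw(self):
        # In the real modules this would launch the delayed weight-gradient GEMM.
        self.wgrad_done = True

class FakeMLP:
    def __init__(self):
        self.linear_fc1 = FakeLinear("fc1")
        self.linear_fc2 = FakeLinear("fc2")

    def backward_dw(self):
        self.linear_fc2.backward_dw()
        self.linear_fc1.backward_dw()

mlp = FakeMLP()
mlp.backward_dw()
assert mlp.linear_fc1.wgrad_done and mlp.linear_fc2.wgrad_done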
dcu_megatron/core/transformer/moe/token_dispatcher.py

+from contextlib import contextmanager
+from typing import Optional, Tuple
+
+import torch
+
+from megatron.core.tensor_parallel import (
+    all_to_all,
+    gather_from_sequence_parallel_region,
+    reduce_scatter_to_sequence_parallel_region,
+)
+from megatron.core.transformer.moe.moe_utils import (
+    permute,
+    sort_chunks_by_idxs,
+    unpermute,
+)
 from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher as MegatronCoreMoEAlltoAllTokenDispatcher
...
@@ -303,7 +318,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         """
         assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"
         hidden_states = self.combine_preprocess(hidden_states)
         permutated_local_input_tokens = self.combine_all_to_all(hidden_states)
         output = self.combine_postprocess(permutated_local_input_tokens)
...
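
The newly added imports pull in the usual MoE dispatch helpers (`permute`, `unpermute`, `all_to_all`, ...). As a generic illustration only, in plain PyTorch and not megatron's API, this is the permutation/inverse-permutation bookkeeping an all-to-all token dispatcher performs:

# Sketch: group tokens by routed expert, then restore the original order.
import torch

tokens = torch.arange(12, dtype=torch.float32).reshape(6, 2)   # 6 tokens, hidden=2
expert_ids = torch.tensor([2, 0, 1, 0, 2, 1])                  # routing decision per token

order = torch.argsort(expert_ids, stable=True)   # tokens grouped by expert id
permuted = tokens.index_select(0, order)

# ... experts would process `permuted` chunk by chunk here ...

restore = torch.argsort(order)                   # inverse permutation
unpermuted = permuted.index_select(0, restore)
assert torch.equal(unpermuted, tokens)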
dcu_megatron/core/transformer/multi_latent_attention.py

-from megatron.core.transformer.multi_latent_attention import MLASelfAttention as MegatronCoreMLASelfAttention
-
-
-class MLASelfAttention(MegatronCoreMLASelfAttention):
+class MLASelfAttention():
     """MLA Self-attention layer class

     Self-attention layer takes input with size [s, b, h]
...
dcu_megatron/core/transformer/transformer_block.py

@@ -9,7 +9,7 @@ def transformer_block_init_wrapper(fn):
         # mtp require seperate layernorms for main model and mtp modules, thus move finalnorm out of block
         config = args[0] if len(args) > 1 else kwargs['config']
-        if getattr(config, "mtp_num_layers", 0) > 0:
+        if hasattr(config, "mtp_num_layers") and config.mtp_num_layers is not None:
             self.main_final_layernorm = self.final_layernorm
             self.final_layernorm = None
...
dcu_megatron/core/transformer/transformer_config.py

@@ -40,6 +40,9 @@ class ExtraTransformerConfig:
     combined_1f1b_recipe: str = 'ep_a2a'
     """Recipe to use for combined 1F1B. Currently only 'ep_a2a' and 'golden' are supported."""

+    split_bw: bool = False
+    """If true, split dgrad and wgrad for better overlapping in combined 1F1B."""
+

 @dataclass
 class TransformerConfigPatch(TransformerConfig, ExtraTransformerConfig):
...
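
`split_bw` is declared on `ExtraTransformerConfig`, which is mixed into megatron's `TransformerConfig` through `TransformerConfigPatch`. A small sketch of that dataclass-mixin pattern with stand-in classes (the real `TransformerConfig` lives in megatron.core and has many more fields):

# Sketch: extra DCU-specific fields merged into a base config by inheritance.
from dataclasses import dataclass

@dataclass
class BaseConfig:                 # stand-in for megatron's TransformerConfig
    hidden_size: int = 1024

@dataclass
class ExtraConfig:                # mirrors ExtraTransformerConfig
    combined_1f1b_recipe: str = 'ep_a2a'
    split_bw: bool = False

@dataclass
class ConfigPatch(BaseConfig, ExtraConfig):   # mirrors TransformerConfigPatch
    pass

cfg = ConfigPatch()
assert cfg.split_bw is False and cfg.hidden_size == 1024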
dcu_megatron/core/transformer/transformer_layer.py

-from megatron.core import parallel_state, tensor_parallel
+from functools import partial
+from typing import Any, Optional
+
+import torch
+from torch import Tensor
+
+from megatron.core import tensor_parallel
+from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.utils import (
     deprecate_inference_params,
     make_viewless_tensor,
 )
 from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
+from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables


 class TransformerLayer(MegatronCoreTransformerLayer):

     def _callable_wrapper(
...
@@ -147,7 +156,6 @@ class TransformerLayer(MegatronCoreTransformerLayer):
             tokens_per_expert,
             permutated_local_input_tokens,
             permuted_probs,
-            probs,
         ]
         return tuple(outputs)
...
dcu_megatron/legacy/model/rms_norm.py

+import warnings
+
 import torch
 from typing import Optional
-import lightop
+try:
+    import lightop
+except ImportError:
+    warnings.warn('lightop is not installed.')
 from functools import partial
 from megatron.core.utils import is_torch_min_version
...
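
The `try/except ImportError` guard turns `lightop` into an optional dependency. A generic sketch of the same guard with a hypothetical module name, extended to record availability so later code can branch on it (an illustrative variant, not the repository's code):

# Sketch; `fastkernels` is a made-up optional accelerator library.
import warnings
import torch

try:
    import fastkernels
except ImportError:
    fastkernels = None
    warnings.warn('fastkernels is not installed; using the PyTorch fallback.')

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    if fastkernels is not None:
        return fastkernels.rms_norm(x, eps)   # fast path (API assumed for this sketch)
    # Reference fallback in plain PyTorch.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)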
dcu_megatron/training/arguments.py

@@ -139,6 +139,8 @@ def _add_extra_moe_args(parser):
                        choices=['ep_a2a', 'golden'],
                        default='golden',
                        help='Options are "ep_a2a" and "golden".')
+    group.add_argument('--split-bw', action='store_true',
+                       help='Split dgrad and wgrad for batch-level overlapping')

     return parser
...
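
The new `--split-bw` flag is a plain `store_true` argument; presumably it feeds `args.split_bw` into the `split_bw` config field added in transformer_config.py, though that wiring is not part of this diff. A minimal argparse sketch (the first argument's name here is inferred from `combined_1f1b_recipe` and may differ from the repository's):

# Sketch: how the added flag parses; group title and first argument name are assumptions.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='extra moe')
group.add_argument('--combined-1f1b-recipe', choices=['ep_a2a', 'golden'], default='golden',
                   help='Options are "ep_a2a" and "golden".')
group.add_argument('--split-bw', action='store_true',
                   help='Split dgrad and wgrad for batch-level overlapping')

args = parser.parse_args(['--split-bw'])
assert args.split_bw is True and args.combined_1f1b_recipe == 'golden'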