evt_fugx1 / dcu_megatron
Commit c964fcca, authored May 17, 2025 by dongcl
Parent: ab2a8334

    only transformer-engine>=2.4.0.dev supports split_bw
Showing 9 changed files with 107 additions and 174 deletions (+107 -174)
.gitignore                                              +2   -0
dcu_megatron/adaptor/features_manager.py                +7   -9
dcu_megatron/adaptor/megatron_adaptor.py                +5   -5
dcu_megatron/core/extensions/transformer_engine.py      +2   -1
dcu_megatron/core/models/gpt/fine_grained_schedule.py   +1   -4
dcu_megatron/core/models/gpt/gpt_model.py               +1   -9
dcu_megatron/core/pipeline_parallel/schedules.py        +28  -3
dcu_megatron/core/transformer/transformer_block.py      +0   -10
dcu_megatron/core/transformer/transformer_layer.py      +61  -133
.gitignore (new file, mode 0 → 100644)

+__pycache__
+*.bak
dcu_megatron/adaptor/features_manager.py

@@ -7,7 +7,6 @@ def a2a_overlap_adaptation(patches_manager):
    """
    from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear
    from ..core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
    from ..core.transformer.transformer_block import TransformerBlock
    from ..core.transformer.transformer_layer import TransformerLayer
    from ..core.models.gpt.gpt_model import GPTModel
    from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving

@@ -32,19 +31,18 @@ def a2a_overlap_adaptation(patches_manager):
    patches_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher', MoEAlltoAllTokenDispatcher)
    patches_manager.register_patch('megatron.core.transformer.transformer_block.TransformerBlock', TransformerBlock)
    patches_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer', TransformerLayer)
    patches_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)
    patches_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.build_schedule_plan',
                                   GPTModel.build_schedule_plan,
                                   create_dummy=True)

    # backward_dw
    # patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
    #                                _get_extra_te_kwargs_wrapper,
    #                                apply_wrapper=True)
    if is_te_min_version("2.4.0.dev0"):
        patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
                                       _get_extra_te_kwargs_wrapper,
                                       apply_wrapper=True)
        patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear', TELinear)
        patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
        ...
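For reference, a minimal, self-contained version check in the spirit of the is_te_min_version("2.4.0.dev0") gate above might look like the following; the helper name, package lookup, and threshold are illustrative assumptions, not the repo's actual implementation.

# Hypothetical stand-in for a version gate such as is_te_min_version();
# names and the package string are illustrative, not the repo's helper.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def te_is_at_least(minimum: str, package: str = "transformer-engine") -> bool:
    """Return True if `package` is installed and its version is >= `minimum`."""
    try:
        installed = Version(version(package))
    except PackageNotFoundError:
        return False
    return installed >= Version(minimum)


if __name__ == "__main__":
    # Only enable the split backward (delayed wgrad) code path on new-enough TE.
    if te_is_at_least("2.4.0.dev0"):
        print("TE supports split_bw / delay_wgrad_compute: register the patches")
    else:
        print("TE too old: fall back to the un-split backward path")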
dcu_megatron/adaptor/megatron_adaptor.py

@@ -101,11 +101,11 @@ class CoreAdaptation(MegatronAdaptationABC):
        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward

        # GPT Model
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
-                                    gpt_model_init_wrapper,
-                                    apply_wrapper=True)
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
-                                    gpt_model_forward)
+        # MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
+        #                             gpt_model_init_wrapper,
+        #                             apply_wrapper=True)
+        # MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
+        #                             gpt_model_forward)

    def patch_core_transformers(self):
        from ..core import transformer_block_init_wrapper
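The MegatronAdaptation.register(...) calls above patch targets by dotted path. A rough, hypothetical sketch of how such a dotted-path patch could be applied is shown below; it is not the repo's actual registry (which also supports deferred application and dummy targets), and apply_patch and its behaviour are invented for illustration.

# Illustrative sketch of applying a patch by dotted path, in the spirit of
# MegatronAdaptation.register(...); this is not the repo's implementation.
import importlib


def apply_patch(dotted_path: str, replacement, apply_wrapper: bool = False):
    """Replace the object named by `dotted_path` (module.attr or module.Class.attr).

    With apply_wrapper=True, `replacement` is treated as a decorator and is
    called with the original object, mirroring the wrapper-style patches above.
    """
    module_path, _, last_attr = dotted_path.rpartition(".")
    parts = module_path.split(".")
    attrs = [last_attr]
    module = None
    # Walk back until an importable module prefix is found.
    while parts:
        try:
            module = importlib.import_module(".".join(parts))
            break
        except ModuleNotFoundError:
            attrs.insert(0, parts.pop())
    if module is None:
        raise ModuleNotFoundError(dotted_path)

    owner = module
    for name in attrs[:-1]:
        owner = getattr(owner, name)
    original = getattr(owner, attrs[-1], None)
    setattr(owner, attrs[-1], replacement(original) if apply_wrapper else replacement)


if __name__ == "__main__":
    # Toy usage: wrap json.dumps so it always sorts keys.
    import json

    def sorted_dumps_wrapper(fn):
        def wrapper(obj, **kwargs):
            kwargs.setdefault("sort_keys", True)
            return fn(obj, **kwargs)
        return wrapper

    apply_patch("json.dumps", sorted_dumps_wrapper, apply_wrapper=True)
    print(json.dumps({"b": 1, "a": 2}))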
dcu_megatron/core/extensions/transformer_engine.py

@@ -29,7 +29,8 @@ def _get_extra_te_kwargs_wrapper(fn):
    @wraps(fn)
    def wrapper(config: TransformerConfig):
        extra_transformer_engine_kwargs = fn(config)
-        extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.split_bw if hasattr(config, "split_bw") else False
+        if hasattr(config, "split_bw"):
+            extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.split_bw
        return extra_transformer_engine_kwargs
    return wrapper
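The replaced line always injected delay_wgrad_compute (defaulting to False), while the new code only injects it when the config actually defines split_bw, presumably so that Transformer Engine builds older than 2.4.0.dev never receive the unknown keyword. A standalone sketch of this wrap-and-augment pattern, with stand-in names instead of the real TE/Megatron types, is:

# Minimal sketch of the wrap-and-augment pattern used above; the config object
# and baseline function are stand-ins, not Transformer Engine's real API.
from functools import wraps
from types import SimpleNamespace


def get_extra_kwargs(config):
    """Pretend baseline: returns the kwargs the unpatched code would build."""
    return {"params_dtype": getattr(config, "params_dtype", "float32")}


def add_delay_wgrad_kwarg(fn):
    @wraps(fn)
    def wrapper(config):
        extra_kwargs = fn(config)
        # Only forward the flag when the config defines it, so configurations
        # without split_bw never see an unknown keyword argument downstream.
        if hasattr(config, "split_bw"):
            extra_kwargs["delay_wgrad_compute"] = config.split_bw
        return extra_kwargs
    return wrapper


get_extra_kwargs = add_delay_wgrad_kwarg(get_extra_kwargs)

print(get_extra_kwargs(SimpleNamespace(params_dtype="bf16")))
print(get_extra_kwargs(SimpleNamespace(params_dtype="bf16", split_bw=True)))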
dcu_megatron/core/models/gpt/fine_grained_schedule.py

@@ -261,9 +261,6 @@ class TransformerLayerNode(ScheduleNode):
    def backward_impl(self, outputs, output_grad):
        detached_grad = tuple([e.grad for e in self.detached])
        grads = output_grad + detached_grad
-        # if len(detached_grad):
-        #     print(f"output_grad: {grads}")
        self.default_backward_func(outputs + self.before_detached, grads)
        self.before_detached = None
        self.detached = None

@@ -344,7 +341,7 @@ class MoeMlPNode(TransformerLayerNode):
        )
        assert mlp_bias is None
-        # pre_mlp_layernorm_output
+        # pre_mlp_layernorm_output used
        self.common_state.pre_mlp_layernorm_output = None
        return expert_output, shared_expert_output
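backward_impl stitches the gradients captured on detached tensors back into a single default backward call. A toy, self-contained illustration of that detach-and-stitch pattern in plain PyTorch follows; the shapes and two-stage split are stand-ins, not the repo's actual schedule nodes.

# Self-contained sketch of the detach-and-stitch backward pattern suggested by
# backward_impl() above; tensors and stages here are toy stand-ins.
import torch

# Stage 1 produces an activation that stage 2 consumes through a detached copy,
# so each stage can run backward independently and later be stitched together.
w1 = torch.randn(4, 4, requires_grad=True)
w2 = torch.randn(4, 4, requires_grad=True)

x = torch.randn(2, 4)
h = x @ w1                                     # stage-1 output ("before_detached")
h_detached = h.detach().requires_grad_(True)   # stage-2 input ("detached")
y = h_detached @ w2                            # stage-2 output

# Backward for stage 2: populates h_detached.grad and w2.grad.
y.backward(torch.ones_like(y))

# Backward for stage 1: reuse the gradient captured on the detached tensor,
# which plays the role of `detached_grad` in backward_impl().
torch.autograd.backward((h,), (h_detached.grad,))

print(w1.grad.shape, w2.grad.shape)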
dcu_megatron/core/models/gpt/gpt_model.py

@@ -11,7 +11,6 @@ from megatron.core.config_logger import has_config_logger_enabled, log_config_to
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import WrappedTensor, deprecate_inference_params
-from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel


def gpt_model_init_wrapper(fn):

@@ -232,17 +231,10 @@ def gpt_model_forward(
    return loss


-class GPTModel(MegatronCoreGPTModel):
+class GPTModel:
    """
    patch megatron GPTModel
    """

-    def get_transformer_callables_by_layer(self, layer_number: int):
-        """
-        Get the callables for the layer at the given transformer layer number.
-        """
-        return self.decoder.get_layer_callables(layer_number)

    def build_schedule_plan(
        self,
        input_ids: Tensor,
        ...
dcu_megatron/core/pipeline_parallel/schedules.py

@@ -37,7 +37,8 @@ def set_current_microbatch(model, microbatch_id):
    except RuntimeError:
        decoder_exists = False
    if decoder_exists and decoder is not None:
        decoder.current_microbatch = microbatch_id
        for layer in decoder.layers:
            layer.current_microbatch = microbatch_id


def get_pp_rank_microbatches(

@@ -87,6 +88,16 @@ def get_pp_rank_microbatches(
    )


+def print_rank_4(message):
+    """If distributed is initialized, print only on rank 0."""
+    if torch.distributed.is_initialized():
+        if torch.distributed.get_rank() == 4:
+            print(message, flush=True)
+    else:
+        print(message, flush=True)
+
+
+from megatron.training import print_rank_0, print_rank_last


def forward_backward_pipelining_with_interleaving(
    *,
    forward_step_func,

@@ -297,6 +308,9 @@ def forward_backward_pipelining_with_interleaving(
    # Both tables are indexed with virtual_microbatch_id.
    microbatch_id_table, model_chunk_id_table = zip(*schedule_table)

+    print_rank_4(f"rank last. microbatch_id_table: {microbatch_id_table}. model_chunk_id_table: {model_chunk_id_table}")
+    print_rank_0(f"rank first. microbatch_id_table: {microbatch_id_table}. model_chunk_id_table: {model_chunk_id_table}")

    def get_model_chunk_id(virtual_microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number."""
        model_chunk_id = model_chunk_id_table[virtual_microbatch_id % total_num_microbatches]

@@ -687,6 +701,7 @@ def forward_backward_pipelining_with_interleaving(
                post_backward=post_backward,
            )
        else:
            output_tensor = None
+            input_tensor_grad = None

        if f_virtual_microbatch_id is not None:
            # forward pass

@@ -711,7 +726,7 @@ def forward_backward_pipelining_with_interleaving(
            input_tensor_grad = backward_step_helper(b_virtual_microbatch_id)
            if post_backward is not None:
                input_tensor_grad = post_backward(input_tensor_grad)

-        return output_tensor if f_virtual_microbatch_id is not None else None, input_tensor_grad
+        return output_tensor, input_tensor_grad

    # Run warmup forward passes.
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)

@@ -897,11 +912,13 @@ def forward_backward_pipelining_with_interleaving(
            output_tensor_grads[num_model_chunks - 1].append(bwd_recv_buffer[-1])

    # Run 1F1B in steady state.
    output_tensor = None
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches
+        print_rank_0(f"rank first. 1F1B in steady state: {k}/{num_microbatches_remaining}")
+        print_rank_4(f"rank last. 1F1B in steady state: {k}/{num_microbatches_remaining}")

        # Decide to checkpoint all layers' activations of the current micro-batch.
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (

@@ -1053,6 +1070,9 @@ def forward_backward_pipelining_with_interleaving(
                post_backward=pp_post_backward,
                checkpoint_activations_microbatch=checkpoint_activations_microbatch,
            )
+            print_rank_0(f"rank first. 1F1B in steady state: {k}/{num_microbatches_remaining} end")
+            print_rank_4(f"rank last. 1F1B in steady state: {k}/{num_microbatches_remaining} end")
        else:
            # No p2p overlap.
            backward_k = k
            output_tensor, input_tensor_grad = forward_backward_helper_wrapper(

@@ -1109,6 +1129,11 @@ def forward_backward_pipelining_with_interleaving(
            if recv_next:
                output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)

+        if k == 0:
+            print_rank_0(f"input_tensor_grad: {input_tensor_grad}")
+
+    print_rank_0(f"rank first. 1F1B in steady state end")
+    print_rank_4(f"rank last. 1F1B in steady state end")

    deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

    # Run cooldown backward passes (flush out pipeline).
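The new debug prints dump microbatch_id_table and model_chunk_id_table, both obtained from zip(*schedule_table). A simplified sketch of that table split follows; the round-robin construction is illustrative only and is not Megatron-LM's exact interleaved ordering.

# Simplified sketch of the schedule_table -> lookup-table split seen above.
num_microbatches = 4
num_model_chunks = 2

schedule_table = [
    (microbatch_id, model_chunk_id)
    for microbatch_id in range(num_microbatches)
    for model_chunk_id in range(num_model_chunks)
]

# Both tables are indexed with virtual_microbatch_id, as in the patched code.
microbatch_id_table, model_chunk_id_table = zip(*schedule_table)

total_num_microbatches = len(schedule_table)


def get_model_chunk_id(virtual_microbatch_id):
    return model_chunk_id_table[virtual_microbatch_id % total_num_microbatches]


print("microbatch_id_table:", microbatch_id_table)
print("model_chunk_id_table:", model_chunk_id_table)
print("chunk for virtual microbatch 5:", get_model_chunk_id(5))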
dcu_megatron/core/transformer/transformer_block.py

from functools import wraps

-from megatron.core.transformer.transformer_block import TransformerBlock as MegatronCoreTransformerBlock


def transformer_block_init_wrapper(fn):
    @wraps(fn)

@@ -14,12 +13,3 @@ def transformer_block_init_wrapper(fn):
        self.final_layernorm = None
    return wrapper


-class TransformerBlock(MegatronCoreTransformerBlock):
-    def get_layer_callables(self, layer_number: int):
-        """
-        Get the callables for the layer at the given layer number.
-        """
-        return self.layers[layer_number].get_submodule_callables()
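The surviving transformer_block_init_wrapper runs the stock constructor and then nulls out final_layernorm. A toy illustration of that __init__-wrapper pattern follows; the Block class is a stand-in, not Megatron's TransformerBlock.

# Toy illustration of the __init__-wrapper pattern used by
# transformer_block_init_wrapper; the Block class here is a stand-in.
from functools import wraps


class Block:
    def __init__(self, hidden_size):
        self.hidden_size = hidden_size
        self.final_layernorm = object()   # stand-in for a LayerNorm module


def init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        # Post-construction adjustment, mirroring `self.final_layernorm = None`
        # in the patched TransformerBlock initializer.
        self.final_layernorm = None
    return wrapper


Block.__init__ = init_wrapper(Block.__init__)

b = Block(hidden_size=1024)
print(b.hidden_size, b.final_layernorm)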
dcu_megatron/core/transformer/transformer_layer.py

@@ -45,6 +45,65 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        torch.cuda.nvtx.range_pop()
        return outputs, detached_output_tensors

+    def forward(
+        self,
+        hidden_states: Tensor,
+        context: Optional[Tensor] = None,
+        context_mask: Optional[Tensor] = None,
+        attention_mask: Optional[Tensor] = None,
+        rotary_pos_emb: Optional[Tensor] = None,
+        rotary_pos_cos: Optional[Tensor] = None,
+        rotary_pos_sin: Optional[Tensor] = None,
+        attention_bias: Optional[Tensor] = None,
+        inference_context: Optional[Any] = None,
+        packed_seq_params: Optional[PackedSeqParams] = None,
+        sequence_len_offset: Optional[Tensor] = None,
+        *,
+        inference_params: Optional[Any] = None,
+    ):
+        (
+            hidden_states,
+            pre_mlp_layernorm_output,
+            tokens_per_expert,
+            permutated_local_input_tokens,
+            permuted_probs,
+        ) = self._submodule_attention_router_compound_forward(
+            hidden_states,
+            attention_mask,
+            rotary_pos_emb,
+            rotary_pos_cos,
+            rotary_pos_sin,
+            attention_bias,
+            inference_context,
+            packed_seq_params,
+            sequence_len_offset,
+            inference_params=inference_params,
+        )
+        (tokens_per_expert, global_input_tokens, global_probs) = self._submodule_dispatch_forward(
+            tokens_per_expert, permutated_local_input_tokens, permuted_probs
+        )
+        (expert_output, shared_expert_output, mlp_bias) = self._submodule_moe_forward(
+            tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output
+        )
+        expert_output = self._submodule_combine_forward(expert_output)[0]
+        output = self._submodule_post_combine_forward(
+            expert_output, shared_expert_output, mlp_bias, hidden_states
+        )
+        return output, None

    def _submodule_attention_forward(
        self,
        hidden_states: Tensor,
        ...

@@ -182,14 +241,14 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        return output

-    def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_prob, pre_mlp_layernorm_output):
+    def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output):
        """
        Performs a forward pass for the MLP submodule, including both expert-based
        and optional shared-expert computations.
        """
        shared_expert_output = None
        (dispatched_input, tokens_per_expert, permuted_probs) = (
-            self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_prob)
+            self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_probs)
        )
        expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert, permuted_probs)
        expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)

@@ -221,141 +280,10 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        return output

    def _submodule_attention_backward(self, hidden_states, pre_mlp_layernorm_output, detached_inputs):
        pre_mlp_layernorm_output.backward(detached_inputs[1].grad)
        hidden_states.backward(detached_inputs[0].grad)

    def _submodule_attention_router_compound_backward(
        self,
        hidden_states,
        pre_mlp_layernorm_output,
        tokens_per_expert,
        permutated_local_input_tokens,
        probs,
        detached_inputs,
    ):
        permutated_local_input_tokens.backward(detached_inputs[3].grad)
        probs.backward(detached_inputs[4].grad)
        # tokens_per_expert.backward(detached_inputs[2].grad)
        pre_mlp_layernorm_output.backward(detached_inputs[1].grad)
        hidden_states.backward(detached_inputs[0].grad)

    def _submodule_dispatch_backward(self, global_input_tokens, detached_inputs):
        global_input_tokens.backward(detached_inputs[0].grad)

    def _submodule_dense_backward(self, output, detached_inputs):
        output.backward(detached_inputs[0].grad)

    def _submodule_moe_backward(self, expert_output, shared_expert_output, mlp_bias, detached_inputs):
        expert_output.backward(detached_inputs[0].grad)
        shared_expert_output.backward(detached_inputs[1].grad)
        if mlp_bias is not None:
            mlp_bias.backward(detached_inputs[2].grad)

    def _submodule_combine_backward(self, hidden_states, detached_inputs):
        hidden_states.backward(detached_inputs[0].grad)

    def _submodule_post_combine_backward(self, output, output_grad):
        output.backward(output_grad)

    def _submodule_attention_router_compound_dgrad(self):
        raise NotImplementedError("Not implemented")

    def _submodule_attention_router_compound_dw(self):
        self.self_attention.backward_dw()
        # raise NotImplementedError("Not implemented")

    def _submodule_dispatch_dgrad(self):
        raise NotImplementedError("Not implemented")

    def _submodule_mlp_dgrad(self):
        raise NotImplementedError("Not implemented")

    def _submodule_mlp_dw(self):
        self.mlp.backward_dw()
        # raise NotImplementedError("Not implemented")

    def _submodule_combine_dgrad(self):
        raise NotImplementedError("Not implemented")

    def _submodule_identity_forward(self, *args):
        return args

    def _submodule_identity_backward(self, *args):
        pass

    def get_submodule_callables(self):
        """
        Returns a dictionary of submodule callables for the transformer layer.
        """
        from megatron.core.transformer.moe.moe_layer import MoELayer

        def get_func_with_default(func, default_func):
            if isinstance(self.mlp, MoELayer):
                return func
            return default_func

        attention_func = get_func_with_default(
            self._submodule_attention_router_compound_forward, self._submodule_attention_forward
        )
        attention_backward_func = get_func_with_default(
            self._submodule_attention_router_compound_backward, self._submodule_attention_backward
        )
        dispatch_func = get_func_with_default(
            self._submodule_dispatch_forward, self._submodule_identity_forward
        )
        dispatch_backward_func = get_func_with_default(
            self._submodule_dispatch_backward, self._submodule_identity_backward
        )
        mlp_func = get_func_with_default(self._submodule_moe_forward, self._submodule_dense_forward)
        mlp_backward_func = get_func_with_default(
            self._submodule_moe_backward, self._submodule_dense_backward
        )
        combine_func = get_func_with_default(
            self._submodule_combine_forward, self._submodule_identity_forward
        )
        combine_backward_func = get_func_with_default(
            self._submodule_combine_backward, self._submodule_identity_backward
        )
        post_combine_func = get_func_with_default(
            self._submodule_post_combine_forward, self._submodule_identity_forward
        )
        post_combine_backward_func = get_func_with_default(
            self._submodule_post_combine_backward, self._submodule_identity_backward
        )

        callables = TransformerLayerSubmoduleCallables(
            attention=SubmoduleCallables(
                forward=partial(self._callable_wrapper, True, attention_func, skip_detach=True),
                backward=partial(self._callable_wrapper, False, attention_backward_func),
                # dgrad=partial(self._callable_wrapper, False, self._submodule_attention_router_compound_dgrad),
                dw=partial(self._callable_wrapper, False, self._submodule_attention_router_compound_dw),
            ),
            dispatch=SubmoduleCallables(
                forward=partial(self._callable_wrapper, True, dispatch_func),
                backward=partial(self._callable_wrapper, False, dispatch_backward_func),
                # dgrad=partial(self._callable_wrapper, False, self._submodule_dispatch_dgrad),
            ),
            mlp=SubmoduleCallables(
                forward=partial(self._callable_wrapper, True, mlp_func),
                backward=partial(self._callable_wrapper, False, mlp_backward_func),
                # dgrad=partial(self._callable_wrapper, False, self._submodule_mlp_dgrad),
                dw=partial(self._callable_wrapper, False, self._submodule_mlp_dw),
            ),
            combine=SubmoduleCallables(
                forward=partial(self._callable_wrapper, True, combine_func),
                backward=partial(self._callable_wrapper, False, combine_backward_func),
                # dgrad=partial(self._callable_wrapper, False, self._submodule_combine_dgrad),
            ),
            post_combine=SubmoduleCallables(
                forward=partial(self._callable_wrapper, True, post_combine_func),
                backward=partial(self._callable_wrapper, False, post_combine_backward_func),
            ),
        )
        return callables
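The removed get_submodule_callables picked MoE-specific or default callables per submodule and bound them with functools.partial. A compact, self-contained sketch of that select-and-bind pattern follows; the classes and method bodies are toy stand-ins for the repo's SubmoduleCallables, _callable_wrapper, and layer variants.

# Sketch of the "pick per-variant callables and bind them with partial" pattern
# from the removed get_submodule_callables(); everything here is a toy stand-in.
from dataclasses import dataclass
from functools import partial
from typing import Callable, Optional


@dataclass
class SubmoduleCallables:          # stand-in for the repo's container
    forward: Callable
    backward: Callable
    dw: Optional[Callable] = None


class ToyLayer:
    def __init__(self, is_moe: bool):
        self.is_moe = is_moe

    # Two forward variants; the dense path stands in for the default callable.
    def _moe_forward(self, x):
        return f"moe({x})"

    def _dense_forward(self, x):
        return f"dense({x})"

    def _identity_backward(self, *args):
        pass

    def _callable_wrapper(self, is_forward, func, *args, **kwargs):
        # In the real code this also handles detaching and NVTX ranges.
        return func(*args, **kwargs)

    def get_submodule_callables(self):
        def with_default(func, default_func):
            return func if self.is_moe else default_func

        mlp_func = with_default(self._moe_forward, self._dense_forward)
        return SubmoduleCallables(
            forward=partial(self._callable_wrapper, True, mlp_func),
            backward=partial(self._callable_wrapper, False, self._identity_backward),
        )


calls = ToyLayer(is_moe=True).get_submodule_callables()
print(calls.forward("tokens"))     # -> moe(tokens)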