Project: evt_fugx1 / dcu_megatron

Commit f3434cc7, authored Jun 03, 2025 by dongcl
Parent: 7c63d1a4

    fix dualpipev broadcast error

Showing 9 changed files with 421 additions and 147 deletions.
dcu_megatron/adaptor/features_manager.py                                        +0    -72
dcu_megatron/adaptor/features_manager/pipeline_parallel/pipeline_feature.py     +1    -1
dcu_megatron/adaptor/megatron_adaptor.py                                        +22   -5
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_chunks.py               +10   -4
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py            +81   -63
dcu_megatron/core/transformer/transformer_layer.py                              +172  -1
dcu_megatron/training/training.py                                               +24   -0
dcu_megatron/training/utils.py                                                  +111  -0
pretrain_gpt.py                                                                 +0    -1
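Summary of the fix, as reconstructed from the diffs below: under the dualpipev schedule the first pipeline stage also hosts the model chunk that computes the loss, so tensor-parallel ranks other than TP rank 0 on that stage need labels and loss_mask in addition to the tokens, attention_mask and position_ids that are already broadcast there. The new dcu_megatron/training/utils.py patches get_batch_on_this_tp_rank accordingly. A minimal sketch of the added broadcast branch, assuming a Megatron-style _broadcast helper and the args.schedule_method flag used below (this is an illustration, not the patched code verbatim):

```python
import torch
from megatron.core import mpu


def broadcast_first_stage_extras(batch, schedule_method):
    """Sketch: broadcast labels/loss_mask across the TP group on the first PP stage
    only when the dualpipev schedule is active (the core idea of this commit)."""
    def _broadcast(item):
        if item is not None:
            torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(),
                                        group=mpu.get_tensor_model_parallel_group())

    if mpu.is_pipeline_first_stage() and schedule_method == 'dualpipev':
        # Under dualpipev the loss-computing chunk lives on this rank too,
        # so every TP rank needs these tensors, not just TP rank 0.
        _broadcast(batch.get('loss_mask'))
        _broadcast(batch.get('labels'))
```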
dcu_megatron/adaptor/features_manager.py  (deleted, 100644 → 0)

from megatron.core.utils import is_te_min_version


def a2a_overlap_adaptation(patches_manager):
    """
    patches_manager: MegatronPatchesManager
    """
    from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear
    from ..core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
    from ..core.transformer.transformer_layer import TransformerLayer
    from ..core.models.gpt.gpt_model import GPTModel
    from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
    from ..core.extensions.transformer_engine import (
        _get_extra_te_kwargs_wrapper,
        TELinear,
        TELayerNormColumnParallelLinear,
    )
    from ..core.transformer.multi_latent_attention import MLASelfAttention
    from ..core.transformer.mlp import MLP
    from ..core.transformer.moe.experts import TEGroupedMLP
    from ..core.transformer.moe.moe_layer import MoELayer

    # num_warmup_microbatches + 1
    patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
                                   get_pp_rank_microbatches)
    # a2a_overlap
    patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
                                   forward_backward_pipelining_with_interleaving)
    patches_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
                                   MoEAlltoAllTokenDispatcher)
    patches_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer',
                                   TransformerLayer)
    patches_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.build_schedule_plan',
                                   GPTModel.build_schedule_plan, create_dummy=True)

    # backward_dw
    patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
                                   _get_extra_te_kwargs_wrapper, apply_wrapper=True)
    patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear', TELinear)
    patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
                                   TELayerNormColumnParallelLinear)
    TEColumnParallelLinear.__bases__ = (TELinear,)
    TERowParallelLinear.__bases__ = (TELinear,)

    if is_te_min_version("1.9.0.dev0"):
        from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear, TERowParallelGroupedLinear
        from ..core.extensions.transformer_engine import TEGroupedLinear

        patches_manager.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear', TEGroupedLinear)
        TEColumnParallelGroupedLinear.__bases__ = (TEGroupedLinear,)
        TERowParallelGroupedLinear.__bases__ = (TEGroupedLinear,)

    patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention.backward_dw',
                                   MLASelfAttention.backward_dw, create_dummy=True)
    patches_manager.register_patch('megatron.core.transformer.mlp.MLP.backward_dw',
                                   MLP.backward_dw, create_dummy=True)
    patches_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP.backward_dw',
                                   TEGroupedMLP.backward_dw, create_dummy=True)
    patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_dw',
                                   MoELayer.backward_dw, create_dummy=True)
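The deleted module above is written against MegatronPatchesManager.register_patch, whose implementation is not part of this commit. For readers unfamiliar with the pattern, a dotted-path monkey-patch registry roughly along these lines would support the calls shown; the class below and its apply() method are illustrative assumptions, and only the register_patch signature (target path, replacement, create_dummy, apply_wrapper) is taken from the usage above:

```python
import importlib


class IllustrativePatchesManager:
    """Sketch of a dotted-path patch registry (not the real MegatronPatchesManager)."""

    def __init__(self):
        self._patches = []

    def register_patch(self, dotted_path, replacement, create_dummy=False, apply_wrapper=False):
        # Record the patch; application is deferred until apply() is called.
        self._patches.append((dotted_path, replacement, create_dummy, apply_wrapper))

    @staticmethod
    def _resolve_parent(dotted_path):
        # Import the longest importable module prefix, then walk attributes,
        # so targets like 'pkg.mod.Class.method' resolve to (Class, 'method').
        parts = dotted_path.split('.')
        for i in range(len(parts) - 1, 0, -1):
            try:
                obj = importlib.import_module('.'.join(parts[:i]))
            except ModuleNotFoundError:
                continue
            for name in parts[i:-1]:
                obj = getattr(obj, name)
            return obj, parts[-1]
        raise ModuleNotFoundError(dotted_path)

    def apply(self):
        for dotted_path, replacement, create_dummy, apply_wrapper in self._patches:
            parent, attr = self._resolve_parent(dotted_path)
            original = getattr(parent, attr, None)
            if original is None and not create_dummy:
                raise AttributeError(f'cannot patch missing attribute {dotted_path}')
            if apply_wrapper and original is not None:
                # apply_wrapper=True: the registered callable wraps the original.
                setattr(parent, attr, replacement(original))
            else:
                setattr(parent, attr, replacement)
```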
dcu_megatron/adaptor/features_manager/pipeline_parallel/pipeline_feature.py

@@ -56,7 +56,7 @@ class PipelineFeature(AbstractFeature):
         patch_manager.register_patch(
             'megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_without_interleaving',
             forward_backward_pipelining_with_cutinhalf)
-        patch_manager.register_patch('megatron.legacy.model.module.Float16Module.forward',
-                                     dualpipev_fp16forward)
+        patch_manager.register_patch('megatron.core.transformer.module.Float16Module.forward',
+                                     dualpipev_fp16forward)
         patch_manager.register_patch('megatron.core.transformer.transformer_block.get_num_layers_to_build',
                                      get_num_layers_to_build)
         patch_manager.register_patch(
...
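The one-line change above retargets the dualpipev_fp16forward patch from megatron.legacy.model.module.Float16Module.forward to megatron.core.transformer.module.Float16Module.forward. This matters because dualpipev_chunks.py (below) now builds Float16Module from megatron.core.transformer.module, and a monkey patch only takes effect on the object that callers actually use; patching the legacy class leaves the core class untouched. A small, self-contained illustration of that general Python behaviour (the modules and classes here are stand-ins, not Megatron code):

```python
import types

# Two stand-in "modules", each exposing its own Float16Module-like class.
legacy_mod = types.ModuleType("legacy_mod")
core_mod = types.ModuleType("core_mod")


class _LegacyFloat16Module:
    def forward(self, x):
        return ("legacy", x)


class _CoreFloat16Module:
    def forward(self, x):
        return ("core", x)


legacy_mod.Float16Module = _LegacyFloat16Module
core_mod.Float16Module = _CoreFloat16Module


def patched_forward(self, x):
    return ("patched", x)


# Patching the legacy class does nothing for code that instantiates the core class.
legacy_mod.Float16Module.forward = patched_forward
print(core_mod.Float16Module().forward(1))   # ('core', 1)  -- patch missed

# Patching the class the caller actually uses is what the commit fixes.
core_mod.Float16Module.forward = patched_forward
print(core_mod.Float16Module().forward(1))   # ('patched', 1)
```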
dcu_megatron/adaptor/megatron_adaptor.py

import os
import abc
import sys
import types
import argparse
import torch
...
@@ -9,6 +7,10 @@ from megatron.core.utils import is_te_min_version
 from .features_manager import ADAPTOR_FEATURES
 from .patch_utils import MegatronPatchesManager
 from dcu_megatron.training.arguments import process_adaptor_args

 _ARGS = None


 def add_args(args, key, value):
...
@@ -42,7 +44,7 @@ def get_adaptor_args():
     global _ARGS
     if _ARGS is None:
         parser = argparse.ArgumentParser(description='Adaptor Arguments', allow_abbrev=False)
-        _ARGS, unknown = process_args(parser).parse_known_args()
+        _ARGS, unknown = process_adaptor_args(parser).parse_known_args()
         parser_unknown_args(_ARGS, unknown)
     return _ARGS
...
@@ -119,7 +121,7 @@ def adaptation_l2(patches_manager, adaptor_args):
     """
     for feature in ADAPTOR_FEATURES:
         if getattr(adaptor_args, feature.feature_name, None) and feature.optimization_level == 2:
-            feature.register_patches(patches_manager, mindspeed_args)
+            feature.register_patches(patches_manager, adaptor_args)


 class MegatronAdaptationABC:
...
@@ -161,6 +163,7 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper
         from ..core.transformer.transformer_layer import get_transformer_layer_offset
         from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch

         # Transformer block. If mtp_num_layers > 0, move final_layernorm outside
...
@@ -187,6 +190,10 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     torch.compile(mode='max-autotune-no-cudagraphs'), apply_wrapper=True)

         # support dualpipev
         MegatronAdaptation.register('megatron.core.transformer.transformer_layer.get_transformer_layer_offset',
                                     get_transformer_layer_offset)

     def patch_core_extentions(self):
         import transformer_engine as te
...
@@ -250,8 +257,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         from ..training.tokenizer import build_tokenizer
         from ..training.initialize import _initialize_distributed
         from ..training.initialize import _compile_dependencies
-        from ..training.training import train
+        from ..training.training import train, build_train_valid_test_data_iterators_wrapper
         from ..training.initialize import _set_random_seed
         from ..training.utils import get_batch_on_this_tp_rank

         MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer', build_tokenizer)
...
@@ -270,6 +278,15 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.training.training.train', train)

         # support dualpipev, two data iterators
         MegatronAdaptation.register('megatron.training.training.build_train_valid_test_data_iterators',
                                     build_train_valid_test_data_iterators_wrapper, apply_wrapper=True)

         # support dualpipev, broadcast loss_mask and labels
         MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank',
                                     get_batch_on_this_tp_rank)

     def patch_miscellaneous(self):
         from ..training.arguments import parse_args
...
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_chunks.py

@@ -4,6 +4,7 @@ from typing import List, Optional
 from megatron.core import mpu, tensor_parallel
 from megatron.core.utils import get_model_config
 from megatron.core.transformer.module import Float16Module
 from megatron.core.rerun_state_machine import get_rerun_state_machine
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.distributed import DistributedDataParallel as DDP
 from megatron.core.enums import ModelType
...
@@ -14,7 +15,10 @@ from megatron.core.transformer.module import fp32_to_float16, float16_to_fp32
 from megatron.core.num_microbatches_calculator import get_num_microbatches
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core import parallel_state
 from megatron.core.distributed.finalize_model_grads import _allreduce_layernorm_grads
 from megatron.training.utils import (logical_and_across_model_parallel_group,
                                      reduce_max_stat_across_model_parallel_group)
 from dcu_megatron.core.pipeline_parallel.dualpipev.dualpipev_schedules import get_dualpipe_chunk
...
@@ -82,7 +86,8 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
     # Fp16 conversion.
     if args.fp16 or args.bf16:
-        model = [Float16Module(model_module, args) for model_module in model]
+        config = get_model_config(model[0])
+        model = [Float16Module(config, model_module) for model_module in model]

     if wrap_with_ddp:
         config = get_model_config(model[0])
...
@@ -217,10 +222,11 @@ def get_num_layers_to_build(config: TransformerConfig) -> int:
 def _allreduce_embedding_grads_wrapper(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
-        if get_args().schedules_method == 'dualpipev':
+        args = get_args()
+        if args.schedule_method == 'dualpipev':
             # dualpipev: no need to do the embedding allreduce,
             # because the embedding and the lm head are on the same rank.
-            if not get_args().untie_embeddings_and_output_weights:
+            if not args.untie_embeddings_and_output_weights:
                 raise NotImplementedError
             else:
                 return
...
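Context for the embedding-allreduce change above: under the dualpipev (cut-in-half) schedule each pipeline rank holds two model chunks arranged in a V, so rank 0 ends up with both the embedding chunk and the final LM-head/loss chunk, which is why the cross-stage embedding-grad allreduce can be skipped and why the first stage needs labels/loss_mask. A small sketch of that placement, assuming an even split of num_layers across 2 * pp_size chunks (consistent with get_transformer_layer_offset in transformer_layer.py below):

```python
def dualpipev_layer_ranges(num_layers, pp_size):
    """Return {rank: [(start, end), (start, end)]} for the two chunks each rank holds."""
    per_chunk = num_layers // (2 * pp_size)
    placement = {}
    for rank in range(pp_size):
        first = (rank * per_chunk, (rank + 1) * per_chunk)                              # master chunk
        second = (num_layers - (rank + 1) * per_chunk, num_layers - rank * per_chunk)   # slave chunk
        placement[rank] = [first, second]
    return placement


# Example: 16 layers on pp_size=4
# rank 0 -> layers [0, 2) and [14, 16)  (embedding side and LM-head/loss side on the same rank)
# rank 3 -> layers [6, 8) and [8, 10)
print(dualpipev_layer_ranges(16, 4))
```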
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py

@@ -19,9 +19,8 @@ from megatron.core.utils import (
 from megatron.core.pipeline_parallel.schedules import clear_embedding_activation_buffer, deallocate_output_tensor
 from megatron.core import ModelParallelConfig
 from megatron.core.pipeline_parallel.p2p_communication import _communicate
-from megatron.core.pipeline_parallel.schedules import backward_step, set_current_microbatch, custom_backward, finish_embedding_wgrad_compute
 from megatron.core.models.gpt import GPTModel
-from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
+from megatron.core.pipeline_parallel.schedules import backward_step, set_current_microbatch, finish_embedding_wgrad_compute
+# from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore

 # Types
...
@@ -115,7 +114,7 @@ def send_backward(input_tensor_grad: torch.Tensor, tensor_shape, config: ModelPa
     return reqs


-def recv_forward(tensor_shape: Shape, config: ModelParallelConfig, model_chunk_id, async_op=False) -> torch.Tensor:
+def recv_forward(tensor_shape: Shape, config: ModelParallelConfig, model_chunk_id, async_op=False, step=-1) -> torch.Tensor:
     """ Receive tensor from previous rank in pipeline (forward receive).

     See _communicate for argument details.
...
@@ -568,8 +567,6 @@ def forward_backward_pipelining_with_cutinhalf(
     total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()

     input_tensors = [[], []]
     output_tensors = [[], []]
     model_graphs = [[], []]
     logits_inputs = []
     forward_data_store = []
     master_chunk_id = 0
...
@@ -584,7 +581,7 @@ def forward_backward_pipelining_with_cutinhalf(
         checkpoint_activations_microbatch = None

-    input_tensor = recv_forward(tensor_shape, config, master_chunk_id)[0]
+    input_tensor = recv_forward(tensor_shape, config, master_chunk_id, step=0)[0]
     fwd_wait_handles_warmup = None

     # Run warmup forward passes
...
@@ -627,10 +624,10 @@ def forward_backward_pipelining_with_cutinhalf(
     fwd_wait_handles_slave_chunk = None
     fwd_wait_handles_send = None

     for i in range(schedule['interleaved_forward'][rank]):
         if fwd_wait_handles is not None:
-            for req in fwd_wait_handles:
-                req.wait()
+            for req, req_handle in fwd_wait_handles.items():
+                if req_handle is not None:
+                    req_handle.wait()
             fwd_wait_handles = None

         is_first_microbatch = parallel_state.is_pipeline_last_stage(ignore_virtual=True) and (i == 0)
...
@@ -659,14 +656,16 @@ def forward_backward_pipelining_with_cutinhalf(
         master_cur_microbatch += 1

         if not parallel_state.is_pipeline_last_stage(ignore_virtual=True) and fwd_wait_handles_send is not None:
-            for req in fwd_wait_handles_send:
-                req.wait()
+            for req, req_handle in fwd_wait_handles_send.items():
+                if req_handle is not None:
+                    req_handle.wait()
             deallocate_output_tensor(output_tensor_send, config.deallocate_pipeline_outputs)
             fwd_wait_handles_send = None

         if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
-            input_tensor_slave_chunk = output_tensor
+            input_tensor_slave_chunk = output_tensor.detach()
+            input_tensor_slave_chunk.requires_grad = True
             input_tensor, fwd_wait_handles = recv_forward(tensor_shape, config, master_chunk_id, async_op=True)
...
@@ -678,15 +677,17 @@ def forward_backward_pipelining_with_cutinhalf(
                 tensor_shape, config, master_chunk_id, async_op=True)

         if fwd_wait_handles_warmup is not None:
-            for req in fwd_wait_handles_warmup:
-                req.wait()
+            for req, req_handle in fwd_wait_handles_warmup.items():
+                if req_handle is not None:
+                    req_handle.wait()
             deallocate_output_tensor(output_tensor_warmup, config.deallocate_pipeline_outputs)
             fwd_wait_handles_warmup = None

         if fwd_wait_handles_slave_chunk is not None:
-            for req in fwd_wait_handles_slave_chunk:
-                req.wait()
+            for req, req_handle in fwd_wait_handles_slave_chunk.items():
+                if req_handle is not None:
+                    req_handle.wait()
             deallocate_output_tensor(output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
             fwd_wait_handles_slave_chunk = None
...
@@ -733,17 +734,21 @@ def forward_backward_pipelining_with_cutinhalf(
             output_tensor_send = output_tensor
             fwd_wait_handles_send = send_forward(output_tensor_send, tensor_shape, config, master_chunk_id, async_op=True)
         else:
             # custom_backward requires output_tensor.numel() == 1
             deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

         if fwd_wait_handles is not None:
-            for req in fwd_wait_handles:
-                req.wait()
+            for req, req_handle in fwd_wait_handles.items():
+                if req_handle is not None:
+                    req_handle.wait()
             fwd_wait_handles = None

     # Run 1b1w1f stages for slave chunk
     bwd_wait_handles = None
     for _ in range(schedule['1b1w1f'][rank]):
-        WeightGradStore.start_decouple()
+        # WeightGradStore.start_decouple()
         input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
         output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
...
@@ -752,11 +757,7 @@ def forward_backward_pipelining_with_cutinhalf(
             input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)

-        WeightGradStore.end_decouple()
-        # If asynchronous, the memory will rise.
-        bwd_wait_handles = send_backward(input_tensor_grad, tensor_shape, config, slave_chunk_id)
+        # WeightGradStore.end_decouple()

         if fwd_wait_handles_slave_chunk is not None:
             for req in fwd_wait_handles_slave_chunk:
...
@@ -765,22 +766,28 @@ def forward_backward_pipelining_with_cutinhalf(
                 output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
             fwd_wait_handles_slave_chunk = None

         if fwd_wait_handles_send is not None:
-            for req in fwd_wait_handles_send:
-                req.wait()
+            for req, req_handle in fwd_wait_handles_send.items():
+                if req_handle is not None:
+                    req_handle.wait()
             deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
             fwd_wait_handles_send = None

+        # If asynchronous, the memory will rise.
+        bwd_wait_handles = send_backward(input_tensor_grad, tensor_shape, config, slave_chunk_id)

         # If asynchronous, the memory will rise.
         input_tensor_slave_chunk, recv_forward_handle = recv_forward(tensor_shape, config, slave_chunk_id)

         # 1w: Weight Grad Compute
-        WeightGradStore.pop()
+        # WeightGradStore.pop()

         if recv_forward_handle is not None:
-            for req in recv_forward_handle:
-                req.wait()
+            for req, handle in recv_forward_handle.items():
+                if handle is not None:
+                    handle.wait()
             recv_forward_handle = None

         # 1F: Forward pass
...
@@ -816,7 +823,7 @@ def forward_backward_pipelining_with_cutinhalf(
     # Run overlaping f&bw stages
     fwd_model_chunk_id = master_chunk_id
     bwd_model_chunk_id = slave_chunk_id
-    for _ in range(schedule['overlap'][rank] + schedule['1b1overlap'][rank] + schedule['interleaved_backward'][rank]):
+    for step_id in range(schedule['overlap'][rank] + schedule['1b1overlap'][rank] + schedule['interleaved_backward'][rank]):
         only_bwd = False
         if fwd_model_chunk_id == master_chunk_id and master_cur_microbatch == master_microbatch_max:
             only_bwd = True
...
@@ -853,24 +860,37 @@ def forward_backward_pipelining_with_cutinhalf(
             fwd_send_only = (master_cur_microbatch == master_microbatch_max)

+            # Wait for the last slave-chunk forward send of the previous stage
+            if fwd_wait_handles_slave_chunk is not None:
+                for req, req_handle in fwd_wait_handles_slave_chunk.items():
+                    if req_handle is not None:
+                        req_handle.wait()
+                deallocate_output_tensor(output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
+                fwd_wait_handles_slave_chunk = None

             if fwd_send_only:
                 fwd_wait_handles = send_forward(output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
             else:
                 if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
-                    input_tensor = output_tensor
+                    input_tensor = output_tensor.detach()
+                    input_tensor.requires_grad = True
                     deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
                 else:
                     input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
                         output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)

             if firstFB_no_overlp_handle is not None:
-                for req in firstFB_no_overlp_handle:
-                    req.wait()
+                for req, req_handle in firstFB_no_overlp_handle.items():
+                    if req_handle is not None:
+                        req_handle.wait()
                 firstFB_no_overlp_handle = None

             if bwd_wait_handles is not None:
-                for req in bwd_wait_handles:
-                    req.wait()
+                for req, req_handle in bwd_wait_handles.items():
+                    if req_handle is not None:
+                        req_handle.wait()
                 bwd_wait_handles = None

             input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
...
@@ -883,8 +903,9 @@ def forward_backward_pipelining_with_cutinhalf(
             )

             if fwd_wait_handles is not None:
-                for req in fwd_wait_handles:
-                    req.wait()
+                for req, req_handle in fwd_wait_handles.items():
+                    if req_handle is not None:
+                        req_handle.wait()
                 fwd_wait_handles = None

             deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
...
@@ -896,21 +917,15 @@ def forward_backward_pipelining_with_cutinhalf(
             output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(
                 input_tensor_grad, tensor_shape, config, fwd_model_chunk_id, async_op=True)

-            if fwd_wait_handles_slave_chunk is not None:
-                for req in fwd_wait_handles_slave_chunk:
-                    # Wait for the last slave-chunk forward send of the previous stage
-                    req.wait()
-                deallocate_output_tensor(output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
-                fwd_wait_handles_slave_chunk = None

         # only run backward
         else:
             if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
                 input_tensor, _ = recv_forward(tensor_shape, config, slave_chunk_id)

             if bwd_wait_handles is not None:
-                for req in bwd_wait_handles:
-                    req.wait()
+                for req, req_handle in bwd_wait_handles.items():
+                    if req_handle is not None:
+                        req_handle.wait()
                 bwd_wait_handles = None

             input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
...
@@ -951,26 +966,28 @@ def forward_backward_pipelining_with_cutinhalf(
     for i in range(pp_size):
         if bwd_wait_handles is not None:
-            for req in bwd_wait_handles:
-                req.wait()
+            for req, req_handle in bwd_wait_handles.items():
+                if req_handle is not None:
+                    req_handle.wait()
             bwd_wait_handles = None

         if bwd_wait_handles_recv is not None:
-            for req in bwd_wait_handles_recv:
-                req.wait()
+            for req, req_handle in bwd_wait_handles_recv.items():
+                if req_handle is not None:
+                    req_handle.wait()
             bwd_wait_handles_recv = None

         input_tensor_bwd = merged_input_tensors.pop(0)[1]
         output_tensor_bwd, bwd_model_chunk_id = merged_output_tensors.pop(0)

-        if not args.dualpipe_no_dw_detach:
-            WeightGradStore.start_decouple()
+        # if not args.dualpipe_no_dw_detach:
+        #     WeightGradStore.start_decouple()

         input_tensor_grad = backward_step(
             input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)

-        if not args.dualpipe_no_dw_detach:
-            WeightGradStore.end_decouple()
+        # if not args.dualpipe_no_dw_detach:
+        #     WeightGradStore.end_decouple()

         if i == pp_size - 1:
             bwd_wait_handles = send_backward(input_tensor_grad,
...
@@ -988,18 +1005,19 @@ def forward_backward_pipelining_with_cutinhalf(
             output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(
                 input_tensor_grad, tensor_shape, config, 1 - bwd_model_chunk_id)

-        WeightGradStore.flush_chunk_grad()
-        if i >= schedule['cooldown'][rank][0] - 1:
-            WeightGradStore.pop_single()
+        # WeightGradStore.flush_chunk_grad()
+        # if i >= schedule['cooldown'][rank][0] - 1:
+        #     WeightGradStore.pop_single()

-    for _ in range(schedule['cooldown'][rank][2] - 1):
-        WeightGradStore.pop_single()
+    # for _ in range(schedule['cooldown'][rank][2] - 1):
+    #     WeightGradStore.pop_single()

-    assert WeightGradStore.weight_grad_queue.empty()
+    # assert WeightGradStore.weight_grad_queue.empty()

     if bwd_wait_handles is not None:
-        for req in bwd_wait_handles:
-            req.wait()
+        for req, req_handle in bwd_wait_handles.items():
+            if req_handle is not None:
+                req_handle.wait()
         bwd_wait_handles = None

     if config.finalize_model_grads_func is not None and not forward_only:
...
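A recurring change in this file is the shape of the P2P wait handles: code that previously iterated `for req in handles: req.wait()` now iterates `for req, req_handle in handles.items()` and checks for None, i.e. the handle container is treated as a dict of named handles (entries possibly None) rather than a flat list of request objects. A minimal sketch of the consumer-side pattern; the handle names below are made up for the example, the real keys come from _communicate:

```python
def wait_all(handles):
    """Wait on a dict of {name: async work handle or None}, then drop the reference."""
    if handles is None:
        return None
    for name, handle in handles.items():
        if handle is not None:
            handle.wait()
    return None


# Illustrative shape of what an async communication call might hand back:
class _FakeWork:
    def wait(self):
        pass


fwd_wait_handles = {"recv_prev": _FakeWork(), "send_next": None}
fwd_wait_handles = wait_all(fwd_wait_handles)
```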
dcu_megatron/core/transformer/transformer_layer.py

@@ -2,7 +2,8 @@ from typing import Any, Optional
 from torch import Tensor

-from megatron.core import tensor_parallel
+from megatron.training import get_args
+from megatron.core import tensor_parallel, parallel_state
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.utils import (
     deprecate_inference_params,
...
@@ -11,6 +12,176 @@ from megatron.core.utils import (
 from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
 from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
 from megatron.core.transformer.transformer_config import TransformerConfig


 def get_transformer_layer_offset(config: TransformerConfig):
     """Get the index offset of current pipeline stage, given the level of pipelining."""
     args = get_args()
     pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
     if not parallel_state.is_inside_encoder():
         pp_decoder_start = parallel_state.get_pipeline_model_parallel_decoder_start()
         if pp_decoder_start is not None:
             pipeline_rank = pipeline_rank - pp_decoder_start

     if config.pipeline_model_parallel_size > 1:
         if (
             config.num_layers_in_first_pipeline_stage is not None
             or config.num_layers_in_last_pipeline_stage is not None
         ):
             # Calculate number of pipeline stages to distribute the remaining Transformer
             # layers after deducting the Transformer layers in the first or the last stages
             middle_pipeline_stages = config.pipeline_model_parallel_size
             if args.schedule_method == 'dualpipev':
                 middle_pipeline_stages *= 2
             middle_pipeline_stages -= sum(
                 [
                     1 if x is not None else 0
                     for x in (
                         config.num_layers_in_first_pipeline_stage,
                         config.num_layers_in_last_pipeline_stage,
                     )
                 ]
             )

             # Calculate layers to distribute in each pipeline stage. If the
             # num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage
             # are not set, we will not enable uneven pipeline. All layers will be treated
             # as middle layers.
             num_layers_in_first_pipeline_stage = (
                 0
                 if config.num_layers_in_first_pipeline_stage is None
                 else config.num_layers_in_first_pipeline_stage
             )
             num_layers_in_last_pipeline_stage = (
                 0
                 if config.num_layers_in_last_pipeline_stage is None
                 else config.num_layers_in_last_pipeline_stage
             )

             middle_num_layers = (
                 config.num_layers
                 - num_layers_in_first_pipeline_stage
                 - num_layers_in_last_pipeline_stage
             )

             if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
                 vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
                 vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()

                 # Calculate number of layers in each virtual model chunk
                 # If the num_layers_in_first_pipeline_stage and
                 # num_layers_in_last_pipeline_stage are not set, all pipeline stages
                 # will be treated as middle pipeline stages in the calculation
                 num_layers_per_virtual_model_chunk_in_first_pipeline_stage = (
                     0
                     if config.num_layers_in_first_pipeline_stage is None
                     else config.num_layers_in_first_pipeline_stage // vp_size
                 )
                 num_layers_per_virtual_model_chunk_in_last_pipeline_stage = (
                     0
                     if config.num_layers_in_last_pipeline_stage is None
                     else config.num_layers_in_last_pipeline_stage // vp_size
                 )
                 num_layers_per_vritual_model_chunk_in_middle_pipeline_stage = (
                     middle_num_layers // vp_size
                 )

                 # First stage + middle stage + last stage
                 total_virtual_chunks = (
                     num_layers_per_virtual_model_chunk_in_first_pipeline_stage
                     + num_layers_per_vritual_model_chunk_in_middle_pipeline_stage
                     + num_layers_per_virtual_model_chunk_in_last_pipeline_stage
                 )

                 # Calculate the layer offset with interleaved uneven pipeline parallelism
                 if pipeline_rank == 0:
                     offset = vp_rank * total_virtual_chunks
                 else:
                     offset = (
                         vp_rank * total_virtual_chunks
                         + num_layers_per_virtual_model_chunk_in_first_pipeline_stage
                         + (pipeline_rank - 1)
                         * (num_layers_per_vritual_model_chunk_in_middle_pipeline_stage // middle_pipeline_stages)
                     )
             else:
                 if middle_pipeline_stages > 0:
                     num_layers_per_pipeline_rank = middle_num_layers // middle_pipeline_stages
                 else:
                     num_layers_per_pipeline_rank = 0

                 middle_pipeline_rank = (
                     pipeline_rank
                     if config.num_layers_in_first_pipeline_stage is None
                     else pipeline_rank - 1
                 )

                 if not getattr(args, 'dualpipev_first_chunk', True):
                     middle_pipeline_rank = (
                         config.pipeline_model_parallel_size
                         if config.num_layers_in_first_pipeline_stage is None
                         else config.pipeline_model_parallel_size - 1
                     ) + (config.pipeline_model_parallel_size - (pipeline_rank + 1))

                 if getattr(args, 'dualpipev_first_chunk', True) and pipeline_rank == 0:
                     offset = 0
                 else:
                     offset = (
                         middle_pipeline_rank * num_layers_per_pipeline_rank
                     ) + num_layers_in_first_pipeline_stage
         else:
             num_layers = config.num_layers

             # Increase the number of layers by one if we include the embedding (loss)
             # layer into pipeline parallelism partition and placement
             if config.account_for_embedding_in_pipeline_split:
                 num_layers += 1
             if config.account_for_loss_in_pipeline_split:
                 num_layers += 1

             num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size
             if args.schedule_method == 'dualpipev':
                 num_layers_per_pipeline_rank = num_layers_per_pipeline_rank // 2

             if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
                 vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
                 vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()

                 num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
                 total_virtual_chunks = num_layers // vp_size
                 offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank)

                 # Reduce the offset of embedding layer from the total layer number
                 if (
                     config.account_for_embedding_in_pipeline_split
                     and not parallel_state.is_pipeline_first_stage()
                 ):
                     offset -= 1
             else:
                 if getattr(args, 'dualpipev_first_chunk', True):
                     offset = pipeline_rank * num_layers_per_pipeline_rank
                 else:
                     offset = num_layers - (pipeline_rank + 1) * num_layers_per_pipeline_rank

                 # Reduce the offset of embedding layer from the total layer number
                 if (
                     config.account_for_embedding_in_pipeline_split
                     and not parallel_state.is_pipeline_first_stage()
                 ):
                     offset -= 1
     else:
         offset = 0
     return offset


 class TransformerLayer(MegatronCoreTransformerLayer):
...
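A worked example of the dualpipev branch of get_transformer_layer_offset above, for a hypothetical even split (num_layers=8, pipeline_model_parallel_size=2, no virtual pipeline, no uneven first/last stage): num_layers_per_pipeline_rank is 8 // 2 = 4 and is then halved to 2 for dualpipev, so the first chunk on rank r starts at 2r and the second chunk starts at 8 - 2(r + 1). A short sketch that reproduces just that arithmetic without the Megatron runtime (the config values are assumptions for the example):

```python
def dualpipev_offset(num_layers, pp_size, pipeline_rank, first_chunk):
    # Mirrors the even-split dualpipev branch of get_transformer_layer_offset.
    num_layers_per_pipeline_rank = (num_layers // pp_size) // 2
    if first_chunk:
        return pipeline_rank * num_layers_per_pipeline_rank
    return num_layers - (pipeline_rank + 1) * num_layers_per_pipeline_rank


# num_layers=8, pp_size=2:
#   rank 0 -> first chunk offset 0, second chunk offset 6  (layers 0-1 and 6-7)
#   rank 1 -> first chunk offset 2, second chunk offset 4  (layers 2-3 and 4-5)
for rank in range(2):
    print(rank, dualpipev_offset(8, 2, rank, True), dualpipev_offset(8, 2, rank, False))
```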
dcu_megatron/training/training.py

import gc
import sys
from functools import wraps

import torch.distributed
import torch
...
@@ -53,6 +54,29 @@ from megatron.training.training import (
 stimer = StragglerDetector()


 def build_train_valid_test_data_iterators_wrapper(build_train_valid_test_data_iterators_func):
     @wraps(build_train_valid_test_data_iterators_func)
     def wrapper(train_valid_test_dataset_provider):
         args = get_args()
         if args.schedule_method == 'dualpipev':
             train_data_iterator = []
             valid_data_iterator = []
             test_data_iterator = []
             for _ in range(2):
                 iterators = build_train_valid_test_data_iterators_func(train_valid_test_dataset_provider)
                 train_data_iterator.append(iterators[0])
                 valid_data_iterator.append(iterators[1])
                 test_data_iterator.append(iterators[2])
         else:
             train_data_iterator, valid_data_iterator, test_data_iterator \
                 = build_train_valid_test_data_iterators_func(train_valid_test_dataset_provider)
         return train_data_iterator, valid_data_iterator, test_data_iterator
     return wrapper


 def train(forward_step_func, model, optimizer, opt_param_scheduler, train_data_iterator, valid_data_iterator,
           process_non_loss_data_func, config, checkpointing_context, non_loss_data_func):
...
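Usage note for the wrapper above: with schedule_method == 'dualpipev' each returned value is a list of two iterators, one per local model chunk, which is what the apply_wrapper=True registration in megatron_adaptor.py installs in place of the stock builder. A self-contained sketch of that behaviour with dummy stand-ins (the builder, provider and flag handling below are placeholders, not Megatron code):

```python
from functools import wraps


def _dummy_builder(provider):
    # Stand-in for megatron.training.training.build_train_valid_test_data_iterators.
    return iter([1, 2]), iter([3]), iter([4])


def dual_iterator_wrapper_sketch(fn, schedule_method='dualpipev'):
    """Mirrors build_train_valid_test_data_iterators_wrapper for illustration."""
    @wraps(fn)
    def wrapper(provider):
        if schedule_method == 'dualpipev':
            train, valid, test = [], [], []
            for _ in range(2):          # one iterator set per local dualpipev chunk
                t, v, te = fn(provider)
                train.append(t)
                valid.append(v)
                test.append(te)
            return train, valid, test
        return fn(provider)
    return wrapper


train_it, valid_it, test_it = dual_iterator_wrapper_sketch(_dummy_builder)(None)
assert len(train_it) == 2               # two train iterators, one per chunk
```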
dcu_megatron/training/utils.py  (new file, 0 → 100644)

import torch

from megatron.training import get_args
from megatron.core import mpu


def get_batch_on_this_tp_rank(data_iterator):

    args = get_args()

    def _broadcast(item):
        if item is not None:
            torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(),
                                        group=mpu.get_tensor_model_parallel_group())

    if mpu.get_tensor_model_parallel_rank() == 0:

        if data_iterator is not None:
            data = next(data_iterator)
        else:
            data = None

        batch = {
            'tokens': data["tokens"].cuda(non_blocking=True),
            'labels': data["labels"].cuda(non_blocking=True),
            'loss_mask': data["loss_mask"].cuda(non_blocking=True),
            'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking=True),
            'position_ids': data["position_ids"].cuda(non_blocking=True)
        }

        if args.pipeline_model_parallel_size == 1:
            _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])

        elif mpu.is_pipeline_first_stage():
            _broadcast(batch['tokens'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])
            if args.schedule_method == "dualpipev":
                _broadcast(batch['loss_mask'])
                _broadcast(batch['labels'])

        elif mpu.is_pipeline_last_stage():
            # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
            # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
            # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
            if args.mtp_num_layers is not None:
                _broadcast(batch['tokens'])
                _broadcast(batch['position_ids'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])

    else:

        tokens = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64,
                             device=torch.cuda.current_device())
        labels = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64,
                             device=torch.cuda.current_device())
        loss_mask = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.float32,
                                device=torch.cuda.current_device())
        if args.create_attention_mask_in_dataloader:
            attention_mask = torch.empty(
                (args.micro_batch_size, 1, args.seq_length, args.seq_length), dtype=torch.bool,
                device=torch.cuda.current_device()
            )
        else:
            attention_mask = None
        position_ids = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64,
                                   device=torch.cuda.current_device())

        if args.pipeline_model_parallel_size == 1:
            _broadcast(tokens)
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            _broadcast(position_ids)

        elif mpu.is_pipeline_first_stage():
            _broadcast(tokens)
            _broadcast(attention_mask)
            _broadcast(position_ids)
            if args.schedule_method == "dualpipev":
                _broadcast(loss_mask)
                _broadcast(labels)
            else:
                labels = None
                loss_mask = None

        elif mpu.is_pipeline_last_stage():
            # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
            # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
            # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
            if args.mtp_num_layers is not None:
                _broadcast(tokens)
                _broadcast(position_ids)
            else:
                tokens = None
                position_ids = None
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)

        batch = {
            'tokens': tokens,
            'labels': labels,
            'loss_mask': loss_mask,
            'attention_mask': attention_mask,
            'position_ids': position_ids
        }

    return batch
\ No newline at end of file
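To summarize the new get_batch_on_this_tp_rank above: on the first pipeline stage it now broadcasts loss_mask and labels across the tensor-parallel group whenever schedule_method == 'dualpipev', in addition to the tokens, attention_mask and position_ids that upstream Megatron already broadcasts there; the single-stage and last-stage branches keep their existing behaviour. A compact, illustrative encoding of which batch keys are broadcast per stage (the helper name is hypothetical, the key sets follow the file above):

```python
def broadcast_keys(pp_size, is_first_stage, is_last_stage, schedule_method, mtp_num_layers=None):
    """Return the batch keys broadcast across the TP group for this rank's pipeline stage."""
    if pp_size == 1:
        return ['tokens', 'labels', 'loss_mask', 'attention_mask', 'position_ids']
    if is_first_stage:
        keys = ['tokens', 'attention_mask', 'position_ids']
        if schedule_method == 'dualpipev':
            keys += ['loss_mask', 'labels']        # the dualpipev broadcast fix
        return keys
    if is_last_stage:
        keys = ['labels', 'loss_mask', 'attention_mask']
        if mtp_num_layers is not None:
            keys = ['tokens', 'position_ids'] + keys
        return keys
    return []


print(broadcast_keys(4, True, False, 'dualpipev'))
# ['tokens', 'attention_mask', 'position_ids', 'loss_mask', 'labels']
```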
pretrain_gpt.py

@@ -136,7 +136,6 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
 def get_batch(data_iterator):
     """Generate a batch."""

     # TODO: this is pretty hacky, find a better way
     if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
         return None, None, None, None, None
...