evt_fugx1 / dcu_megatron / Commits / f3ef5e1b

Commit f3ef5e1b, authored Jun 12, 2025 by dongcl
Parent: bb6ab0fb

    patch for megatron commit 0595ef2b0c93f8d61f473c9f99f9ff73803ff919

Showing 11 changed files with 226 additions and 114 deletions.
Changed files:
  dcu_megatron/adaptor/features_manager/pipeline_parallel/pipeline_feature.py  (+19, -2)
  dcu_megatron/adaptor/megatron_adaptor.py                                      (+2, -16)
  dcu_megatron/core/extensions/transformer_engine.py                            (+27, -5)
  dcu_megatron/core/models/gpt/fine_grained_schedule.py                         (+8, -11)
  dcu_megatron/core/models/gpt/gpt_model.py                                     (+11, -9)
  dcu_megatron/core/transformer/moe/token_dispatcher.py                         (+39, -11)
  dcu_megatron/core/transformer/transformer_layer.py                            (+51, -26)
  dcu_megatron/legacy/model/transformer.py                                      (+16, -19)
  dcu_megatron/training/arguments.py                                            (+7, -0)
  dcu_megatron/training/initialize.py                                           (+6, -1)
  dcu_megatron/training/utils.py                                                (+40, -14)
dcu_megatron/adaptor/features_manager/pipeline_parallel/pipeline_feature.py  (+19, -2)

@@ -49,6 +49,9 @@ class PipelineFeature(AbstractFeature):
             _allreduce_embedding_grads_wrapper
         )
         from dcu_megatron.training.training import evaluate
+        from dcu_megatron.core.transformer.transformer_layer import get_transformer_layer_offset
+        from dcu_megatron.training.utils import get_batch_on_this_tp_rank
+        from dcu_megatron.training.training import build_train_valid_test_data_iterators_wrapper

         patch_manager.register_patch(
             'megatron.training.training.get_model', get_model)
@@ -69,6 +72,20 @@ class PipelineFeature(AbstractFeature):
         patch_manager.register_patch(
             'megatron.training.training.evaluate', evaluate)
+        patch_manager.register_patch(
+            'megatron.core.transformer.transformer_layer.get_transformer_layer_offset',
+            get_transformer_layer_offset)
+        # support dualpipev, two data iterators
+        patch_manager.register_patch(
+            'megatron.training.training.build_train_valid_test_data_iterators',
+            build_train_valid_test_data_iterators_wrapper,
+            apply_wrapper=True)
+        # support dualpipev, broadcast loss_mask and labels
+        patch_manager.register_patch(
+            'megatron.training.utils.get_batch_on_this_tp_rank',
+            get_batch_on_this_tp_rank)

         if args.combined_1f1b:
             from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear
@@ -86,10 +103,10 @@ class PipelineFeature(AbstractFeature):
             from dcu_megatron.core.transformer.moe.moe_layer import MoELayer
             patch_manager.register_patch(
                 'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
                 MoEAlltoAllTokenDispatcher)
             patch_manager.register_patch(
                 'megatron.core.transformer.transformer_layer.TransformerLayer',
                 TransformerLayer)
             patch_manager.register_patch(
                 'megatron.core.models.gpt.gpt_model.GPTModel.build_schedule_plan',
                 GPTModel.build_schedule_plan,
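Note: the patches above are routed through the repository's patch manager rather than edited into Megatron directly. A minimal sketch of how a register_patch/apply_wrapper mechanism like this typically works is shown below; the class and its body are hypothetical and only illustrate the pattern, not dcu_megatron's actual implementation.

    import importlib

    class PatchManagerSketch:
        """Illustrative monkey-patch registry (hypothetical, not the real patch_manager)."""

        def __init__(self):
            self._patches = []

        def register_patch(self, target, replacement, apply_wrapper=False):
            # target is a dotted path such as 'megatron.training.training.evaluate'
            self._patches.append((target, replacement, apply_wrapper))

        def apply(self):
            for target, replacement, apply_wrapper in self._patches:
                module_path, attr = target.rsplit('.', 1)
                module = importlib.import_module(module_path)
                original = getattr(module, attr)
                # apply_wrapper=True treats `replacement` as a decorator that receives
                # the original callable and returns the patched one.
                setattr(module, attr, replacement(original) if apply_wrapper else replacement)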
dcu_megatron/adaptor/megatron_adaptor.py  (+2, -16)

@@ -163,7 +163,6 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper
-        from ..core.transformer.transformer_layer import get_transformer_layer_offset
         from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch

         # Transformer block. If mtp_num_layers > 0, move final_layernorm outside
@@ -190,10 +189,6 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)
-        # support dualpipev
-        MegatronAdaptation.register(
-            'megatron.core.transformer.transformer_layer.get_transformer_layer_offset',
-            get_transformer_layer_offset)

     def patch_core_extentions(self):
         import transformer_engine as te
@@ -257,10 +252,10 @@ class CoreAdaptation(MegatronAdaptationABC):
         from ..training.tokenizer import build_tokenizer
         from ..training.initialize import _initialize_distributed
         from ..training.initialize import _compile_dependencies
-        from ..training.training import train, build_train_valid_test_data_iterators_wrapper
+        from ..training.training import train
         from ..training.initialize import _set_random_seed
-        from ..training.utils import get_batch_on_this_tp_rank

         # add Llama3Tokenizer, QwenTokenizer, DeepSeekV2Tokenizer
         MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
                                     build_tokenizer)

         # specify init_method
@@ -278,15 +273,6 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.training.training.train',
                                     train)
-        # support dualpipev, two data iterators
-        MegatronAdaptation.register(
-            'megatron.training.training.build_train_valid_test_data_iterators',
-            build_train_valid_test_data_iterators_wrapper,
-            apply_wrapper=True)
-        # support dualpipev, broadcast loss_mask and labels
-        MegatronAdaptation.register(
-            'megatron.training.utils.get_batch_on_this_tp_rank',
-            get_batch_on_this_tp_rank)

     def patch_miscellaneous(self):
         from ..training.arguments import parse_args
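Note: this file only loses the dualpipev registrations that moved into PipelineFeature above. For context, a dualpipe-style schedule drives two model chunks per rank, so the data-iterator wrapper referenced in these patches has to hand back two training iterators. The sketch below shows one way such a wrapper could look; the use of itertools.tee and the exact return shape are assumptions for illustration, not the repository's implementation.

    import itertools

    def build_train_valid_test_data_iterators_wrapper_sketch(build_fn):
        """Hypothetical wrapper: expose two train iterators for a dualpipe schedule."""
        def wrapper(*args, **kwargs):
            train_it, valid_it, test_it = build_fn(*args, **kwargs)
            # itertools.tee yields two independent views over one stream; the real
            # wrapper may instead build a second dataloader for the second chunk.
            train_a, train_b = itertools.tee(train_it) if train_it is not None else (None, None)
            return (train_a, train_b), valid_it, test_it
        return wrapper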
dcu_megatron/core/extensions/transformer_engine.py  (+27, -5)

@@ -160,6 +160,7 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
         k_channels: Optional[int] = None,
         v_channels: Optional[int] = None,
         cp_comm_type: str = "p2p",
+        model_comm_pgs: ModelCommProcessGroups = None,
     ):
         self.config = config
         self.te_forward_mask_type = False
@@ -186,6 +187,26 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
                 f"num_attention_heads ({self.config.num_attention_heads}))"
             )

+        if model_comm_pgs is None:
+            # For backward compatibility, remove in v0.14 and raise error
+            # raise ValueError("TEDotProductAttention was called without ModelCommProcessGroups")
+            model_comm_pgs = ModelCommProcessGroups(
+                tp=get_tensor_model_parallel_group(check_initialized=False),
+                cp=get_context_parallel_group(check_initialized=False),
+                hcp=get_hierarchical_context_parallel_groups(check_initialized=False),
+            )
+        else:
+            assert hasattr(model_comm_pgs, 'tp'), "TEDotProductAttention model_comm_pgs must have tp pg"
+            assert hasattr(model_comm_pgs, 'cp'), "TEDotProductAttention model_comm_pgs must have cp pg"
+            if cp_comm_type == "a2a+p2p":
+                assert hasattr(model_comm_pgs, 'hcp'), (
+                    "TEDotProductAttention model_comm_pgs must have hierarchical cp pg"
+                )
+
         if is_te_min_version("0.10.0"):
             extra_kwargs["attention_type"] = attention_type
         # older version don't need attention_type
@@ -201,9 +222,9 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
             ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
             if getattr(TEDotProductAttention, "cp_stream") is None:
                 TEDotProductAttention.cp_stream = torch.cuda.Stream()
-            extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
-            extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
-                check_initialized=False
+            extra_kwargs["cp_group"] = model_comm_pgs.cp
+            extra_kwargs["cp_global_ranks"] = torch.distributed.get_process_group_ranks(
+                model_comm_pgs.cp
             )
             extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
             if is_te_min_version("1.10.0"):
@@ -277,7 +298,7 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
             get_rng_state_tracker=(
                 get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
             ),
-            tp_group=get_tensor_model_parallel_group(check_initialized=False),
+            tp_group=model_comm_pgs.tp,
             layer_number=layer_number,
             **extra_kwargs,
         )
@@ -294,7 +315,6 @@ if is_te_min_version("1.9.0.dev0"):
         yet, the tp_group passed to TE will be None and must be set later
         via set_tensor_parallel_group().
         """
-
         def __init__(
             self,
             num_gemms: int,
@@ -308,6 +328,7 @@ if is_te_min_version("1.9.0.dev0"):
             skip_bias_add: bool,
             is_expert: bool = False,
             tp_comm_buffer_name: Optional[str] = None,
+            tp_group: Optional[torch.distributed.ProcessGroup] = None,
         ):
             args = get_args()
             self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
@@ -329,6 +350,7 @@ if is_te_min_version("1.9.0.dev0"):
                 skip_bias_add=skip_bias_add,
                 is_expert=is_expert,
                 tp_comm_buffer_name=tp_comm_buffer_name,
+                tp_group=tp_group,
             )

         def backward_dw(self):
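Note: the change threads a single ModelCommProcessGroups container (tp/cp/hcp) through the TE attention wrapper instead of querying parallel_state for each group, and recovers the context-parallel ranks straight from the group object. A hedged sketch of that fallback-and-validate idea, with the container reduced to a namedtuple for illustration:

    from collections import namedtuple

    # Illustrative stand-in; megatron-core's ModelCommProcessGroups carries more fields.
    CommGroups = namedtuple("CommGroups", ["tp", "cp", "hcp"])

    def resolve_comm_groups(model_comm_pgs, get_tp_group, get_cp_group, get_hcp_groups):
        """If the caller did not pass process groups, fall back to the globally
        initialized ones; otherwise check the required groups are present."""
        if model_comm_pgs is None:
            return CommGroups(
                tp=get_tp_group(check_initialized=False),
                cp=get_cp_group(check_initialized=False),
                hcp=get_hcp_groups(check_initialized=False),
            )
        assert hasattr(model_comm_pgs, "tp") and hasattr(model_comm_pgs, "cp")
        return model_comm_pgs

    # Ranks can then be read off the group itself, which is what the diff switches to:
    #   cp_ranks = torch.distributed.get_process_group_ranks(model_comm_pgs.cp)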
dcu_megatron/core/models/gpt/fine_grained_schedule.py  (+8, -11)

@@ -288,7 +288,7 @@ class MoeAttnNode(TransformerLayerNode):
             pre_mlp_layernorm_output,
             tokens_per_expert,
             permutated_local_input_tokens,
-            probs,
+            permuted_probs,
         ) = self.layer._submodule_attention_router_compound_forward(
             hidden_states,
             attention_mask=attention_mask,
@@ -304,11 +304,10 @@ class MoeAttnNode(TransformerLayerNode):
         self.common_state.tokens_per_expert = tokens_per_expert
         # detached here
-        self.common_state.probs = self.detach(probs)
         self.common_state.residual = self.detach(hidden_states)
         self.common_state.pre_mlp_layernorm_output = self.detach(pre_mlp_layernorm_output)
-        return permutated_local_input_tokens
+        return permutated_local_input_tokens, permuted_probs

     def dw(self):
         with torch.cuda.nvtx.range(f"{self.name} wgrad"):
@@ -317,26 +316,26 @@ class MoeAttnNode(TransformerLayerNode):

 class MoeDispatchNode(TransformerLayerNode):
-    def forward_impl(self, permutated_local_input_tokens):
+    def forward_impl(self, permutated_local_input_tokens, permuted_probs):
         token_dispatcher = self.layer.mlp.token_dispatcher
         with token_dispatcher.per_batch_state_context(self.common_state):
-            tokens_per_expert, global_input_tokens = token_dispatcher.dispatch_all_to_all(
-                self.common_state.tokens_per_expert, permutated_local_input_tokens
+            tokens_per_expert, global_input_tokens, global_probs = token_dispatcher.dispatch_all_to_all(
+                self.common_state.tokens_per_expert, permutated_local_input_tokens, permuted_probs
             )
         # release tensor not used by backward
         # inputs.untyped_storage().resize_(0)
         self.common_state.tokens_per_expert = tokens_per_expert
-        return global_input_tokens
+        return global_input_tokens, global_probs


 class MoeMlPNode(TransformerLayerNode):
-    def forward_impl(self, global_input_tokens):
+    def forward_impl(self, global_input_tokens, global_probs):
         pre_mlp_layernorm_output = self.common_state.pre_mlp_layernorm_output
         token_dispatcher = self.layer.mlp.token_dispatcher
         with token_dispatcher.per_batch_state_context(self.common_state):
             expert_output, shared_expert_output, mlp_bias = self.layer._submodule_moe_forward(
-                self.common_state.tokens_per_expert, global_input_tokens, pre_mlp_layernorm_output
+                self.common_state.tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output
             )
         assert mlp_bias is None
@@ -363,9 +362,7 @@ class MoeCombineNode(TransformerLayerNode):
             )
         cur_stream = torch.cuda.current_stream()
         self.common_state.residual.record_stream(cur_stream)
-        self.common_state.probs.record_stream(cur_stream)
         self.common_state.residual = None
-        self.common_state.probs = None
         return output
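Note: the schedule nodes now pass permuted_probs along with the permuted tokens instead of stashing probs in common_state, because both tensors are reordered by the same mapping and must stay paired through every dispatch step. A toy illustration of that invariant (plain torch, no Megatron code):

    import torch

    tokens = torch.arange(6, dtype=torch.float32).reshape(6, 1)   # [num_tokens, hidden=1]
    probs = torch.tensor([0.9, 0.1, 0.8, 0.2, 0.7, 0.3])
    perm = torch.tensor([3, 0, 5, 1, 4, 2])                       # expert-major ordering

    permuted_tokens = tokens[perm]   # what the dispatch node forwards
    permuted_probs = probs[perm]     # travels in the same tuple, so row i of each
                                     # tensor still refers to the same original token
    inverse = torch.argsort(perm)
    assert torch.equal(permuted_probs[inverse], probs)   # un-permuting restores the pairing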
dcu_megatron/core/models/gpt/gpt_model.py  (+11, -9)

@@ -125,8 +125,9 @@ def gpt_model_forward(
         and inference_context.is_static_batching()
         and not self.training
     ):
+        current_batch_size = input_ids.shape[0]
         sequence_len_offset = torch.tensor(
-            [inference_context.sequence_len_offset] * inference_context.current_batch_size,
+            [inference_context.sequence_len_offset] * current_batch_size,
            dtype=torch.int32,
            device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
        )
@@ -156,12 +157,6 @@ def gpt_model_forward(
        **(extra_block_kwargs or {}),
    )

-    # Process inference output.
-    if inference_context and not inference_context.is_static_batching():
-        hidden_states = inference_context.last_token_logits(
-            hidden_states.squeeze(1).unsqueeze(0)
-        ).unsqueeze(1)
-
    # logits and loss
    output_weight = None
    if self.share_embeddings_and_output_weights:
@@ -202,10 +197,17 @@ def gpt_model_forward(
    if (
        not self.training
        and inference_context is not None
-        and inference_context.is_static_batching()
        and inference_context.materialize_only_last_token_logits
    ):
-        hidden_states = hidden_states[-1:, :, :]
+        if inference_context.is_static_batching():
+            hidden_states = hidden_states[-1:, :, :]
+        else:
+            # Reshape [B, 1, H] to [1, B, H] → extract each sample's true last-token hidden
+            # state ([B, H]) → unsqueeze back to [1, B, H]
+            # (so that the output layer, which expects S×B×H, receives only the final token)
+            hidden_states = inference_context.last_token_logits(
+                hidden_states.squeeze(1).unsqueeze(0)
+            ).unsqueeze(1)

    logits, _ = self.output_layer(
        hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
    )
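Note: a toy shape walk-through of the reshape used in the new dynamic-batching branch above; the tensors are random placeholders and the last_token_logits call itself is only described, not reproduced.

    import torch

    B, H = 4, 8
    hidden_states = torch.randn(B, 1, H)            # one generated token per sample, batch-first

    as_sbh = hidden_states.squeeze(1).unsqueeze(0)  # [B, 1, H] -> [B, H] -> [1, B, H]
    assert as_sbh.shape == (1, B, H)
    # inference_context.last_token_logits(...) then picks each sample's last valid token
    # from this S x B x H layout before the output layer is applied.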
dcu_megatron/core/transformer/moe/token_dispatcher.py  (+39, -11)

@@ -12,8 +12,11 @@ from megatron.core.transformer.moe.moe_utils import (
     permute,
     sort_chunks_by_idxs,
     unpermute,
+    pad_routing_map,
 )
 from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher as MegatronCoreMoEAlltoAllTokenDispatcher
+from megatron.core.fp8_utils import get_fp8_align_size
+from megatron.core.fusions.fused_pad_routing_map import fused_pad_routing_map

 from dcu_megatron.core.tensor_parallel import all_to_all
@@ -101,6 +104,12 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
         assert routing_map.dtype == torch.bool, "Expected bool tensor for mask"

+        if self.config.moe_router_padding_for_fp8:
+            pad_multiple = get_fp8_align_size(self.config.fp8_recipe)
+            if experimental_config.ENABLE_EXPERIMENTAL and self.config.moe_permute_fusion:
+                self.routing_map = fused_pad_routing_map(self.routing_map, pad_multiple)
+            else:
+                self.routing_map = pad_routing_map(self.routing_map, pad_multiple)
         tokens_per_expert = self.preprocess(self.routing_map)

         return tokens_per_expert
@@ -117,18 +126,20 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         self.hidden_shape_before_permute = hidden_states.shape
         (
             permutated_local_input_tokens,
+            permuted_probs,
             self.reversed_local_input_permutation_mapping,
         ) = permute(
             hidden_states,
             routing_map,
+            self.probs,
             num_out_tokens=self.num_out_tokens,
             fused=self.config.moe_permute_fusion,
             drop_and_pad=self.drop_and_pad,
         )
-        return tokens_per_expert, permutated_local_input_tokens
+        return tokens_per_expert, permutated_local_input_tokens, permuted_probs

-    def dispatch_all_to_all(self, tokens_per_expert, permutated_local_input_tokens):
+    def dispatch_all_to_all(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
         # Perform expert parallel AlltoAll communication
         tokens_per_expert = self._maybe_dtoh_and_synchronize(
             "before_ep_alltoall", tokens_per_expert
@@ -136,10 +147,13 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         global_input_tokens = all_to_all(
             self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits, use_qcomm=self.use_qcomm
         )
+        global_probs = all_to_all(
+            self.ep_group, permuted_probs, self.output_splits, self.input_splits, use_qcomm=self.use_qcomm
+        )

-        return tokens_per_expert, global_input_tokens
+        return tokens_per_expert, global_input_tokens, global_probs

-    def dispatch_postprocess(self, tokens_per_expert, global_input_tokens):
+    def dispatch_postprocess(self, tokens_per_expert, global_input_tokens, global_probs):
         if self.shared_experts is not None:
             self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
@@ -152,6 +166,9 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
             global_input_tokens = gather_from_sequence_parallel_region(
                 global_input_tokens, group=self.tp_group, output_split_sizes=output_split_sizes
             )
+            global_probs = gather_from_sequence_parallel_region(
+                global_probs, group=self.tp_group, output_split_sizes=output_split_sizes
+            )

         # Permutation 2: Sort tokens by local expert.
         tokens_per_expert = self._maybe_dtoh_and_synchronize(
@@ -170,16 +187,28 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
                     .contiguous()
                     .flatten(start_dim=0, end_dim=2)
                 )
+                global_probs = (
+                    global_probs.view(
+                        self.tp_size * self.ep_size,
+                        self.num_local_experts,
+                        self.capacity,
+                        *global_probs.size()[1:],
+                    )
+                    .transpose(0, 1)
+                    .contiguous()
+                    .flatten(start_dim=0, end_dim=2)
+                )
             else:
-                global_input_tokens = sort_chunks_by_idxs(
+                global_input_tokens, global_probs = sort_chunks_by_idxs(
                     global_input_tokens,
                     self.num_global_tokens_per_local_expert.ravel(),
                     self.sort_input_by_local_experts,
+                    probs=global_probs,
                     fused=self.config.moe_permute_fusion,
                 )

         tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)

-        return global_input_tokens, tokens_per_expert
+        return global_input_tokens, tokens_per_expert, global_probs

     def token_permutation(
         self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
@@ -207,15 +236,15 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         # Preprocess: Get the metadata for communication, permutation and computation operations.
         # Permutation 1: input to AlltoAll input
         tokens_per_expert = self.meta_prepare(hidden_states, probs, routing_map)
-        tokens_per_expert, permutated_local_input_tokens = self.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)
+        tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)
         # Perform expert parallel AlltoAll communication
-        tokens_per_expert, global_input_tokens = self.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens)
+        tokens_per_expert, global_input_tokens, global_probs = self.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens, permuted_probs)
         # Permutation 2: Sort tokens by local expert.
-        global_input_tokens, tokens_per_expert = self.dispatch_postprocess(tokens_per_expert, global_input_tokens)
+        global_input_tokens, tokens_per_expert, global_probs = self.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_probs)

-        return global_input_tokens, tokens_per_expert
+        return global_input_tokens, tokens_per_expert, global_probs

     def combine_preprocess(self, hidden_states):
         # Unpermutation 2: Unsort tokens by local expert.
@@ -272,7 +301,6 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
             permutated_local_input_tokens,
             self.reversed_local_input_permutation_mapping,
             restore_shape=self.hidden_shape_before_permute,
-            probs=self.probs,
             routing_map=self.routing_map,
             fused=self.config.moe_permute_fusion,
             drop_and_pad=self.drop_and_pad,
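Note: moe_router_padding_for_fp8 pads every expert's token count up to the FP8 alignment size so the grouped GEMM shapes stay FP8-friendly. The routine below is a toy re-implementation of that padding idea for illustration only; megatron-core's pad_routing_map / fused_pad_routing_map are the real implementations and may behave differently at the edges.

    import torch

    def pad_routing_map_sketch(routing_map: torch.Tensor, pad_multiple: int) -> torch.Tensor:
        """Mark extra tokens as routed so each expert receives a multiple of pad_multiple."""
        num_tokens, num_experts = routing_map.shape
        padded = routing_map.clone()
        tokens_per_expert = padded.sum(dim=0)
        for e in range(num_experts):
            deficit = (-int(tokens_per_expert[e])) % pad_multiple
            if deficit:
                # route the first `deficit` currently-unrouted tokens to expert e
                free = (~padded[:, e]).nonzero(as_tuple=True)[0][:deficit]
                padded[free, e] = True
        return padded

    rm = torch.zeros(10, 2, dtype=torch.bool)
    rm[:3, 0] = True   # expert 0 starts with 3 tokens
    rm[:5, 1] = True   # expert 1 starts with 5 tokens
    padded = pad_routing_map_sketch(rm, pad_multiple=4)
    assert (padded.sum(dim=0) % 4 == 0).all()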
dcu_megatron/core/transformer/transformer_layer.py  (+51, -26)

@@ -8,6 +8,8 @@ from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.utils import (
     deprecate_inference_params,
     make_viewless_tensor,
+    nvtx_range_pop,
+    nvtx_range_push,
 )
 from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
@@ -15,7 +17,7 @@ from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
 from megatron.core.transformer.transformer_config import TransformerConfig


-def get_transformer_layer_offset(config: TransformerConfig):
+def get_transformer_layer_offset(config: TransformerConfig, vp_stage: Optional[int] = None):
     """Get the index offset of current pipeline stage, given the level of pipelining."""
     args = get_args()
     pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
@@ -67,9 +69,10 @@ def get_transformer_layer_offset(config: TransformerConfig):
                 - num_layers_in_last_pipeline_stage
             )

-            if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
-                vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
-                vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
+            if (vp_size := config.virtual_pipeline_model_parallel_size) is not None:
+                assert (
+                    vp_stage is not None
+                ), "vp_stage must be provided if virtual pipeline model parallel size is set"

                 # Calculate number of layers in each virtual model chunk
                 # If the num_layers_in_first_pipeline_stage and
@@ -100,10 +103,10 @@ def get_transformer_layer_offset(config: TransformerConfig):
                 # Calculate the layer offset with interleaved uneven pipeline parallelism
                 if pipeline_rank == 0:
-                    offset = vp_rank * total_virtual_chunks
+                    offset = vp_stage * total_virtual_chunks
                 else:
                     offset = (
-                        vp_rank * total_virtual_chunks
+                        vp_stage * total_virtual_chunks
                         + num_layers_per_virtual_model_chunk_in_first_pipeline_stage
                         + (pipeline_rank - 1)
                         * (
@@ -151,20 +154,23 @@ def get_transformer_layer_offset(config: TransformerConfig):
         if args.schedule_method == 'dualpipev':
             num_layers_per_pipeline_rank = num_layers_per_pipeline_rank // 2

-        if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
-            vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
-            vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
+        if (vp_size := config.virtual_pipeline_model_parallel_size) is not None:
+            assert (
+                vp_stage is not None
+            ), "vp_stage must be provided if virtual pipeline model parallel size is set"

             num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
             total_virtual_chunks = num_layers // vp_size
-            offset = vp_rank * total_virtual_chunks + (
+            offset = vp_stage * total_virtual_chunks + (
                 pipeline_rank * num_layers_per_virtual_rank
             )

             # Reduce the offset of embedding layer from the total layer number
             if (
                 config.account_for_embedding_in_pipeline_split
-                and not parallel_state.is_pipeline_first_stage()
+                and not parallel_state.is_pipeline_first_stage(
+                    ignore_virtual=False, vp_stage=vp_stage
+                )
             ):
                 offset -= 1
         else:
@@ -176,7 +182,9 @@ def get_transformer_layer_offset(config: TransformerConfig):
             # Reduce the offset of embedding layer from the total layer number
             if (
                 config.account_for_embedding_in_pipeline_split
-                and not parallel_state.is_pipeline_first_stage()
+                and not parallel_state.is_pipeline_first_stage(
+                    ignore_virtual=False, vp_stage=vp_stage
+                )
             ):
                 offset -= 1
         else:
@@ -188,9 +196,9 @@ class TransformerLayer(MegatronCoreTransformerLayer):
     def forward(
         self,
         hidden_states: Tensor,
-        attention_mask: Optional[Tensor] = None,
         context: Optional[Tensor] = None,
         context_mask: Optional[Tensor] = None,
+        attention_mask: Optional[Tensor] = None,
         rotary_pos_emb: Optional[Tensor] = None,
         rotary_pos_cos: Optional[Tensor] = None,
         rotary_pos_sin: Optional[Tensor] = None,
@@ -208,9 +216,9 @@ class TransformerLayer(MegatronCoreTransformerLayer):
     ):
         return super().forward(
             hidden_states=hidden_states,
-            attention_mask=attention_mask,
             context=context,
             context_mask=context_mask,
+            attention_mask=attention_mask,
             rotary_pos_emb=rotary_pos_emb,
             rotary_pos_cos=rotary_pos_cos,
             rotary_pos_sin=rotary_pos_sin,
@@ -226,7 +234,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
             pre_mlp_layernorm_output,
             tokens_per_expert,
             permutated_local_input_tokens,
-            _,
+            permuted_probs,
         ) = self._submodule_attention_router_compound_forward(
             hidden_states,
             attention_mask,
@@ -240,14 +248,16 @@ class TransformerLayer(MegatronCoreTransformerLayer):
             inference_params=inference_params,
         )

-        (tokens_per_expert, global_input_tokens) = self._submodule_dispatch_forward(
+        (tokens_per_expert, global_input_tokens, global_probs) = self._submodule_dispatch_forward(
             tokens_per_expert,
             permutated_local_input_tokens,
+            permuted_probs,
         )

         (expert_output, shared_expert_output, mlp_bias) = self._submodule_moe_forward(
             tokens_per_expert,
             global_input_tokens,
+            global_probs,
             pre_mlp_layernorm_output
         )
@@ -292,6 +302,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
             input_layernorm_output = self.input_layernorm(hidden_states)

         # Self attention.
+        nvtx_range_push(suffix="self_attention")
         attention_output_with_bias = self.self_attention(
             input_layernorm_output,
             attention_mask=attention_mask,
@@ -303,6 +314,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
             packed_seq_params=packed_seq_params,
             sequence_len_offset=sequence_len_offset,
         )
+        nvtx_range_pop(suffix="self_attention")

         if self.recompute_input_layernorm:
             # discard the output of the input layernorm and register the recompute
@@ -313,10 +325,12 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         # TODO: could we move `bias_dropout_add_exec_handler` itself
         # inside the module provided in the `bias_dropout_add_spec` module?
+        nvtx_range_push(suffix="self_attn_bda")
         with self.bias_dropout_add_exec_handler():
             hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
                 attention_output_with_bias, residual, self.hidden_dropout
             )
+        nvtx_range_pop(suffix="self_attn_bda")

         return hidden_states
@@ -363,7 +377,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(
             pre_mlp_layernorm_output, probs, routing_map
         )
-        tokens_per_expert, permutated_local_input_tokens = self.mlp.token_dispatcher.dispatch_preprocess(
+        tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.mlp.token_dispatcher.dispatch_preprocess(
             pre_mlp_layernorm_output, routing_map, tokens_per_expert
         )
@@ -372,18 +386,18 @@ class TransformerLayer(MegatronCoreTransformerLayer):
             pre_mlp_layernorm_output,
             tokens_per_expert,
             permutated_local_input_tokens,
-            probs,
+            permuted_probs,
         ]
         return tuple(outputs)

-    def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens):
+    def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
         """
         Dispatches tokens to the appropriate experts based on the router output.
         """
-        tokens_per_expert, global_input_tokens = self.mlp.token_dispatcher.dispatch_all_to_all(
-            tokens_per_expert, permutated_local_input_tokens
+        tokens_per_expert, global_input_tokens, global_probs = self.mlp.token_dispatcher.dispatch_all_to_all(
+            tokens_per_expert, permutated_local_input_tokens, permuted_probs
         )
-        return [tokens_per_expert, global_input_tokens]
+        return [tokens_per_expert, global_input_tokens, global_probs]

     def _submodule_dense_forward(self, hidden_states):
         residual = hidden_states
@@ -399,18 +413,20 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         return output

-    def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, pre_mlp_layernorm_output):
+    def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output):
         """
         Performs a forward pass for the MLP submodule, including both expert-based
         and optional shared-expert computations.
         """
         shared_expert_output = None
-        (dispatched_input, tokens_per_expert) = (
-            self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens)
+        (dispatched_input, tokens_per_expert, permuted_probs) = (
+            self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_probs)
         )
-        expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert)
+        expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert, permuted_probs)
         expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)

         if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
+            # if shared_expert_overlap is True, the expert calculation happens in
+            # the token_dispatcher to overlap communications and computations
             shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)

         return expert_output, shared_expert_output, mlp_bias
@@ -438,10 +454,19 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         # TODO: could we move `bias_dropout_add_exec_handler` itself
         # inside the module provided in the `bias_dropout_add_spec` module?
+        nvtx_range_push(suffix="mlp_bda")
         with self.bias_dropout_add_exec_handler():
             hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
                 mlp_output_with_bias, residual, self.hidden_dropout
             )
+        nvtx_range_pop(suffix="mlp_bda")

+        # Jit compiled function creates 'view' tensor. This tensor
+        # potentially gets saved in the MPU checkpoint function context,
+        # which rejects view tensors. While making a viewless tensor here
+        # won't result in memory savings (like the data loader, or
+        # p2p_communication), it serves to document the origin of this
+        # 'view' tensor.
         output = make_viewless_tensor(
             inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
         )
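Note: the nvtx_range_push/nvtx_range_pop calls added around self_attention and the bias-dropout-add blocks label those regions on profiler timelines (e.g. Nsight Systems). A minimal stand-alone equivalent of the same pattern, written against torch.cuda.nvtx directly rather than the megatron-core helpers:

    import torch

    def profiled_step(name, fn, *args, **kwargs):
        """Run fn inside a named NVTX range so it shows up as one span in the profile."""
        torch.cuda.nvtx.range_push(name)
        try:
            return fn(*args, **kwargs)
        finally:
            torch.cuda.nvtx.range_pop()

    # usage sketch:
    #   attention_output_with_bias = profiled_step("self_attention", self.self_attention, ...)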
dcu_megatron/legacy/model/transformer.py  (+16, -19)

@@ -6,6 +6,7 @@ from functools import wraps
 from megatron.training import get_args
 from megatron.core import tensor_parallel
 from megatron.legacy.model.enums import AttnType
+from megatron.core.utils import deprecate_inference_params
 from megatron.core.models.common.embeddings import apply_rotary_pos_emb
 from megatron.legacy.model.module import MegatronModule
@@ -86,26 +87,21 @@ def parallel_attention_init_wrapper(fn):
     return wrapper


 class ParallelAttentionPatch(MegatronModule):
     """Parallel self-attention layer abstract class.

     Self-attention layer takes input with size [s, b, h]
     and returns output of the same size.
     """

     def forward(self, hidden_states, attention_mask,
-                encoder_output=None, inference_params=None,
-                rotary_pos_emb=None):
+                encoder_output=None, inference_context=None,
+                rotary_pos_emb=None, *, inference_params=None):
         # hidden_states: [sq, b, h]
+        inference_context = deprecate_inference_params(inference_context, inference_params)

         # =================================================
         # Pre-allocate memory for key-values for inference.
         # =================================================
         is_first_step = False
-        if inference_params:
-            if self.layer_number not in inference_params.key_value_memory_dict:
-                inf_max_seq_len = inference_params.max_sequence_length
-                inf_max_batch_size = inference_params.max_batch_size
+        if inference_context:
+            if self.layer_number not in inference_context.key_value_memory_dict:
+                inf_max_seq_len = inference_context.max_sequence_length
+                inf_max_batch_size = inference_context.max_batch_size
                 inference_key_memory = self._allocate_memory(
                     inf_max_seq_len, inf_max_batch_size,
                     self.num_query_groups_per_partition)
@@ -113,12 +109,12 @@ class ParallelAttentionPatch(MegatronModule):
                     inf_max_seq_len, inf_max_batch_size,
                     self.num_query_groups_per_partition)
-                inference_params.key_value_memory_dict[self.layer_number] = (
+                inference_context.key_value_memory_dict[self.layer_number] = (
                     inference_key_memory, inference_value_memory)
                 is_first_step = True
             else:
                 inference_key_memory, inference_value_memory = \
-                    inference_params.key_value_memory_dict[self.layer_number]
+                    inference_context.key_value_memory_dict[self.layer_number]

         # =====================
         # Query, Key, and Value
@@ -188,13 +184,14 @@ class ParallelAttentionPatch(MegatronModule):
         else:
             rotary_pos_emb = ((rotary_pos_emb,) * 2)

-        if inference_params:
-            batch_start = inference_params.batch_size_offset
+        if inference_context:
+            batch_start = inference_context.batch_size_offset
             batch_end = batch_start + key_layer.size(1)
             assert batch_end <= inference_key_memory.size(1)
-            sequence_start = inference_params.sequence_len_offset
+            sequence_start = inference_context.sequence_len_offset
             sequence_end = sequence_start + key_layer.size(0)
-            assert sequence_end <= inference_key_memory.size(0)
+            assert sequence_end <= inference_key_memory.size(0), ("Current sequence length is "
+                "longer than expected maximum sequence length! Increase inference_max_seq_length.")
             # Copy key and values.
             inference_key_memory[sequence_start:sequence_end,
                                  batch_start:batch_end, ...] = key_layer
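Note: the legacy attention forward now accepts inference_context while keeping inference_params as a keyword-only fallback, routed through megatron-core's deprecate_inference_params. The sketch below mirrors that compatibility pattern; the shim body is an assumption written for illustration, not the library's exact code.

    import warnings

    def deprecate_inference_params_sketch(inference_context, inference_params):
        """Prefer the new argument; fall back to the deprecated one with a warning."""
        if inference_context is None and inference_params is not None:
            warnings.warn(
                "inference_params is deprecated, pass inference_context instead",
                DeprecationWarning,
            )
            return inference_params
        return inference_context

    def forward(hidden_states, attention_mask, encoder_output=None, inference_context=None,
                rotary_pos_emb=None, *, inference_params=None):
        # Keyword-only inference_params keeps old keyword call sites working while
        # new code passes inference_context.
        inference_context = deprecate_inference_params_sketch(inference_context, inference_params)
        ...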
dcu_megatron/training/arguments.py  (+7, -0)

@@ -3,6 +3,7 @@ import argparse
 from typing import Union

 from megatron.training.arguments import add_megatron_arguments
+from megatron.core.msc_utils import MultiStorageClientFeature

 from dcu_megatron.adaptor.features_manager import ADAPTOR_FEATURES
@@ -65,6 +66,12 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     # args.rank = int(os.getenv('RANK', '0'))
     # args.world_size = int(os.getenv("WORLD_SIZE", '1'))

+    # Args to disable MSC
+    if not args.enable_msc:
+        MultiStorageClientFeature.disable()
+        assert MultiStorageClientFeature.is_enabled() is False
+        print('WARNING: The MSC feature is disabled.')
+
     return args
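Note: the new block simply turns the Multi-Storage Client feature off when args.enable_msc is false. For context, one common way a boolean flag like this can be declared with argparse is sketched below; the actual --enable-msc definition lives in Megatron's argument parser and may differ.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--enable-msc",
        action=argparse.BooleanOptionalAction,   # also generates --no-enable-msc
        default=True,
        help="Enable the Multi-Storage Client (MSC) feature.",
    )
    args = parser.parse_args(["--no-enable-msc"])
    assert args.enable_msc is False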
dcu_megatron/training/initialize.py  (+6, -1)

@@ -8,6 +8,7 @@ from datetime import timedelta
 from megatron.training import get_args
 from megatron.core import mpu, tensor_parallel
+from megatron.training import inprocess_restart


 def _compile_dependencies():
@@ -76,7 +77,7 @@ def _compile_dependencies():
     )


-def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
+def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, store):
     """Initialize torch.distributed and core model parallel."""
     args = get_args()
@@ -109,6 +110,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
         # Call the init process
         init_process_group_kwargs = {
             'backend': args.distributed_backend,
+            'store': store,
             'world_size': args.world_size,
             'rank': args.rank,
             'init_method': args.dist_url,
@@ -116,6 +118,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
         }

         torch.distributed.init_process_group(**init_process_group_kwargs)
+        inprocess_restart.maybe_force_nccl_backend_init(device_id)

         # Set the tensor model-parallel, pipeline model-parallel, and
         # data-parallel communicators.
@@ -129,6 +132,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
             args.virtual_pipeline_model_parallel_size,
             args.pipeline_model_parallel_split_rank,
             pipeline_model_parallel_comm_backend=args.pipeline_model_parallel_comm_backend,
+            use_sharp=args.use_sharp,
             context_parallel_size=args.context_parallel_size,
             hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes,
             expert_model_parallel_size=args.expert_model_parallel_size,
@@ -142,6 +146,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
             get_embedding_ranks=get_embedding_ranks,
             get_position_embedding_ranks=get_position_embedding_ranks,
             create_gloo_process_groups=args.enable_gloo_process_groups,
+            high_priority_stream_groups=args.high_priority_stream_groups,
         )
         if args.rank == 0:
             print(
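Note: _initialize_distributed now receives a store and forwards it to torch.distributed.init_process_group, which lets the caller (for example the in-process restart logic imported above) own the rendezvous object and reuse it across re-initializations. A hedged, self-contained sketch of that pattern with placeholder host/port values:

    import torch.distributed as dist

    def init_with_store(rank: int, world_size: int, host: str = "127.0.0.1", port: int = 29500):
        # One rank creates the TCPStore as master; the others connect to it.
        store = dist.TCPStore(host, port, world_size, rank == 0)
        dist.init_process_group(
            backend="gloo",        # "nccl" for GPU training
            store=store,           # when a store is given, rank and world_size are required
            rank=rank,
            world_size=world_size,
        )
        return store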
dcu_megatron/training/utils.py  (+40, -14)

@@ -19,7 +19,11 @@ def get_batch_on_this_tp_rank(data_iterator):

     def _broadcast(item):
         if item is not None:
-            torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
+            torch.distributed.broadcast(
+                item,
+                mpu.get_tensor_model_parallel_src_rank(),
+                group=mpu.get_tensor_model_parallel_group(),
+            )

     if mpu.get_tensor_model_parallel_rank() == 0:
@@ -29,11 +33,15 @@ def get_batch_on_this_tp_rank(data_iterator):
            data = None

        batch = {
            'tokens': data["tokens"].cuda(non_blocking=True),
            'labels': data["labels"].cuda(non_blocking=True),
            'loss_mask': data["loss_mask"].cuda(non_blocking=True),
-            'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking=True),
-            'position_ids': data["position_ids"].cuda(non_blocking=True)
+            'attention_mask': (
+                None
+                if "attention_mask" not in data
+                else data["attention_mask"].cuda(non_blocking=True)
+            ),
+            'position_ids': data["position_ids"].cuda(non_blocking=True),
        }

        if args.pipeline_model_parallel_size == 1:
@@ -64,16 +72,34 @@ def get_batch_on_this_tp_rank(data_iterator):
    else:
-        tokens = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64, device=torch.cuda.current_device())
-        labels = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64, device=torch.cuda.current_device())
-        loss_mask = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.float32, device=torch.cuda.current_device())
+        tokens = torch.empty(
+            (args.micro_batch_size, args.seq_length),
+            dtype=torch.int64,
+            device=torch.cuda.current_device(),
+        )
+        labels = torch.empty(
+            (args.micro_batch_size, args.seq_length),
+            dtype=torch.int64,
+            device=torch.cuda.current_device(),
+        )
+        loss_mask = torch.empty(
+            (args.micro_batch_size, args.seq_length),
+            dtype=torch.float32,
+            device=torch.cuda.current_device(),
+        )
        if args.create_attention_mask_in_dataloader:
            attention_mask = torch.empty(
-                (args.micro_batch_size, 1, args.seq_length, args.seq_length), dtype=torch.bool, device=torch.cuda.current_device()
+                (args.micro_batch_size, 1, args.seq_length, args.seq_length),
+                dtype=torch.bool,
+                device=torch.cuda.current_device(),
            )
        else:
            attention_mask = None
-        position_ids = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64, device=torch.cuda.current_device())
+        position_ids = torch.empty(
+            (args.micro_batch_size, args.seq_length),
+            dtype=torch.int64,
+            device=torch.cuda.current_device(),
+        )

        if args.pipeline_model_parallel_size == 1:
            _broadcast(tokens)
@@ -117,4 +143,4 @@ def get_batch_on_this_tp_rank(data_iterator):
        'position_ids': position_ids
    }

-    return batch
\ No newline at end of file
+    return batch
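Note: the reformatted helper above is the piece dualpipev relies on to make labels and loss_mask available on every tensor-parallel rank. The sketch below condenses the pattern to its core: the TP source rank reads the dataloader, the other ranks allocate empty buffers of the agreed shape and dtype, and a broadcast fills them. The names and the two-field batch are simplifications for illustration.

    import torch
    import torch.distributed as dist

    def get_batch_sketch(data_iterator, tp_rank, tp_src, tp_group, mbs, seqlen, device):
        if tp_rank == 0:
            data = next(data_iterator)
            tokens = data["tokens"].to(device, non_blocking=True)
            labels = data["labels"].to(device, non_blocking=True)
        else:
            # Shapes and dtypes must match exactly what the source rank sends.
            tokens = torch.empty((mbs, seqlen), dtype=torch.int64, device=device)
            labels = torch.empty((mbs, seqlen), dtype=torch.int64, device=device)
        for t in (tokens, labels):
            dist.broadcast(t, tp_src, group=tp_group)
        return {"tokens": tokens, "labels": labels}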