evt_fugx1 / dcu_megatron · Commits

Commit 770fa304, authored Apr 25, 2025 by dongcl
Commit message: Modify mtp ("修改mtp")
Parent: 8096abd4
Changes: 44 · Showing 20 changed files with 4876 additions and 476 deletions (+4876 / -476)
Changed files shown in this view:

dcu_megatron/core/pipeline_parallel/fb_overlap/modules/experts.py            (+121 / -0)
dcu_megatron/core/pipeline_parallel/fb_overlap/modules/layers.py             (+761 / -0)
dcu_megatron/core/pipeline_parallel/fb_overlap/modules/token_dispatcher.py   (+208 / -0)
dcu_megatron/core/pipeline_parallel/fb_overlap/modules/utils.py              (+133 / -0)
dcu_megatron/core/pipeline_parallel/fb_overlap/modules/weight_grad_store.py  (+159 / -0)
dcu_megatron/core/pipeline_parallel/fb_overlap/overlap_funcs/__init__.py     (+5 / -0)
dcu_megatron/core/pipeline_parallel/fb_overlap/overlap_funcs/bwd.py          (+131 / -0)
dcu_megatron/core/pipeline_parallel/fb_overlap/overlap_funcs/fwd.py          (+324 / -0)
dcu_megatron/core/pipeline_parallel/fb_overlap/overlap_funcs/fwdbwd.py       (+1007 / -0)
dcu_megatron/core/pipeline_parallel/fb_overlap/transformer_block.py          (+208 / -0)
dcu_megatron/core/pipeline_parallel/fb_overlap/transformer_layer.py          (+175 / -0)
dcu_megatron/core/pipeline_parallel/fb_overlap/vpp_schedules.py              (+826 / -0)
dcu_megatron/core/pipeline_parallel/schedules.py                             (+52 / -0)
dcu_megatron/core/tensor_parallel/__init__.py                                (+0 / -2)
dcu_megatron/core/tensor_parallel/layers.py                                  (+2 / -124)
dcu_megatron/core/transformer/mtp/mtp_spec.py                                (+0 / -51)
dcu_megatron/core/transformer/mtp/multi_token_predictor.py                   (+0 / -285)
dcu_megatron/core/transformer/multi_token_prediction.py                      (+737 / -0)
dcu_megatron/core/transformer/transformer_block.py                           (+1 / -1)
dcu_megatron/core/transformer/transformer_config.py                          (+26 / -13)
dcu_megatron/core/pipeline_parallel/fb_overlap/modules/experts.py (new file, 0 → 100644)

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
from einops import rearrange
import torch

from megatron.training import get_args
from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
from mindspeed.ops.gmm import GMMFunction
from mindspeed.model.transformer import should_recompute_activation
from mindspeed.ops.npu_groupmatmul_add import npu_groupmatmul_add_fp32


def get_gmm_weight_grad(inputs, grad_out, group_list, group_list_data_type, weight_param, weight_tensor):
    if WeightGradStore.is_decoupleBlock:
        WeightGradStore.put(
            [inputs, group_list, group_list_data_type],
            grad_out,
            weight_param,
            sequence_parallel=False,
            in_row=False,
        )
        if hasattr(weight_param, 'grad_added_to_main_grad') and get_args().overlap_grad_reduce:
            # When overlap_grad_reduce is True, need to ensure that backward hooks
            # are all run on the main backprop thread to prevent deadlocks. Setup
            # dummy grad_weight tensor to prevent backward hooks from being run
            # in a background thread.
            shape = list(weight_tensor.shape)
            shape[1], shape[2] = shape[2], shape[1]
            weight_param.skip_grad_accum = True
        grad_weights = None
    else:
        if get_args().gemm_gradient_accumulation_fusion:
            npu_groupmatmul_add_fp32(inputs, grad_out, group_list, weight_param.main_grad)
            if hasattr(weight_param, 'grad_added_to_main_grad'):
                shape = list(weight_tensor.shape)
                shape[1], shape[2] = shape[2], shape[1]
                if getattr(weight_tensor, 'zero_out_wgrad', False):
                    grad_weights = torch.zeros(
                        shape,
                        dtype=inputs.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
                else:
                    grad_weights = torch.empty(
                        shape,
                        dtype=inputs.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
                weight_param.grad_added_to_main_grad = True
            else:
                grad_weights = None
        else:
            grad_weights = GMMFunction.builder.load().npu_gmm(
                [inputs.t()], [grad_out], [], group_list, 2, group_list_data_type
            )[0]

    return grad_weights


class GroupedMatmulWithWeightGradDetach(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, weight_tensor, weight_param, group_list, in_row=False):
        mm_out = GMMFunction.builder.load().npu_gmm([inputs], [weight_tensor], [], group_list, 0, 0)[0]
        ctx.save_for_backward(inputs, weight_tensor, group_list)
        ctx.weight_param = weight_param
        ctx.in_row = in_row
        return mm_out

    @staticmethod
    def backward(ctx, *grad_outs):
        grad_out = grad_outs[0]
        inputs, weight_tensor, group_list = ctx.saved_tensors
        weight_param = ctx.weight_param
        weight_tensor = rearrange(weight_tensor, 'n h f -> n f h')
        grad_inputs = \
            GMMFunction.builder.load().npu_gmm([grad_out], [weight_tensor], [], group_list, 0, 0)[0]
        grad_weights = get_gmm_weight_grad(inputs, grad_out, group_list, 0, weight_param, weight_tensor)

        return grad_inputs, grad_weights, None, None, None


def npu_gmm_with_detach(inputs, weight_tensor, weight_param, bias=None, group_list=None):
    return GroupedMatmulWithWeightGradDetach.apply(inputs, weight_tensor, weight_param, group_list)


def group_mlp_forward_detach(self, permuted_local_hidden_states, tokens_per_expert):
    args = get_args()
    is_recompute_activation = args.moe_zero_memory == 'level0' or should_recompute_activation(self.layer_number)
    if permuted_local_hidden_states.nelement() != 0:
        group_list = torch.cumsum(tokens_per_expert, dim=0)
        w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
        w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)
        fc1_output = npu_gmm_with_detach(
            permuted_local_hidden_states, w1, self.weight1, bias=None, group_list=group_list
        )
        intermediate_parallel = self.activation_func(fc1_output)
        fc2_output = npu_gmm_with_detach(
            intermediate_parallel, w2, self.weight2, bias=None, group_list=group_list
        )
        if is_recompute_activation:
            intermediate_parallel.untyped_storage().resize_(0)
    else:
        # No token is allocated for local experts.
        assert torch.count_nonzero(tokens_per_expert) == 0

        # Make sure parameters still have gradients when no tokens are routed to this set of experts.
        w1 = self.weight1.view(self.config.hidden_size, -1)
        w2 = self.weight2.view(-1, self.config.hidden_size)
        fc1_output = torch.matmul(permuted_local_hidden_states, w1)
        intermediate_parallel = self.activation_func(fc1_output)
        fc2_output = torch.matmul(intermediate_parallel, w2)
        if is_recompute_activation:
            intermediate_parallel.untyped_storage().resize_(0)

    return (fc2_output, fc1_output, intermediate_parallel), None
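
Illustrative note (not part of the commit): the grouped MLP above builds `group_list = torch.cumsum(tokens_per_expert, dim=0)` so the fused `npu_gmm` kernel knows where each expert's rows end in the permuted token tensor. A minimal CPU-only sketch of the same bookkeeping, using a per-expert `torch.matmul` loop instead of the fused kernel; the names `grouped_matmul_reference` and the toy shapes are hypothetical:

# Sketch only: emulates a grouped matmul driven by a cumulative group_list.
import torch

def grouped_matmul_reference(tokens, weights, tokens_per_expert):
    # tokens: [num_tokens, hidden], already sorted so each expert's tokens are contiguous
    # weights: [num_experts, hidden, ffn], one weight matrix per expert
    # tokens_per_expert: [num_experts] token counts per expert
    group_list = torch.cumsum(tokens_per_expert, dim=0)  # end offsets, as in group_mlp_forward_detach
    outputs = []
    start = 0
    for expert_id, end in enumerate(group_list.tolist()):
        outputs.append(tokens[start:end] @ weights[expert_id])  # per-expert GEMM
        start = end
    return torch.cat(outputs, dim=0)

tokens = torch.randn(10, 4)
weights = torch.randn(3, 4, 8)
tokens_per_expert = torch.tensor([3, 5, 2])
out = grouped_matmul_reference(tokens, weights, tokens_per_expert)
print(out.shape)  # torch.Size([10, 8])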
dcu_megatron/core/pipeline_parallel/fb_overlap/modules/layers.py (new file, 0 → 100644)

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
import os
import warnings
from typing import Any, Callable, List, Optional

import torch
import torch.distributed
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.parameter import Parameter

from megatron.core.tensor_parallel.layers import (
    _initialize_affine_weight_cpu,
    _initialize_affine_weight_gpu,
    linear_with_grad_accumulation_and_async_allreduce,
    linear_with_frozen_weight
)
from megatron.core.tensor_parallel.mappings import (
    copy_to_tensor_model_parallel_region,
    gather_from_tensor_model_parallel_region,
    reduce_from_tensor_model_parallel_region,
    reduce_scatter_to_sequence_parallel_region,
    scatter_to_tensor_model_parallel_region,
    _reduce_scatter_along_first_dim,
    _gather_along_first_dim
)
from megatron.core.tensor_parallel.utils import VocabUtility, divide, split_tensor_along_last_dim
from megatron.core.utils import (
    make_tp_sharded_tensor_for_checkpoint,
    prepare_input_tensors_for_wgrad_compute
)
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.parallel_state import (
    get_global_memory_buffer,
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from megatron.training import get_args
from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore


def linear_backward_wgrad_detach(ctx, grad_output):
    input_, weight = ctx.saved_tensors
    use_bias = ctx.use_bias
    grad_output_buffer = ctx.grad_output_buffer
    wgrad_deferral_limit = ctx.wgrad_deferral_limit

    wgrad_compute = True
    if grad_output_buffer is not None:
        if wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit:
            grad_output_buffer.append(grad_output)
            wgrad_compute = False

    if wgrad_compute:
        if ctx.sequence_parallel and not WeightGradStore.is_decoupleBlock:
            world_size = get_tensor_model_parallel_world_size()
            dim_size = list(input_.size())
            dim_size[0] = dim_size[0] * world_size

            all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu")
            handle = torch.distributed._all_gather_base(
                all_gather_buffer, input_, group=get_tensor_model_parallel_group(), async_op=True
            )

            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # gather is scheduled before the input gradient computation
            total_input = all_gather_buffer
        else:
            total_input = input_
    grad_input = grad_output.matmul(weight)

    if ctx.sequence_parallel and wgrad_compute and not WeightGradStore.is_decoupleBlock:
        handle.wait()

    if wgrad_compute and not WeightGradStore.is_decoupleBlock:
        grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
            grad_output, total_input
        )

    if ctx.allreduce_dgrad:
        # Asynchronous all-reduce
        handle = torch.distributed.all_reduce(
            grad_input, group=get_tensor_model_parallel_group(), async_op=True
        )
        # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
        # all-reduce is scheduled before the weight gradient computation

    if ctx.sequence_parallel:
        assert not ctx.allreduce_dgrad
        dim_size = list(input_.size())
        sub_grad_input = torch.empty(
            dim_size, dtype=input_.dtype, device=torch.cuda.current_device(), requires_grad=False
        )
        # reduce_scatter
        handle = torch.distributed._reduce_scatter_base(
            sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True
        )
        # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
        # reduce scatter is scheduled before the weight gradient computation

    if WeightGradStore.is_decoupleBlock:
        # TODO: remove clone under MLA setting
        WeightGradStore.put(
            total_input.clone().detach(),
            grad_output.clone().detach(),
            weight,
            ctx.sequence_parallel,
            in_row=not ctx.sequence_parallel
        )
        if hasattr(weight, 'grad_added_to_main_grad') and get_args().overlap_grad_reduce:
            weight.skip_grad_accum = True
        grad_weight = None
    else:
        if ctx.gradient_accumulation_fusion:
            if wgrad_compute:
                if weight.main_grad.dtype == torch.float32:
                    from mindspeed.ops.npu_matmul_add import npu_matmul_add_fp32
                    npu_matmul_add_fp32(total_input, grad_output, weight.main_grad)
                elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
                    raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")

            if hasattr(weight, 'grad_added_to_main_grad'):
                # When overlap_grad_reduce is True, need to ensure that backward hooks
                # are all run on the main backprop thread to prevent deadlocks. Setup
                # dummy grad_weight tensor to prevent backward hooks from being run
                # in a background thread.
                if getattr(weight, 'zero_out_wgrad', False):
                    grad_weight = torch.zeros(
                        weight.main_grad.shape,
                        dtype=input_.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
                else:
                    grad_weight = torch.empty(
                        weight.main_grad.shape,
                        dtype=input_.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
                weight.grad_added_to_main_grad = True
            else:
                grad_weight = None
        else:
            grad_weight = grad_output.t().matmul(total_input)

    grad_bias = grad_output.sum(dim=0) if use_bias else None

    if ctx.sequence_parallel:
        handle.wait()
        # Need to return None's as gradient has to flow for all the input arguments
        # provided during forward
        return sub_grad_input, grad_weight, grad_bias, None, None, None, None, None

    if ctx.allreduce_dgrad:
        handle.wait()

    return grad_input, grad_weight, grad_bias, None, None, None, None, None


class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
    """See linear_with_grad_accumulation_and_async_allreduce"""

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        input,
        weight,
        bias,
        gradient_accumulation_fusion,
        async_grad_allreduce,
        sequence_parallel,
        grad_output_buffer,
        shared_expert,
    ):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
        ctx.async_grad_allreduce = async_grad_allreduce
        ctx.sequence_parallel = sequence_parallel
        ctx.grad_output_buffer = grad_output_buffer
        ctx.shared_expert = shared_expert

        if sequence_parallel:
            if shared_expert:
                from mindspeed.core.transformer.moe.moe_utils import AG_SHARED_EXPERTS_INPUTS
                ag_shared_experts_inputs = AG_SHARED_EXPERTS_INPUTS.pop(0)
                if isinstance(ag_shared_experts_inputs, tuple):
                    ag_shared_experts_inputs, handle = ag_shared_experts_inputs
                    handle.wait()
                total_input = ag_shared_experts_inputs
            else:
                world_size = get_tensor_model_parallel_world_size()
                dim_size = list(input.size())
                dim_size[0] = dim_size[0] * world_size

                all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
                torch.distributed._all_gather_base(
                    all_gather_buffer, input, group=get_tensor_model_parallel_group()
                )
                total_input = all_gather_buffer
        else:
            total_input = input

        output = torch.matmul(total_input, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias
        grad_output_buffer = ctx.grad_output_buffer

        wgrad_compute = True
        if grad_output_buffer is not None:
            grad_output_buffer.append(grad_output)
            wgrad_compute = False

        if wgrad_compute:
            if ctx.sequence_parallel and not WeightGradStore.is_decoupleBlock:
                world_size = get_tensor_model_parallel_world_size()
                dim_size = list(input.size())
                dim_size[0] = dim_size[0] * world_size

                all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
                handle = torch.distributed._all_gather_base(
                    all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True
                )

                # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
                # gather is scheduled before the input gradient computation
                total_input = all_gather_buffer
            else:
                total_input = input
        grad_input = grad_output.matmul(weight)

        if ctx.sequence_parallel and wgrad_compute and not WeightGradStore.is_decoupleBlock:
            handle.wait()

        if wgrad_compute and not WeightGradStore.is_decoupleBlock:
            grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
                grad_output, total_input
            )

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce
            handle = torch.distributed.all_reduce(
                grad_input, group=get_tensor_model_parallel_group(), async_op=True
            )
            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # all-reduce is scheduled before the weight gradient computation

        if ctx.sequence_parallel:
            assert not ctx.async_grad_allreduce
            dim_size = list(input.size())
            sub_grad_input = torch.empty(
                dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False
            )
            # reduce_scatter
            handle = torch.distributed._reduce_scatter_base(
                sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True
            )
            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # reduce scatter is scheduled before the weight gradient computation

        if WeightGradStore.is_decoupleBlock:
            # TODO: remove clone under MLA setting
            WeightGradStore.put(
                total_input.clone().detach(),
                grad_output.clone().detach(),
                weight,
                ctx.sequence_parallel,
                in_row=not ctx.sequence_parallel
            )
            if hasattr(weight, 'grad_added_to_main_grad') and get_args().overlap_grad_reduce:
                weight.skip_grad_accum = True
            grad_weight = None
        else:
            if ctx.gradient_accumulation_fusion:
                if wgrad_compute:
                    if weight.main_grad.dtype == torch.float32:
                        from mindspeed.ops.npu_matmul_add import npu_matmul_add_fp32
                        npu_matmul_add_fp32(total_input, grad_output, weight.main_grad)
                    elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
                        raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")

                if hasattr(weight, 'grad_added_to_main_grad'):
                    # When overlap_grad_reduce is True, need to ensure that backward hooks
                    # are all run on the main backprop thread to prevent deadlocks. Setup
                    # dummy grad_weight tensor to prevent backward hooks from being run
                    # in a background thread.
                    if getattr(weight, 'zero_out_wgrad', False):
                        grad_weight = torch.zeros(
                            weight.main_grad.shape,
                            dtype=input.dtype,
                            device=torch.cuda.current_device(),
                            requires_grad=False,
                        )
                    else:
                        grad_weight = torch.empty(
                            weight.main_grad.shape,
                            dtype=input.dtype,
                            device=torch.cuda.current_device(),
                            requires_grad=False,
                        )
                    weight.grad_added_to_main_grad = True
                else:
                    grad_weight = None
            else:
                grad_weight = grad_output.t().matmul(total_input)

        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.sequence_parallel:
            handle.wait()
            # Need to return None's as gradient has to flow for all the input arguments
            # provided during forward
            return sub_grad_input, grad_weight, grad_bias, None, None, None, None, None

        if ctx.async_grad_allreduce:
            handle.wait()

        return grad_input, grad_weight, grad_bias, None, None, None, None, None


def linear_with_grad_accumulation_and_async_allreduce(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    gradient_accumulation_fusion: bool,
    async_grad_allreduce: bool,
    sequence_parallel: bool,
    grad_output_buffer: Optional[List[torch.Tensor]] = None,
    shared_expert: bool = False
) -> torch.Tensor:
    args = [
        input,
        weight,
        bias,
        gradient_accumulation_fusion,
        async_grad_allreduce,
        sequence_parallel,
        grad_output_buffer,
        shared_expert,
    ]

    if not linear_with_grad_accumulation_and_async_allreduce.warned:
        if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
            if sequence_parallel:
                warnings.warn(
                    "When using sequence parallelism it is recommended to set the "
                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
                    "maximum speedup"
                )
                linear_with_grad_accumulation_and_async_allreduce.warned = True

            if async_grad_allreduce:
                warnings.warn(
                    "When using async grad allreduce it is recommended to set the "
                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
                    "maximum speedup"
                )
                linear_with_grad_accumulation_and_async_allreduce.warned = True

    return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)


linear_with_grad_accumulation_and_async_allreduce.warned = False


class ColumnParallelLinear(torch.nn.Module):
    def __init__(
        self,
        input_size,
        output_size,
        *,
        config: ModelParallelConfig,
        init_method: Callable,
        bias=True,
        gather_output=False,
        stride=1,
        keep_master_weight_for_test=False,
        skip_bias_add=False,
        skip_weight_param_allocation: bool = False,
        embedding_activation_buffer: Optional[List[torch.Tensor]] = None,
        grad_output_buffer: Optional[List[torch.Tensor]] = None,
        is_expert: bool = False,
        tp_comm_buffer_name: str = None,  # Not used
        shared_expert: bool = False
    ):
        super(ColumnParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        world_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, world_size)
        self.skip_bias_add = skip_bias_add
        self.is_expert = is_expert
        self.expert_parallel = config.expert_model_parallel_size > 1
        self.embedding_activation_buffer = embedding_activation_buffer
        self.grad_output_buffer = grad_output_buffer
        self.config = config
        self.shared_expert = shared_expert

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        if not skip_weight_param_allocation:
            if config.use_cpu_initialization:
                self.weight = Parameter(
                    torch.empty(
                        self.output_size_per_partition, self.input_size, dtype=config.params_dtype
                    )
                )
                if config.perform_initialization:
                    self.master_weight = _initialize_affine_weight_cpu(
                        self.weight,
                        self.output_size,
                        self.input_size,
                        self.output_size_per_partition,
                        0,
                        init_method,
                        stride=stride,
                        return_master_weight=keep_master_weight_for_test,
                    )
            else:
                self.weight = Parameter(
                    torch.empty(
                        self.output_size_per_partition,
                        self.input_size,
                        device=torch.cuda.current_device(),
                        dtype=config.params_dtype,
                    )
                )
                if config.perform_initialization:
                    _initialize_affine_weight_gpu(
                        self.weight,
                        init_method,
                        partition_dim=0,
                        stride=stride,
                        expert_parallel=(self.is_expert and self.expert_parallel),
                    )

            setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel))
        else:
            self.weight = None

        self.register_parameter('bias', None)

        self.async_tensor_model_parallel_allreduce = (
            config.async_tensor_model_parallel_allreduce and world_size > 1
        )

        self.sequence_parallel = config.sequence_parallel
        if self.sequence_parallel and world_size <= 1:
            self.sequence_parallel = False

        self.gradient_accumulation_fusion = config.gradient_accumulation_fusion

        if self.async_tensor_model_parallel_allreduce and self.sequence_parallel:
            raise RuntimeError(
                "`async_tensor_model_parallel_allreduce` and `sequence_parallel` "
                "cannot be enabled at the same time."
            )

        self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
        self.explicit_expert_comm = self.is_expert and (self.sequence_parallel or self.expert_parallel)

        # Hook adding a default empty _extra_state for state dict
        self._register_load_state_dict_pre_hook(
            lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault(
                f'{prefix}_extra_state'
            )
        )

    def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
        """Forward of ColumnParallelLinear

        Args:
            input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
            weight (optional): weight tensor to use, compulsory when
                skip_weight_param_allocation is True.

        Returns:
            - output
            - bias
        """
        if weight is None:
            if self.weight is None:
                raise RuntimeError(
                    "weight was not supplied to ColumnParallelLinear forward pass "
                    "and skip_weight_param_allocation is True."
                )
            weight = self.weight
        else:
            # Check the weight passed in is the correct shape
            expected_shape = (self.output_size_per_partition, self.input_size)
            if weight.shape != expected_shape:
                raise RuntimeError(
                    f"supplied weight's shape is {tuple(weight.shape)}, "
                    f"not {expected_shape} as expected"
                )

        if self.config._cpu_offloading_context is not None:
            if self.config._cpu_offloading_context.inside_context == True:
                assert (
                    self.config.cpu_offloading == False
                ), "CPU Offloading cannot be enabled while using non-TE modules"

        bias = self.bias if not self.skip_bias_add else None

        if (
            self.async_tensor_model_parallel_allreduce
            or self.sequence_parallel
            or self.explicit_expert_comm
        ):
            input_parallel = input_
        else:
            input_parallel = copy_to_tensor_model_parallel_region(input_)

        if self.config.defer_embedding_wgrad_compute:
            self.embedding_activation_buffer.append(input_parallel)

        # Matrix multiply.
        if not weight.requires_grad:
            self._forward_impl = linear_with_frozen_weight
        else:
            self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
        output_parallel = self._forward_impl(
            input=input_parallel,
            weight=weight,
            bias=bias,
            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
            async_grad_allreduce=False
            if self.explicit_expert_comm
            else self.async_tensor_model_parallel_allreduce,
            sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
            grad_output_buffer=self.grad_output_buffer
            if self.config.defer_embedding_wgrad_compute
            else None,
            shared_expert=self.shared_expert
        )
        if self.gather_output:
            # All-gather across the partitions.
            assert not self.sequence_parallel
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """ Sharding along axis 0, bias sharded """
        state_dict = self.state_dict(prefix='', keep_vars=True)
        return make_sharded_tensors_for_checkpoint(
            state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
        )

    def set_extra_state(self, state: Any):
        """ Extra state is ignored """

    def get_extra_state(self) -> None:
        """ Keep compatibility with TE state dict. """
        return None


class RowParallelLinear(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: ModelParallelConfig,
        init_method: Callable,
        bias: bool,
        input_is_parallel: bool,
        skip_bias_add: bool,
        stride: int = 1,
        keep_master_weight_for_test: bool = False,
        is_expert: bool = False,
        tp_comm_buffer_name: str = None,  # Not used
        shared_expert: bool = False
    ):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        # Divide the weight matrix along the last dimension.
        world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)
        self.skip_bias_add = skip_bias_add
        self.config = config
        self.is_expert = is_expert
        self.expert_parallel = config.expert_model_parallel_size > 1
        self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
        self.sequence_parallel = config.sequence_parallel
        self.shared_expert = shared_expert
        if self.sequence_parallel and not self.input_is_parallel:
            raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        if config.use_cpu_initialization:
            self.weight = Parameter(
                torch.empty(
                    self.output_size, self.input_size_per_partition, dtype=config.params_dtype
                )
            )
            if config.perform_initialization:
                self.master_weight = _initialize_affine_weight_cpu(
                    self.weight,
                    self.output_size,
                    self.input_size,
                    self.input_size_per_partition,
                    1,
                    init_method,
                    stride=stride,
                    return_master_weight=keep_master_weight_for_test,
                    params_dtype=config.params_dtype,
                )
        else:
            self.weight = Parameter(
                torch.empty(
                    self.output_size,
                    self.input_size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=config.params_dtype,
                )
            )
            if config.perform_initialization:
                _initialize_affine_weight_gpu(
                    self.weight,
                    init_method,
                    partition_dim=1,
                    stride=stride,
                    expert_parallel=(self.is_expert and self.expert_parallel),
                )
        setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel))

        if bias:
            if config.use_cpu_initialization:
                self.bias = Parameter(torch.empty(self.output_size, dtype=config.params_dtype))
            else:
                self.bias = Parameter(
                    torch.empty(
                        self.output_size,
                        device=torch.cuda.current_device(),
                        dtype=config.params_dtype,
                    )
                )

            if config.perform_initialization:
                # Always initialize bias to zero.
                with torch.no_grad():
                    self.bias.zero_()
            setattr(self.bias, 'allreduce', not (self.is_expert and self.expert_parallel))
            setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
        else:
            self.register_parameter('bias', None)

        self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
        self.explicit_expert_comm = self.is_expert and (self.sequence_parallel or self.expert_parallel)

        # Hook adding a default empty _extra_state for state dict
        self._register_load_state_dict_pre_hook(
            lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault(
                f'{prefix}_extra_state'
            )
        )

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: 3D tensor whose order of dimension is [sequence, batch, hidden]

        Returns:
            - output
            - bias
        """

        if self.config._cpu_offloading_context is not None:
            if self.config._cpu_offloading_context.inside_context == True:
                assert (
                    self.config.cpu_offloading == False
                ), "CPU Offloading cannot be enabled while using non-TE modules"

        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            assert not self.sequence_parallel
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        if not self.weight.requires_grad:
            self._forward_impl = linear_with_frozen_weight
        else:
            self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
        output_parallel = self._forward_impl(
            input=input_parallel,
            weight=self.weight,
            bias=None,
            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
            async_grad_allreduce=False,
            sequence_parallel=False,
        )

        # All-reduce across all the partitions.
        if self.explicit_expert_comm or self.shared_expert:
            assert self.skip_bias_add
            output_ = output_parallel
        elif self.sequence_parallel:
            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
        else:
            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            output = (output_ + self.bias) if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """ Sharding along axis 1, bias not sharded """
        state_dict = self.state_dict(prefix='', keep_vars=True)
        return make_sharded_tensors_for_checkpoint(
            state_dict, prefix, {'weight': 1}, sharded_offsets
        )

    def set_extra_state(self, state: Any):
        """ Extra state is ignored """

    def get_extra_state(self) -> None:
        """ Keep compatibility with TE state dict. """
        return None
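
Illustrative note (not part of the commit): the backward passes above repeatedly launch an async collective, do a GEMM while it is in flight, and only wait on the handle right before the communicated tensor is consumed. A minimal sketch of that overlap pattern with plain torch.distributed, assuming it is launched with torchrun and a gloo backend; the function name and shapes are hypothetical and stand in for the fused Megatron/MindSpeed calls:

# Sketch only: async-collective / compute overlap, run e.g. with
#   torchrun --nproc_per_node=2 overlap_sketch.py
import torch
import torch.distributed as dist

def overlapped_backward_step(grad_output, weight, activation):
    grad_input = grad_output @ weight                  # dgrad GEMM
    handle = dist.all_reduce(grad_input, async_op=True)  # launch collective asynchronously
    grad_weight = grad_output.t() @ activation         # wgrad GEMM overlaps with the collective
    handle.wait()                                      # wait only before grad_input is consumed
    return grad_input, grad_weight

if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    g = torch.randn(6, 8)
    w = torch.randn(8, 4)
    a = torch.randn(6, 4)
    gi, gw = overlapped_backward_step(g, w, a)
    print(gi.shape, gw.shape)  # torch.Size([6, 4]) torch.Size([8, 4])
    dist.destroy_process_group()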
dcu_megatron/core/pipeline_parallel/fb_overlap/modules/token_dispatcher.py (new file, 0 → 100644)

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import torch

from megatron.core import parallel_state, tensor_parallel
from megatron.core.transformer.moe.moe_utils import permute, unpermute
from megatron.core.tensor_parallel.mappings import _gather_along_first_dim_expert_parallel
from megatron.core.utils import make_viewless_tensor
from megatron.training import get_args
from mindspeed.core.transformer.moe.unpermute_without_activation import UnpermuteWithoutActivation


def preprocess(self, indices: torch.Tensor) -> torch.Tensor:
    # use 0.7.0 implement for better performance
    num_local_tokens_per_expert = torch.histc(
        indices, bins=self.num_experts, min=0, max=self.num_experts
    )

    ep_size = self.config.expert_model_parallel_size
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    tp_extended_ep_size = ep_size * tp_size
    if self.drop_and_pad:
        self.capacity = self.probs.size(1)
        num_tokens_per_local_expert = torch.full(
            (self.num_local_experts,),
            self.capacity * self.ep_size,
            dtype=torch.long,
            device=torch.cuda.current_device()
        )
        return num_tokens_per_local_expert
    elif self.config.moe_expert_capacity_factor is not None:
        # Token drop but no pad. A synchronization is needed before the first
        # permutation to get the `num_out_tokens` CPU value.
        self.num_out_tokens = num_local_tokens_per_expert.sum().to(
            torch.device("cpu"), non_blocking=True
        )
        self.cuda_sync_point = "before_permutation_1"
    elif tp_extended_ep_size > 1:
        # Token dropless and enable ep. A synchronization is needed before expert parallel
        # AlltoAll communication to get the `input_splits` and `output_splits` CPU values.
        self.cuda_sync_point = "before_ep_alltoall"
    else:
        # Token dropless and no ep. A synchronization is needed before the token_permutation()
        # function returns to get the `tokens_per_expert` CPU value.
        self.cuda_sync_point = "before_finish"

    if tp_extended_ep_size > 1:
        # ===================================================
        # Calculate input_splits, output_splits for alltoall-v.
        # ===================================================
        self.input_splits = (
            num_local_tokens_per_expert.reshape(tp_extended_ep_size, self.num_local_experts)
            .sum(axis=1)
            .to(torch.device("cpu"), non_blocking=True)
            .numpy()
        )
        num_global_tokens_per_expert = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
            num_local_tokens_per_expert
        ).reshape(tp_extended_ep_size, self.num_experts)
        self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[
            :, self.local_expert_indices[0]: self.local_expert_indices[-1] + 1
        ]
        self.output_splits = (
            self.num_global_tokens_per_local_expert.sum(axis=-1)
            .to(torch.device("cpu"), non_blocking=True)
            .numpy()
        )
        num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(axis=0)
        # ===================================================
        # num_global_tokens_per_expert: [ep_size, num_experts]
        # num_global_tokens_per_local_expert: [ep_size, num_local_experts]
        # num_tokens_per_local_expert: [num_local_experts]
        # ===================================================
    else:
        self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
            -1, self.num_experts
        )
        num_tokens_per_local_expert = num_local_tokens_per_expert

    if self.num_local_experts > 1:
        # No further synchronization is needed because torch.repeat_interleave() calls stream
        # synchronization internally when the `output_size` parameter is not provided.
        self.cuda_sync_point = "no_sync"
        self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
            self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()
        )

    return num_tokens_per_local_expert


def alltoall_token_perm1(
    self,
    hidden_states: torch.Tensor,
    probs: torch.Tensor,
    indices: torch.Tensor,
):
    self.hidden_shape = hidden_states.shape
    self.probs = probs
    assert probs.dim() == 2, "Expected 2D tensor for probs"
    assert indices.dim() == 2, "Expected 2D tensor for indices"
    tokens_per_expert = preprocess(self, indices)

    # Flatten the input tensor
    # hidden_states: [S/TP, B, H] -> [S*B/TP, H]
    hidden_states = hidden_states.view(-1, self.hidden_shape[-1])

    # Permutation 1: input to AlltoAll input
    self.hiddden_shape_before_permute = hidden_states.shape
    if self.cuda_sync_point == "before_permutation_1":
        torch.cuda.current_stream().synchronize()
    permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
        hidden_states,
        indices,
        num_out_tokens=self.num_out_tokens,
        padded_mode=self.drop_and_pad,
    )

    # Perform expert parallel AlltoAll communication
    if self.cuda_sync_point == "before_ep_alltoall":
        torch.cuda.current_stream().synchronize()

    return permutated_local_input_tokens, tokens_per_expert


def alltoall_token_perm2(self, global_input_tokens):
    # Permutation 2: AlltoAll output to expert input if num_local_experts > 1
    if self.num_local_experts > 1:
        if not self.drop_and_pad:
            global_input_tokens, self.reversed_global_input_permutation_mapping = permute(
                global_input_tokens, self.global_input_tokens_local_experts_indices
            )
        else:
            global_input_tokens = global_input_tokens.reshape(
                self.ep_size, self.num_local_experts, self.capacity, -1
            )
            global_input_tokens = (
                global_input_tokens.transpose(0, 1)
                .reshape(self.num_local_experts * self.ep_size * self.capacity, -1)
                .contiguous()
            )

    if self.cuda_sync_point == "before_finish":
        torch.cuda.current_stream().synchronize()

    return global_input_tokens


def alltoall_token_unperm1(
    self,
    hidden_states: torch.Tensor,
    bias: torch.Tensor = None,
):
    """
    Reverse the token permutation to restore the original order.

    Args:
        hidden_states (torch.Tensor): Output from local experts.
        bias (torch.Tensor, optional): Bias tensor (not supported).

    Returns:
        Tuple[torch.Tensor, Optional[torch.Tensor]]:
            - Unpermuted token embeddings in the original order.
            - None (bias is not supported).
    """
    assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"

    # Unpermutation 2: expert output to AlltoAll input
    if self.num_local_experts > 1:
        if not self.drop_and_pad:
            hidden_states = unpermute(
                hidden_states, self.reversed_global_input_permutation_mapping,
            )
        else:
            hidden_states = hidden_states.reshape(
                self.num_local_experts, self.ep_size, self.capacity, -1
            )
            hidden_states = (
                hidden_states.transpose(0, 1)
                .reshape(self.ep_size * self.num_local_experts * self.capacity, -1)
                .contiguous()
            )

    return hidden_states


def alltoall_token_unperm2(self, permutated_local_input_tokens, probs=None):
    # Unpermutation 1: AlltoAll output to output
    probs = probs if probs is not None else self.probs
    output = unpermute(
        permutated_local_input_tokens,
        self.reversed_local_input_permutation_mapping,
        probs=probs,
        padded_mode=self.drop_and_pad,
        restore_shape=self.hiddden_shape_before_permute,
    )

    # Reshape the output tensor
    output = output.view(self.hidden_shape)
    output = make_viewless_tensor(inp=output, requires_grad=output.requires_grad, keep_graph=True)
    return output, None
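
Illustrative note (not part of the commit): the perm/unperm steps above reorder tokens so that all tokens routed to the same expert become contiguous, then restore the original order after the experts run. A minimal sketch of that idea with `torch.argsort` instead of the fused `permute`/`unpermute` kernels; the toy tensors below are hypothetical:

# Sketch only: token permutation and its inverse, grouped by routed expert id.
import torch

hidden = torch.arange(12, dtype=torch.float32).reshape(6, 2)   # 6 tokens, hidden size 2
expert_ids = torch.tensor([2, 0, 1, 0, 2, 1])                  # routed expert per token

# Permutation: group tokens so each expert's tokens are contiguous (cf. alltoall_token_perm1).
sorted_order = torch.argsort(expert_ids)
permuted = hidden[sorted_order]

# Unpermutation: scatter the (expert-processed) rows back to the original token order
# (cf. alltoall_token_unperm2).
restored = torch.empty_like(permuted)
restored[sorted_order] = permuted
assert torch.equal(restored, hidden)
print(permuted)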
dcu_megatron/core/pipeline_parallel/fb_overlap/modules/utils.py (new file, 0 → 100644)

import torch
from torch.autograd.variable import Variable
from megatron.core.pipeline_parallel import p2p_communication


def detach_tensor(tensor, checkpoint_forward=False):
    if checkpoint_forward:
        return tensor
    if tensor is None:
        return None
    detached_tensor = tensor.detach()
    detached_tensor.requires_grad = True
    return detached_tensor


def run_graph_backward(graph, output_tensor_grad=None, keep_graph=False, keep_grad=False):
    grad_tensor = output_tensor_grad
    if output_tensor_grad is None and graph[1] is not None and graph[1].grad is not None:
        grad_tensor = graph[1].grad

    Variable._execution_engine.run_backward(
        tensors=(graph[0],),
        grad_tensors=(grad_tensor,),
        keep_graph=False,
        create_graph=False,
        inputs=tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )

    if not keep_graph:
        graph[0].untyped_storage().resize_(0)
    if not keep_grad:
        grad_tensor.untyped_storage().resize_(0)


class NoopLayerGraph:
    def __init__(self, layer_input, layer_output, layer, checkpointed=False):
        self.layer_input = layer_input
        if not checkpointed:
            self.unperm2_graph = (layer_output, None)
        else:
            self.unperm2_graph = (None, None)

        self.checkpointed = checkpointed
        self.layer = layer

    def record_layer_inputs(self, *args):
        self.layer_inputs = args


class LayerGraph:
    def __init__(self, saved_graph_and_graph_inputs, recompute_needed_tensors,
                 input_splits, output_splits, layer, checkpointed=False):
        if not checkpointed:
            self.attn_graph = saved_graph_and_graph_inputs[0]
            self.pre_mlp_layernorm_graph = saved_graph_and_graph_inputs[1]
            self.router_graph = saved_graph_and_graph_inputs[2]
            self.perm1_graph = saved_graph_and_graph_inputs[3]
            self.perm_a2a_graph = saved_graph_and_graph_inputs[4]
            self.perm2_graph = saved_graph_and_graph_inputs[5]
            self.grouped_mlp_graph = saved_graph_and_graph_inputs[6]
            self.unperm1_graph = saved_graph_and_graph_inputs[7]
            self.unperm_a2a_graph = saved_graph_and_graph_inputs[8]
            self.unperm2_graph = saved_graph_and_graph_inputs[9]
            self.shared_experts_graph = saved_graph_and_graph_inputs[10]
        else:
            self.unperm2_graph = (None, None)

        self.layer_input = saved_graph_and_graph_inputs[-1]
        self.recompute_needed_tensors = recompute_needed_tensors
        self.input_splits = input_splits
        self.output_splits = output_splits
        self.checkpointed = checkpointed
        self.layer = layer
        self.is_moe_layer = hasattr(layer, 'mlp') and hasattr(layer.mlp, 'experts')

    def record_layer_inputs(self, *args):
        self.layer_inputs = args


class P2PCommParams:
    tensor_shape = None
    config = None

    def __init__(self, send_next=False, send_prev=False, recv_next=False, recv_prev=False):
        self.send_next = send_next
        self.send_prev = send_prev
        self.recv_next = recv_next
        self.recv_prev = recv_prev

    def __str__(self):
        return (f'send next: {self.send_next} send_prev: {self.send_prev} '
                f'recv_next: {self.recv_next} recv_prev: {self.recv_prev}')


class P2PCommOutput:
    def __init__(self, input_tensor=None, output_tensor_grad=None, fwd_wait_handles=None,
                 bwd_wait_handles=None, input_tensor_grad=None):
        self.input_tensor = input_tensor
        self.fwd_wait_handles = fwd_wait_handles
        self.output_tensor_grad = output_tensor_grad
        self.bwd_wait_handles = bwd_wait_handles
        self.input_tensor_grad = input_tensor_grad


def is_p2p_comm_needed(pp_comm_params: P2PCommParams):
    return pp_comm_params is not None and \
        (pp_comm_params.send_next or pp_comm_params.send_prev
         or pp_comm_params.recv_next or pp_comm_params.recv_prev)


def p2p_comm_helper(comm_params: P2PCommParams, tensor_tosend):
    assert not (comm_params.send_next and comm_params.send_prev)
    assert not (comm_params.recv_next and comm_params.recv_prev)
    tensor_send_next = None
    if comm_params.send_next:
        tensor_send_next = tensor_tosend
    tensor_send_prev = None
    if comm_params.send_prev:
        tensor_send_prev = tensor_tosend

    tensor_recv_prev, tensor_recv_next, p2p_handles = p2p_communication._communicate(
        tensor_send_next=tensor_send_next,
        tensor_send_prev=tensor_send_prev,
        recv_prev=comm_params.recv_prev,
        recv_next=comm_params.recv_next,
        tensor_shape=comm_params.tensor_shape,
        wait_on_reqs=False,
        config=comm_params.config
    )

    if comm_params.recv_next:
        return tensor_recv_next, p2p_handles
    elif comm_params.recv_prev:
        return tensor_recv_prev, p2p_handles
    else:
        return None, p2p_handles
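
Illustrative note (not part of the commit): `detach_tensor` cuts the autograd graph at module boundaries so each segment can be backpropagated independently, and `run_graph_backward` later replays a segment using the gradient accumulated on the detached boundary tensor. A minimal sketch of that detach-and-replay pattern written with the public `torch.autograd.backward` API; the variable names are hypothetical:

# Sketch only: split a forward pass at a detach boundary and backprop the segments separately.
import torch

x = torch.randn(4, requires_grad=True)

# Forward segment 1, then cut the graph at a detach boundary (cf. detach_tensor).
seg1_out = x * 2
detached = seg1_out.detach()
detached.requires_grad = True

# Forward segment 2 builds only on the detached tensor.
seg2_out = (detached ** 2).sum()

# Backward segment 2 first; its gradient lands in `detached.grad`.
seg2_out.backward()

# Later, backward segment 1 using the boundary grad
# (cf. run_graph_backward on the (seg1_out, detached) pair).
torch.autograd.backward(seg1_out, grad_tensors=detached.grad)
print(x.grad)  # equals 8 * x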
dcu_megatron/core/pipeline_parallel/fb_overlap/modules/weight_grad_store.py (new file, 0 → 100644)

# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
import operator
import queue
from functools import reduce
import torch
import torch_npu

from megatron.core.parallel_state import (
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_world_size
)
from megatron.training import get_args
from mindspeed.ops.gmm import GMMFunction
from mindspeed.ops.npu_groupmatmul_add import npu_groupmatmul_add_fp32


def gather(input_slice, stream):
    world_size = get_tensor_model_parallel_world_size()
    dim_size = list(input_slice.size())
    dim_size[0] = dim_size[0] * world_size

    all_gather_buffer = torch.empty(
        dim_size, dtype=input_slice.dtype, device=torch.cuda.current_device(), requires_grad=False
    )
    handle = None
    forward_event = torch.npu.Event()
    forward_event.record()
    with torch.no_grad():
        with torch_npu.npu.stream(stream):
            stream.wait_event(forward_event)
            handle = torch.distributed._all_gather_base(
                all_gather_buffer, input_slice, group=get_tensor_model_parallel_group(), async_op=True
            )

    # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
    # gather is scheduled before the input gradient computation
    return all_gather_buffer, handle


class WeightGradStore:
    cache = []
    weight_grad_queue = queue.Queue()
    store_grad_cache = []
    grad_store = []
    gather_stream = None
    is_decoupleBlock = False

    @classmethod
    def put(cls, total_input, grad_output, weight, sequence_parallel, in_row=False):
        cls.cache.append((total_input, grad_output, weight, sequence_parallel, in_row))

    @classmethod
    def flush_chunk_grad(cls):
        cls.weight_grad_queue.put(cls.cache)
        cls.cache = []

    @classmethod
    def start_decouple(cls):
        cls.is_decoupleBlock = True

    @classmethod
    def end_decouple(cls):
        cls.is_decoupleBlock = False

    @classmethod
    def overlap_all_gather(cls):
        # used for grad_output all gather in RowParallel and input all gather in ColumnParallel.
        if len(cls.cache) > 0:
            [input_, grad_output_slice, weight, sequence_parallel, in_row] = cls.cache.pop(0)
            if not sequence_parallel:
                return (input_, grad_output_slice, weight, sequence_parallel, in_row), None
            if not in_row:
                total_input, handle = gather(input_, cls.gather_stream)
                grad_output = grad_output_slice
            else:
                grad_output, handle = gather(grad_output_slice, cls.gather_stream)
                total_input = input_
            return [total_input, grad_output, weight, sequence_parallel, in_row], handle
        else:
            raise Exception("All Gather empty queue.")

    @classmethod
    def overlap_matmul(cls, grad_store_cache):
        total_input, grad_output, weight, sequence_parallel, in_row = grad_store_cache
        args = get_args()
        if hasattr(weight, 'gmm_weight'):
            inputs, group_list, group_list_data_type = total_input
            if get_args().gemm_gradient_accumulation_fusion:
                npu_groupmatmul_add_fp32(inputs, grad_output, group_list, weight.main_grad)
            else:
                grad_weight = GMMFunction.builder.load().npu_gmm(
                    [inputs.t()], [grad_output], [], group_list, 2, 0
                )[0]
                weight.main_grad.data.add_(grad_weight.view(-1, weight.shape[-1]))
            inputs.untyped_storage().resize_(0)
            grad_output.untyped_storage().resize_(0)
        else:
            if len(grad_output.shape) > 2:
                grad_output = grad_output.contiguous()
                sb = grad_output.shape[0] * grad_output.shape[1]
                # Convert the tensor shapes to 2D for execution compatibility
                grad_output = grad_output.view(sb, grad_output.shape[2])
                total_input = total_input.view(sb, total_input.shape[2])
            if get_args().gradient_accumulation_fusion:
                import fused_weight_gradient_mlp_cuda
                if weight.main_grad.dtype == torch.float32:
                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
                        total_input, grad_output, weight.main_grad
                    )
                elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
                        total_input, grad_output, weight.main_grad
                    )
                else:
                    raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
            else:
                grad_weight = grad_output.t().matmul(total_input)
                weight.main_grad.data.add_(grad_weight)
            total_input.untyped_storage().resize_(0)
            grad_output.untyped_storage().resize_(0)

    @classmethod
    def pop(cls, overlap_arg=None):
        if len(cls.cache) == 0:
            return
        if cls.gather_stream is None:
            cls.gather_stream = torch_npu.npu.Stream(device=torch.npu.current_device())
        (input_, grad_output_slice, weight, sequence_parallel, in_row), handle = cls.overlap_all_gather()
        if not sequence_parallel or get_args().moe_fb_overlap:
            grad_output = grad_output_slice
        else:
            grad_output, handle = gather(grad_output_slice, cls.gather_stream)
        cls.store_grad_cache = (input_, grad_output, weight, sequence_parallel, in_row)
        while len(cls.cache) > 0:
            if handle is not None:
                handle.wait()
            next_grad_cache, handle = cls.overlap_all_gather()
            cls.overlap_matmul(cls.store_grad_cache)
            cls.store_grad_cache = next_grad_cache
        if handle is not None:
            handle.wait()
        cls.overlap_matmul(cls.store_grad_cache)
        cls.store_grad_cache = None

    @classmethod
    def pop_single(cls):
        if cls.weight_grad_queue.empty():
            return
        cache_list = cls.weight_grad_queue.get()
        assert len(cls.cache) == 0
        cls.cache = cache_list
        cls.pop()
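
Illustrative note (not part of the commit): between `start_decouple()` and `end_decouple()`, linear backward passes store their (input, grad_output, weight) triples instead of computing dW, and a later `pop()` replays the deferred weight-gradient GEMMs. A CPU-friendly analogue of that idea in plain PyTorch; the class name `TinyWeightGradStore` and the toy shapes are hypothetical and stand in for the NPU-specific implementation above:

# Sketch only: defer weight-gradient GEMMs and replay them later in a batch.
import torch

class TinyWeightGradStore:
    cache = []

    @classmethod
    def put(cls, total_input, grad_output, weight):
        # During backward: remember the operands instead of computing dW now.
        cls.cache.append((total_input, grad_output, weight))

    @classmethod
    def pop(cls):
        # Later: replay all deferred dW = grad_output^T @ input accumulations.
        while cls.cache:
            total_input, grad_output, weight = cls.cache.pop(0)
            weight.main_grad.add_(grad_output.t() @ total_input)

weight = torch.nn.Parameter(torch.randn(3, 5))
weight.main_grad = torch.zeros_like(weight)

x = torch.randn(7, 5)
grad_out = torch.randn(7, 3)
TinyWeightGradStore.put(x, grad_out, weight)   # during backward: defer
TinyWeightGradStore.pop()                      # later: accumulate into main_grad
print(weight.main_grad.shape)                  # torch.Size([3, 5])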
dcu_megatron/core/pipeline_parallel/fb_overlap/overlap_funcs/__init__.py (new file, 0 → 100644)

# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
from .fwd import *
from .bwd import *
from .fwdbwd import *
dcu_megatron/core/pipeline_parallel/fb_overlap/overlap_funcs/bwd.py (new file, 0 → 100644)

# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import torch
from megatron.core import parallel_state
from megatron.training import get_args
from megatron.core.transformer.moe.moe_utils import permute
from mindspeed.core.transformer.moe.comm_utils import async_all_to_all, async_all_gather, async_reduce_scatter
from mindspeed.model.transformer import should_recompute_activation
from mindspeed.core.transformer.moe.moe_utils import get_prob_backward_need_tensors
from ..modules.weight_grad_store import WeightGradStore
from ..modules.utils import run_graph_backward


def transformer_layer_backward_moe(
    layer_output_grad,
    layer_graph
):
    self = layer_graph
    args = get_args()
    in_detach_stage = WeightGradStore.is_decoupleBlock
    dispached_input, fc1_out, act_out, probs, indices, global_input_tokens_local_experts_indices = \
        self.recompute_needed_tensors
    ep_group = parallel_state.get_expert_model_parallel_group()
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    if args.moe_tp_extend_ep:
        ep_group = parallel_state.get_tensor_and_expert_parallel_group()

    if tp_size > 1:
        shared_expert_grad = layer_output_grad if layer_output_grad is not None else self.unperm2_graph[1].grad
        _, backward_ag_shared, backward_ag_shared_handle = async_all_gather(
            shared_expert_grad, parallel_state.get_tensor_model_parallel_group()
        )
    else:
        backward_ag_shared = layer_output_grad if layer_output_grad is not None else self.unperm2_graph[1].grad
        backward_ag_shared_handle = None

    run_graph_backward(self.unperm2_graph, layer_output_grad, keep_grad=True)
    if backward_ag_shared_handle is not None:
        backward_ag_shared_handle.wait()
        backward_ag_shared_handle = None
        if layer_output_grad is not None:
            layer_output_grad.untyped_storage().resize_(0)
    _, unperm1_out_grad, handle = async_all_to_all(
        self.unperm_a2a_graph[1].grad,
        self.output_splits,
        self.input_splits,
        ep_group
    )
    # overlap alltoall by shared experts backward
    if self.shared_experts_graph[0] is not None:
        run_graph_backward(self.shared_experts_graph, backward_ag_shared)
    if get_args().moe_zero_memory == 'level0' or should_recompute_activation(self.layer.layer_number):
        with torch.no_grad():
            recompute_act_out = self.layer.mlp.experts.activation_func(fc1_out)
            act_out.untyped_storage().resize_(recompute_act_out.untyped_storage().size())
            act_out.untyped_storage().copy_(recompute_act_out.untyped_storage())
            recompute_act_out.untyped_storage().resize_(0)
    handle.wait()
    handle = None

    # recomp permute1 and overlap all2all
    if get_args().moe_zero_memory == 'level0':
        with torch.no_grad():
            input_before_perm1 = self.pre_mlp_layernorm_graph[0]

            def recomp_token_permutation1(hidden_states, indices):
                hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
                permutated_local_input_tokens, _ = permute(hidden_states, indices)
                return permutated_local_input_tokens

            perm1_out = recomp_token_permutation1(input_before_perm1, indices)
            _, perm_a2a_out, perm_a2a_handle = async_all_to_all(
                perm1_out,
                self.output_splits,
                self.input_splits,
                ep_group
            )

    run_graph_backward(self.unperm1_graph, unperm1_out_grad)
    WeightGradStore.start_decouple()
    run_graph_backward(self.grouped_mlp_graph, keep_grad=True)  # keep for dw computation
    if not in_detach_stage:
        WeightGradStore.end_decouple()
    run_graph_backward(self.perm2_graph, keep_graph=True)  # keep for dw computation
    if get_args().moe_zero_memory == 'level0':
        perm_a2a_handle.wait()
        perm_a2a_handle = None
    _, perm1_out_grad, handle = async_all_to_all(
        self.perm_a2a_graph[1].grad,
        self.input_splits,
        self.output_splits,
        ep_group
    )
    if get_args().moe_zero_memory == 'level0':
        with torch.no_grad():
            recompute_fc1_input, _ = permute(perm_a2a_out, global_input_tokens_local_experts_indices)
            perm_a2a_out.untyped_storage().resize_(0)
            # restore fc1 input for dw computation
            dispached_input.untyped_storage().resize_(recompute_fc1_input.untyped_storage().size())
            dispached_input.untyped_storage().copy_(recompute_fc1_input.untyped_storage())
            recompute_fc1_input.untyped_storage().resize_(0)
    # dw computation
    if not in_detach_stage:
        WeightGradStore.pop()
    handle.wait()
    handle = None
    run_graph_backward(self.perm1_graph, perm1_out_grad)
    run_graph_backward(self.router_graph)
    run_graph_backward(self.pre_mlp_layernorm_graph)
    run_graph_backward(self.attn_graph)

    self.recompute_needed_tensors = [None for _ in range(len(self.recompute_needed_tensors))]

    return self.layer_input.grad


def transformer_layer_backward_dense(layer_output_grad, layer_graph):
    run_graph_backward(layer_graph.unperm2_graph, layer_output_grad)
    run_graph_backward(layer_graph.pre_mlp_layernorm_graph)
    run_graph_backward(layer_graph.attn_graph)

    return layer_graph.layer_input.grad


def transformer_layer_backward_noop(layer_output_grad, layer_graph):
    run_graph_backward(layer_graph.unperm2_graph, layer_output_grad, keep_grad=True)

    return layer_graph.layer_input.grad
dcu_megatron/core/pipeline_parallel/fb_overlap/overlap_funcs/fwd.py
0 → 100644
View file @
770fa304
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import
torch
from
megatron.core.utils
import
make_sharded_tensor_for_checkpoint
,
make_viewless_tensor
from
megatron.core
import
parallel_state
,
tensor_parallel
from
megatron.training
import
get_args
from
mindspeed.core.transformer.moe.comm_utils
import
async_all_to_all
,
async_all_gather
,
async_reduce_scatter
from
mindspeed.core.tensor_parallel.random
import
CheckpointWithoutOutput
from
mindspeed.core.transformer.moe.moe_utils
import
AG_SHARED_EXPERTS_INPUTS
from
mindspeed.model.transformer
import
should_recompute_activation
from
..modules.token_dispatcher
import
(
alltoall_token_perm1
,
alltoall_token_perm2
,
alltoall_token_unperm1
,
alltoall_token_unperm2
)
from
..modules.attention
import
attention_forward
from
..modules.utils
import
(
detach_tensor
,
NoopLayerGraph
,
LayerGraph
,
)
def
router_forward
(
self
,
hidden_states
):
probs
,
indices
=
self
.
mlp
.
router
(
hidden_states
)
return
probs
,
indices
def transformer_layer_forward_moe(
    self,
    hidden_states,
    attention_mask,
    context=None,
    context_mask=None,
    rotary_pos_emb=None,
    inference_params=None,
    packed_seq_params=None,
    checkpoint=False
):
    # hidden_states: [s, b, h]
    args = get_args()
    ep_group = parallel_state.get_expert_model_parallel_group()
    if args.moe_tp_extend_ep:
        ep_group = parallel_state.get_tensor_and_expert_parallel_group()
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    tp_group = parallel_state.get_tensor_model_parallel_group()
    use_shared_experts = hasattr(self.mlp, 'shared_experts') and self.mlp.shared_experts is not None
    recomp_norm = getattr(args, 'recompute_norm', False)

    detached_layer_input = detach_tensor(hidden_states, checkpoint_forward=checkpoint)

    # Residual connection.
    residual1 = detached_layer_input

    # input_layernorm + AttentionForward
    hidden_states = attention_forward(
        self, detached_layer_input, residual1,
        attention_mask=attention_mask,
        inference_params=inference_params,
        rotary_pos_emb=rotary_pos_emb,
        packed_seq_params=packed_seq_params,
        recompute_norm=recomp_norm
    )

    attention_out, detached_attention_out = hidden_states, detach_tensor(hidden_states, checkpoint_forward=checkpoint)

    # Residual connection.
    residual2 = detached_attention_out

    # Layer Norm after attention
    if recomp_norm:
        self.norm_ckpt2 = CheckpointWithoutOutput()
        pre_mlp_layernorm_output = self.norm_ckpt2.checkpoint(self.pre_mlp_layernorm, False, detached_attention_out)
    else:
        pre_mlp_layernorm_output = self.pre_mlp_layernorm(detached_attention_out)

    # MLP.
    detached_mlp_input = detach_tensor(pre_mlp_layernorm_output, checkpoint_forward=checkpoint)

    if tp_size > 1 and use_shared_experts:
        # shared experts tp communication
        _, shared_experts_input, shared_experts_allgather_handle = async_all_gather(
            detached_mlp_input, tp_group, is_use_get_global_memory_buffer=True
        )
        AG_SHARED_EXPERTS_INPUTS.append((shared_experts_input, shared_experts_allgather_handle))
    else:
        shared_experts_input, shared_experts_allgather_handle = detached_mlp_input, None

    # Router forward.
    probs, indices = router_forward(self, detached_mlp_input)
    shared_expert_output = None

    # Token Perm1 Forward
    probs_detached = detach_tensor(probs, checkpoint_forward=checkpoint)
    perm1_out, tokens_per_expert = alltoall_token_perm1(self.mlp.token_dispatcher, detached_mlp_input, probs_detached, indices)

    if shared_experts_allgather_handle is not None:
        # overlap shared experts tp comm by token perm1.
        shared_experts_allgather_handle.wait()

    # Async Perm A2A.
    _, perm_a2a_out, perm_a2a_handle = async_all_to_all(
        perm1_out,
        self.mlp.token_dispatcher.output_splits,
        self.mlp.token_dispatcher.input_splits,
        ep_group
    )

    # Shared Experts Forward.
    if use_shared_experts:
        shared_expert_output, _ = self.mlp.shared_experts(detached_mlp_input)

    if recomp_norm:
        self.norm_ckpt2.discard_output()

    # overlap perm a2a by shared experts computation.
    perm_a2a_handle.wait()
    # perm1_out tensor storage is not needed by backward,
    # but the backward func of perm1_out is needed, so resize the storage but keep the tensor.
    perm1_out.untyped_storage().resize_(0)

    if tp_size > 1 and use_shared_experts:
        # tp comm for shared experts
        share_experts_graph, shared_expert_output, rs_shared_experts_handle = async_reduce_scatter(
            shared_expert_output, tp_group
        )
    else:
        share_experts_graph = shared_expert_output
        rs_shared_experts_handle = None

    detached_perm_a2a_out = detach_tensor(perm_a2a_out, checkpoint_forward=checkpoint)
    # Token Perm2 Forward.
    dispached_input = alltoall_token_perm2(self.mlp.token_dispatcher, detached_perm_a2a_out)
    perm_a2a_out.untyped_storage().resize_(0)

    # Grouped MLP Forward
    detached_dispached_input = detach_tensor(dispached_input, checkpoint_forward=checkpoint)
    (expert_output, fc1_output, act_out), _ = self.mlp.experts(detached_dispached_input, tokens_per_expert)
    if args.moe_zero_memory == 'level0':
        dispached_input.untyped_storage().resize_(0)
        recompute_needed_tensors = [
            dispached_input, fc1_output, act_out, probs, indices,
            self.mlp.token_dispatcher.global_input_tokens_local_experts_indices
        ]
    else:
        if should_recompute_activation(self.layer_number):
            recompute_needed_tensors = [None, fc1_output, act_out, None, None, None]
        else:
            recompute_needed_tensors = [None, None, None, None, None, None]
    detached_expert_output = detach_tensor(expert_output, checkpoint_forward=checkpoint)

    # Token Unperm1 Forward
    unperm1_out = alltoall_token_unperm1(self.mlp.token_dispatcher, detached_expert_output, None)
    expert_output.untyped_storage().resize_(0)

    if rs_shared_experts_handle is not None:
        # overlap shared experts tp comm by token perm2 + gmm
        rs_shared_experts_handle.wait()
        # share_experts_graph tensor storage is not needed by backward,
        # but the backward func of share_experts_graph is needed, so resize the storage but keep the tensor.
        share_experts_graph.untyped_storage().resize_(0)

    # Launch Token Unperm2 A2A
    _, unperm_a2a_out, unperm_a2a_handle = async_all_to_all(
        unperm1_out,
        self.mlp.token_dispatcher.input_splits,
        self.mlp.token_dispatcher.output_splits,
        ep_group
    )
    unperm_a2a_handle.wait()
    # unperm1_out tensor storage is not needed by backward,
    # but the backward func of unperm1_out is needed, so resize the storage but keep the tensor.
    unperm1_out.untyped_storage().resize_(0)

    detached_unperm_a2a_out = detach_tensor(unperm_a2a_out, checkpoint_forward=checkpoint)
    route_expert_output, _ = alltoall_token_unperm2(self.mlp.token_dispatcher, detached_unperm_a2a_out)

    if use_shared_experts:
        detached_shared_expert_output = detach_tensor(shared_expert_output, checkpoint_forward=checkpoint)
        mlp_output = route_expert_output + detached_shared_expert_output
        shared_expert_output.untyped_storage().resize_(0)
    else:
        detached_shared_expert_output = None
        share_experts_graph = None
        mlp_output = route_expert_output

    if recomp_norm:
        mlp_output.register_hook(self.norm_ckpt2.recompute)

    with self.bias_dropout_add_exec_handler():
        hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
            (mlp_output, None), residual2, self.hidden_dropout
        )

    # Jit compiled function creates 'view' tensor. This tensor
    # potentially gets saved in the MPU checkpoint function context,
    # which rejects view tensors. While making a viewless tensor here
    # won't result in memory savings (like the data loader, or
    # p2p_communication), it serves to document the origin of this
    # 'view' tensor.
    output = make_viewless_tensor(
        inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
    )

    saved_tensors = (
        (attention_out, detached_attention_out),
        (pre_mlp_layernorm_output, detached_mlp_input),
        (probs, probs_detached),
        (perm1_out, None),  # perm1 graph
        (None, detached_perm_a2a_out),
        (dispached_input, detached_dispached_input),  # perm2 graph
        (expert_output, detached_expert_output),  # grouped mlp graph
        (unperm1_out, None),  # unperm1 graph
        (None, detached_unperm_a2a_out),
        (output, None),  # unperm2 graph
        (share_experts_graph, detached_shared_expert_output),
        detached_layer_input
    )

    graph = LayerGraph(
        saved_tensors, recompute_needed_tensors,
        self.mlp.token_dispatcher.input_splits, self.mlp.token_dispatcher.output_splits,
        self, checkpointed=checkpoint
    )

    return output, context, graph
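# Illustrative sketch only (not part of this commit): the forward above cuts the layer into
# segments by detaching at sub-module boundaries and records (output, detached_input) pairs
# in LayerGraph, so the backward can later be replayed segment by segment, feeding each
# segment's input grad into the previous one (which is what run_graph_backward drives).
# A toy two-segment version of the same idea:
import torch

w1 = torch.randn(4, 4, requires_grad=True)
w2 = torch.randn(4, 4, requires_grad=True)
x = torch.randn(2, 4)

seg1_out = x @ w1                                    # segment 1 forward
seg1_detached = seg1_out.detach().requires_grad_()   # cut the autograd graph here
seg2_out = seg1_detached @ w2                        # segment 2 forward

# backward segment 2 first, then push its input grad into segment 1
torch.autograd.backward(seg2_out, torch.ones_like(seg2_out))
torch.autograd.backward(seg1_out, seg1_detached.grad)
assert w1.grad is not None and w2.grad is not None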
def transformer_layer_forward_dense(
    self,
    hidden_states,
    attention_mask,
    context=None,
    context_mask=None,
    rotary_pos_emb=None,
    inference_params=None,
    packed_seq_params=None,
    checkpoint=False
):
    # hidden_states: [s, b, h]
    args = get_args()
    recomp_norm = getattr(args, 'recompute_norm', False)

    detached_layer_input = detach_tensor(hidden_states, checkpoint_forward=checkpoint)

    # Residual connection.
    residual1 = detached_layer_input

    # input_layernorm + AttentionForward
    hidden_states = attention_forward(
        self, detached_layer_input, residual1,
        attention_mask=attention_mask,
        inference_params=inference_params,
        rotary_pos_emb=rotary_pos_emb,
        packed_seq_params=packed_seq_params,
        recompute_norm=recomp_norm
    )

    attention_graph, detached_attention_out = hidden_states, detach_tensor(hidden_states, checkpoint_forward=checkpoint)

    # Residual connection.
    residual2 = detached_attention_out

    if recomp_norm:
        self.norm_ckpt2 = CheckpointWithoutOutput()
        pre_mlp_layernorm_output = self.norm_ckpt2.checkpoint(self.pre_mlp_layernorm, False, detached_attention_out)
    else:
        pre_mlp_layernorm_output = self.pre_mlp_layernorm(detached_attention_out)

    # MLP.
    detached_mlp_input = detach_tensor(pre_mlp_layernorm_output, checkpoint_forward=checkpoint)
    mlp_output_with_bias = self.mlp(detached_mlp_input)

    if recomp_norm:
        self.norm_ckpt2.discard_output()
        mlp_output_with_bias[0].register_hook(self.norm_ckpt2.recompute)

    with self.bias_dropout_add_exec_handler():
        hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
            mlp_output_with_bias, residual2, self.hidden_dropout
        )

    # Jit compiled function creates 'view' tensor. This tensor
    # potentially gets saved in the MPU checkpoint function context,
    # which rejects view tensors. While making a viewless tensor here
    # won't result in memory savings (like the data loader, or
    # p2p_communication), it serves to document the origin of this
    # 'view' tensor.
    output = make_viewless_tensor(
        inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
    )

    saved_tensors = (
        (attention_graph, detached_attention_out),
        (pre_mlp_layernorm_output, detached_mlp_input),
        (None, None),
        (None, None),
        (None, None),
        (None, None),
        (None, None),
        (None, None),
        (None, None),
        (output, None),
        (None, None),
        detached_layer_input
    )

    graph = LayerGraph(
        saved_tensors, [], None, None,
        self, checkpointed=checkpoint
    )

    return output, context, graph
def transformer_layer_forward_noop(
    self,
    hidden_states,
    attention_mask,
    context=None,
    context_mask=None,
    rotary_pos_emb=None,
    inference_params=None,
    packed_seq_params=None,
    checkpoint=False
):
    detached_layer_input = detach_tensor(hidden_states, checkpoint_forward=checkpoint)
    output = detached_layer_input.clone()

    return output, context, NoopLayerGraph(detached_layer_input, output, self, checkpointed=checkpoint)
\ No newline at end of file
dcu_megatron/core/pipeline_parallel/fb_overlap/overlap_funcs/fwdbwd.py
0 → 100644
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
from contextlib import nullcontext

import torch

from megatron.core.utils import make_sharded_tensor_for_checkpoint, make_viewless_tensor
from megatron.core import parallel_state, tensor_parallel
from megatron.training import get_args
from megatron.core.transformer.moe.moe_utils import permute
from mindspeed.core.transformer.moe.comm_utils import async_all_to_all, async_all_gather, async_reduce_scatter
from mindspeed.model.transformer import should_recompute_activation
from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput
from mindspeed.core.transformer.moe.moe_utils import AG_SHARED_EXPERTS_INPUTS
from ..modules.token_dispatcher import (
    alltoall_token_perm1, alltoall_token_perm2,
    alltoall_token_unperm1, alltoall_token_unperm2
)
from ..modules.weight_grad_store import WeightGradStore
from ..modules.attention import (
    attention_forward,
    set_async_alltoall_inputs,
    get_async_alltoall_outputs
)
from ..modules.utils import (
    detach_tensor,
    run_graph_backward,
    LayerGraph,
    is_p2p_comm_needed,
    p2p_comm_helper,
    P2PCommOutput,
    P2PCommParams
)


def router_forward(self, hidden_states):
    probs, indices = self.mlp.router(hidden_states)

    return probs, indices
def transformer_layer_forward_dense_backward_moe_overlaping(
    fwd_layer,
    hidden_states,
    attention_mask,
    bwd_layer_output_grad=None,
    bwd_layer_graph: LayerGraph = None,
    bwd_unperm_a2a_handle=None,
    next_bwd_layer_graph: LayerGraph = None,
    context=None,
    context_mask=None,
    rotary_pos_emb=None,
    inference_params=None,
    packed_seq_params=None,
    pp_comm_params: P2PCommParams = None,
    bwd_pp_comm_params: P2PCommParams = None,
    checkpoint=False
):
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    if checkpoint:
        checkpoint_context = torch.no_grad()
    else:
        checkpoint_context = nullcontext()
    args = get_args()
    ep_group = parallel_state.get_expert_model_parallel_group()
    if args.moe_tp_extend_ep:
        ep_group = parallel_state.get_tensor_and_expert_parallel_group()
    recomp_norm = getattr(args, 'recompute_norm', False)

    bwd_dispached_input, bwd_fc1_out, bwd_act_out, bwd_probs, bwd_indices, \
        global_input_tokens_local_experts_indices = bwd_layer_graph.recompute_needed_tensors

    # Unperm2 Bwd
    # check if backward unpermutation alltoall is launched at bwd layer before
    if bwd_unperm_a2a_handle is None:
        run_graph_backward(bwd_layer_graph.unperm2_graph, bwd_layer_output_grad)
        # Async Unperm A2A
        _, unperm1_out_grad, bwd_unperm_a2a_handle = async_all_to_all(
            bwd_layer_graph.unperm_a2a_graph[1].grad,
            bwd_layer_graph.output_splits,
            bwd_layer_graph.input_splits,
            ep_group
        )
    else:
        unperm1_out_grad = bwd_layer_output_grad

    if args.moe_zero_memory == 'level0':
        with torch.no_grad():
            bwd_input_before_perm1 = bwd_layer_graph.pre_mlp_layernorm_graph[0]

            def recomp_token_permutation1(hidden_states, indices):
                hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
                permutated_local_input_tokens, _ = permute(hidden_states, indices)
                return permutated_local_input_tokens

            bwd_perm1_out = recomp_token_permutation1(bwd_input_before_perm1, bwd_indices)

    with checkpoint_context:
        # Atten Fwd
        detached_layer_input = detach_tensor(hidden_states, checkpoint_forward=checkpoint)
        # Residual connection.
        residual1 = detached_layer_input

        # input_layernorm + AttentionForward
        hidden_states = attention_forward(
            fwd_layer, detached_layer_input, residual1,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            packed_seq_params=packed_seq_params,
            recompute_norm=recomp_norm
        )

        attention_graph, detached_attention_out = hidden_states, detach_tensor(hidden_states, checkpoint_forward=checkpoint)

        # Residual connection.
        residual2 = detached_attention_out

        if recomp_norm:
            fwd_layer.norm_ckpt2 = CheckpointWithoutOutput()
            pre_mlp_layernorm_output = fwd_layer.norm_ckpt2.checkpoint(fwd_layer.pre_mlp_layernorm, False, detached_attention_out)
        else:
            pre_mlp_layernorm_output = fwd_layer.pre_mlp_layernorm(detached_attention_out)

    if args.moe_zero_memory == 'level0':
        _, bwd_perm_a2a_out, bwd_recomp_perm_a2a_handle = async_all_to_all(
            bwd_perm1_out,
            bwd_layer_graph.output_splits,
            bwd_layer_graph.input_splits,
            ep_group,
            event=bwd_unperm_a2a_handle,
            stream=torch.npu.current_stream()
        )

    if args.moe_zero_memory == 'level0' or should_recompute_activation(bwd_layer_graph.layer.layer_number):
        with torch.no_grad():
            recompute_act_out = bwd_layer_graph.layer.mlp.experts.activation_func(bwd_fc1_out)
            bwd_act_out.untyped_storage().resize_(recompute_act_out.untyped_storage().size())
            bwd_act_out.untyped_storage().copy_(recompute_act_out.untyped_storage())
            recompute_act_out.untyped_storage().resize_(0)

    bwd_unperm_a2a_handle.wait()
    bwd_unperm_a2a_handle = None
    run_graph_backward(bwd_layer_graph.unperm1_graph, unperm1_out_grad)
    unperm1_out_grad.untyped_storage().resize_(0)
    WeightGradStore.start_decouple()
    run_graph_backward(bwd_layer_graph.grouped_mlp_graph, keep_grad=True)  # keep for dw
    WeightGradStore.end_decouple()
    run_graph_backward(bwd_layer_graph.perm2_graph, keep_graph=True)  # keep for dw

    if args.moe_zero_memory == 'level0':
        with torch.no_grad():
            bwd_recomp_perm_a2a_handle.wait()
            bwd_recomp_perm_a2a_handle = None
            recompute_fc1_input, _ = permute(bwd_perm_a2a_out, global_input_tokens_local_experts_indices)
            bwd_perm_a2a_out.untyped_storage().resize_(0)

    if tp_size > 1:
        shared_expert_grad = bwd_layer_graph.shared_experts_graph[1].grad
        _, backward_ag_shared, backward_ag_shared_handle = async_all_gather(
            shared_expert_grad, parallel_state.get_tensor_model_parallel_group()
        )
    else:
        backward_ag_shared = bwd_layer_graph.shared_experts_graph[1].grad
        backward_ag_shared_handle = None

    _, perm1_out_grad, bwd_perm_a2a_handle = async_all_to_all(
        bwd_layer_graph.perm_a2a_graph[1].grad,
        bwd_layer_graph.input_splits,
        bwd_layer_graph.output_splits,
        ep_group,
        event=backward_ag_shared_handle
    )

    # Grouped MLP dw computation
    with checkpoint_context:
        # MLP Forward
        detached_mlp_input = detach_tensor(pre_mlp_layernorm_output, checkpoint_forward=checkpoint)
        mlp_output_with_bias = fwd_layer.mlp(detached_mlp_input)

        if recomp_norm:
            fwd_layer.norm_ckpt2.discard_output()
            mlp_output_with_bias[0].register_hook(fwd_layer.norm_ckpt2.recompute)

    bwd_perm_a2a_handle.wait()
    bwd_perm_a2a_handle = None
    run_graph_backward(bwd_layer_graph.perm1_graph, perm1_out_grad)
    perm1_out_grad.untyped_storage().resize_(0)
    WeightGradStore.start_decouple()
    if backward_ag_shared_handle is not None:
        backward_ag_shared_handle.wait()
        backward_ag_shared_handle = None
        shared_expert_grad.untyped_storage().resize_(0)
    run_graph_backward(bwd_layer_graph.shared_experts_graph, backward_ag_shared, keep_grad=True)  # dw computation
    WeightGradStore.end_decouple()
    run_graph_backward(bwd_layer_graph.router_graph)
    run_graph_backward(bwd_layer_graph.pre_mlp_layernorm_graph, keep_graph=True)
    WeightGradStore.start_decouple()
    run_graph_backward(bwd_layer_graph.attn_graph, keep_grad=True)
    WeightGradStore.end_decouple()

    if next_bwd_layer_graph is not None and getattr(next_bwd_layer_graph, 'is_moe_layer', False):
        run_graph_backward(next_bwd_layer_graph.unperm2_graph, bwd_layer_graph.layer_input.grad, keep_graph=True)

    next_layer_output_grad, next_bwd_unperm_a2a_handle = bwd_layer_graph.layer_input.grad, None
    if next_bwd_layer_graph is not None and getattr(next_bwd_layer_graph, 'is_moe_layer', False):
        _, next_layer_output_grad, next_bwd_unperm_a2a_handle = async_all_to_all(
            next_bwd_layer_graph.unperm_a2a_graph[1].grad,
            next_bwd_layer_graph.output_splits,
            next_bwd_layer_graph.input_splits,
            ep_group
        )

    with checkpoint_context:
        with fwd_layer.bias_dropout_add_exec_handler():
            hidden_states = fwd_layer.mlp_bda(fwd_layer.training, fwd_layer.config.bias_dropout_fusion)(
                mlp_output_with_bias, residual2, fwd_layer.hidden_dropout
            )

        # Jit compiled function creates 'view' tensor. This tensor
        # potentially gets saved in the MPU checkpoint function context,
        # which rejects view tensors. While making a viewless tensor here
        # won't result in memory savings (like the data loader, or
        # p2p_communication), it serves to document the origin of this
        # 'view' tensor.
        output = make_viewless_tensor(
            inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
        )

    # handle fwd p2p communication
    next_iter_input_tensor, fwd_p2p_handles = None, None
    fwd_pp_comm_params = pp_comm_params
    if is_p2p_comm_needed(fwd_pp_comm_params):
        next_iter_input_tensor, fwd_p2p_handles = p2p_comm_helper(fwd_pp_comm_params, output)

    # handle bwd p2p communication
    next_iter_output_tensor_grad, bwd_p2p_handles = None, None
    if is_p2p_comm_needed(bwd_pp_comm_params):
        next_iter_output_tensor_grad, bwd_p2p_handles = p2p_comm_helper(bwd_pp_comm_params, bwd_layer_graph.layer_input.grad)

    if args.moe_zero_memory == 'level0':
        # restore fc1 input for dw computation
        bwd_dispached_input.untyped_storage().resize_(recompute_fc1_input.untyped_storage().size())
        bwd_dispached_input.untyped_storage().copy_(recompute_fc1_input.untyped_storage())
        recompute_fc1_input.untyped_storage().resize_(0)
    WeightGradStore.pop()

    saved_tensors = (
        (attention_graph, detached_attention_out),
        (pre_mlp_layernorm_output, detached_mlp_input),
        (None, None),
        (None, None),
        (None, None),
        (None, None),  # perm2 graph
        (None, None),  # grouped mlp graph
        (None, None),  # unperm1 graph
        (None, None),
        (output, None),  # unperm2 graph
        (None, None),
        detached_layer_input
    )

    graph = LayerGraph(
        saved_tensors, [], None, None,
        fwd_layer, checkpointed=checkpoint
    )

    for tensor in bwd_layer_graph.recompute_needed_tensors:
        if tensor is not None:
            tensor.untyped_storage().resize_(0)

    return (output, context, graph, (next_layer_output_grad, next_bwd_unperm_a2a_handle),
            P2PCommOutput(next_iter_input_tensor, next_iter_output_tensor_grad, fwd_p2p_handles, bwd_p2p_handles, bwd_layer_graph.layer_input.grad))
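# Illustrative sketch only (not part of this commit): WeightGradStore.start_decouple() /
# end_decouple() / pop() split each backward into an immediate dx pass and a deferred dw pass,
# so weight-gradient GEMMs can be queued and flushed later to fill communication bubbles.
# A toy queue with a similar shape of API (the class and the put() hook below are hypothetical
# stand-ins, not the real WeightGradStore from ..modules.weight_grad_store):
class ToyWeightGradStore:
    _decoupled = False
    _queue = []

    @classmethod
    def start_decouple(cls):
        cls._decoupled = True

    @classmethod
    def end_decouple(cls):
        cls._decoupled = False

    @classmethod
    def put(cls, dw_fn):
        # called from a matmul backward: defer dw while decoupled, otherwise run it now
        if cls._decoupled:
            cls._queue.append(dw_fn)
        else:
            dw_fn()

    @classmethod
    def pop(cls):
        # flush all deferred weight-gradient computations
        while cls._queue:
            cls._queue.pop(0)()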
def transformer_layer_forward_moe_backward_dense_overlaping(
    fwd_layer,
    hidden_states,
    attention_mask,
    bwd_layer_output_grad=None,
    bwd_layer_graph: LayerGraph = None,
    bwd_unperm_a2a_handle=None,
    next_bwd_layer_graph: LayerGraph = None,
    context=None,
    context_mask=None,
    rotary_pos_emb=None,
    inference_params=None,
    packed_seq_params=None,
    pp_comm_params: P2PCommParams = None,
    bwd_pp_comm_params: P2PCommParams = None,
    checkpoint=False
):
    args = get_args()
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    tp_group = parallel_state.get_tensor_model_parallel_group()
    use_shared_experts = hasattr(fwd_layer.mlp, 'shared_experts') and fwd_layer.mlp.shared_experts is not None
    if checkpoint:
        checkpoint_context = torch.no_grad()
    else:
        checkpoint_context = nullcontext()
    args = get_args()
    ep_group = parallel_state.get_expert_model_parallel_group()
    if args.moe_tp_extend_ep:
        ep_group = parallel_state.get_tensor_and_expert_parallel_group()
    recomp_norm = getattr(args, 'recompute_norm', False)

    with checkpoint_context:
        # Atten Fwd
        detached_layer_input = detach_tensor(hidden_states, checkpoint_forward=checkpoint)
        # Residual connection.
        residual1 = detached_layer_input

        # input_layernorm + AttentionForward
        hidden_states = attention_forward(
            fwd_layer, detached_layer_input, residual1,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            packed_seq_params=packed_seq_params,
            recompute_norm=recomp_norm
        )

        attention_graph, detached_attention_out = hidden_states, detach_tensor(hidden_states)

        # Residual connection.
        residual2 = detached_attention_out

        if recomp_norm:
            fwd_layer.norm_ckpt2 = CheckpointWithoutOutput()
            pre_mlp_layernorm_output = fwd_layer.norm_ckpt2.checkpoint(fwd_layer.pre_mlp_layernorm, False, detached_attention_out)
        else:
            pre_mlp_layernorm_output = fwd_layer.pre_mlp_layernorm(detached_attention_out)

        # MLP.
        detached_mlp_input = detach_tensor(pre_mlp_layernorm_output, checkpoint_forward=checkpoint)
        probs, indices = router_forward(fwd_layer, detached_mlp_input)

        # Token Permutation Forward
        probs_detached = detach_tensor(probs, checkpoint_forward=checkpoint)
        perm1_out, tokens_per_expert = alltoall_token_perm1(fwd_layer.mlp.token_dispatcher, detached_mlp_input, probs_detached, indices)
        _, perm_a2a_out, perm_a2a_handle = async_all_to_all(
            perm1_out,
            fwd_layer.mlp.token_dispatcher.output_splits,
            fwd_layer.mlp.token_dispatcher.input_splits,
            ep_group
        )

    WeightGradStore.start_decouple()
    run_graph_backward(bwd_layer_graph.unperm2_graph, bwd_layer_output_grad, keep_grad=True)  # keep for dw
    run_graph_backward(bwd_layer_graph.pre_mlp_layernorm_graph, keep_graph=True)
    WeightGradStore.end_decouple()
    perm_a2a_handle.wait()
    perm_a2a_handle = None

    # Grouped MLP dw computation
    with checkpoint_context:
        detached_perm_a2a_out = detach_tensor(perm_a2a_out, checkpoint_forward=checkpoint)
        dispached_input = alltoall_token_perm2(fwd_layer.mlp.token_dispatcher, detached_perm_a2a_out)
        perm_a2a_out.untyped_storage().resize_(0)

        if tp_size > 1 and use_shared_experts:
            _, shared_experts_input, shared_experts_allgather_handle = async_all_gather(
                detached_mlp_input, tp_group, is_use_get_global_memory_buffer=True
            )
            AG_SHARED_EXPERTS_INPUTS.append((shared_experts_input, shared_experts_allgather_handle))
        else:
            shared_experts_input, shared_experts_allgather_handle = detached_mlp_input, None

        # Grouped MLP Forward
        detached_dispached_input = detach_tensor(dispached_input, checkpoint_forward=checkpoint)
        (expert_output, fc1_output, act_out), _ = fwd_layer.mlp.experts(detached_dispached_input, tokens_per_expert)
        if args.moe_zero_memory == 'level0':
            dispached_input.untyped_storage().resize_(0)
            recompute_needed_tensors = [
                dispached_input, fc1_output, act_out, probs, indices,
                fwd_layer.mlp.token_dispatcher.global_input_tokens_local_experts_indices
            ]
        else:
            if should_recompute_activation(fwd_layer.layer_number):
                recompute_needed_tensors = [None, fc1_output, act_out, None, None, None]
            else:
                recompute_needed_tensors = [None, None, None, None, None, None]
        detached_expert_output = detach_tensor(expert_output, checkpoint_forward=checkpoint)

        # Token Unpermutation Forward
        unperm1_out = alltoall_token_unperm1(fwd_layer.mlp.token_dispatcher, detached_expert_output, None)
        expert_output.untyped_storage().resize_(0)

        if shared_experts_allgather_handle is not None:
            shared_experts_allgather_handle.wait()
            shared_experts_allgather_handle = None

        _, unperm_a2a_out, unperm_a2a_handle = async_all_to_all(
            unperm1_out,
            fwd_layer.mlp.token_dispatcher.input_splits,
            fwd_layer.mlp.token_dispatcher.output_splits,
            ep_group
        )

        share_experts_graph = None
        if use_shared_experts:
            shared_expert_output, _ = fwd_layer.mlp.shared_experts(detached_mlp_input)
            if tp_size > 1:
                share_experts_graph, shared_expert_output, rs_shared_experts_handle = async_reduce_scatter(
                    shared_expert_output, tp_group
                )
                rs_shared_experts_handle.wait()
                rs_shared_experts_handle = None
                share_experts_graph.untyped_storage().resize_(0)
            else:
                share_experts_graph = shared_expert_output

        if recomp_norm:
            fwd_layer.norm_ckpt2.discard_output()

    WeightGradStore.start_decouple()
    run_graph_backward(bwd_layer_graph.attn_graph, keep_grad=True)
    WeightGradStore.end_decouple()

    if next_bwd_layer_graph is not None and getattr(next_bwd_layer_graph, 'is_moe_layer', False):
        run_graph_backward(next_bwd_layer_graph.unperm2_graph, bwd_layer_graph.layer_input.grad, keep_graph=True)

    unperm_a2a_handle.wait()
    unperm_a2a_handle = None
    unperm1_out.untyped_storage().resize_(0)
    next_layer_output_grad, next_bwd_unperm_a2a_handle = bwd_layer_graph.layer_input.grad, None
    if next_bwd_layer_graph is not None and getattr(next_bwd_layer_graph, 'is_moe_layer', False):
        _, next_layer_output_grad, next_bwd_unperm_a2a_handle = async_all_to_all(
            next_bwd_layer_graph.unperm_a2a_graph[1].grad,
            next_bwd_layer_graph.output_splits,
            next_bwd_layer_graph.input_splits,
            ep_group
        )

    with checkpoint_context:
        detached_unperm_a2a_out = detach_tensor(unperm_a2a_out, checkpoint_forward=checkpoint)
        route_expert_output, _ = alltoall_token_unperm2(fwd_layer.mlp.token_dispatcher, detached_unperm_a2a_out)

        if hasattr(fwd_layer.mlp, 'shared_experts') and fwd_layer.mlp.shared_experts is not None:
            detached_shared_expert_output = detach_tensor(shared_expert_output, checkpoint_forward=checkpoint)
            mlp_output = route_expert_output + detached_shared_expert_output
            shared_expert_output.untyped_storage().resize_(0)
        else:
            detached_shared_expert_output = None
            mlp_output = route_expert_output

        if recomp_norm:
            mlp_output.register_hook(fwd_layer.norm_ckpt2.recompute)

        with fwd_layer.bias_dropout_add_exec_handler():
            hidden_states = fwd_layer.mlp_bda(fwd_layer.training, fwd_layer.config.bias_dropout_fusion)(
                (mlp_output, None), residual2, fwd_layer.hidden_dropout
            )

        # Jit compiled function creates 'view' tensor. This tensor
        # potentially gets saved in the MPU checkpoint function context,
        # which rejects view tensors. While making a viewless tensor here
        # won't result in memory savings (like the data loader, or
        # p2p_communication), it serves to document the origin of this
        # 'view' tensor.
        output = make_viewless_tensor(
            inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
        )

    # handle fwd p2p communication
    next_iter_input_tensor, fwd_p2p_handles = None, None
    fwd_pp_comm_params = pp_comm_params
    if is_p2p_comm_needed(fwd_pp_comm_params):
        next_iter_input_tensor, fwd_p2p_handles = p2p_comm_helper(fwd_pp_comm_params, output)

    # handle bwd p2p communication
    next_iter_output_tensor_grad, bwd_p2p_handles = None, None
    if is_p2p_comm_needed(bwd_pp_comm_params):
        next_iter_output_tensor_grad, bwd_p2p_handles = p2p_comm_helper(bwd_pp_comm_params, bwd_layer_graph.layer_input.grad)

    WeightGradStore.pop()

    saved_tensors = (
        (attention_graph, detached_attention_out),
        (pre_mlp_layernorm_output, detached_mlp_input),
        (probs, probs_detached),
        (perm1_out, None),  # perm1 graph
        (None, detached_perm_a2a_out),
        (dispached_input, detached_dispached_input),  # perm2 graph
        (expert_output, detached_expert_output),  # grouped mlp graph
        (unperm1_out, None),  # unperm1 graph
        (None, detached_unperm_a2a_out),
        (output, None),  # unperm2 graph
        (share_experts_graph, detached_shared_expert_output),
        detached_layer_input
    )

    graph = LayerGraph(
        saved_tensors, recompute_needed_tensors,
        fwd_layer.mlp.token_dispatcher.input_splits, fwd_layer.mlp.token_dispatcher.output_splits,
        fwd_layer, checkpointed=checkpoint
    )

    for tensor in bwd_layer_graph.recompute_needed_tensors:
        if tensor is not None:
            tensor.untyped_storage().resize_(0)

    return (output, context, graph, (next_layer_output_grad, next_bwd_unperm_a2a_handle),
            P2PCommOutput(next_iter_input_tensor, next_iter_output_tensor_grad, fwd_p2p_handles, bwd_p2p_handles, bwd_layer_graph.layer_input.grad))
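# Illustrative sketch only (not part of this commit): the recurring pattern in these functions
# is to launch a collective asynchronously, run computation that does not depend on its result,
# and only then wait on the handle. With plain torch.distributed the same overlap looks like
# the function below (it assumes a process group has already been initialized elsewhere):
import torch
import torch.distributed as dist

def overlapped_step(tokens_to_send, independent_input, weight):
    recv_buf = torch.empty_like(tokens_to_send)
    handle = dist.all_to_all_single(recv_buf, tokens_to_send, async_op=True)  # launch comm
    partial = independent_input @ weight                                      # overlap compute
    handle.wait()                                                             # sync before using recv_buf
    return recv_buf, partial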
def transformer_layer_forward_dense_backward_dense_overlaping(
    fwd_layer,
    hidden_states,
    attention_mask,
    bwd_layer_output_grad=None,
    bwd_layer_graph: LayerGraph = None,
    bwd_unperm_a2a_handle=None,
    next_bwd_layer_graph: LayerGraph = None,
    context=None,
    context_mask=None,
    rotary_pos_emb=None,
    inference_params=None,
    packed_seq_params=None,
    pp_comm_params: P2PCommParams = None,
    bwd_pp_comm_params: P2PCommParams = None,
    checkpoint=False
):
    if checkpoint:
        checkpoint_context = torch.no_grad()
    else:
        checkpoint_context = nullcontext()
    args = get_args()
    ep_group = parallel_state.get_expert_model_parallel_group()
    if args.moe_tp_extend_ep:
        ep_group = parallel_state.get_tensor_and_expert_parallel_group()
    recomp_norm = getattr(args, 'recompute_norm', False)

    with checkpoint_context:
        # Atten Fwd
        detached_layer_input = detach_tensor(hidden_states, checkpoint_forward=checkpoint)
        # Residual connection.
        residual1 = detached_layer_input

        # input_layernorm + AttentionForward
        hidden_states = attention_forward(
            fwd_layer, detached_layer_input, residual1,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            packed_seq_params=packed_seq_params,
            recompute_norm=recomp_norm
        )

        attention_graph, detached_attention_out = hidden_states, detach_tensor(hidden_states, checkpoint_forward=checkpoint)

        # Residual connection.
        residual2 = detached_attention_out

        if recomp_norm:
            fwd_layer.norm_ckpt2 = CheckpointWithoutOutput()
            pre_mlp_layernorm_output = fwd_layer.norm_ckpt2.checkpoint(fwd_layer.pre_mlp_layernorm, False, detached_attention_out)
        else:
            pre_mlp_layernorm_output = fwd_layer.pre_mlp_layernorm(detached_attention_out)

        # MLP.
        detached_mlp_input = detach_tensor(pre_mlp_layernorm_output, checkpoint_forward=checkpoint)
        mlp_output_with_bias = fwd_layer.mlp(detached_mlp_input)

        if recomp_norm:
            fwd_layer.norm_ckpt2.discard_output()
            mlp_output_with_bias[0].register_hook(fwd_layer.norm_ckpt2.recompute)

        # TODO: could we move `bias_dropout_add_exec_handler` itself
        # inside the module provided in the `bias_dropout_add_spec` module?
        with fwd_layer.bias_dropout_add_exec_handler():
            hidden_states = fwd_layer.mlp_bda(fwd_layer.training, fwd_layer.config.bias_dropout_fusion)(
                mlp_output_with_bias, residual2, fwd_layer.hidden_dropout
            )

        # Jit compiled function creates 'view' tensor. This tensor
        # potentially gets saved in the MPU checkpoint function context,
        # which rejects view tensors. While making a viewless tensor here
        # won't result in memory savings (like the data loader, or
        # p2p_communication), it serves to document the origin of this
        # 'view' tensor.
        output = make_viewless_tensor(
            inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
        )

    # handle fwd p2p communication
    next_iter_input_tensor, fwd_p2p_handles = None, None
    fwd_pp_comm_params = pp_comm_params
    if is_p2p_comm_needed(fwd_pp_comm_params):
        next_iter_input_tensor, fwd_p2p_handles = p2p_comm_helper(fwd_pp_comm_params, output)

    # Detach backward into dx/dw
    WeightGradStore.start_decouple()
    run_graph_backward(bwd_layer_graph.unperm2_graph, bwd_layer_output_grad, keep_grad=True)  # keep for dw
    run_graph_backward(bwd_layer_graph.pre_mlp_layernorm_graph, keep_graph=True)
    run_graph_backward(bwd_layer_graph.attn_graph, keep_grad=True)
    WeightGradStore.end_decouple()

    if next_bwd_layer_graph is not None and getattr(next_bwd_layer_graph, 'is_moe_layer', False):
        run_graph_backward(next_bwd_layer_graph.unperm2_graph, bwd_layer_graph.layer_input.grad, keep_graph=True)

    next_layer_output_grad, next_bwd_unperm_a2a_handle = bwd_layer_graph.layer_input.grad, None
    if next_bwd_layer_graph is not None and getattr(next_bwd_layer_graph, 'is_moe_layer', False):
        _, next_layer_output_grad, next_bwd_unperm_a2a_handle = async_all_to_all(
            next_bwd_layer_graph.unperm_a2a_graph[1].grad,
            next_bwd_layer_graph.output_splits,
            next_bwd_layer_graph.input_splits,
            ep_group
        )

    # handle bwd p2p communication
    next_iter_output_tensor_grad, bwd_p2p_handles = None, None
    if is_p2p_comm_needed(bwd_pp_comm_params):
        next_iter_output_tensor_grad, bwd_p2p_handles = p2p_comm_helper(bwd_pp_comm_params, bwd_layer_graph.layer_input.grad)

    WeightGradStore.pop()

    saved_tensors = (
        (attention_graph, detached_attention_out),
        (pre_mlp_layernorm_output, detached_mlp_input),
        (None, None),
        (None, None),  # perm1 graph
        (None, None),
        (None, None),  # perm2 graph
        (None, None),  # grouped mlp graph
        (None, None),  # unperm1 graph
        (None, None),
        (output, None),  # unperm2 graph
        (None, None),
        detached_layer_input
    )

    graph = LayerGraph(
        saved_tensors, [], None, None,
        fwd_layer, checkpointed=checkpoint
    )

    for tensor in bwd_layer_graph.recompute_needed_tensors:
        if tensor is not None:
            tensor.untyped_storage().resize_(0)

    return (output, context, graph, (next_layer_output_grad, next_bwd_unperm_a2a_handle),
            P2PCommOutput(next_iter_input_tensor, next_iter_output_tensor_grad, fwd_p2p_handles, bwd_p2p_handles, bwd_layer_graph.layer_input.grad))
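# Illustrative sketch only (not part of this commit): when `checkpoint` is set, the forward
# above runs under torch.no_grad(), and the layer is later re-run under torch.enable_grad()
# to rebuild its graph before backward (see transformer_layer_backward). A minimal form of
# that recompute pattern, with hypothetical helper names:
import torch
from contextlib import nullcontext

def run_forward(fn, x, checkpoint=False):
    # forward without saving a graph when checkpointed
    ctx = torch.no_grad() if checkpoint else nullcontext()
    with ctx:
        return fn(x)

def recompute_and_backward(fn, x, grad_out):
    # re-run the forward with grad enabled, then backward through the restored graph
    with torch.enable_grad():
        y = fn(x)
    torch.autograd.backward(y, grad_out)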
def transformer_layer_forward_moe_backward_moe_overlaping(
    fwd_layer,
    hidden_states,
    attention_mask,
    bwd_layer_output_grad=None,
    bwd_layer_graph: LayerGraph = None,
    bwd_unperm_a2a_handle=None,
    next_bwd_layer_graph: LayerGraph = None,
    context=None,
    context_mask=None,
    rotary_pos_emb=None,
    inference_params=None,
    packed_seq_params=None,
    pp_comm_params: P2PCommParams = None,
    bwd_pp_comm_params: P2PCommParams = None,
    checkpoint=False
):
    if checkpoint:
        checkpoint_context = torch.no_grad()
    else:
        checkpoint_context = nullcontext()
    args = get_args()
    ep_group = parallel_state.get_expert_model_parallel_group()
    if args.moe_tp_extend_ep:
        ep_group = parallel_state.get_tensor_and_expert_parallel_group()
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    tp_group = parallel_state.get_tensor_model_parallel_group()
    use_shared_experts = hasattr(fwd_layer.mlp, 'shared_experts') and fwd_layer.mlp.shared_experts is not None
    recomp_norm = getattr(args, 'recompute_norm', False)

    bwd_dispached_input, bwd_fc1_out, bwd_act_out, bwd_probs, bwd_indices, \
        global_input_tokens_local_experts_indices = bwd_layer_graph.recompute_needed_tensors
    a2a_hooked_on_attention = getattr(fwd_layer.self_attention, 'a2a_hooked_on_attention', False)

    # Unperm2 Bwd
    # check if backward unpermutation alltoall is launched at bwd layer before
    if bwd_unperm_a2a_handle is None:
        run_graph_backward(bwd_layer_graph.unperm2_graph, bwd_layer_output_grad)
        # Async Unperm A2A
        if tp_size > 1 and a2a_hooked_on_attention:
            set_async_alltoall_inputs(
                bwd_layer_graph.unperm_a2a_graph[1].grad,
                bwd_layer_graph.output_splits,
                bwd_layer_graph.input_splits,
                ep_group
            )
        else:
            _, unperm1_out_grad, bwd_unperm_a2a_handle = async_all_to_all(
                bwd_layer_graph.unperm_a2a_graph[1].grad,
                bwd_layer_graph.output_splits,
                bwd_layer_graph.input_splits,
                ep_group
            )
    else:
        unperm1_out_grad = bwd_layer_output_grad

    if args.moe_zero_memory == 'level0':
        with torch.no_grad():
            bwd_input_before_perm1 = bwd_layer_graph.pre_mlp_layernorm_graph[0]

            def recomp_token_permutation1(hidden_states, indices):
                hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
                permutated_local_input_tokens, _ = permute(hidden_states, indices)
                return permutated_local_input_tokens

            bwd_perm1_out = recomp_token_permutation1(bwd_input_before_perm1, bwd_indices)

    with checkpoint_context:
        # Residual connection.
        detached_layer_input = detach_tensor(hidden_states)
        residual1 = detached_layer_input

        # input_layernorm + AttentionForward
        hidden_states = attention_forward(
            fwd_layer, detached_layer_input, residual1,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            packed_seq_params=packed_seq_params,
            recompute_norm=recomp_norm
        )
        if bwd_unperm_a2a_handle is None and tp_size > 1 and a2a_hooked_on_attention:
            unperm1_out_grad, bwd_unperm_a2a_handle = get_async_alltoall_outputs()

        attention_graph, detached_attention_out = hidden_states, detach_tensor(hidden_states)

        # Residual connection.
        residual2 = detached_attention_out

        if recomp_norm:
            fwd_layer.norm_ckpt2 = CheckpointWithoutOutput()
            pre_mlp_layernorm_output = fwd_layer.norm_ckpt2.checkpoint(fwd_layer.pre_mlp_layernorm, False, detached_attention_out)
        else:
            pre_mlp_layernorm_output = fwd_layer.pre_mlp_layernorm(detached_attention_out)

        # MLP.
        detached_mlp_input = detach_tensor(pre_mlp_layernorm_output)
        probs, indices = router_forward(fwd_layer, detached_mlp_input)

        if tp_size > 1 and use_shared_experts:
            # launch tp comm here and wait last async comm finish
            _, shared_experts_input, shared_experts_allgather_handle = async_all_gather(
                detached_mlp_input, tp_group,
                event=bwd_unperm_a2a_handle,
                stream=torch.npu.current_stream() if bwd_unperm_a2a_handle else None,
                is_use_get_global_memory_buffer=True
            )
            AG_SHARED_EXPERTS_INPUTS.append((shared_experts_input, shared_experts_allgather_handle))
        else:
            shared_experts_input, shared_experts_allgather_handle = detached_mlp_input, None

        # Token Permutation Forward
        probs_detached = detach_tensor(probs)
        perm1_out, tokens_per_expert = alltoall_token_perm1(fwd_layer.mlp.token_dispatcher, detached_mlp_input, probs_detached, indices)

    if args.moe_zero_memory == 'level0' or should_recompute_activation(bwd_layer_graph.layer.layer_number):
        with torch.no_grad():
            recompute_act_out = bwd_layer_graph.layer.mlp.experts.activation_func(bwd_fc1_out)
            bwd_act_out.untyped_storage().resize_(recompute_act_out.untyped_storage().size())
            bwd_act_out.untyped_storage().copy_(recompute_act_out.untyped_storage())
            recompute_act_out.untyped_storage().resize_(0)

    last_comm_handle = shared_experts_allgather_handle if shared_experts_allgather_handle else bwd_unperm_a2a_handle
    if args.moe_zero_memory == 'level0':
        _, bwd_perm_a2a_out, bwd_recomp_perm_a2a_handle = async_all_to_all(
            bwd_perm1_out,
            bwd_layer_graph.output_splits,
            bwd_layer_graph.input_splits,
            ep_group,
            event=last_comm_handle,
            stream=torch.npu.current_stream() if last_comm_handle else None
        )
        last_comm_handle = bwd_recomp_perm_a2a_handle

    with checkpoint_context:
        _, perm_a2a_out, perm_a2a_handle = async_all_to_all(
            perm1_out,
            fwd_layer.mlp.token_dispatcher.output_splits,
            fwd_layer.mlp.token_dispatcher.input_splits,
            ep_group,
            event=last_comm_handle,
            stream=torch.npu.current_stream() if last_comm_handle else None
        )
        last_comm_handle = perm_a2a_handle

    with checkpoint_context:
        shared_expert_output = None
        if use_shared_experts:
            if shared_experts_allgather_handle is not None:
                shared_experts_allgather_handle.wait()
                shared_experts_allgather_handle = None
            shared_expert_output, _ = fwd_layer.mlp.shared_experts(detached_mlp_input)
            if tp_size > 1:
                # launch tp comm after permf a2a and wait until shared experts computation finish.
                share_experts_graph, shared_expert_output, rs_shared_experts_handle = async_reduce_scatter(
                    shared_expert_output, tp_group,
                    event=last_comm_handle,
                    stream=torch.npu.current_stream() if last_comm_handle else None
                )
                last_comm_handle = rs_shared_experts_handle
            else:
                share_experts_graph = shared_expert_output
                rs_shared_experts_handle = None

        if recomp_norm:
            fwd_layer.norm_ckpt2.discard_output()

    bwd_unperm_a2a_handle.wait()
    bwd_unperm_a2a_handle = None
    run_graph_backward(bwd_layer_graph.unperm1_graph, unperm1_out_grad)
    unperm1_out_grad.untyped_storage().resize_(0)
    WeightGradStore.start_decouple()
    run_graph_backward(bwd_layer_graph.grouped_mlp_graph, keep_grad=True)  # keep for dw
    WeightGradStore.end_decouple()
    run_graph_backward(bwd_layer_graph.perm2_graph, keep_graph=True)  # keep for dw
    perm_a2a_handle.wait()
    perm_a2a_handle = None
    perm1_out.untyped_storage().resize_(0)

    _, perm1_out_grad, bwd_perm_a2a_handle = async_all_to_all(
        bwd_layer_graph.perm_a2a_graph[1].grad,
        bwd_layer_graph.input_splits,
        bwd_layer_graph.output_splits,
        ep_group,
        event=last_comm_handle,
        stream=torch.npu.current_stream() if last_comm_handle else None
    )
    last_comm_handle = bwd_perm_a2a_handle

    # launch shared expert grad allgather here
    if tp_size > 1:
        _, backward_ag_shared, backward_ag_shared_handle = async_all_gather(
            bwd_layer_graph.shared_experts_graph[1].grad, tp_group,
            event=last_comm_handle,
            stream=torch.npu.current_stream() if last_comm_handle else None
        )
    else:
        backward_ag_shared = bwd_layer_graph.shared_experts_graph[1].grad
        backward_ag_shared_handle = None

    # Grouped MLP dw computation
    if args.moe_zero_memory == 'level0':
        # restore fc1 input for dw computation
        with torch.no_grad():
            bwd_recomp_perm_a2a_handle.wait()
            bwd_recomp_perm_a2a_handle = None
            recompute_fc1_input, _ = permute(bwd_perm_a2a_out, global_input_tokens_local_experts_indices)
            bwd_perm_a2a_out.untyped_storage().resize_(0)
            bwd_dispached_input.untyped_storage().resize_(recompute_fc1_input.untyped_storage().size())
            bwd_dispached_input.untyped_storage().copy_(recompute_fc1_input.untyped_storage())
            recompute_fc1_input.untyped_storage().resize_(0)
    WeightGradStore.pop()

    with checkpoint_context:
        detached_perm_a2a_out = detach_tensor(perm_a2a_out)
        dispached_input = alltoall_token_perm2(fwd_layer.mlp.token_dispatcher, detached_perm_a2a_out)
        perm_a2a_out.untyped_storage().resize_(0)

        # Grouped MLP Forward
        detached_dispached_input = detach_tensor(dispached_input)
        (expert_output, fc1_output, act_out), _ = fwd_layer.mlp.experts(detached_dispached_input, tokens_per_expert)
        if args.moe_zero_memory == 'level0':
            dispached_input.untyped_storage().resize_(0)
            recompute_needed_tensors = [
                dispached_input, fc1_output, act_out, probs, indices,
                fwd_layer.mlp.token_dispatcher.global_input_tokens_local_experts_indices
            ]
        else:
            if should_recompute_activation(fwd_layer.layer_number):
                recompute_needed_tensors = [None, fc1_output, act_out, None, None, None]
            else:
                recompute_needed_tensors = [None, None, None, None, None, None]
        detached_expert_output = detach_tensor(expert_output)

        # Token Unpermutation Forward
        unperm1_out = alltoall_token_unperm1(fwd_layer.mlp.token_dispatcher, detached_expert_output, None)
        expert_output.untyped_storage().resize_(0)

    if rs_shared_experts_handle is not None:
        rs_shared_experts_handle.wait()
        rs_shared_experts_handle = None
        share_experts_graph.untyped_storage().resize_(0)
    bwd_perm_a2a_handle.wait()
    bwd_perm_a2a_handle = None
    if backward_ag_shared_handle is not None:
        # ensure tp comm is not overlapped with alltoall comm
        backward_ag_shared_handle.wait()
        backward_ag_shared_handle = None
    # move shared experts backward before unpermF all2all to avoid tp comm collision.
    WeightGradStore.start_decouple()
    run_graph_backward(bwd_layer_graph.shared_experts_graph, backward_ag_shared, keep_grad=True)  # dw computation
    WeightGradStore.end_decouple()

    with checkpoint_context:
        # launch async all2all in the middle of attention graph backward
        if tp_size > 1 and a2a_hooked_on_attention:
            set_async_alltoall_inputs(
                unperm1_out,
                fwd_layer.mlp.token_dispatcher.input_splits,
                fwd_layer.mlp.token_dispatcher.output_splits,
                ep_group
            )
        else:
            _, unperm_a2a_out, unperm_a2a_handle = async_all_to_all(
                unperm1_out,
                fwd_layer.mlp.token_dispatcher.input_splits,
                fwd_layer.mlp.token_dispatcher.output_splits,
                ep_group
            )

    run_graph_backward(bwd_layer_graph.perm1_graph, perm1_out_grad)
    perm1_out_grad.untyped_storage().resize_(0)
    run_graph_backward(bwd_layer_graph.router_graph)
    run_graph_backward(bwd_layer_graph.pre_mlp_layernorm_graph, keep_graph=True)
    WeightGradStore.start_decouple()
    run_graph_backward(bwd_layer_graph.attn_graph, keep_grad=True)
    WeightGradStore.end_decouple()

    if tp_size > 1 and a2a_hooked_on_attention:
        unperm_a2a_out, unperm_a2a_handle = get_async_alltoall_outputs()

    if next_bwd_layer_graph is not None and getattr(next_bwd_layer_graph, 'is_moe_layer', False):
        run_graph_backward(next_bwd_layer_graph.unperm2_graph, bwd_layer_graph.layer_input.grad, keep_graph=True)

    unperm_a2a_handle.wait()
    unperm_a2a_handle = None
    unperm1_out.untyped_storage().resize_(0)
    next_layer_output_grad, next_bwd_unperm_a2a_handle = bwd_layer_graph.layer_input.grad, None
    if next_bwd_layer_graph is not None and getattr(next_bwd_layer_graph, 'is_moe_layer', False):
        _, next_layer_output_grad, next_bwd_unperm_a2a_handle = async_all_to_all(
            next_bwd_layer_graph.unperm_a2a_graph[1].grad,
            next_bwd_layer_graph.output_splits,
            next_bwd_layer_graph.input_splits,
            ep_group
        )

    with checkpoint_context:
        detached_unperm_a2a_out = detach_tensor(unperm_a2a_out)
        route_expert_output, _ = alltoall_token_unperm2(fwd_layer.mlp.token_dispatcher, detached_unperm_a2a_out)

        if hasattr(fwd_layer.mlp, 'shared_experts') and fwd_layer.mlp.shared_experts is not None:
            detached_shared_expert_output = detach_tensor(shared_expert_output)
            mlp_output = route_expert_output + detached_shared_expert_output
            shared_expert_output.untyped_storage().resize_(0)
        else:
            detached_shared_expert_output = None
            share_experts_graph = None
            mlp_output = route_expert_output

        if recomp_norm:
            mlp_output.register_hook(fwd_layer.norm_ckpt2.recompute)

        with fwd_layer.bias_dropout_add_exec_handler():
            hidden_states = fwd_layer.mlp_bda(fwd_layer.training, fwd_layer.config.bias_dropout_fusion)(
                (mlp_output, None), residual2, fwd_layer.hidden_dropout
            )

        # Jit compiled function creates 'view' tensor. This tensor
        # potentially gets saved in the MPU checkpoint function context,
        # which rejects view tensors. While making a viewless tensor here
        # won't result in memory savings (like the data loader, or
        # p2p_communication), it serves to document the origin of this
        # 'view' tensor.
        output = make_viewless_tensor(
            inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
        )

    # handle fwd p2p communication
    next_iter_input_tensor, fwd_p2p_handles = None, None
    fwd_pp_comm_params = pp_comm_params
    if is_p2p_comm_needed(fwd_pp_comm_params):
        next_iter_input_tensor, fwd_p2p_handles = p2p_comm_helper(fwd_pp_comm_params, output)

    # handle bwd p2p communication
    next_iter_output_tensor_grad, bwd_p2p_handles = None, None
    if is_p2p_comm_needed(bwd_pp_comm_params):
        next_iter_output_tensor_grad, bwd_p2p_handles = p2p_comm_helper(bwd_pp_comm_params, bwd_layer_graph.layer_input.grad)

    WeightGradStore.pop()

    saved_tensors = (
        (attention_graph, detached_attention_out),
        (pre_mlp_layernorm_output, detached_mlp_input),
        (probs, probs_detached),
        (perm1_out, None),  # perm1 graph
        (None, detached_perm_a2a_out),
        (dispached_input, detached_dispached_input),  # perm2 graph
        (expert_output, detached_expert_output),  # grouped mlp graph
        (unperm1_out, None),  # unperm1 graph
        (None, detached_unperm_a2a_out),
        (output, None),  # unperm2 graph
        (share_experts_graph, detached_shared_expert_output),
        detached_layer_input
    )

    graph = LayerGraph(
        saved_tensors, recompute_needed_tensors,
        fwd_layer.mlp.token_dispatcher.input_splits, fwd_layer.mlp.token_dispatcher.output_splits,
        fwd_layer, checkpointed=checkpoint
    )

    for tensor in bwd_layer_graph.recompute_needed_tensors:
        if tensor is not None:
            tensor.untyped_storage().resize_(0)

    return (output, context, graph, (next_layer_output_grad, next_bwd_unperm_a2a_handle),
            P2PCommOutput(next_iter_input_tensor, next_iter_output_tensor_grad, fwd_p2p_handles, bwd_p2p_handles, bwd_layer_graph.layer_input.grad))
dcu_megatron/core/pipeline_parallel/fb_overlap/transformer_block.py
0 → 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
from typing import List
from contextlib import nullcontext

from megatron.training import get_args
from megatron.core.utils import make_sharded_tensor_for_checkpoint, make_viewless_tensor
from mindspeed.core.transformer.transformer_block import NoopTransformerLayer
from .modules.utils import (
    detach_tensor,
    LayerGraph,
    P2PCommParams
)
from .transformer_layer import transformer_layer_backward


def transformer_block_forward(
    self,
    hidden_states,
    attention_mask,
    context=None,
    context_mask=None,
    rotary_pos_emb=None,
    inference_params=None,
    packed_seq_params=None,
):
    if not self.pre_process:
        # See set_input_tensor()
        hidden_states = self.input_tensor

    # Viewless tensor.
    # - We only need to create a viewless tensor in the case of micro batch
    #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
    #   above creates a view tensor, and '.contiguous()' is a pass-through.
    #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
    #   the need to make it viewless.
    #
    #   However, we don't explicitly check mbs == 1 here because
    #   make_viewless_tensor() has negligible overhead when its input
    #   is already viewless.
    #
    # - For the 'else' case above, calling make_viewless_tensor() here is
    #   likely redundant, since p2p_communication.py (likely originator)
    #   already creates viewless tensors. That said, make_viewless_tensor()
    #   is called here to be future-proof and corner-case-proof.
    hidden_states = make_viewless_tensor(
        inp=hidden_states, requires_grad=True, keep_graph=True,
    )

    rng_context = nullcontext()
    fp8_context = nullcontext()

    assert not self.config.enable_cuda_graph

    layer_graphs = []
    args = get_args()
    with rng_context and fp8_context:
        for l_no, layer in enumerate(self.layers):
            checkpoint = False
            if self.config.recompute_granularity == 'full' and self.training:
                if self.config.recompute_method == 'block':
                    recompute_skip_num_layers = 0
                    if self.config.fp8 and not hidden_states.requires_grad:
                        recompute_skip_num_layers += 1
                    if (l_no >= recompute_skip_num_layers and
                            l_no < self.config.recompute_num_layers + recompute_skip_num_layers):
                        checkpoint = True
                if self.config.recompute_method == 'uniform':
                    assert self.config.recompute_num_layers == 1
                    checkpoint = True

            hidden_states, context, saved_graphs = layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                context=context,
                context_mask=context_mask,
                rotary_pos_emb=rotary_pos_emb,
                inference_params=inference_params,
                packed_seq_params=packed_seq_params,
                checkpoint=checkpoint
            )
            layer_graphs.append(saved_graphs)

    # Final layer norm.
    if self.post_process and self.post_layer_norm and self.final_layernorm is not None:
        detached_hidden_states = detach_tensor(hidden_states)
        layer_graphs[-1].unperm2_graph = (layer_graphs[-1].unperm2_graph[0], detached_hidden_states)
        hidden_states = self.final_layernorm(detached_hidden_states)

    return (hidden_states, layer_graphs)
def transformer_block_backward(
    block_output_grad,
    layer_graphs: List[LayerGraph],
):
    # should call backward first for final_layernorm and postprocess grad
    layer_output_grad = block_output_grad
    while len(layer_graphs) > 0:
        layer_graph = layer_graphs.pop(-1)
        layer_output_grad = transformer_layer_backward(layer_output_grad, layer_graph)

    return layer_output_grad
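# Illustrative usage outline only (not part of this commit): transformer_block_forward returns
# the per-layer LayerGraphs alongside the activations, and transformer_block_backward consumes
# them in reverse order. A hypothetical call site (compute_loss is a placeholder) would pair
# them roughly like this:
#
#     out, graphs = transformer_block_forward(block, hidden, attention_mask)
#     loss = compute_loss(out)                      # hypothetical loss function
#     grad_out = torch.autograd.grad(loss, out)[0]
#     input_grad = transformer_block_backward(grad_out, graphs)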
def transformer_block_forward_backward_overlaping(
    fwd_block,
    hidden_states,
    attention_mask,
    bwd_block_output_grad,
    bwd_block_graphs: List[LayerGraph],
    context=None,
    context_mask=None,
    rotary_pos_emb=None,
    inference_params=None,
    packed_seq_params=None,
    pp_comm_params: P2PCommParams = None,
    bwd_pp_comm_params: P2PCommParams = None,
):
    if not fwd_block.pre_process:
        # See set_input_tensor()
        hidden_states = fwd_block.input_tensor

    # Viewless tensor.
    # - We only need to create a viewless tensor in the case of micro batch
    #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
    #   above creates a view tensor, and '.contiguous()' is a pass-through.
    #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
    #   the need to make it viewless.
    #
    #   However, we don't explicitly check mbs == 1 here because
    #   make_viewless_tensor() has negligible overhead when its input
    #   is already viewless.
    #
    # - For the 'else' case above, calling make_viewless_tensor() here is
    #   likely redundant, since p2p_communication.py (likely originator)
    #   already creates viewless tensors. That said, make_viewless_tensor()
    #   is called here to be future-proof and corner-case-proof.
    hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)

    rng_context = nullcontext()
    fp8_context = nullcontext()

    assert not fwd_block.config.enable_cuda_graph

    fwd_layer_graphs = []
    bwd_layer_output_grad = bwd_block_output_grad
    bwd_unperm_a2a_handle = None
    fwd_hidden_states, fwd_context = hidden_states, context

    with rng_context and fp8_context:
        for l_no, fwd_layer in enumerate(fwd_block.layers):
            checkpoint = False
            if fwd_block.config.recompute_granularity == 'full' and fwd_block.training:
                if fwd_block.config.recompute_method == 'block':
                    recompute_skip_num_layers = 0
                    if fwd_block.config.fp8 and not hidden_states.requires_grad:
                        recompute_skip_num_layers += 1
                    if (l_no >= recompute_skip_num_layers
                            and l_no < fwd_block.config.recompute_num_layers + recompute_skip_num_layers):
                        checkpoint = True
                if fwd_block.config.recompute_method == 'uniform':
                    assert fwd_block.config.recompute_num_layers == 1
                    checkpoint = True
            bwd_layer_graph = bwd_block_graphs.pop(-1)
            cur_p2p_params = pp_comm_params
            cur_bwd_p2p_params = bwd_pp_comm_params
            if l_no != len(fwd_block.layers) - 1 or len(bwd_block_graphs) > 0:
                # no need to execute pp communication in the intermediate layers
                cur_p2p_params = P2PCommParams()
                cur_bwd_p2p_params = P2PCommParams()
            next_bwd_layer_graph = None
            if (len(bwd_block_graphs) > 0
                    and not bwd_block_graphs[-1].checkpointed
                    and l_no != len(fwd_block.layers) - 1
                    and not isinstance(fwd_block.layers[l_no + 1], NoopTransformerLayer)):
                next_bwd_layer_graph = bwd_block_graphs[-1]
            fwd_hidden_states, fwd_context, fwd_layer_graph, \
                (bwd_layer_output_grad, bwd_unperm_a2a_handle), \
                pp_comm_output = \
                fwd_layer(
                    fwd_hidden_states,
                    attention_mask,
                    bwd_layer_output_grad,
                    bwd_layer_graph=bwd_layer_graph,
                    bwd_unperm_a2a_handle=bwd_unperm_a2a_handle,
                    next_bwd_layer_graph=next_bwd_layer_graph,
                    context=fwd_context,
                    context_mask=context_mask,
                    rotary_pos_emb=rotary_pos_emb,
                    inference_params=inference_params,
                    packed_seq_params=packed_seq_params,
                    pp_comm_params=cur_p2p_params,
                    bwd_pp_comm_params=cur_bwd_p2p_params,
                    checkpoint=checkpoint
                )
            fwd_layer_graphs.append(fwd_layer_graph)

    # Final layer norm.
    if fwd_block.post_process and fwd_block.post_layer_norm and fwd_block.final_layernorm is not None:
        detached_hidden_states = detach_tensor(fwd_hidden_states)
        fwd_layer_graphs[-1].unperm2_graph = (fwd_layer_graphs[-1].unperm2_graph[0], detached_hidden_states)
        fwd_hidden_states = fwd_block.final_layernorm(detached_hidden_states)

    return (fwd_hidden_states, fwd_layer_graphs), bwd_layer_output_grad, pp_comm_output
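A note on the pairing order used above, not part of the commit itself: the block walks its forward layers front-to-back while popping backward graphs off the end of `bwd_block_graphs`, so the backward pass runs in reverse layer order alongside the forward pass. The toy sketch below (plain lists, hypothetical helper name) just illustrates that ordering and can be run standalone.

# Illustrative sketch only: pairing of forward layer index with the backward
# graph index, using the same pop(-1) as the loop above.
def pairing_order(num_layers):
    fwd_layers = list(range(num_layers))
    bwd_graphs = list(range(num_layers))  # stands in for bwd_block_graphs
    pairs = []
    for l_no in fwd_layers:
        bwd_layer = bwd_graphs.pop(-1)
        pairs.append((l_no, bwd_layer))
    return pairs

print(pairing_order(4))  # [(0, 3), (1, 2), (2, 1), (3, 0)]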
dcu_megatron/core/pipeline_parallel/fb_overlap/transformer_layer.py (new file)
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
from contextlib import nullcontext

import torch

from mindspeed.core.transformer.transformer_block import NoopTransformerLayer
from .modules.utils import (
    NoopLayerGraph, LayerGraph, is_p2p_comm_needed, p2p_comm_helper, P2PCommOutput, P2PCommParams
)
from .overlap_funcs import (
    transformer_layer_forward_moe,
    transformer_layer_forward_dense,
    transformer_layer_forward_noop,
    transformer_layer_backward_moe,
    transformer_layer_backward_dense,
    transformer_layer_backward_noop,
    transformer_layer_forward_moe_backward_moe_overlaping,
    transformer_layer_forward_dense_backward_moe_overlaping,
    transformer_layer_forward_moe_backward_dense_overlaping,
    transformer_layer_forward_dense_backward_dense_overlaping,
)


def transformer_layer_forward(
    self,
    hidden_states,
    attention_mask,
    context=None,
    context_mask=None,
    rotary_pos_emb=None,
    inference_params=None,
    packed_seq_params=None,
    use_orig_layer_forward=False,
    checkpoint=False
):
    if checkpoint:
        checkpoint_context = torch.no_grad()
    else:
        checkpoint_context = nullcontext()

    with checkpoint_context:
        layer_forward_func = None
        if use_orig_layer_forward:
            from mindspeed.core.pipeline_parallel.fp_overlap.megatron_adaptor import get_orig_transformer_layer_forward_func
            # for mtp transformer layer forward
            layer_forward_func = get_orig_transformer_layer_forward_func()
            return layer_forward_func(
                self, hidden_states, attention_mask, context, context_mask,
                rotary_pos_emb, inference_params, packed_seq_params
            )
        elif isinstance(self, NoopTransformerLayer):
            layer_forward_func = transformer_layer_forward_noop
        elif hasattr(self.mlp, 'experts'):
            layer_forward_func = transformer_layer_forward_moe
        else:
            layer_forward_func = transformer_layer_forward_dense

        return layer_forward_func(
            self, hidden_states, attention_mask, context, context_mask,
            rotary_pos_emb, inference_params, packed_seq_params, checkpoint=checkpoint
        )


def transformer_layer_backward(layer_output_grad, layer_graph):
    if layer_graph.checkpointed:
        with torch.enable_grad():
            _, _, restored_layer_graph = transformer_layer_forward(
                layer_graph.layer, layer_graph.layer_input, *layer_graph.layer_inputs, checkpoint=False
            )
            restored_layer_graph.unperm2_graph = (
                restored_layer_graph.unperm2_graph[0], layer_graph.unperm2_graph[1]
            )
            layer_graph = restored_layer_graph

    if isinstance(layer_graph, NoopLayerGraph):
        return transformer_layer_backward_noop(layer_output_grad, layer_graph)
    elif layer_graph.is_moe_layer:
        return transformer_layer_backward_moe(layer_output_grad, layer_graph)
    else:
        return transformer_layer_backward_dense(layer_output_grad, layer_graph)


def transformer_layer_forward_backward_overlaping(
    fwd_layer,
    hidden_states,
    attention_mask,
    bwd_layer_output_grad=None,
    bwd_layer_graph: LayerGraph = None,
    bwd_unperm_a2a_handle=None,
    next_bwd_layer_graph: LayerGraph = None,
    context=None,
    context_mask=None,
    rotary_pos_emb=None,
    inference_params=None,
    packed_seq_params=None,
    pp_comm_params: P2PCommParams = None,
    bwd_pp_comm_params: P2PCommParams = None,
    use_orig_layer_forward=False,
    checkpoint=False
):
    if isinstance(fwd_layer, NoopTransformerLayer) or bwd_layer_graph is None or isinstance(bwd_layer_graph, NoopLayerGraph):
        # no forward & backward overlapping
        if bwd_layer_graph is None:
            out = transformer_layer_forward(
                fwd_layer, hidden_states, attention_mask, context, context_mask,
                rotary_pos_emb, inference_params, packed_seq_params,
                use_orig_layer_forward, checkpoint=checkpoint
            )
            if len(out) > 2 and checkpoint:
                out[2].record_layer_inputs(
                    attention_mask, context, context_mask, rotary_pos_emb,
                    inference_params, packed_seq_params, use_orig_layer_forward
                )
            return out
        else:
            output, context, graph = transformer_layer_forward(
                fwd_layer, hidden_states, attention_mask, context, context_mask,
                rotary_pos_emb, inference_params, packed_seq_params,
                use_orig_layer_forward, checkpoint=checkpoint
            )
            # handle fwd p2p communication
            next_iter_input_tensor, fwd_p2p_handles = None, None
            fwd_pp_comm_params = pp_comm_params
            if is_p2p_comm_needed(fwd_pp_comm_params):
                next_iter_input_tensor, fwd_p2p_handles = p2p_comm_helper(fwd_pp_comm_params, output)
            bwd_input_grad = transformer_layer_backward(bwd_layer_output_grad, bwd_layer_graph)
            next_iter_output_tensor_grad, bwd_p2p_handles = None, None
            if bwd_input_grad is not None:
                # handle bwd p2p communication
                if is_p2p_comm_needed(bwd_pp_comm_params):
                    next_iter_output_tensor_grad, bwd_p2p_handles = p2p_comm_helper(bwd_pp_comm_params, bwd_input_grad)

            if checkpoint:
                graph.record_layer_inputs(
                    attention_mask, context, context_mask, rotary_pos_emb,
                    inference_params, packed_seq_params, use_orig_layer_forward
                )

            return (output, context, graph, (bwd_input_grad, None),
                    P2PCommOutput(next_iter_input_tensor, next_iter_output_tensor_grad,
                                  fwd_p2p_handles, bwd_p2p_handles, bwd_input_grad))
    else:
        fb_overlap_func = None
        if hasattr(fwd_layer.mlp, 'experts') and bwd_layer_graph.is_moe_layer:
            fb_overlap_func = transformer_layer_forward_moe_backward_moe_overlaping
        elif hasattr(fwd_layer.mlp, 'experts') and not bwd_layer_graph.is_moe_layer:
            fb_overlap_func = transformer_layer_forward_moe_backward_dense_overlaping
        elif not hasattr(fwd_layer.mlp, 'experts') and bwd_layer_graph.is_moe_layer:
            fb_overlap_func = transformer_layer_forward_dense_backward_moe_overlaping
        elif not hasattr(fwd_layer.mlp, 'experts') and not bwd_layer_graph.is_moe_layer:
            fb_overlap_func = transformer_layer_forward_dense_backward_dense_overlaping
        else:
            raise AssertionError('Check Layer Spec, f&b overlap func is not supported!')

        if bwd_layer_graph.checkpointed:
            _, _, bwd_layer_graph = transformer_layer_forward(
                bwd_layer_graph.layer, bwd_layer_graph.layer_input, *bwd_layer_graph.layer_inputs, checkpoint=False
            )
        out = fb_overlap_func(
            fwd_layer, hidden_states, attention_mask, bwd_layer_output_grad,
            bwd_layer_graph, bwd_unperm_a2a_handle, next_bwd_layer_graph,
            context, context_mask, rotary_pos_emb, inference_params, packed_seq_params,
            pp_comm_params, bwd_pp_comm_params, checkpoint=checkpoint
        )
        if checkpoint:
            out[2].record_layer_inputs(
                attention_mask, context, context_mask, rotary_pos_emb,
                inference_params, packed_seq_params, use_orig_layer_forward
            )
        return out
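The four-way branch above selects the overlap kernel from whether the forward layer and the backward layer graph are MoE or dense. An equivalent lookup is sketched below for reference; the overlap function names are the real ones imported in this file, while the helper itself is hypothetical and not part of the commit.

# Illustrative sketch only: same selection as the if/elif chain above.
def select_fb_overlap_func(fwd_layer, bwd_layer_graph):
    dispatch = {
        (True, True): transformer_layer_forward_moe_backward_moe_overlaping,
        (True, False): transformer_layer_forward_moe_backward_dense_overlaping,
        (False, True): transformer_layer_forward_dense_backward_moe_overlaping,
        (False, False): transformer_layer_forward_dense_backward_dense_overlaping,
    }
    key = (hasattr(fwd_layer.mlp, 'experts'), bwd_layer_graph.is_moe_layer)
    return dispatch[key]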
dcu_megatron/core/pipeline_parallel/fb_overlap/vpp_schedules.py (new file)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import contextlib
from typing import Callable, Iterator, List, Optional, Union
from functools import partial

import torch
from torch.autograd.variable import Variable

from megatron.training import get_timers, get_args
from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.pipeline_parallel.schedules import (
    deallocate_output_tensor,
    set_current_microbatch,
    check_first_val_step,
    clear_embedding_activation_buffer,
    finish_embedding_wgrad_compute,
    custom_backward
)
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import (
    drain_embedding_wgrad_compute,
    get_attr_wrapped_model,
    get_model_config,
    get_model_type,
)

from .gpt_model import gpt_model_backward
from .modules.utils import P2PCommParams

LOSS_BACKWARD_SCALE = torch.tensor(1.0)

# Types
Shape = Union[List[int], torch.Size]
def forward_step(
    forward_step_func,
    data_iterator,
    model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    config,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
    is_first_microbatch=False,
    current_microbatch=None,
    extra_block_kwargs=None
):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
    passed-in input_tensor is used.

    Returns output tensor."""
    if config.timers is not None:
        config.timers('forward-compute', log_level=2).start()

    if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
        model.set_is_first_microbatch()
    if current_microbatch is not None:
        set_current_microbatch(model, current_microbatch)

    unwrap_output_tensor = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_output_tensor = True

    set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
    set_input_tensor(input_tensor)

    if config.enable_autocast:
        context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
    else:
        context_manager = contextlib.nullcontext()
    with context_manager:
        if checkpoint_activations_microbatch is None:
            output_tensor, loss_func = forward_step_func(data_iterator, model, extra_block_kwargs)
        else:
            output_tensor, loss_func = forward_step_func(
                data_iterator, model, checkpoint_activations_microbatch, extra_block_kwargs
            )

    num_tokens = torch.tensor(0, dtype=torch.int)
    if parallel_state.is_pipeline_last_stage():
        if not collect_non_loss_data:
            next_info = None
            if isinstance(output_tensor, tuple):
                # use pp overlapping
                if len(output_tensor) == 2:
                    output_tensor, model_graph = output_tensor
                elif len(output_tensor) == 3:
                    output_tensor, model_graph, next_info = output_tensor
            outputs = loss_func(output_tensor)
            if len(outputs) == 3:
                output_tensor, num_tokens, loss_reduced = outputs
                if not config.calculate_per_token_loss:
                    output_tensor /= num_tokens
                    output_tensor /= num_microbatches
            else:
                # preserve legacy loss averaging behavior (ie, over the number of microbatches)
                assert len(outputs) == 2
                output_tensor, loss_reduced = outputs
                output_tensor /= num_microbatches
            forward_data_store.append(loss_reduced)
            output_tensor = (
                (output_tensor, model_graph, next_info) if next_info is not None
                else (output_tensor, model_graph)
            )
        else:
            data = loss_func(output_tensor, non_loss_data=True)
            forward_data_store.append(data)

    if config.timers is not None:
        config.timers('forward-compute').stop()

    # Set the loss scale for the auxiliary loss of the MoE layer.
    # Since we use a trick to do backward on the auxiliary loss, we need to set the scale explicitly.
    if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
        loss_scale = (
            config.grad_scale_func(LOSS_BACKWARD_SCALE)
            if config.grad_scale_func is not None
            else torch.tensor(1.0)
        )
        # Set the loss scale
        MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    # If T5 model (or other model with encoder and decoder)
    # and in decoder stack, then send encoder_hidden_state
    # downstream as well.
    model_type = get_model_type(model)
    if parallel_state.is_pipeline_stage_after_split() and model_type == ModelType.encoder_and_decoder:
        return [output_tensor, input_tensor[-1]], num_tokens

    if unwrap_output_tensor:
        return output_tensor, num_tokens
    return [output_tensor], num_tokens
def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config, model_graph=None):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Returns gradient of loss with respect to input tensor (None if first
    stage)."""

    # NOTE: This code currently can handle at most one skip connection. It
    # needs to be modified slightly to support arbitrary numbers of skip
    # connections.

    if config.timers is not None:
        config.timers('backward-compute', log_level=2).start()

    # Retain the grad on the input_tensor.
    unwrap_input_tensor_grad = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_input_tensor_grad = True
    for x in input_tensor:
        if x is not None:
            x.retain_grad()

    if not isinstance(output_tensor, list):
        output_tensor = [output_tensor]
    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]

    # Backward pass.
    if output_tensor_grad[0] is None and config.grad_scale_func is not None and model_graph is None:
        output_tensor[0] = config.grad_scale_func(output_tensor[0])

    if config.deallocate_pipeline_outputs:
        if model_graph is None:
            custom_backward(output_tensor[0], output_tensor_grad[0])
        else:
            layer_output_grad = gpt_model_backward(output_tensor_grad[0], model_graph)
    else:
        torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])

    # Collect the grad of the input_tensor.
    input_tensor_grad = [None]
    if input_tensor is not None:
        input_tensor_grad = []
        if model_graph is not None:
            input_tensor_grad.append(layer_output_grad)
        else:
            for x in input_tensor:
                if x is None:
                    input_tensor_grad.append(None)
                else:
                    input_tensor_grad.append(x.grad)

    # Handle single skip connection if it exists (encoder_hidden_state in
    # model with encoder and decoder).
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and parallel_state.is_pipeline_stage_after_split()
        and model_type == ModelType.encoder_and_decoder
    ):
        if output_tensor_grad[1] is not None:
            input_tensor_grad[-1].add_(output_tensor_grad[1])
    if unwrap_input_tensor_grad:
        input_tensor_grad = input_tensor_grad[0]

    if config.timers is not None:
        config.timers('backward-compute').stop()

    return input_tensor_grad
def forward_step_vpp_overlap(data_iterator, model, extra_block_kwargs=None):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model
    """
    from pretrain_gpt import get_batch, loss_func
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
    timers('batch-generator').stop()

    if extra_block_kwargs is not None:
        # execute forward/backward overlapping
        output_tensor, model_graph, pp_comm_output = \
            model(tokens, position_ids, attention_mask, labels=labels, extra_block_kwargs=extra_block_kwargs)
        return (output_tensor, model_graph, pp_comm_output), partial(loss_func, loss_mask)
    else:
        output_tensor, model_graph = model(tokens, position_ids, attention_mask, labels=labels)
        return (output_tensor, model_graph), partial(loss_func, loss_mask)
def forward_backward_pipelining_with_interleaving(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int = None,
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
    first_val_step: bool = None,
):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    assert isinstance(model, list), "interleaved pipeline parallelism expected model chunking"
    assert all(isinstance(chunk, torch.nn.Module) for chunk in model), "invalid model chunking"
    assert isinstance(
        data_iterator, list
    ), "interleaved pipeline parallelism expected each model chunk to have a data iterator"
    # override the forward step func with forward_step_vpp_overlap
    forward_step_func = forward_step_vpp_overlap

    config = get_model_config(model[0])
    if config.overlap_p2p_comm and config.batch_p2p_comm:
        raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")

    # Needed only when gradients are finalized in M-Core
    if config.finalize_model_grads_func is not None and not forward_only:
        embedding_module = clear_embedding_activation_buffer(config, model)

    if config.timers is not None:
        config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

    # Disable async grad reductions
    no_sync_func = config.no_sync_func
    if isinstance(no_sync_func, list):

        def multi_no_sync():
            stack = contextlib.ExitStack()
            for model_chunk_no_sync_func in config.no_sync_func:
                stack.enter_context(model_chunk_no_sync_func())
            return stack

        no_sync_func = multi_no_sync
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext
    no_sync_context = None

    if config.grad_sync_func is not None and not isinstance(config.grad_sync_func, list):
        config.grad_sync_func = [config.grad_sync_func for _ in model]

    if config.param_sync_func is not None and not isinstance(config.param_sync_func, list):
        config.param_sync_func = [config.param_sync_func for _ in model]

    def disable_grad_sync():
        """Disable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is None:
            no_sync_context = no_sync_func()
            no_sync_context.__enter__()

    def enable_grad_sync():
        """Enable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is not None:
            no_sync_context.__exit__(None, None, None)
            no_sync_context = None

    disable_grad_sync()

    # Model chunk IDs with synchronized grads
    synchronized_model_chunks = set()

    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    model_graphs = [[] for _ in range(len(model))]
    logits_inputs = []
    total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()

    forward_data_store = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()

    if num_microbatches % pipeline_parallel_size != 0:
        msg = f'number of microbatches ({num_microbatches}) is not divisible by '
        msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) '
        msg += 'when using interleaved schedule'
        raise RuntimeError(msg)

    model_type = get_model_type(model[0])
    if model_type == ModelType.encoder_and_decoder:
        raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")

    if decoder_seq_length is not None and decoder_seq_length != seq_length:
        raise RuntimeError(
            "Interleaving is not supported with a different decoder sequence length."
        )

    tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
    tensor_shape[0] = tensor_shape[0] // parallel_state.get_context_parallel_world_size()
    if config.sequence_parallel:
        tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    total_num_microbatches = num_microbatches * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = total_num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if num_microbatches == pipeline_parallel_size:
            num_warmup_microbatches = total_num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
            # add one more warmup microbatch for the 1F1B overlapping
            num_warmup_microbatches += 1
    num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches

    # Checkpoint the activations of partial Transformer layers in a number of micro-batches
    # within the maximum outstanding micro-batch backpropagations.
    # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
    # checkpoint partial Transformer layers (or skip checkpointing) and
    # the rest of micro-batches within a window of micro-batches checkpoint
    # all Transformer layers. The window of micro-batches is set by the maximum
    # outstanding backpropagations and becomes smaller at later pipeline stages.
    # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf
    max_outstanding_backprops = None
    if config.num_microbatches_with_partial_activation_checkpoints is not None:
        max_outstanding_backprops = num_warmup_microbatches + 1

    # Synchronize params for first two model chunks
    if config.param_sync_func is not None:
        config.param_sync_func[0](model[0].parameters())
        config.param_sync_func[1](model[1].parameters())
    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = num_model_chunks - model_chunk_id - 1
        return model_chunk_id

    def get_microbatch_id_in_model_chunk(iteration_id, forward):
        """Helper method to get the microbatch_id within model chunk given the iteration number."""
        assert forward
        iteration_group_id = iteration_id // (pipeline_parallel_size * num_model_chunks)
        microbatch_id_in_model_chunk = (iteration_group_id * pipeline_parallel_size) + (
            iteration_id % pipeline_parallel_size
        )
        return microbatch_id_in_model_chunk

    def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
        """Check if an iteration is the first for a model chunk."""
        microbatch_group_size = pipeline_parallel_size * num_model_chunks
        num_microbatch_groups = total_num_microbatches // microbatch_group_size
        microbatch_group_id = microbatch_id // microbatch_group_size
        microbatch_id_in_group = microbatch_id % microbatch_group_size
        if microbatch_group_id == 0:
            return microbatch_id_in_group % pipeline_parallel_size == 0
        else:
            return False

    def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
        """Check if an iteration is the last for a model chunk."""
        microbatch_group_size = pipeline_parallel_size * num_model_chunks
        num_microbatch_groups = total_num_microbatches // microbatch_group_size
        microbatch_group_id = microbatch_id // microbatch_group_size
        microbatch_id_in_group = microbatch_id % microbatch_group_size
        if microbatch_group_id == num_microbatch_groups - 1:
            return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1
        else:
            return False
    def forward_step_helper(
        microbatch_id,
        current_microbatch,
        checkpoint_activations_microbatch,
        extra_block_kwargs=None,
        backward_k=None
    ):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # launch param synchronization for next model chunk
        # Note: Asynchronous communication tends to slow down compute.
        # To reduce idling from mismatched microbatch times, we launch
        # asynchronous communication at the same time across the
        # pipeline-parallel group.
        if config.param_sync_func is not None:
            param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank
            if (
                param_sync_microbatch_id < total_num_microbatches
                and is_first_microbatch_for_model_chunk(param_sync_microbatch_id)
            ):
                param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
                if 1 < param_sync_chunk_id < num_model_chunks:
                    config.param_sync_func[param_sync_chunk_id](
                        model[param_sync_chunk_id].parameters()
                    )

        # forward step
        if parallel_state.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]

        output_tensor, num_tokens = forward_step(
            forward_step_func,
            data_iterator[model_chunk_id],
            model[model_chunk_id],
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
            check_first_val_step(
                first_val_step, forward_only, is_first_microbatch_for_model_chunk(microbatch_id),
            ),
            current_microbatch=current_microbatch,
            extra_block_kwargs=extra_block_kwargs
        )

        if isinstance(output_tensor, tuple):
            if len(output_tensor) == 2:
                output_tensor_, model_graph = output_tensor
            elif len(output_tensor) == 3:
                output_tensor_, model_graph, pp_comm_output = output_tensor
            if parallel_state.is_pipeline_last_stage():
                logits_inputs.append(model_graph.layer_graphs[-1].unperm2_graph[1])
            model_graphs[model_chunk_id].append(model_graph)
        else:
            output_tensor_ = output_tensor
        output_tensors[model_chunk_id].append(output_tensor_)

        if backward_k is not None:
            backward_chunk_id = get_model_chunk_id(backward_k, forward=False)
            input_tensors[backward_chunk_id].pop(0)
            output_tensors[backward_chunk_id].pop(0)
            output_tensor_grads[backward_chunk_id].pop(0)

        nonlocal total_num_tokens
        total_num_tokens += num_tokens.item()

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id, logits_bwd=False):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # launch grad synchronization (default)
        if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id):
            enable_grad_sync()
            synchronized_model_chunks.add(model_chunk_id)

        if parallel_state.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        if not logits_bwd:
            input_tensor = input_tensors[model_chunk_id].pop(0)
            output_tensor = output_tensors[model_chunk_id].pop(0)
            output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
            model_graph = model_graphs[model_chunk_id].pop(0)
        else:
            input_tensor = logits_inputs.pop(0)
            output_tensor = output_tensors[model_chunk_id][0]
            output_tensor_grad = output_tensor_grads[model_chunk_id][0]
            model_graph = None
        input_tensor_grad = backward_step(
            input_tensor, output_tensor, output_tensor_grad, model_type, config, model_graph
        )

        # launch grad synchronization (custom grad sync)
        # Note: Asynchronous communication tends to slow down compute.
        # To reduce idling from mismatched microbatch times, we launch
        # asynchronous communication at the same time across the
        # pipeline-parallel group.
        if config.grad_sync_func is not None:
            grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank
            if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(grad_sync_microbatch_id):
                grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False)
                enable_grad_sync()
                config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
                synchronized_model_chunks.add(grad_sync_chunk_id)
        disable_grad_sync()

        return input_tensor_grad
    def check_pipeline_stage(forward_k, backward_k):
        send_next = not (
            get_model_chunk_id(forward_k, forward=True) == num_model_chunks - 1
            and pipeline_parallel_rank == parallel_state.get_pipeline_model_parallel_world_size() - 1
        )
        send_prev = not (
            get_model_chunk_id(backward_k, forward=False) == 0 and pipeline_parallel_rank == 0
        )
        recv_prev = not (
            get_model_chunk_id(forward_k + 1, forward=True) == 0 and pipeline_parallel_rank == 0
        )
        if forward_k + 1 >= total_num_microbatches:
            recv_prev = False
        recv_next = not (
            get_model_chunk_id(backward_k + 1, forward=False) == num_model_chunks - 1
            and pipeline_parallel_rank == parallel_state.get_pipeline_model_parallel_world_size() - 1
        )

        return (
            P2PCommParams(send_next=send_next, recv_prev=recv_prev),
            P2PCommParams(send_prev=send_prev, recv_next=recv_next),
        )
    # Run warmup forward passes.
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))

    fwd_wait_handles = None
    bwd_wait_handles = None
    P2PCommParams.tensor_shape = tensor_shape
    P2PCommParams.config = config

    for k in range(num_warmup_microbatches):
        if fwd_wait_handles is not None:
            for req in fwd_wait_handles:
                req.wait()

        cur_model_chunk_id = get_model_chunk_id(k, forward=True)
        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                k % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None

        current_microbatch = get_microbatch_id_in_model_chunk(k, forward=True)
        output_tensor = forward_step_helper(
            k, current_microbatch, checkpoint_activations_microbatch, backward_k=None
        )
        if isinstance(output_tensor, tuple):
            # use pp overlapping
            if len(output_tensor) == 2:
                output_tensor, model_graph = output_tensor
            elif len(output_tensor) == 3:
                output_tensor, model_graph, pp_comm_output = output_tensor

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (total_num_microbatches - 1):
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        if parallel_state.is_pipeline_last_stage():
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if not config.overlap_p2p_comm:
            if isinstance(output_tensor, tuple):
                if len(output_tensor) == 2:
                    output_tensor, model_graph = output_tensor
                elif len(output_tensor) == 3:
                    output_tensor, model_graph, pp_comm_output = output_tensor
                if parallel_state.is_pipeline_last_stage():
                    model_graph, logits_input = model_graph
                    logits_inputs.append(logits_input)
            if (
                k == (num_warmup_microbatches - 1)
                and not forward_only
                and not all_warmup_microbatches
            ):
                input_tensor_grad = None
                recv_next = True
                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    config=config,
                )
                output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
            else:
                input_tensor = p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config
                )
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        else:
            input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
                output_tensor,
                recv_prev=recv_prev,
                tensor_shape=tensor_shape,
                config=config,
                overlap_p2p_comm=True,
            )
            if (
                k == (num_warmup_microbatches - 1)
                and not forward_only
                and not all_warmup_microbatches
            ):
                input_tensor_grad = None
                recv_next = True
                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False
                (
                    output_tensor_grad,
                    bwd_wait_handles,
                ) = p2p_communication.send_backward_recv_backward(
                    input_tensor_grad,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    config=config,
                    overlap_p2p_comm=True,
                )
                output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
            input_tensors[next_forward_model_chunk_id].append(input_tensor)

        deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches

        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                forward_k % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None

        # The chunk id is determined by the absolute microbatch id, so no change is needed here.
        cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        current_microbatch = get_microbatch_id_in_model_chunk(forward_k, forward=True)
        if config.overlap_p2p_comm:
            if fwd_wait_handles is not None:
                for req in fwd_wait_handles:
                    req.wait()
            if bwd_wait_handles is not None:
                for req in bwd_wait_handles:
                    req.wait()

            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

            extra_block_kwargs = {}
            backward_k = k
            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
            if parallel_state.is_pipeline_last_stage():
                input_tensor_grad = backward_step_helper(backward_k, logits_bwd=True)
                assert input_tensor_grad is not None, "logits backward should not be None"
                extra_block_kwargs.setdefault('bwd_model_grad', input_tensor_grad)
            else:
                # input_tensor_grad is obtained through pp communication
                output_tensor_grad = output_tensor_grads[backward_model_chunk_id][0]
                extra_block_kwargs.setdefault('bwd_model_grad', output_tensor_grad)

            fwd_pp_comm_params, bwd_pp_comm_params = check_pipeline_stage(forward_k, backward_k)
            extra_block_kwargs.setdefault('bwd_model_graph', model_graphs[backward_model_chunk_id].pop(0))
            extra_block_kwargs.setdefault('pp_comm_params', fwd_pp_comm_params)
            extra_block_kwargs.setdefault('bwd_pp_comm_params', bwd_pp_comm_params)

            output_tensor = forward_step_helper(
                forward_k,
                current_microbatch,
                checkpoint_activations_microbatch,
                extra_block_kwargs,
                backward_k=backward_k
            )
            output_tensor, model_graph, pp_comm_output = output_tensor
            input_tensor, fwd_wait_handles = pp_comm_output.input_tensor, pp_comm_output.fwd_wait_handles
            output_tensor_grad, bwd_wait_handles = pp_comm_output.output_tensor_grad, pp_comm_output.bwd_wait_handles

            if fwd_pp_comm_params.recv_prev:
                next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
                input_tensors[next_forward_model_chunk_id].append(input_tensor)
            if bwd_pp_comm_params.recv_next:
                next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
                output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)

            if parallel_state.is_pipeline_last_stage():
                output_tensor = None
            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if config.overlap_p2p_comm and bwd_wait_handles is not None:
            for wait_handle in bwd_wait_handles:
                wait_handle.wait()

        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape, config=config)
            )
        for k in range(num_microbatches_remaining, total_num_microbatches):
            chunk_id = get_model_chunk_id(k, False)
            parallel_state.set_virtual_pipeline_model_parallel_rank(chunk_id)
            if parallel_state.is_pipeline_last_stage():
                input_tensor_grad = backward_step_helper(k, logits_bwd=True)
                output_tensor_grads[chunk_id].append(input_tensor_grad)
                output_tensors[chunk_id].pop(0)
                output_tensors[chunk_id].append(None)  # dummy output tensors
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (total_num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config
                )
            )

    # Launch any remaining grad reductions.
    enable_grad_sync()
    if config.grad_sync_func is not None:
        for model_chunk_id in range(num_model_chunks):
            if model_chunk_id not in synchronized_model_chunks:
                config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
                synchronized_model_chunks.add(model_chunk_id)

    if config.finalize_model_grads_func is not None and not forward_only:
        # If defer_embedding_wgrad_compute is enabled we need to do the
        # weight gradient GEMM's here.
        finish_embedding_wgrad_compute(config, embedding_module)

        # Finalize model grads (perform full grad all-reduce / reduce-scatter for
        # data parallelism, layernorm all-reduce for sequence parallelism, and
        # embedding all-reduce for pipeline parallelism).
        config.finalize_model_grads_func(
            model, total_num_tokens if config.calculate_per_token_loss else None
        )

    if config.timers is not None:
        config.timers('forward-backward').stop()

    return forward_data_store
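A quick worked example of the warmup count used by this schedule, for the common case where num_microbatches differs from the pipeline size. The helper below is illustrative only (it mirrors the arithmetic above with plain integers) and is not part of the commit.

# Illustrative sketch: warmup microbatches per rank for
# pipeline_parallel_size=4, num_model_chunks=2, num_microbatches=8.
def warmup_microbatches(pp_size, rank, num_chunks, num_microbatches):
    total = num_microbatches * num_chunks
    if num_microbatches == pp_size:
        return total
    warmup = (pp_size - rank - 1) * 2
    warmup += (num_chunks - 1) * pp_size
    warmup = min(warmup, total)
    return warmup + 1  # the extra warmup microbatch added for the 1F1B overlap

for rank in range(4):
    print(rank, warmup_microbatches(4, rank, 2, 8))
# rank 0 -> 11, rank 1 -> 9, rank 2 -> 7, rank 3 -> 5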
dcu_megatron/core/pipeline_parallel/schedules.py (new file)
import torch
from functools import wraps

from dcu_megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler


def forward_step_wrapper(fn):
    @wraps(fn)
    def wrapper(
        forward_step_func,
        data_iterator,
        model,
        num_microbatches,
        input_tensor,
        forward_data_store,
        config,
        **kwargs,
    ):
        output, num_tokens = fn(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            **kwargs
        )

        if not isinstance(input_tensor, list):
            # unwrap_output_tensor True
            output_tensor = output
        else:
            output_tensor = output[0]

        # Set the loss scale for Multi-Token Prediction (MTP) loss.
        if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
            # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
            loss_scale = (
                config.grad_scale_func(torch.ones(1, device=output_tensor.device))
                if config.grad_scale_func is not None
                else torch.ones(1, device=output_tensor.device)
            )
            # Set the loss scale
            if config.calculate_per_token_loss:
                MTPLossAutoScaler.set_loss_scale(loss_scale)
            else:
                MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

        return output, num_tokens
    return wrapper
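The scale selection above can be checked in isolation. The snippet below is a minimal sketch that assumes a stand-in config object (not the real Megatron config class) and reproduces the same branching.

# Minimal sketch, assuming a config-like object with the two attributes used above.
import torch

class _Cfg:
    calculate_per_token_loss = False
    grad_scale_func = None

def mtp_loss_scale(config, num_microbatches, device):
    scale = (config.grad_scale_func(torch.ones(1, device=device))
             if config.grad_scale_func is not None else torch.ones(1, device=device))
    return scale if config.calculate_per_token_loss else scale / num_microbatches

print(mtp_loss_scale(_Cfg(), 4, "cpu"))  # tensor([0.2500])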
dcu_megatron/core/tensor_parallel/__init__.py
from .layers import (
    FluxColumnParallelLinear,
    FluxRowParallelLinear,
    vocab_parallel_embedding_forward,
    vocab_parallel_embedding_init,
)
dcu_megatron/core/tensor_parallel/layers.py
import os
import socket
import warnings
from functools import wraps
from typing import Callable, List, Optional

try:
...
...
@@ -10,35 +9,19 @@ except ImportError:
    raise ImportError("flux is NOT installed")

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from megatron.training import print_rank_0
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.parallel_state import (
    get_global_memory_buffer,
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from megatron.core.utils import (
    is_torch_min_version,
    prepare_input_tensors_for_wgrad_compute
)
from megatron.core.tensor_parallel.layers import (
    _initialize_affine_weight_cpu,
    _initialize_affine_weight_gpu,
    VocabParallelEmbedding,
)
from megatron.core.utils import prepare_input_tensors_for_wgrad_compute
from megatron.core.tensor_parallel.mappings import (
    _reduce,
    copy_to_tensor_model_parallel_region,
    reduce_from_tensor_model_parallel_region,
    reduce_scatter_to_sequence_parallel_region,
    _reduce_scatter_along_first_dim,
    _gather_along_first_dim,
)
from megatron.core.tensor_parallel.utils import VocabUtility
from megatron.core.tensor_parallel.mappings import _reduce
from megatron.core.tensor_parallel import (
    ColumnParallelLinear,
    RowParallelLinear,
...
...
@@ -47,8 +30,6 @@ from megatron.core.tensor_parallel.layers import (
    custom_fwd,
    custom_bwd,
    dist_all_gather_func,
    linear_with_frozen_weight,
    linear_with_grad_accumulation_and_async_allreduce
)

from dcu_megatron.core.utils import is_flux_min_version
...
...
@@ -60,109 +41,6 @@ except ImportError:
    _grad_accum_fusion_available = False


def vocab_parallel_embedding_init(
    self,
    num_embeddings: int,
    embedding_dim: int,
    *,
    init_method: Callable,
    reduce_scatter_embeddings: bool = False,
    config: ModelParallelConfig,
    skip_weight_param_allocation: bool = False
):
    super(VocabParallelEmbedding, self).__init__()
    # Keep the input dimensions.
    self.num_embeddings = num_embeddings
    self.embedding_dim = embedding_dim
    self.reduce_scatter_embeddings = reduce_scatter_embeddings
    self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
    # Divide the weight matrix along the vocabulary dimension.
    (self.vocab_start_index, self.vocab_end_index) = (
        VocabUtility.vocab_range_from_global_vocab_size(
            self.num_embeddings, get_tensor_model_parallel_rank(), self.tensor_model_parallel_size,
        )
    )
    self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
    self.deterministic_mode = config.deterministic_mode

    # Allocate weights and initialize.
    if not skip_weight_param_allocation:
        if config.use_cpu_initialization:
            self.weight = Parameter(
                torch.empty(
                    self.num_embeddings_per_partition, self.embedding_dim, dtype=config.params_dtype
                )
            )
            if config.perform_initialization:
                _initialize_affine_weight_cpu(
                    self.weight,
                    self.num_embeddings,
                    self.embedding_dim,
                    self.num_embeddings_per_partition,
                    0,
                    init_method,
                    params_dtype=config.params_dtype,
                )
        else:
            self.weight = Parameter(
                torch.empty(
                    self.num_embeddings_per_partition,
                    self.embedding_dim,
                    device=torch.cuda.current_device(),
                    dtype=config.params_dtype,
                )
            )
            if config.perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
    else:
        self.weight = None


@torch.compile(mode='max-autotune-no-cudagraphs')
def vocab_parallel_embedding_forward(self, input_, weight=None):
    """Forward.

    Args:
        input_ (torch.Tensor): Input tensor.
    """
    if weight is None:
        if self.weight is None:
            raise RuntimeError(
                "weight was not supplied to VocabParallelEmbedding forward pass "
                "and skip_weight_param_allocation is True."
            )
        weight = self.weight

    if self.tensor_model_parallel_size > 1:
        # Build the mask.
        input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
        # Mask the input.
        masked_input = input_.clone() - self.vocab_start_index
        masked_input[input_mask] = 0
    else:
        masked_input = input_
    # Get the embeddings.
    if self.deterministic_mode:
        output_parallel = weight[masked_input]
    else:
        # F.embedding currently has a non-deterministic backward function
        output_parallel = F.embedding(masked_input, weight)
    # Mask the output embedding.
    if self.tensor_model_parallel_size > 1:
        output_parallel[input_mask, :] = 0.0

    if self.reduce_scatter_embeddings:
        # Data format change to avoid explicit transposes : [b s h] --> [s b h].
        output_parallel = output_parallel.transpose(0, 1).contiguous()
        output = reduce_scatter_to_sequence_parallel_region(output_parallel)
    else:
        # Reduce across all the model parallel GPUs.
        output = reduce_from_tensor_model_parallel_region(output_parallel)
    return output


def get_tensor_model_parallel_node_size(group=None):
    """Get the number of nodes."""
...
...
dcu_megatron/core/transformer/mtp/mtp_spec.py (deleted)
import warnings

from megatron.core.tensor_parallel import ColumnParallelLinear
from megatron.core.transformer import ModuleSpec
from .multi_token_predictor import (
    MultiTokenPredicationSubmodules,
    MultiTokenPredictor
)

try:
    from megatron.core.extensions.transformer_engine import (
        TEColumnParallelLinear,
        TENorm
    )
    HAVE_TE = True
except ImportError:
    HAVE_TE = False

try:
    import apex
    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
    LNImpl = FusedLayerNorm
except ImportError:
    from megatron.core.transformer.torch_norm import WrappedTorchNorm
    warnings.warn('Apex is not installed. Falling back to Torch Norm')
    LNImpl = WrappedTorchNorm


def get_mtp_spec(transformer_layer, use_te=False):
    """
    Multi Token Prediction Layer Specification.
    """
    use_te = use_te & HAVE_TE
    mtp_spec = ModuleSpec(
        module=MultiTokenPredictor,
        submodules=MultiTokenPredicationSubmodules(
            embedding=None,
            enorm=TENorm if use_te else LNImpl,
            hnorm=TENorm if use_te else LNImpl,
            eh_proj=TEColumnParallelLinear if use_te else ColumnParallelLinear,
            transformer_layer=transformer_layer,
            final_layernorm=TENorm if use_te else LNImpl,
            output_layer=None,
        )
    )
    return mtp_spec
dcu_megatron/core/transformer/mtp/multi_token_predictor.py (deleted)
import os
import logging
from dataclasses import dataclass
from typing import Union, Optional, Literal

import torch
from torch import Tensor

from megatron.core import tensor_parallel, InferenceParams
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.module import MegatronModule
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from megatron.core.transformer import ModuleSpec, TransformerConfig, build_module

from ...tensor_parallel.random import CheckpointWithoutOutput


@dataclass
class MultiTokenPredicationSubmodules:
    embedding: Union[ModuleSpec, type] = None
    output_layer: Union[ModuleSpec, type] = None
    eh_proj: Union[ModuleSpec, type] = None
    enorm: Union[ModuleSpec, type] = None
    hnorm: Union[ModuleSpec, type] = None
    transformer_layer: Union[ModuleSpec, type] = None
    final_layernorm: Union[ModuleSpec, type] = None


class MultiTokenPredictor(MegatronModule):
    def __init__(
        self,
        config: TransformerConfig,
        submodules: MultiTokenPredicationSubmodules,
        vocab_size: int,
        max_sequence_length: int,
        layer_number: int = 1,
        hidden_dropout: float = None,
        pre_process: bool = True,
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
        rotary_percent: float = 1.0,
        rotary_base: int = 10000,
        seq_len_interpolation_factor: Optional[float] = None,
        share_mtp_embedding_and_output_weight=True,
        recompute_mtp_norm=False,
        recompute_mtp_layer=False,
        add_output_layer_bias=False
    ):
        super().__init__(config=config)
        self.config = config
        self.submodules = submodules
        self.layer_number = layer_number
        self.hidden_dropout = hidden_dropout
        self.hidden_size = self.config.hidden_size
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.pre_process = pre_process
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.position_embedding_type = position_embedding_type
        # share with main model
        self.share_mtp_embedding_and_output_weight = share_mtp_embedding_and_output_weight
        self.recompute_layer_norm = recompute_mtp_norm
        self.recompute_mtp_layer = recompute_mtp_layer
        self.add_output_layer_bias = add_output_layer_bias

        self.embedding = LanguageModelEmbedding(
            config=self.config,
            vocab_size=self.vocab_size,
            max_sequence_length=self.max_sequence_length,
            position_embedding_type=self.position_embedding_type,
            skip_weight_param_allocation=self.pre_process and self.share_mtp_embedding_and_output_weight
        )

        if self.position_embedding_type == 'rope':
            self.rotary_pos_emb = RotaryEmbedding(
                kv_channels=self.config.kv_channels,
                rotary_percent=rotary_percent,
                rotary_interleaved=self.config.rotary_interleaved,
                seq_len_interpolation_factor=seq_len_interpolation_factor,
                rotary_base=rotary_base,
                use_cpu_initialization=self.config.use_cpu_initialization,
            )

        self.enorm = build_module(
            self.submodules.enorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )
        self.hnorm = build_module(
            self.submodules.hnorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )
        self.eh_proj = build_module(
            self.submodules.eh_proj,
            self.hidden_size + self.hidden_size,
            self.hidden_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
            is_expert=False,
            tp_comm_buffer_name='eh',
        )
        self.transformer_layer = build_module(
            self.submodules.transformer_layer,
            config=self.config,
        )
        if self.submodules.final_layernorm:
            self.final_layernorm = build_module(
                self.submodules.final_layernorm,
                config=self.config,
                hidden_size=self.config.hidden_size,
                eps=self.config.layernorm_epsilon,
            )
        else:
            self.final_layernorm = None

        if self.config.defer_embedding_wgrad_compute:
            self.embedding_activation_buffer = []
            self.grad_output_buffer = []
        else:
            self.embedding_activation_buffer = None
            self.grad_output_buffer = None

        if int(os.getenv("USE_FLUX_OVERLAP", "0")):
            column_parallel_linear_impl = FluxColumnParallelLinear
        else:
            column_parallel_linear_impl = tensor_parallel.ColumnParallelLinear

        self.output_layer = column_parallel_linear_impl(
            self.config.hidden_size,
            self.vocab_size,
            config=self.config,
            init_method=self.config.init_method,
            bias=False,
            skip_bias_add=False,
            gather_output=not self.parallel_output,
            skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight,
            embedding_activation_buffer=self.embedding_activation_buffer,
            grad_output_buffer=self.grad_output_buffer,
        )

    def forward(
        self,
        hidden_input_ids: Tensor,
        embed_input_ids: Tensor,
        position_ids: Tensor,
        attention_mask: Tensor,
        labels: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
        embeding_weight: Optional[torch.Tensor] = None,
        output_weight: Optional[torch.Tensor] = None,
    ):
        """Forward function of the MTP module"""
        # Decoder embedding.
        decoder_input = self.embedding(
            input_ids=embed_input_ids,
            position_ids=position_ids,
            weight=embeding_weight,
        )

        # Rotary positional embeddings (embedding is None for PP intermediate devices)
        rotary_pos_emb = None
        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            if inference_params is not None:
                rotary_seq_len = inference_params.max_sequence_length
            else:
                rotary_seq_len = decoder_input.size(0)
                if self.config.sequence_parallel:
                    rotary_seq_len *= self.config.tensor_model_parallel_size
            rotary_seq_len *= self.config.context_parallel_size
            rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)

        if self.recompute_layer_norm:
            self.enorm_ckpt = CheckpointWithoutOutput()
            enorm_output = self.enorm_ckpt.checkpoint(self.enorm, False, decoder_input)
            self.hnorm_ckpt = CheckpointWithoutOutput()
            hnorm_output = self.hnorm_ckpt.checkpoint(self.hnorm, False, hidden_input_ids)
        else:
            enorm_output = self.enorm(decoder_input)
            hnorm_output = self.hnorm(hidden_input_ids)

        # [s, b, h] -> [s, b, 2h]
        hidden_states = torch.concat([hnorm_output, enorm_output], dim=-1)
        if self.recompute_layer_norm:
            self.enorm_ckpt.discard_output()
            self.hnorm_ckpt.discard_output()
            hidden_states.register_hook(self.enorm_ckpt.recompute)
            hidden_states.register_hook(self.hnorm_ckpt.recompute)
        # hidden_states -> [s, b, h]
        hidden_states, _ = self.eh_proj(hidden_states)
        if self.config.tensor_model_parallel_size > 1:
            hidden_states = tensor_parallel.gather_from_tensor_model_parallel_region(hidden_states)
        if self.config.sequence_parallel:
            hidden_states = tensor_parallel.scatter_to_sequence_parallel_region(hidden_states)

        if self.recompute_mtp_layer:
            hidden_states, context = tensor_parallel.checkpoint(
                self.transformer_layer,
                self.config.distribute_saved_activations,
                hidden_states,
                attention_mask,
                None,
                None,
                rotary_pos_emb,
                inference_params,
                packed_seq_params,
            )
        else:
            hidden_states, _ = self.transformer_layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                rotary_pos_emb=rotary_pos_emb,
                inference_params=inference_params,
                packed_seq_params=packed_seq_params,
                **(extra_block_kwargs or {}),
            )

        # Final layer norm.
        if self.final_layernorm is not None:
            if self.recompute_layer_norm:
                self.finalnorm_ckpt = CheckpointWithoutOutput()
                finalnorm_output = self.finalnorm_ckpt.checkpoint(self.final_layernorm, False, hidden_states)
            else:
                finalnorm_output = self.final_layernorm(hidden_states)
        else:
            finalnorm_output = hidden_states

        logits, _ = self.output_layer(finalnorm_output, weight=output_weight)
        if self.recompute_layer_norm:
            self.finalnorm_ckpt.discard_output()
            logits.register_hook(self.finalnorm_ckpt.recompute)

        if labels is None:
            # [s b h] => [b s h]
            return logits.transpose(0, 1).contiguous()

        loss = self.compute_language_model_loss(labels, logits)
        return hidden_states, loss

    def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
        """Computes the language model loss (Cross entropy across vocabulary)

        Args:
            labels (Tensor): The labels of dimension [batch size, seq length]
            logits (Tensor): The final logits returned by the output layer of the transformer model

        Returns:
            Tensor: Loss tensor of dimensions [batch size, sequence_length]
        """
        # [b s] => [s b]
        labels = labels.transpose(0, 1).contiguous()
        if self.config.cross_entropy_loss_fusion:
            loss = fused_vocab_parallel_cross_entropy(logits, labels)
        else:
            loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)

        # [s b] => [b, s]
        loss = loss.transpose(0, 1).contiguous()
        return loss
dcu_megatron/core/transformer/multi_token_prediction.py (new file)
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
from torch import Tensor

from megatron.core import InferenceParams, mpu, parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel import (
    all_gather_last_dim_from_tensor_parallel_region,
    scatter_to_sequence_parallel_region,
)
from megatron.core.tensor_parallel.layers import ColumnParallelLinear
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint, make_viewless_tensor

SUPPORTED_ATTN_MASK = [
    AttnMaskType.padding,
    AttnMaskType.causal,
    AttnMaskType.no_mask,
    AttnMaskType.padding_causal,
]

try:
    from megatron.core.extensions.transformer_engine import (
        TEColumnParallelLinear,
        TEDelayedScaling,
        TENorm,
    )

    HAVE_TE = True
except ImportError:
    HAVE_TE = False
    from megatron.core.transformer.torch_norm import WrappedTorchNorm

try:
    import apex  # pylint: disable=unused-import

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm

    HAVE_APEX = True
    LNImpl = FusedLayerNorm
except ImportError:
    import warnings

    from megatron.core.transformer.torch_norm import WrappedTorchNorm

    warnings.warn('Apex is not installed. Falling back to Torch Norm')
    LNImpl = WrappedTorchNorm


def tie_word_embeddings_state_dict(
    sharded_state_dict: ShardedStateDict, word_emb_weight: Tensor, word_emb_weight_key: str
) -> None:
    """tie the embedding of the mtp processing stage in a given sharded state dict.

    Args:
        sharded_state_dict (ShardedStateDict): state dict with the weight to tie.
        word_emb_weight (Tensor): weight of the word embedding.
        word_emb_weight_key (str): key of the word embedding in the sharded state dict.

    Returns: None, acts in-place
    """
    mtp_word_emb_replica_id = (
        1,  # copy of embedding in pre processing stage
        0,
        parallel_state.get_data_parallel_rank(with_context_parallel=True),
    )
    assert word_emb_weight_key in sharded_state_dict
    del sharded_state_dict[word_emb_weight_key]
    sharded_state_dict[word_emb_weight_key] = make_tp_sharded_tensor_for_checkpoint(
        tensor=word_emb_weight,
        key=word_emb_weight_key,
        replica_id=mtp_word_emb_replica_id,
        allow_shape_mismatch=True,
    )


def tie_output_layer_state_dict(
    sharded_state_dict: ShardedStateDict, output_layer_weight: Tensor, output_layer_weight_key: str
) -> None:
    """tie the output layer of the mtp processing stage in a given sharded state dict.

    Args:
        sharded_state_dict (ShardedStateDict): state dict with the weight to tie.
        output_layer_weight (Tensor): weight of the output layer.
        output_layer_weight_key (str): key of the output layer in the sharded state dict.

    Returns: None, acts in-place
    """
    mtp_output_layer_replica_id = (
        1,  # copy of output layer in post processing stage
        0,
        parallel_state.get_data_parallel_rank(with_context_parallel=True),
    )
    assert output_layer_weight_key in sharded_state_dict
    del sharded_state_dict[output_layer_weight_key]
    sharded_state_dict[output_layer_weight_key] = make_tp_sharded_tensor_for_checkpoint(
        tensor=output_layer_weight,
        key=output_layer_weight_key,
        replica_id=mtp_output_layer_replica_id,
        allow_shape_mismatch=True,
    )


def roll_tensor(tensor, shifts=-1, dims=-1):
    """Roll the tensor input along the given dimension(s).
    Inserted elements are set to be 0.0.
    """
    rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims)
    rolled_tensor.select(dims, shifts).fill_(0)
    return rolled_tensor, rolled_tensor.sum()
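A quick illustration of the roll above (values are purely illustrative and assume this module is on the path, i.e. torch and roll_tensor are in scope): shifting a [b, s] tensor left by one moves token i+1 into position i and zeroes the vacated last column, which is how labels, input_ids, and the loss mask are advanced for each extra prediction depth.

    labels = torch.tensor([[10, 11, 12, 13]])              # [b=1, s=4]
    rolled, total = roll_tensor(labels, shifts=-1, dims=-1)
    # rolled -> tensor([[11, 12, 13, 0]])
    # total  -> rolled.sum(); when rolling a 0/1 loss mask this doubles as the
    #           number of tokens that still contribute to the loss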
class MTPLossLoggingHelper:
    """Helper class for logging MTP losses."""

    tracker = {}

    @staticmethod
    def save_loss_to_tracker(
        loss: torch.Tensor,
        layer_number: int,
        num_layers: int,
        reduce_group: torch.distributed.ProcessGroup = None,
        avg_group: torch.distributed.ProcessGroup = None,
    ):
        """Save the mtp loss for logging.

        Args:
            loss (torch.Tensor): The loss tensor.
            layer_number (int): Layer index of the loss.
            num_layers (int): The number of total layers.
            reduce_group (torch.distributed.ProcessGroup): The group for reducing the loss.
            avg_group (torch.distributed.ProcessGroup): The group for averaging the loss.
        """
        # Skip mtp loss logging if layer_number is None.
        if layer_number is None:
            return

        tracker = MTPLossLoggingHelper.tracker
        if "values" not in tracker:
            tracker["values"] = torch.zeros(num_layers, device=loss.device)
        tracker["values"][layer_number] += loss.detach()
        tracker["reduce_group"] = reduce_group
        tracker["avg_group"] = avg_group

    def clean_loss_in_tracker():
        """Clear the mtp losses."""
        tracker = MTPLossLoggingHelper.tracker
        tracker["values"].zero_()
        tracker["reduce_group"] = None
        tracker["avg_group"] = None

    def reduce_loss_in_tracker():
        """Collect and reduce the mtp losses across ranks."""
        tracker = MTPLossLoggingHelper.tracker
        if "values" not in tracker:
            return
        values = tracker["values"]
        # Reduce mtp losses across ranks.
        if tracker.get('reduce_group') is not None:
            torch.distributed.all_reduce(values, group=tracker.get('reduce_group'))
        if tracker.get('avg_group') is not None:
            torch.distributed.all_reduce(
                values, group=tracker['avg_group'], op=torch.distributed.ReduceOp.AVG
            )

    def track_mtp_metrics(loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None):
        """Track the Multi-Token Prediction (MTP) metrics for logging."""
        MTPLossLoggingHelper.reduce_loss_in_tracker()
        tracker = MTPLossLoggingHelper.tracker
        if "values" not in tracker:
            return
        mtp_losses = tracker["values"] * loss_scale
        mtp_num_layers = mtp_losses.shape[0]
        for i in range(mtp_num_layers):
            name = f"mtp_{i + 1} loss"
            loss = mtp_losses[i]
            if total_loss_dict is not None:
                total_loss_dict[name] = loss
            if writer is not None:
                writer.add_scalar(name, loss, iteration)
            if wandb_writer is not None:
                wandb_writer.log({f"{name}": loss}, iteration)

        MTPLossLoggingHelper.clean_loss_in_tracker()
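As a usage sketch (the writer object and log location below are assumptions, not part of this file), a training loop on the last pipeline stage could surface the per-depth losses like this:

    # Hypothetical logging call: `tb_writer` is a torch.utils.tensorboard SummaryWriter
    # and `loss_scale` should match the scale applied to the main loss.
    from torch.utils.tensorboard import SummaryWriter

    tb_writer = SummaryWriter(log_dir="./runs/mtp")   # assumed log directory
    MTPLossLoggingHelper.track_mtp_metrics(
        loss_scale=1.0, iteration=100, writer=tb_writer, wandb_writer=None, total_loss_dict={}
    )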
@dataclass
class MultiTokenPredictionLayerSubmodules:
    """
    Dataclass for specifying the submodules of a MultiTokenPrediction module.

    Args:
        hnorm (Union[ModuleSpec, type]): Specification or instance of the
            hidden states normalization to be applied.
        enorm (Union[ModuleSpec, type]): Specification or instance of the
            embedding normalization to be applied.
        eh_proj (Union[ModuleSpec, type]): Specification or instance of the
            linear projection to be applied.
        transformer_layer (Union[ModuleSpec, type]): Specification
            or instance of the transformer block to be applied.
    """

    enorm: Union[ModuleSpec, type] = None
    hnorm: Union[ModuleSpec, type] = None
    eh_proj: Union[ModuleSpec, type] = None
    transformer_layer: Union[ModuleSpec, type] = None
    layer_norm: Union[ModuleSpec, type] = None


def get_mtp_layer_spec(transformer_layer_spec: ModuleSpec, use_transformer_engine: bool) -> ModuleSpec:
    """Get the MTP layer spec.

    Returns:
        ModuleSpec: Module specification with TE modules
    """
    if use_transformer_engine:
        assert HAVE_TE, "transformer_engine should be installed if use_transformer_engine is True"
        layer_norm_impl = TENorm
        column_parallel_linear_impl = TEColumnParallelLinear
    else:
        layer_norm_impl = LNImpl
        column_parallel_linear_impl = ColumnParallelLinear

    mtp_layer_spec = ModuleSpec(
        module=MultiTokenPredictionLayer,
        submodules=MultiTokenPredictionLayerSubmodules(
            enorm=layer_norm_impl,
            hnorm=layer_norm_impl,
            eh_proj=column_parallel_linear_impl,
            transformer_layer=transformer_layer_spec,
            layer_norm=layer_norm_impl,
        ),
    )
    return mtp_layer_spec
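For context, a hedged sketch of how the spec might be built: the decoder-layer spec helper below is assumed to come from megatron.core's GPT layer specs and is not part of this file.

    # Sketch only: reuse the main model's decoder-layer spec for the transformer
    # layer embedded in each MTP module, without Transformer Engine.
    from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec  # assumed helper

    decoder_layer_spec = get_gpt_layer_local_spec()
    mtp_layer_spec = get_mtp_layer_spec(decoder_layer_spec, use_transformer_engine=False)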
def get_mtp_layer_offset(config: TransformerConfig) -> int:
    """Get the offset of the MTP layer."""
    # Currently, we only support putting all of the MTP layers on the last pipeline stage.
    return 0


def get_mtp_num_layers_to_build(config: TransformerConfig) -> int:
    """Get the number of MTP layers to build."""
    # Currently, we only support putting all of the MTP layers on the last pipeline stage.
    if mpu.is_pipeline_last_stage():
        return config.mtp_num_layers if config.mtp_num_layers else 0
    else:
        return 0
class MTPLossAutoScaler(torch.autograd.Function):
    """An AutoScaler that triggers the backward pass and scales the grad for mtp loss."""

    main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)

    @staticmethod
    def forward(ctx, output: torch.Tensor, mtp_loss: torch.Tensor):
        """Preserve the mtp loss by storing it in the context to avoid garbage collection.

        Args:
            output (torch.Tensor): The output tensor.
            mtp_loss (torch.Tensor): The mtp loss tensor.

        Returns:
            torch.Tensor: The output tensor.
        """
        ctx.save_for_backward(mtp_loss)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Compute and scale the gradient for mtp loss.

        Args:
            grad_output (torch.Tensor): The gradient of the output.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled mtp loss
                gradient.
        """
        (mtp_loss,) = ctx.saved_tensors
        mtp_loss_backward_scale = MTPLossAutoScaler.main_loss_backward_scale
        scaled_mtp_loss_grad = torch.ones_like(mtp_loss) * mtp_loss_backward_scale
        return grad_output, scaled_mtp_loss_grad

    @staticmethod
    def set_loss_scale(scale: torch.Tensor):
        """set the scale of the mtp loss.

        Args:
            scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in
                matches the scale of the main_loss.
        """
        MTPLossAutoScaler.main_loss_backward_scale = scale
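A toy sketch of what this autograd function buys you (the tensors below are invented for illustration): the MTP loss never changes the forward value of the main output, but its gradient is injected when the main path runs backward.

    hidden = torch.ones(2, 3, requires_grad=True)        # stand-in for main-model hidden states
    mtp_loss = (hidden * 0).sum() + 3.0                  # stand-in for the scaled MTP loss
    out = MTPLossAutoScaler.apply(hidden, mtp_loss)      # forward: returns `hidden` unchanged
    MTPLossAutoScaler.set_loss_scale(torch.tensor(1.0))  # keep in sync with the main loss scale
    out.sum().backward()                                 # backward also feeds ones * scale into mtp_loss's graph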
class MultiTokenPredictionLayer(MegatronModule):
    """The implementation for Multi-Token Prediction (MTP) which extends
    the prediction scope to multiple future tokens at each position.

    This MTP implementation sequentially predicts additional tokens and keeps the complete
    causal chain at each prediction depth, by using D sequential modules to predict
    D additional tokens.

    The k-th MTP module consists of a shared embedding layer, a projection matrix,
    a Transformer block, and a shared output head.

    For the i-th input token at the (k - 1)-th prediction depth, we first combine
    the representation of the i-th token and the embedding of the (i + K)-th token with
    the linear projection. The combined representation serves as the input of the Transformer
    block at the k-th depth to produce the output representation.

    For more information, please refer to the DeepSeek-V3 Technical Report:
    https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: MultiTokenPredictionLayerSubmodules,
        layer_number: int = 1,
    ):
        super().__init__(config=config)
        self.sequence_parallel = config.sequence_parallel
        self.submodules = submodules
        self.layer_number = layer_number

        self_attention_spec = self.submodules.transformer_layer.submodules.self_attention
        attn_mask_type = self_attention_spec.params.get('attn_mask_type', '')
        assert attn_mask_type in SUPPORTED_ATTN_MASK, (
            f"Multi-Token Prediction (MTP) is not yet supported with "
            + f"{attn_mask_type} attention mask type. "
            + f"The supported attention mask types are {SUPPORTED_ATTN_MASK}."
        )

        self.enorm = build_module(
            self.submodules.enorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )

        self.hnorm = build_module(
            self.submodules.hnorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )

        # For the linear projection at the (k - 1)-th MTP layer, the input is the concatenation
        # of the i-th token's hidden states and the (i + K)-th token's decoder input,
        # so the input's shape is [s, b, 2*h].
        # The output will be sent to the following transformer layer,
        # so the output's shape should be [s, b, h].
        self.eh_proj = build_module(
            self.submodules.eh_proj,
            self.config.hidden_size * 2,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=False,
            skip_bias_add=False,
            is_expert=False,
        )
        self.transformer_layer = build_module(self.submodules.transformer_layer, config=self.config)

        self.final_layernorm = build_module(
            self.submodules.layer_norm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )

    def forward(
        self,
        decoder_input: Tensor,
        hidden_states: Tensor,
        attention_mask: Tensor,
        context: Tensor = None,
        context_mask: Tensor = None,
        rotary_pos_emb: Tensor = None,
        rotary_pos_cos: Tensor = None,
        rotary_pos_sin: Tensor = None,
        attention_bias: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        sequence_len_offset: Tensor = None,
    ):
        """
        Perform the forward pass through the MTP layer.

        Args:
            hidden_states (Tensor): hidden states tensor of shape [s, b, h] where s is the
                sequence length, b is the batch size, and h is the hidden size.
            decoder_input (Tensor): Input tensor of shape [s, b, h] where s is the
                sequence length, b is the batch size, and h is the hidden size.
                At the (k - 1)-th MTP module, the i-th element of decoder input is
                the embedding of the (i + K)-th token.
            attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking
                self-attention.
            context (Tensor, optional): Context tensor for cross-attention.
            context_mask (Tensor, optional): Mask for cross-attention context.
            rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
            attention_bias (Tensor): Bias tensor for Q * K.T of shape broadcastable
                to [b, num_head, sq, skv], e.g. [1, 1, sq, skv].
                Used as an alternative to apply attention mask for TE cuDNN attention.
            inference_params (InferenceParams, optional): Parameters for inference-time
                optimizations.
            packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence
                processing.

        Returns:
            Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape
            [s, b, h], and optionally the updated context tensor if cross-attention is used.
        """
        assert context is None, f"multi token prediction + cross attention is not yet supported."
        assert (
            packed_seq_params is None
        ), f"multi token prediction + sequence packing is not yet supported."

        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)

        if self.config.sequence_parallel:
            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()

        if self.config.fp8:
            import transformer_engine  # To keep out TE dependency when not training in fp8

            if self.config.fp8 == "e4m3":
                fp8_format = transformer_engine.common.recipe.Format.E4M3
            elif self.config.fp8 == "hybrid":
                fp8_format = transformer_engine.common.recipe.Format.HYBRID
            else:
                raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")

            fp8_recipe = TEDelayedScaling(
                config=self.config,
                fp8_format=fp8_format,
                override_linear_precision=(False, False, not self.config.fp8_wgrad),
            )
            fp8_group = None
            if parallel_state.model_parallel_is_initialized():
                fp8_group = parallel_state.get_amax_reduction_group(
                    with_context_parallel=True, tp_only_amax_red=self.tp_only_amax_red
                )
            fp8_context = transformer_engine.pytorch.fp8_autocast(
                enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group
            )
        else:
            fp8_context = nullcontext()

        with rng_context, fp8_context:
            decoder_input = self.enorm(decoder_input)
            decoder_input = make_viewless_tensor(
                inp=decoder_input, requires_grad=True, keep_graph=True
            )
            hidden_states = self.hnorm(hidden_states)
            hidden_states = make_viewless_tensor(
                inp=hidden_states, requires_grad=True, keep_graph=True
            )
            # At the (k - 1)-th MTP module, concatenate the i-th token's hidden_states
            # and the (i + K)-th token's embedding, and combine them with the linear projection.
            hidden_states = torch.cat((decoder_input, hidden_states), -1)
            hidden_states, _ = self.eh_proj(hidden_states)
            # For tensor parallel, all gather after linear_fc.
            hidden_states = all_gather_last_dim_from_tensor_parallel_region(hidden_states)
            # For sequence parallel, scatter after linear_fc and before the transformer layer.
            if self.sequence_parallel:
                hidden_states = scatter_to_sequence_parallel_region(hidden_states)
            hidden_states, _ = self.transformer_layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                context=context,
                context_mask=context_mask,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                attention_bias=attention_bias,
                inference_params=inference_params,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
            )

        # Layer norm before shared head layer.
        hidden_states = self.final_layernorm(hidden_states)
        # TENorm produces a "viewed" tensor. This will result in schedule.py's
        # deallocate_output_tensor() throwing an error, so a viewless tensor is
        # created to prevent this.
        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)

        return hidden_states

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
    ) -> ShardedStateDict:
        """
        Generate a sharded state dictionary for the multi token prediction layer.

        Args:
            prefix (str, optional): Prefix to be added to all keys in the state dict.
            sharded_offsets (tuple, optional): Tuple of sharding offsets.
            metadata (Optional[dict], optional): Additional metadata for sharding.

        Returns:
            ShardedStateDict: A dictionary containing the sharded state of the multi
            token prediction layer.
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        return sharded_state_dict
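To make the shape bookkeeping in this layer concrete, here is a minimal stand-in (a plain torch.nn.Linear replaces the column-parallel eh_proj, and the TP gather, SP scatter, and fp8 context are ignored; sizes are arbitrary):

    # Shape walk-through with dummy tensors (s=4, b=2, h=8).
    s, b, h = 4, 2, 8
    token_emb = torch.randn(s, b, h)       # embedding of the (i + K)-th tokens, after enorm
    prev_hidden = torch.randn(s, b, h)     # depth (k - 1) hidden states, after hnorm
    combined = torch.cat((token_emb, prev_hidden), dim=-1)    # -> [s, b, 2h]
    projected = torch.nn.Linear(2 * h, h)(combined)           # eh_proj equivalent -> [s, b, h]
    assert projected.shape == (s, b, h)    # this goes into the depth-k transformer layer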
@dataclass
class MultiTokenPredictionBlockSubmodules:
    """
    Dataclass for specifying the submodules of a multi token prediction block.

    This class defines the structure for configuring the layers, allowing for
    flexible and customizable architecture designs.

    Args:
        layer_specs (List[ModuleSpec], optional): A list of module specifications for
            the layers within the multi token prediction block. Each specification typically
            defines a complete multi token prediction layer (e.g., shared embedding,
            projection matrix, transformer block, shared output head).
    """

    layer_specs: List[ModuleSpec] = None


def _get_mtp_block_submodules(
    config: TransformerConfig, spec: Union[MultiTokenPredictionBlockSubmodules, ModuleSpec]
) -> MultiTokenPredictionBlockSubmodules:
    """
    Retrieve or construct MultiTokenPredictionBlockSubmodules based on the provided specification.

    Args:
        config (TransformerConfig): Configuration object for the transformer model.
        spec (Union[MultiTokenPredictionBlockSubmodules, ModuleSpec]): Specification for the
            multi token prediction block submodules.
            Can be either a MultiTokenPredictionBlockSubmodules instance or a ModuleSpec.

    Returns:
        MultiTokenPredictionBlockSubmodules: The submodules for the multi token prediction block.
    """
    # Transformer block submodules.
    if isinstance(spec, MultiTokenPredictionBlockSubmodules):
        return spec
    elif isinstance(spec, ModuleSpec):
        if issubclass(spec.module, MultiTokenPredictionBlock):
            return spec.submodules
        else:
            raise Exception(f"specialize for {spec.module.__name__}.")
    else:
        raise Exception(f"specialize for {type(spec).__name__}.")
class MultiTokenPredictionBlock(MegatronModule):
    """The implementation for Multi-Token Prediction (MTP) which extends
    the prediction scope to multiple future tokens at each position.

    This MTP implementation sequentially predicts additional tokens and keeps the complete
    causal chain at each prediction depth, by using D sequential modules to predict
    D additional tokens.

    The k-th MTP module consists of a shared embedding layer, a projection matrix,
    a Transformer block, and a shared output head.

    For the i-th input token at the (k - 1)-th prediction depth, we first combine
    the representation of the i-th token and the embedding of the (i + K)-th token with
    the linear projection. The combined representation serves as the input of the Transformer
    block at the k-th depth to produce the output representation.

    For more information, please refer to the DeepSeek-V3 Technical Report:
    https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf
    """

    def __init__(self, config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec]):
        super().__init__(config=config)
        self.submodules = _get_mtp_block_submodules(config, spec)
        self.mtp_loss_scaling_factor = config.mtp_loss_scaling_factor
        self._build_layers()
        assert len(self.layers) > 0, "MultiTokenPredictionBlock must have at least one layer."

    def _build_layers(self):
        def build_layer(layer_spec, layer_number):
            return build_module(layer_spec, config=self.config, layer_number=layer_number)

        self.layers = torch.nn.ModuleList(
            [build_layer(layer_spec, i + 1) for i, layer_spec in enumerate(self.submodules.layer_specs)]
        )

    def forward(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        hidden_states: Tensor,
        attention_mask: Tensor,
        labels: Tensor = None,
        context: Tensor = None,
        context_mask: Tensor = None,
        rotary_pos_emb: Tensor = None,
        rotary_pos_cos: Tensor = None,
        rotary_pos_sin: Tensor = None,
        attention_bias: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        sequence_len_offset: Tensor = None,
        extra_block_kwargs: dict = None,
        runtime_gather_output: Optional[bool] = None,
        loss_mask: Optional[Tensor] = None,
        embedding=None,
        output_layer=None,
        output_weight: Optional[torch.Tensor] = None,
        compute_language_model_loss=None,
    ) -> Tensor:
        """
        Perform the forward pass through all of the MTP modules.

        Args:
            hidden_states (Tensor): Hidden states for input token with the shape [s, b, h]
                where s is the sequence length, b is the batch size, and h is the hidden size.
            attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking
                self-attention.

        Returns:
            (Tensor): The mtp loss tensor of shape [b, s].
        """
        assert (
            labels is not None
        ), f"labels should not be None for calculating multi token prediction loss."
        if loss_mask is None:
            # If loss_mask is not provided, use all ones as the loss_mask.
            loss_mask = torch.ones_like(labels)

        hidden_states_main_model = hidden_states
        for layer_number in range(len(self.layers)):
            # Calc logits for the current Multi-Token Prediction (MTP) layers.
            input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1)
            # embedding
            decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
            # norm, linear projection and transformer
            hidden_states = self.layers[layer_number](
                decoder_input=decoder_input,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
                **(extra_block_kwargs or {}),
            )
            # output
            mtp_logits, _ = output_layer(
                hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
            )
            # Calc loss for the current Multi-Token Prediction (MTP) layers.
            labels, _ = roll_tensor(labels, shifts=-1, dims=-1)
            loss_mask, num_tokens = roll_tensor(loss_mask, shifts=-1, dims=-1)
            mtp_loss = compute_language_model_loss(labels, mtp_logits)
            mtp_loss = loss_mask * mtp_loss
            if self.training:
                MTPLossLoggingHelper.save_loss_to_tracker(
                    torch.sum(mtp_loss) / num_tokens,
                    layer_number,
                    self.config.mtp_num_layers,
                    avg_group=parallel_state.get_tensor_and_context_parallel_group(),
                )
            mtp_loss_scale = self.mtp_loss_scaling_factor / self.config.mtp_num_layers
            if self.config.calculate_per_token_loss:
                hidden_states_main_model = MTPLossAutoScaler.apply(
                    hidden_states_main_model, mtp_loss_scale * mtp_loss
                )
            else:
                hidden_states_main_model = MTPLossAutoScaler.apply(
                    hidden_states_main_model, mtp_loss_scale * mtp_loss / num_tokens
                )

        return hidden_states_main_model

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
    ) -> ShardedStateDict:
        """
        Generate a sharded state dictionary for the multi token prediction module.

        Args:
            prefix (str, optional): Prefix to be added to all keys in the state dict.
            sharded_offsets (tuple, optional): Tuple of sharding offsets.
            metadata (Optional[dict], optional): Additional metadata for sharding.

        Returns:
            ShardedStateDict: A dictionary containing the sharded state of the multi
            token prediction module.
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        layer_prefix = f'{prefix}layers.'
        for layer in self.layers:
            offset = get_mtp_layer_offset(self.config)
            sharded_prefix = f'{layer_prefix}{layer.layer_number - 1}.'
            state_dict_prefix = f'{layer_prefix}{layer.layer_number - 1 - offset}.'
            sharded_pp_offset = []
            layer_sharded_state_dict = layer.sharded_state_dict(
                state_dict_prefix, sharded_pp_offset, metadata
            )
            replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix)
            sharded_state_dict.update(layer_sharded_state_dict)
        return sharded_state_dict
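Because the forward loop above rolls input_ids, labels, and loss_mask once per iteration, depth k effectively trains on targets shifted k positions into the future. A small runnable check of that composition (values invented, assumes torch and roll_tensor from this module are in scope):

    ids = torch.arange(6).unsqueeze(0)                     # [[0, 1, 2, 3, 4, 5]]
    for depth in range(2):                                 # two MTP layers
        ids, _ = roll_tensor(ids, shifts=-1, dims=-1)
    # after two rolls: [[2, 3, 4, 5, 0, 0]] -> position i now holds token i + 2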
dcu_megatron/core/transformer/transformer_block.py
View file @
770fa304
...
...
@@ -8,7 +8,7 @@ def transformer_block_init_wrapper(fn):
         # mtp require seperate layernorms for main model and mtp modules, thus move finalnorm out of block
         config = args[0] if len(args) > 1 else kwargs['config']
-        if getattr(config, "num_nextn_predict_layers", 0) > 0:
+        if getattr(config, "mtp_num_layers", 0) > 0:
             self.main_final_layernorm = self.final_layernorm
             self.final_layernorm = None
...
...
dcu_megatron/core/transformer/transformer_config.py
View file @
770fa304
 from functools import wraps
 from dataclasses import dataclass
 from megatron.training import get_args
 from megatron.core.transformer.transformer_config import TransformerConfig, MLATransformerConfig


 def transformer_config_post_init_wrapper(fn):
     @wraps(fn)
     def wrapper(self):
         fn(self)
         args = get_args()
+        """Number of Multi-Token Prediction (MTP) Layers."""
+        self.mtp_num_layers = args.mtp_num_layers
+        """Weighting factor of Multi-Token Prediction (MTP) loss."""
+        self.mtp_loss_scaling_factor = args.mtp_loss_scaling_factor
         self.flux_transpose_weight = args.flux_transpose_weight
     return wrapper


 @dataclass
 class ExtraTransformerConfig:
     ##################
     # multi-token prediction
     ##################
-    num_nextn_predict_layers: int = 0
-    """The number of multi-token prediction layers"""
-    mtp_loss_scale: float = 0.3
-    """Multi-token prediction loss scale"""
-    recompute_mtp_norm: bool = False
-    """Whether to recompute mtp normalization"""
-    recompute_mtp_layer: bool = False
-    """Whether to recompute mtp layer"""
+    mtp_num_layers: Optional[int] = None
+    """Number of Multi-Token Prediction (MTP) Layers."""
+    share_mtp_embedding_and_output_weight: bool = False
+    """share embedding and output weight with mtp layer."""
+    mtp_loss_scaling_factor: Optional[float] = None
+    """Weighting factor of Multi-Token Prediction (MTP) loss."""
     ##################
     # flux
...
...