evt_fugx1 / dcu_megatron / Commits

Commit 770fa304, authored Apr 25, 2025 by dongcl
Parent: 8096abd4

Commit message: Modify MTP (original: 修改mtp)

Showing 4 changed files with 106 additions and 140 deletions (+106 −140)
dcu_megatron/core/utils.py           +0   −30
dcu_megatron/training/arguments.py   +10  −8
dcu_megatron/training/utils.py       +82  −88
pretrain_gpt.py                      +14  −14
dcu_megatron/core/utils.py
@@ -30,33 +30,3 @@ def is_flux_min_version(version, check_equality=True):
     if check_equality:
         return get_flux_version() >= PkgVersion(version)
     return get_flux_version() > PkgVersion(version)
-
-
-def tensor_slide(
-    tensor: Optional[torch.Tensor],
-    num_slice: int,
-    dims: Union[int, List[int]] = -1,
-    step: int = 1,
-    return_first=False,
-) -> List[Union[torch.Tensor, None]]:
-    """Generic sliding-window helper supporting arbitrary dimensions."""
-    if tensor is None:
-        # return `List[None]` to avoid NoneType Error
-        return [None] * (num_slice + 1)
-
-    if num_slice == 0:
-        return [tensor]
-
-    window_size = tensor.shape[-1] - num_slice
-    dims = [dims] if isinstance(dims, int) else sorted(dims, reverse=True)
-
-    # Slide over multiple contiguous dimensions.
-    slices = []
-    for i in range(0, tensor.size(dims[-1]) - window_size + 1, step):
-        slice_obj = [slice(None)] * tensor.dim()
-        for dim in dims:
-            slice_obj[dim] = slice(i, i + window_size)
-        slices.append(tensor[tuple(slice_obj)])
-        if return_first:
-            return slices
-    return slices
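For context, the deleted tensor_slide produced shifted views along a target dimension; this commit removes its last caller (loss_func in pretrain_gpt.py, below), since labels and masks are no longer over-allocated by num_nextn_predict_layers. A minimal standalone sketch of the behavior for the common case (dims=-1, step=1); this is an illustration, not the project's code:

import torch

def slide_last_dim(tensor: torch.Tensor, num_slice: int):
    # Mirrors the deleted helper's common case: the window is
    # seq_len - num_slice wide, yielding num_slice + 1 shifted views.
    window_size = tensor.shape[-1] - num_slice
    return [tensor[..., i:i + window_size] for i in range(num_slice + 1)]

x = torch.arange(10).view(1, 10)        # shape [1, 10]
views = slide_last_dim(x, num_slice=2)  # three views of shape [1, 8]
assert torch.equal(views[1], x[..., 1:9])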
dcu_megatron/training/arguments.py
@@ -170,14 +170,16 @@ def _add_extra_tokenizer_args(parser):

 def _add_mtp_args(parser):
     group = parser.add_argument_group(title='multi token prediction')
-    group.add_argument('--num-nextn-predict-layers', type=int, default=0,
-                       help='Multi-Token prediction layer num')
-    group.add_argument('--mtp-loss-scale', type=float, default=0.3,
-                       help='Multi-Token prediction loss scale')
-    group.add_argument('--recompute-mtp-norm', action='store_true', default=False,
-                       help='Multi-Token prediction recompute norm')
-    group.add_argument('--recompute-mtp-layer', action='store_true', default=False,
-                       help='Multi-Token prediction recompute layer')
-    group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False,
-                       help='Main model share embedding and output weight with mtp layer.')
+    group.add_argument('--mtp-num-layers', type=int, default=None,
+                       help='Number of Multi-Token Prediction (MTP) Layers. '
+                            'MTP extends the prediction scope to multiple future tokens at each position. '
+                            'This MTP implementation sequentially predicts additional tokens '
+                            'by using D sequential modules to predict D additional tokens.')
+    group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.1,
+                       help='Scaling factor of Multi-Token Prediction (MTP) loss. '
+                            'We compute the average of the MTP losses across all depths, '
+                            'and multiply it by the scaling factor to obtain the overall MTP loss, '
+                            'which serves as an additional training objective.')
     return parser
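The five bespoke flags collapse into two that match upstream Megatron-LM naming. A quick parsing sketch with plain argparse (a standalone parser for illustration; the real _add_mtp_args receives Megatron's shared parser):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='multi token prediction')
group.add_argument('--mtp-num-layers', type=int, default=None,
                   help='Number of Multi-Token Prediction (MTP) layers.')
group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.1,
                   help='Scaling factor of the MTP loss.')

args = parser.parse_args(['--mtp-num-layers', '1'])
assert args.mtp_num_layers == 1             # MTP block enabled
assert args.mtp_loss_scaling_factor == 0.1  # default scale
# When --mtp-num-layers is omitted it stays None, which the code below
# treats as "MTP disabled" via `if args.mtp_num_layers is not None:`.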
dcu_megatron/training/utils.py
@@ -9,103 +9,97 @@ def get_batch_on_this_tp_rank(data_iterator):

     args = get_args()

     def _broadcast(item):
         if item is not None:
             torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())

     if mpu.get_tensor_model_parallel_rank() == 0:

         if data_iterator is not None:
             data = next(data_iterator)
         else:
             data = None

         batch = {
             'tokens': data["tokens"].cuda(non_blocking=True),
             'labels': data["labels"].cuda(non_blocking=True),
             'loss_mask': data["loss_mask"].cuda(non_blocking=True),
             'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking=True),
             'position_ids': data["position_ids"].cuda(non_blocking=True)
         }

         if args.pipeline_model_parallel_size == 1:
             _broadcast(batch['tokens'])
             _broadcast(batch['labels'])
             _broadcast(batch['loss_mask'])
             _broadcast(batch['attention_mask'])
             _broadcast(batch['position_ids'])

         elif mpu.is_pipeline_first_stage():
             _broadcast(batch['tokens'])
             _broadcast(batch['attention_mask'])
             _broadcast(batch['position_ids'])

         elif mpu.is_pipeline_last_stage():
-            if args.num_nextn_predict_layers:
+            # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
+            # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
+            # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
+            if args.mtp_num_layers is not None:
                 _broadcast(batch['tokens'])
+                _broadcast(batch['position_ids'])
             _broadcast(batch['labels'])
             _broadcast(batch['loss_mask'])
             _broadcast(batch['attention_mask'])
-            if args.reset_position_ids or args.num_nextn_predict_layers:
-                _broadcast(batch['position_ids'])

     else:

-        tokens = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), dtype=torch.int64,
-                             device=torch.cuda.current_device())
-        labels = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
-                             dtype=torch.int64,
-                             device=torch.cuda.current_device())
-        loss_mask = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
-                                dtype=torch.float32,
-                                device=torch.cuda.current_device())
+        tokens = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64,
+                             device=torch.cuda.current_device())
+        labels = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64,
+                             device=torch.cuda.current_device())
+        loss_mask = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.float32,
+                                device=torch.cuda.current_device())
         if args.create_attention_mask_in_dataloader:
             attention_mask = torch.empty(
-                (args.micro_batch_size, 1, args.seq_length + args.num_nextn_predict_layers,
-                 args.seq_length + args.num_nextn_predict_layers), dtype=torch.bool,
+                (args.micro_batch_size, 1, args.seq_length, args.seq_length), dtype=torch.bool,
                 device=torch.cuda.current_device()
             )
         else:
             attention_mask = None
-        position_ids = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
-                                   dtype=torch.int64,
-                                   device=torch.cuda.current_device())
+        position_ids = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64,
+                                   device=torch.cuda.current_device())

         if args.pipeline_model_parallel_size == 1:
             _broadcast(tokens)
             _broadcast(labels)
             _broadcast(loss_mask)
             _broadcast(attention_mask)
             _broadcast(position_ids)

         elif mpu.is_pipeline_first_stage():
             labels = None
             loss_mask = None

             _broadcast(tokens)
             _broadcast(attention_mask)
             _broadcast(position_ids)

         elif mpu.is_pipeline_last_stage():
-            if args.num_nextn_predict_layers:
+            # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
+            # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
+            # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
+            if args.mtp_num_layers is not None:
                 _broadcast(tokens)
+                _broadcast(position_ids)
             else:
                 tokens = None
+                position_ids = None

             _broadcast(labels)
             _broadcast(loss_mask)
             _broadcast(attention_mask)
-            if args.reset_position_ids or args.num_nextn_predict_layers:
-                _broadcast(position_ids)
-            else:
-                position_ids = None

         batch = {
             'tokens': tokens,
             'labels': labels,
             'loss_mask': loss_mask,
             'attention_mask': attention_mask,
             'position_ids': position_ids
         }

     return batch
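The key behavioral point in this rewrite is that non-source tensor-parallel ranks must pre-allocate receive buffers whose shape and dtype exactly match what rank 0 sends; dropping the `+ args.num_nextn_predict_layers` padding on only one side would break the broadcast. A minimal runnable sketch of that contract (a single-process gloo group for illustration; the real code uses the tensor-model-parallel group and CUDA tensors):

import os
import torch
import torch.distributed as dist

def main():
    # One-process "group" so the example runs anywhere.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    micro_batch_size, seq_length = 2, 8
    if dist.get_rank() == 0:
        # Source rank: real data from the iterator.
        tokens = torch.arange(micro_batch_size * seq_length, dtype=torch.int64)
        tokens = tokens.view(micro_batch_size, seq_length)
    else:
        # Receiving ranks: broadcast writes in place, so the buffer must
        # already have the sender's shape/dtype, exactly like the diff's
        # torch.empty((args.micro_batch_size, args.seq_length), ...) calls.
        tokens = torch.empty((micro_batch_size, seq_length), dtype=torch.int64)

    dist.broadcast(tokens, src=0)
    print(tokens.shape)  # torch.Size([2, 8]) on every rank
    dist.destroy_process_group()

if __name__ == "__main__":
    main()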
pretrain_gpt.py
@@ -39,9 +39,7 @@ from megatron.core.models.gpt.gpt_layer_specs import (
     get_gpt_layer_with_transformer_engine_spec,
 )
-from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
-from dcu_megatron.core.transformer.mtp.mtp_spec import get_mtp_spec
-from dcu_megatron.core.utils import tensor_slide
+from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
 from dcu_megatron import megatron_adaptor

@@ -133,13 +131,12 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
             raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.")

     # Define the mtp layer spec
-    if isinstance(transformer_layer_spec, TransformerBlockSubmodules):
-        mtp_transformer_layer_spec = transformer_layer_spec.layer_specs[-1]
-    else:
-        mtp_transformer_layer_spec = transformer_layer_spec
+    mtp_block_spec = None
+    if args.mtp_num_layers is not None:
+        from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
+        mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te)

     with build_model_context(**build_model_context_args):
-        config.mtp_spec = get_mtp_spec(mtp_transformer_layer_spec, use_te=use_te)
         model = GPTModel(
             config=config,
             transformer_layer_spec=transformer_layer_spec,

@@ -153,7 +150,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
             position_embedding_type=args.position_embedding_type,
             rotary_percent=args.rotary_percent,
             rotary_base=args.rotary_base,
-            rope_scaling=args.use_rope_scaling
+            rope_scaling=args.use_rope_scaling,
+            mtp_block_spec=mtp_block_spec,
         )
         # model = torch.compile(model,mode='max-autotune-no-cudagraphs')
         print_rank_0(model)

@@ -197,8 +195,6 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
     args = get_args()

     losses = output_tensor.float()
-    if getattr(args, "num_nextn_predict_layers", 0) > 0:
-        loss_mask = tensor_slide(loss_mask, args.num_nextn_predict_layers, return_first=True)[0]
     loss_mask = loss_mask.view(-1).float()
     total_tokens = loss_mask.sum()
     loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])

@@ -267,8 +263,12 @@ def forward_step(data_iterator, model: GPTModel):
     timers('batch-generator').stop()

     with stimer:
-        output_tensor = model(tokens, position_ids, attention_mask,
-                              labels=labels)
+        if args.use_legacy_models:
+            output_tensor = model(tokens, position_ids, attention_mask,
+                                  labels=labels)
+        else:
+            output_tensor = model(tokens, position_ids, attention_mask,
+                                  labels=labels, loss_mask=loss_mask)

     return output_tensor, partial(loss_func, loss_mask)

@@ -289,7 +289,7 @@ def core_gpt_dataset_config_from_args(args):
     return GPTDatasetConfig(
         random_seed=args.seed,
-        sequence_length=args.seq_length + getattr(args, "num_nextn_predict_layers", 0),
+        sequence_length=args.seq_length,
         blend=blend,
         blend_per_split=blend_per_split,
         split=args.split,
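With tensor_slide gone, loss_func no longer trims an over-length mask; it simply packs the masked loss sum together with the unmasked-token count. A small numeric sketch of that packing (toy values for illustration, not project code):

import torch

# Toy stand-ins: per-token losses for 4 positions, last one masked out.
losses = torch.tensor([0.5, 1.0, 1.5, 2.0])
loss_mask = torch.tensor([1.0, 1.0, 1.0, 0.0])

loss_mask = loss_mask.view(-1).float()
total_tokens = loss_mask.sum()
# Same packing as the diff: [masked loss sum, unmasked token count].
# Keeping sum and count separate lets a caller reduce the pair across
# ranks before dividing, so uneven masking still yields an exact mean.
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])

print(loss)  # tensor([3., 3.])  -> mean loss = 3.0 / 3 = 1.0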