OpenDAS / Megatron-LM

Commit f00f0256
deepseek mtp bug fix

Authored Mar 07, 2025 by dongcl
Parent: 627a739f
Pipeline #2462 passed

Showing 4 changed files with 85 additions and 78 deletions:

megatron/core/models/gpt/gpt_model.py            +3   -0
megatron/core/transformer/transformer_block.py   +3   -1
megatron/training/utils.py                       +76  -77
pretrain_gpt.py                                  +3   -0
megatron/core/models/gpt/gpt_model.py

@@ -17,9 +17,11 @@ from megatron.core.models.common.language_module.language_module import LanguageModule

from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor
from megatron.core.extensions.transformer_engine import TENorm


class GPTModel(LanguageModule):

@@ -137,6 +139,7 @@ class GPTModel(LanguageModule):

            spec=transformer_layer_spec,
            pre_process=self.pre_process,
            post_process=self.post_process,
            num_nextn_predict_layers=num_nextn_predict_layers
        )

        # Output
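The net effect in this file is that GPTModel now threads num_nextn_predict_layers into its TransformerBlock and has MultiTokenPredictor, build_module and TENorm in scope. The hunks do not show how the MTP targets are actually consumed, so the plain-PyTorch toy below (made-up sizes, a hypothetical windows helper) is only an assumed illustration of why N extra predict depths imply sequences carrying seq_length + N tokens, consistent with the buffer sizes and the tensor_slide call in the files further down.

# Toy illustration with assumed semantics: with N = num_nextn_predict_layers,
# window k of the (seq_length + N)-long stream feeds the depth-k module, and
# each window is still trained on plain next-token labels.
import torch

L, N = 6, 2                                   # seq_length, num_nextn_predict_layers (toy values)
raw = torch.arange(L + N + 1)                 # one extra element so labels can be shifted by 1
tokens = raw[:-1]                             # length L + N, like the extended dataloader buffers
labels = raw[1:]                              # length L + N, next-token labels

def windows(t, n):
    # sliding windows of length t.size(-1) - n; window k belongs to the depth-k module
    length = t.size(-1) - n
    return [t[..., k:k + length] for k in range(n + 1)]

tok_w, lab_w = windows(tokens, N), windows(labels, N)
print(tok_w[0].tolist(), lab_w[0].tolist())   # main model:  [0..5] -> [1..6]
print(tok_w[1].tolist(), lab_w[1].tolist())   # MTP depth 1: [1..6] -> [2..7]
print(tok_w[2].tolist(), lab_w[2].tolist())   # MTP depth 2: [2..7] -> [3..8]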
megatron/core/transformer/transformer_block.py

@@ -178,6 +178,7 @@ class TransformerBlock(MegatronModule):

        post_layer_norm: bool = True,
        pre_process: bool = True,
        post_process: bool = True,
        num_nextn_predict_layers: int = 0
    ):
        super().__init__(config=config)

@@ -185,6 +186,7 @@ class TransformerBlock(MegatronModule):

        self.post_layer_norm = post_layer_norm
        self.pre_process = pre_process
        self.post_process = post_process
        self.num_nextn_predict_layers = num_nextn_predict_layers

        # Dictionary to store CUDA graphs. Number of items in the dictionary = len(self.layers).
        # Item `i` in the dictionary is a list of `N` CUDA graphs for layer 'i' where N is the
        # number of microbatches. Multiple CUDA graphs per layer is required to support

@@ -246,7 +248,7 @@ class TransformerBlock(MegatronModule):

        # In pipeline parallelism, we want to add this LN only to the last stage of the pipeline
        # self.post_process and self.post_layer_norm guide this behavior
        # mtp require seperate layernorms for main model and mtp modules, thus move finalnorm out of block
-       move_final_norm_out_of_block = args.num_nextn_predict_layers > 0
+       move_final_norm_out_of_block = self.num_nextn_predict_layers > 0
        if self.submodules.layer_norm and self.post_process and self.post_layer_norm and not move_final_norm_out_of_block:
            self.final_layernorm = build_module(
                self.submodules.layer_norm,
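The only behavioural change in this file is the last hunk: move_final_norm_out_of_block was previously computed from args.num_nextn_predict_layers, reaching for training-script argument state from inside a core module, and is now computed from the value stored on the block itself. That reliance on a global args object is presumably the MTP bug the commit message refers to. A minimal toy sketch of the corrected pattern, with hypothetical class names and no Megatron imports:

# Toy contrast: reusable modules should take settings through their constructor
# instead of a module-global argument object that may be undefined or stale.
import torch.nn as nn

class BlockUsingGlobals(nn.Module):
    def __init__(self):
        super().__init__()
        # Risky: `args` is training-script state, not something a core module owns.
        # self.move_final_norm_out_of_block = args.num_nextn_predict_layers > 0
        ...

class BlockUsingConstructorArgs(nn.Module):
    def __init__(self, num_nextn_predict_layers: int = 0):
        super().__init__()
        self.num_nextn_predict_layers = num_nextn_predict_layers
        # Mirrors the fixed line: decide locally from constructor-provided state.
        self.move_final_norm_out_of_block = self.num_nextn_predict_layers > 0

block = BlockUsingConstructorArgs(num_nextn_predict_layers=1)
print(block.move_final_norm_out_of_block)  # True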
megatron/training/utils.py

@@ -388,105 +388,104 @@ def get_batch_on_this_tp_rank(data_iterator):

    args = get_args()

    def _broadcast(item):
        if item is not None:
            torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())

    if mpu.get_tensor_model_parallel_rank() == 0:

        if data_iterator is not None:
            data = next(data_iterator)
        else:
            data = None

        batch = {
            'tokens': data["tokens"].cuda(non_blocking=True),
            'labels': data["labels"].cuda(non_blocking=True),
            'loss_mask': data["loss_mask"].cuda(non_blocking=True),
            'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking=True),
            'position_ids': data["position_ids"].cuda(non_blocking=True)
        }

        if args.pipeline_model_parallel_size == 1:
            _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])

        elif mpu.is_pipeline_first_stage():
            _broadcast(batch['tokens'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])

        elif mpu.is_pipeline_last_stage():
            if args.num_nextn_predict_layers:
                _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            if args.reset_position_ids or args.num_nextn_predict_layers:
                _broadcast(batch['position_ids'])

    else:

        tokens = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), dtype=torch.int64, device=torch.cuda.current_device())
        labels = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), dtype=torch.int64, device=torch.cuda.current_device())
        loss_mask = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), dtype=torch.float32, device=torch.cuda.current_device())
        if args.create_attention_mask_in_dataloader:
            attention_mask = torch.empty(
                (args.micro_batch_size, 1, args.seq_length + args.num_nextn_predict_layers,
                 args.seq_length + args.num_nextn_predict_layers),
                dtype=torch.bool, device=torch.cuda.current_device()
            )
        else:
            attention_mask = None
        position_ids = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), dtype=torch.int64, device=torch.cuda.current_device())

        if args.pipeline_model_parallel_size == 1:
            _broadcast(tokens)
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            _broadcast(position_ids)

        elif mpu.is_pipeline_first_stage():
            labels = None
            loss_mask = None

            _broadcast(tokens)
            _broadcast(attention_mask)
            _broadcast(position_ids)

        elif mpu.is_pipeline_last_stage():
            if args.num_nextn_predict_layers:
                _broadcast(tokens)
            else:
                tokens = None
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            if args.reset_position_ids or args.num_nextn_predict_layers:
                _broadcast(position_ids)
            else:
                position_ids = None

        batch = {
            'tokens': tokens,
            'labels': labels,
            'loss_mask': loss_mask,
            'attention_mask': attention_mask,
            'position_ids': position_ids
        }

    return batch
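In get_batch_on_this_tp_rank, every receive buffer on non-source tensor-parallel ranks is sized args.seq_length + args.num_nextn_predict_layers, and the last pipeline stage now also broadcasts tokens (and position_ids) whenever num_nextn_predict_layers is set, since the MTP modules need the shifted input stream rather than only labels and masks. The snippet below is a self-contained single-process illustration of the broadcast contract those buffers exist to satisfy; the gloo/CPU setup and all sizes are invented for the demo and are not how the real tensor-parallel group is built.

# Single-process demo: torch.distributed.broadcast fills a pre-allocated buffer
# in place, so every non-source rank must allocate a tensor whose shape and dtype
# match the source exactly, hence the seq_length + num_nextn_predict_layers sizing above.
import torch
import torch.distributed as dist

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29512", rank=0, world_size=1)

micro_batch_size, seq_length, num_nextn_predict_layers = 2, 8, 3   # made-up sizes
extended = seq_length + num_nextn_predict_layers

if dist.get_rank() == 0:
    # Source rank: tokens straight from the dataloader, already length seq_length + N.
    tokens = torch.arange(micro_batch_size * extended, dtype=torch.int64).view(micro_batch_size, extended)
else:
    # Any other rank in the group: an empty receive buffer of the same extended shape.
    tokens = torch.empty((micro_batch_size, extended), dtype=torch.int64)

# In the real code: src=mpu.get_tensor_model_parallel_src_rank(), group=the TP group.
dist.broadcast(tokens, src=0)
print(tokens.shape)   # torch.Size([2, 11]) on every rank
dist.destroy_process_group()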
pretrain_gpt.py

@@ -37,6 +37,7 @@ from megatron.core.models.gpt.gpt_layer_specs import (

)
from megatron.core.transformer.mtp.mtp_spec import get_mtp_spec
from megatron.core.utils import tensor_slide

import torch._dynamo
torch._dynamo.config.suppress_errors = True

@@ -190,6 +191,8 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):

    args = get_args()

    losses = output_tensor.float()
    if args.num_nextn_predict_layers > 0:
        loss_mask = tensor_slide(loss_mask, args.num_nextn_predict_layers, return_first=True)[0]
    loss_mask = loss_mask.view(-1).float()

    total_tokens = loss_mask.sum()
    loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
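tensor_slide comes from this fork's megatron.core.utils and its implementation is not part of this commit; from the call above (slide distance num_nextn_predict_layers, return_first=True, then [0]) it appears to yield sliding windows over the last dimension whose length matches the main model's seq_length. The stand-in below encodes that assumption only so the new loss_func logic can be run end to end with toy shapes; it is not the fork's actual implementation.

# Hypothetical stand-in for tensor_slide, used only to exercise the loss change.
import torch

def tensor_slide_guess(t: torch.Tensor, n: int, return_first: bool = False):
    # windows t[..., 0:L-n], t[..., 1:L-n+1], ..., t[..., n:L] over the last dimension
    length = t.size(-1) - n
    windows = [t[..., i:i + length] for i in range(n + 1)]
    return windows[:1] if return_first else windows

# Toy shapes: micro_batch=2, seq_length=8, num_nextn_predict_layers=3.
num_nextn_predict_layers = 3
loss_mask = torch.ones(2, 8 + num_nextn_predict_layers)   # extended by N, like the dataloader buffers
losses = torch.rand(2, 8)                                  # assumed per-token loss from the main head

# Mirror of the new loss_func logic: trim the mask back to the main head's length.
loss_mask = tensor_slide_guess(loss_mask, num_nextn_predict_layers, return_first=True)[0]
loss_mask = loss_mask.view(-1).float()
total_tokens = loss_mask.sum()
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
print(loss)   # second entry is the number of unmasked tokens: 16.0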