OpenDAS / Megatron-LM · Commit 7b78db5d

Authored Mar 12, 2025 by dongcl
Commit message: Support DeepSeek V3 training
Parent: c099d843
Pipeline #2478 passed

Showing 7 changed files with 29 additions and 9 deletions (+29, -9)
megatron/core/models/gpt/gpt_model.py                    +1   -1
megatron/core/transformer/attention.py                   +18  -5
megatron/core/transformer/moe/router.py                  +1   -1
megatron/core/transformer/mtp/multi_token_predictor.py   +2   -1
megatron/core/transformer/multi_latent_attention.py      +4   -0
megatron/training/arguments.py                           +2   -1
megatron/training/tokenizer/tokenizer.py                 +1   -0
megatron/core/models/gpt/gpt_model.py

@@ -461,7 +461,7 @@ class GPTModel(LanguageModule):
         if (
             self.num_nextn_predict_layers
-            and getattr(self.decoder, final_layernorm, None) is not None
+            and getattr(self.decoder, "final_layernorm", None) is not None
         ):
             # move block main model final norms here
             hidden_states = self.decoder.final_layernorm(hidden_states)
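The one-line fix above matters because getattr expects the attribute name as a string; the bare name final_layernorm is an undefined variable at that point and would raise NameError before the None default could ever apply. A minimal standalone illustration (not code from this repository):

    class Decoder:
        pass  # no final_layernorm attribute

    d = Decoder()
    print(getattr(d, "final_layernorm", None))  # -> None, falls back safely
    # getattr(d, final_layernorm, None) would instead raise
    # NameError: name 'final_layernorm' is not defined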
megatron/core/transformer/attention.py

@@ -103,6 +103,10 @@ class Attention(MegatronModule, ABC):
         self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
         self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)

+        # To support both CUDA Graphs and key value with different hidden size
+        self.key_hidden_size = self.hidden_size_per_attention_head
+        self.val_hidden_size = self.hidden_size_per_attention_head
+
         self.core_attention = build_module(
             submodules.core_attention,
             config=self.config,
@@ -209,10 +213,10 @@ class Attention(MegatronModule, ABC):
...
@@ -209,10 +213,10 @@ class Attention(MegatronModule, ABC):
inf_max_seq_length
=
inference_params
.
max_sequence_length
inf_max_seq_length
=
inference_params
.
max_sequence_length
inf_max_batch_size
=
inference_params
.
max_batch_size
inf_max_batch_size
=
inference_params
.
max_batch_size
inference_key_memory
=
self
.
_allocate_memory
(
inference_key_memory
=
self
.
_allocate_memory
(
inf_max_seq_length
,
inf_max_batch_size
,
key
.
shape
[
-
1
]
,
key
.
dtype
inf_max_seq_length
,
inf_max_batch_size
,
self
.
key_hidden_size
,
key
.
dtype
)
)
inference_value_memory
=
self
.
_allocate_memory
(
inference_value_memory
=
self
.
_allocate_memory
(
inf_max_seq_length
,
inf_max_batch_size
,
value
.
shape
[
-
1
]
,
value
.
dtype
inf_max_seq_length
,
inf_max_batch_size
,
self
.
val_hidden_size
,
value
.
dtype
)
)
inference_params
.
key_value_memory_dict
[
self
.
layer_number
]
=
(
inference_params
.
key_value_memory_dict
[
self
.
layer_number
]
=
(
inference_key_memory
,
inference_key_memory
,
...
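Allocating the inference KV cache from self.key_hidden_size / self.val_hidden_size instead of key.shape[-1] / value.shape[-1] lets the per-head key and value widths differ, which is what MLA needs (see multi_latent_attention.py below, where the key side carries the extra rotary position dimensions), and it keeps the buffer sizes known at construction time, which is friendlier to CUDA Graph capture. A rough sketch of the resulting cache shapes, with all sizes below being illustrative assumptions rather than values from this commit:

    import torch

    # Hypothetical sizes for illustration only:
    max_seq_len, max_batch, num_kv_heads = 4096, 8, 128
    key_hidden_size = 128 + 64   # qk_head_dim + qk_pos_emb_head_dim (MLA override)
    val_hidden_size = 128        # v_head_dim

    # Mirrors the shape logic of an _allocate_memory-style helper:
    inference_key_memory = torch.empty(
        max_seq_len, max_batch, num_kv_heads, key_hidden_size, dtype=torch.bfloat16)
    inference_value_memory = torch.empty(
        max_seq_len, max_batch, num_kv_heads, val_hidden_size, dtype=torch.bfloat16)

    # If both buffers were sized from key.shape[-1], the value cache would be
    # forced to the key width, which breaks when key and value head dims differ.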
@@ -234,7 +238,10 @@ class Attention(MegatronModule, ABC):
         assert batch_end <= inference_key_memory.size(1)
         sequence_start = inference_params.sequence_len_offset
         sequence_end = sequence_start + key.size(0)
-        assert sequence_end <= inference_key_memory.size(0)
+        assert sequence_end <= inference_key_memory.size(0), (
+            "Current sequence length is longer than expected maximum sequence length! "
+            "Increase inference_max_seq_length."
+        )

         if self.config.flash_decode:
             assert (
@@ -245,7 +252,7 @@ class Attention(MegatronModule, ABC):
                 rotary_pos_sin_q = rotary_pos_sin[sequence_end - 1 : sequence_end]
                 rotary_pos_cos_k = rotary_pos_cos[sequence_end - 1 : sequence_end]
                 rotary_pos_sin_k = rotary_pos_sin[sequence_end - 1 : sequence_end]
-            else:
+            else:  # Prefill
                 rotary_pos_cos_q = rotary_pos_cos[:sequence_end]
                 rotary_pos_sin_q = rotary_pos_sin[:sequence_end]
                 rotary_pos_cos_k = rotary_pos_cos[:sequence_end]
@@ -394,7 +401,13 @@ class Attention(MegatronModule, ABC):
             return output, bias

         query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
-            inference_params, query, key, value, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin
+            inference_params,
+            query,
+            key,
+            value,
+            rotary_pos_emb,
+            rotary_pos_cos,
+            rotary_pos_sin,
         )

         if packed_seq_params is not None:
megatron/core/transformer/moe/router.py

@@ -322,7 +322,7 @@ class TopKRouter(Router):
             scores, routing_map = self.aux_loss_load_balancing(logits)
         elif self.routing_type == "seq_aux_loss":
             scores, routing_map = self.seq_aux_loss_load_balancing(logits, bsz, seq_length)
-        elif self.routing_type == "none":
+        elif self.routing_type in ["none", "noaux_tc"]:
             # A naive top-k routing without load balancing
             scores, routing_map, _ = topk_softmax_with_capacity(
                 logits,
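The routing type "noaux_tc" (the top-k method name used in DeepSeek V3 configurations, for auxiliary-loss-free routing) is mapped here onto the same code path as "none", i.e. plain top-k selection with no balancing loss added to the training loss. As a rough standalone sketch of what that branch computes (simplified; the real topk_softmax_with_capacity also handles expert capacity, grouping, and score scaling, and DeepSeek V3's full scheme additionally applies a bias correction not shown here):

    import torch

    def naive_topk_routing(logits: torch.Tensor, topk: int):
        """Simplified stand-in for topk_softmax_with_capacity with no aux loss.

        logits: [num_tokens, num_experts]
        Returns per-token expert scores and a boolean routing map.
        """
        probs = torch.softmax(logits, dim=-1)
        topk_vals, topk_idx = torch.topk(probs, k=topk, dim=-1)
        routing_map = torch.zeros_like(probs, dtype=torch.bool).scatter_(1, topk_idx, True)
        scores = torch.zeros_like(probs).scatter_(1, topk_idx, topk_vals)
        return scores, routing_map

    scores, routing_map = naive_topk_routing(torch.randn(4, 8), topk=2)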
megatron/core/transformer/mtp/multi_token_predictor.py

@@ -173,7 +173,7 @@ class MultiTokenPredictor(MegatronModule):

         # Rotary positional embeddings (embedding is None for PP intermediate devices)
         rotary_pos_emb = None
-        if self.position_embedding_type == 'rope':
+        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
             if inference_params is not None:
                 rotary_seq_len = inference_params.max_sequence_length
             else:
@@ -184,6 +184,7 @@ class MultiTokenPredictor(MegatronModule):
                 rotary_seq_len *= self.config.context_parallel_size
             rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
+
         if self.recompute_layer_norm:
             self.enorm_ckpt = CheckpointWithoutOutput()
             enorm_output = self.enorm_ckpt.checkpoint(self.enorm, False, decoder_input)
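The guard added in the first hunk means the predictor no longer precomputes rotary embeddings when multi-latent attention is enabled. In the MLA path the attention module applies its own rotary embedding to the decoupled positional sub-dimension (qk_pos_emb_head_dim), so a module-level rotary_pos_emb passed down from here would be redundant; this reading follows from the DeepSeek-V3 MLA design rather than from code shown in this diff.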
megatron/core/transformer/multi_latent_attention.py

@@ -68,6 +68,10 @@ class MultiLatentAttention(Attention):
         self.q_head_dim = self.config.qk_head_dim + self.config.qk_pos_emb_head_dim

+        # Overwrite the base class kv shape to support MLA inference
+        self.key_hidden_size = self.q_head_dim
+        self.val_hidden_size = self.config.v_head_dim
+
         mscale = _yarn_get_mscale(self.config.rotary_scaling_factor, self.config.mscale)
         self.softmax_scale = mscale * mscale / math.sqrt(self.q_head_dim)
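To make the overrides concrete, here is the arithmetic with DeepSeek-V3-style head dimensions; these numbers are typical published values for that model family, not values read from this commit:

    # Assumed DeepSeek-V3-style config values (illustrative only):
    qk_head_dim = 128          # "content" part of each query/key head
    qk_pos_emb_head_dim = 64   # decoupled rotary part of each query/key head
    v_head_dim = 128

    q_head_dim = qk_head_dim + qk_pos_emb_head_dim   # 192
    key_hidden_size = q_head_dim                     # 192: cached keys carry the rope part
    val_hidden_size = v_head_dim                     # 128: values stay narrower

    # With mscale = 1.0 the attention softmax scale becomes
    # mscale * mscale / sqrt(192) ≈ 0.0722 rather than 1 / sqrt(128) ≈ 0.0884.

This asymmetry between key and value widths is exactly why the base Attention class above now allocates its two inference cache buffers with separate sizes.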
megatron/training/arguments.py

@@ -1858,7 +1858,8 @@ def _add_tokenizer_args(parser):
                                'QwenTokenizer',
                                'TikTokenizer',
                                'MultimodalTokenizer',
-                               'NullTokenizer'],
+                               'NullTokenizer',
+                               'DeepSeekV2Tokenizer'],
                        help='What type of tokenizer to use.')
     group.add_argument('--tokenizer-model', type=str, default=None,
                        help='Sentencepiece tokenizer model.')
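With this choice registered, a DeepSeek training run would presumably be launched with something like --tokenizer-type DeepSeekV2Tokenizer --tokenizer-model pointing at the Hugging Face tokenizer directory. The wiring of that choice inside build_tokenizer is not part of this diff, so the exact flag combination is an assumption rather than something the commit shows.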
megatron/training/tokenizer/tokenizer.py

@@ -16,6 +16,7 @@ from .bert_tokenization import FullTokenizer as FullBertTokenizer
 from .gpt2_tokenization import GPT2Tokenizer
 from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer
 from transformers import Qwen2Tokenizer
+from transformers import AutoTokenizer


 def build_tokenizer(args, **kwargs):
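Only the import is added in this hunk; the diff does not show how build_tokenizer uses AutoTokenizer for the new DeepSeekV2Tokenizer choice. Below is a minimal sketch of what such a wrapper typically looks like in Megatron-style tokenizer code; the class name, methods, and trust_remote_code flag are assumptions for illustration, not the commit's actual implementation:

    from transformers import AutoTokenizer

    class _DeepSeekV2TokenizerSketch:
        """Hypothetical wrapper; the real class name and interface may differ."""

        def __init__(self, tokenizer_path: str):
            # trust_remote_code is commonly needed for DeepSeek tokenizers on the Hub.
            self._tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_path, trust_remote_code=True
            )

        @property
        def vocab_size(self) -> int:
            return len(self._tokenizer)

        def tokenize(self, text: str):
            return self._tokenizer.encode(text)

        def detokenize(self, token_ids):
            return self._tokenizer.decode(token_ids)

        @property
        def eod(self) -> int:
            return self._tokenizer.eos_token_id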