evt_fugx1 / dcu_megatron / Commits

Commit f3581a8d
authored Apr 30, 2025 by sdwldchl

patch for megatron 6ba97dd

parent bc3d72d1
Showing 2 changed files with 49 additions and 53 deletions:

  dcu_megatron/adaptor/megatron_adaptor.py   +6  -7
  dcu_megatron/core/models/gpt/gpt_model.py  +43 -46
dcu_megatron/adaptor/megatron_adaptor.py
@@ -89,12 +89,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         pass
 
     def patch_core_models(self):
-        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward
+        from ..core.models.gpt.gpt_model import gpt_model_forward
 
         # GPT Model
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
-                                    gpt_model_init_wrapper, apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
                                     gpt_model_forward)
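The registry implementation is not part of this commit, but the calls above pin down its contract: resolve a dotted path to an attribute and replace it, where apply_wrapper=True means the registered object is a factory applied to the original attribute rather than a drop-in substitute. A rough sketch under those assumptions (AdaptationRegistry and its internals are illustrative names, not the real class):

import importlib

class AdaptationRegistry:
    """Illustrative stand-in for MegatronAdaptation; the real class is not in this diff."""

    @staticmethod
    def register(dotted_path, patch, apply_wrapper=False):
        names = dotted_path.split('.')
        module, split = None, 0
        # Import the longest importable module prefix of the dotted path.
        for i in range(len(names) - 1, 0, -1):
            try:
                module = importlib.import_module('.'.join(names[:i]))
                split = i
                break
            except ImportError:
                continue
        if module is None:
            raise ImportError(f"no importable prefix in {dotted_path!r}")
        # Walk the remaining attribute names to the owner of the target.
        owner = module
        for name in names[split:-1]:
            owner = getattr(owner, name)
        target = names[-1]
        if apply_wrapper:
            patch = patch(getattr(owner, target))  # wrapper factory receives the original
        setattr(owner, target, patch)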
@@ -116,9 +113,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
                                     torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
                                     apply_wrapper=True)
-        # MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
-        #                             torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
-        #                             apply_wrapper=True)
+        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
+                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
+                                    apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)
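These registrations lean on the fact that torch.compile, when called without a function, returns a decorator; that is what makes it usable with apply_wrapper=True, since the registry can apply it directly to the original MoE utility. A standalone equivalent of the pattern (fused_scale_softmax is a made-up stand-in for the patched function):

import torch

# Calling torch.compile with only keyword options yields a decorator.
compile_with_cudagraphs = torch.compile(
    options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}
)

@compile_with_cudagraphs
def fused_scale_softmax(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Compilation is lazy: nothing is compiled until the first call.
    return torch.softmax(x * scale, dim=-1)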
@@ -174,6 +171,8 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     FluxRowParallelLinear)
         MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
                                     get_gpt_layer_with_flux_spec)
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear",
+                                    FluxColumnParallelLinear)
 
     def patch_pipeline_parallel(self):
         pass
dcu_megatron/core/models/gpt/gpt_model.py
-import os
 from collections import OrderedDict
 from typing import Optional
-from functools import wraps
 
 import torch
 from torch import Tensor
 
-from megatron.core import InferenceParams, tensor_parallel
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
+from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.utils import deprecate_inference_params
 
-from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
-
-
-def gpt_model_init_wrapper(fn):
-    @wraps(fn)
-    def wrapper(self, *args, **kwargs):
-        fn(self, *args, **kwargs)
-
-        # Output
-        if self.post_process or self.mtp_process:
-            if int(os.getenv("USE_FLUX_OVERLAP", "0")):
-                parallel_linear_impl = FluxColumnParallelLinear
-            else:
-                parallel_linear_impl = tensor_parallel.ColumnParallelLinear
-            self.output_layer = parallel_linear_impl(
-                self.config.hidden_size,
-                self.vocab_size,
-                config=self.config,
-                init_method=self.config.init_method,
-                bias=False,
-                skip_bias_add=False,
-                gather_output=not self.parallel_output,
-                skip_weight_param_allocation=self.pre_process
-                and self.share_embeddings_and_output_weights,
-                embedding_activation_buffer=self.embedding_activation_buffer,
-                grad_output_buffer=self.grad_output_buffer,
-            )
-
-        if self.pre_process or self.post_process:
-            self.setup_embeddings_and_output_layer()
-
-    return wrapper
 
 def gpt_model_forward(
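The deleted wrapper gated the Flux output layer on an environment variable; the adaptor now injects FluxColumnParallelLinear by registration instead (see the ColumnParallelLinear hunk above). For reference, the parsing semantics the removed gate had:

import os

# "1" (or any nonzero integer string) enabled Flux, "0" or unset disabled it,
# and a non-integer value such as "true" raised ValueError rather than
# silently enabling the path.
use_flux_overlap = bool(int(os.getenv("USE_FLUX_OVERLAP", "0")))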
@@ -52,14 +17,16 @@ def gpt_model_forward(
     attention_mask: Tensor,
     decoder_input: Tensor = None,
     labels: Tensor = None,
-    inference_params: InferenceParams = None,
+    inference_context: BaseInferenceContext = None,
     packed_seq_params: PackedSeqParams = None,
     extra_block_kwargs: dict = None,
     runtime_gather_output: Optional[bool] = None,
+    *,
+    inference_params: Optional[BaseInferenceContext] = None,
     loss_mask: Optional[Tensor] = None,
 ) -> Tensor:
     """Forward function of the GPT Model This function passes the input tensors
-    through the embedding layer, and then the decoder and finally into the post
+    through the embedding layer, and then the decoeder and finally into the post
     processing layer (optional).
 
     It either returns the Loss values if labels are given or the final hidden units
@@ -71,6 +38,8 @@ def gpt_model_forward(
     # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
     # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
+    inference_context = deprecate_inference_params(inference_context, inference_params)
+
     # Decoder embedding.
     if decoder_input is not None:
         pass
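deprecate_inference_params comes from megatron.core.utils; its body is not shown in this diff, but the call only makes sense as a shim that prefers the new inference_context and falls back to the legacy keyword-only inference_params. A hedged sketch of that behavior:

import warnings

def deprecate_inference_params_sketch(inference_context, inference_params):
    # Assumed behavior, not the actual megatron.core.utils implementation:
    # keep the legacy keyword working while steering callers to the new one.
    if inference_context is None and inference_params is not None:
        warnings.warn(
            "'inference_params' is deprecated; pass 'inference_context' instead.",
            DeprecationWarning,
        )
        return inference_params
    return inference_context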
@@ -86,28 +55,43 @@ def gpt_model_forward(
     rotary_pos_cos = None
     rotary_pos_sin = None
     if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
-        if not self.training and self.config.flash_decode and inference_params:
+        if not self.training and self.config.flash_decode and inference_context:
+            assert (
+                inference_context.is_static_batching()
+            ), "GPTModel currently only supports static inference batching."
             # Flash decoding uses precomputed cos and sin for RoPE
             rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
-                inference_params.max_sequence_length,
-                self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
+                inference_context.max_sequence_length,
+                self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length),
             )
         else:
             rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
-                inference_params, self.decoder, decoder_input, self.config, packed_seq_params
+                inference_context, self.decoder, decoder_input, self.config, packed_seq_params
             )
             rotary_pos_emb = self.rotary_pos_emb(
                 rotary_seq_len,
                 packed_seq=packed_seq_params is not None
                 and packed_seq_params.qkv_format == 'thd',
             )
+    elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention:
+        if self.training or not self.config.flash_decode:
+            rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section)
+        else:
+            # Flash decoding uses precomputed cos and sin for RoPE
+            raise NotImplementedError(
+                "Flash decoding uses precomputed cos and sin for RoPE, not implmented in "
+                "MultimodalRotaryEmbedding yet."
+            )
     if (
         (self.config.enable_cuda_graph or self.config.flash_decode)
         and rotary_pos_cos is not None
-        and inference_params
+        and inference_context
+        and inference_context.is_static_batching()
+        and not self.training
     ):
         sequence_len_offset = torch.tensor(
-            [inference_params.sequence_len_offset] * inference_params.current_batch_size,
+            [inference_context.sequence_len_offset] * inference_context.current_batch_size,
             dtype=torch.int32,
             device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
         )
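One behavior of the RoPE cache line above is worth noting: dict.setdefault evaluates its second argument eagerly, so get_cos_sin runs on every call even when the entry already exists; the cache only avoids storing a second copy. A standalone sketch of the pattern (this get_cos_sin is a simplified stand-in, not the RotaryEmbedding method):

import torch

rotary_cache = {}

def get_cos_sin(seq_len):
    # Simplified stand-in for precomputing RoPE cos/sin tables.
    t = torch.arange(seq_len, dtype=torch.float32)
    return torch.cos(t), torch.sin(t)

# First call computes and stores; later calls reuse the stored pair,
# although the default value is still computed before setdefault runs.
cos, sin = rotary_cache.setdefault(2048, get_cos_sin(2048))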
@@ -118,7 +102,7 @@ def gpt_model_forward(
     hidden_states = self.decoder(
         hidden_states=decoder_input,
         attention_mask=attention_mask,
-        inference_params=inference_params,
+        inference_context=inference_context,
         rotary_pos_emb=rotary_pos_emb,
         rotary_pos_cos=rotary_pos_cos,
         rotary_pos_sin=rotary_pos_sin,
@@ -127,6 +111,12 @@ def gpt_model_forward(
         **(extra_block_kwargs or {}),
     )
 
+    # Process inference output.
+    if inference_context and not inference_context.is_static_batching():
+        hidden_states = inference_context.last_token_logits(
+            hidden_states.squeeze(1).unsqueeze(0)
+        ).unsqueeze(1)
+
     # logits and loss
     output_weight = None
     if self.share_embeddings_and_output_weights:
@@ -164,6 +154,13 @@ def gpt_model_forward(
     if not self.post_process:
         return hidden_states
 
+    if (
+        not self.training
+        and inference_context is not None
+        and inference_context.is_static_batching()
+        and inference_context.materialize_only_last_token_logits
+    ):
+        hidden_states = hidden_states[-1:, :, :]
+
     logits, _ = self.output_layer(
         hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
     )
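The [-1:, :, :] slice in the added block keeps a size-1 sequence dimension (hidden_states is laid out [s, b, h] here), so the output layer still receives a 3-D tensor; plain [-1] indexing would drop that dimension. A quick check:

import torch

h = torch.randn(5, 2, 8)        # [s=5, b=2, h=8]
last = h[-1:, :, :]             # slice keeps the sequence dim
assert last.shape == (1, 2, 8)
assert h[-1].shape == (2, 8)    # [-1] alone would lose it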