Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
evt_fugx1
dcu_megatron
Commits
0b492884
Commit
0b492884
authored
Apr 09, 2025
by
dongcl
Browse files
support for removing wrappers
parent
b0b00f4a
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
72 additions
and
404 deletions
+72
-404
dcu_megatron/adaptor/megatron_adaptor.py
dcu_megatron/adaptor/megatron_adaptor.py
+19
-3
dcu_megatron/adaptor/patch_utils.py
dcu_megatron/adaptor/patch_utils.py
+34
-4
dcu_megatron/training/arguments.py
dcu_megatron/training/arguments.py
+19
-397
No files found.
dcu_megatron/adaptor/megatron_adaptor.py
View file @
0b492884
...
...
@@ -24,15 +24,26 @@ class MegatronAdaptation:
# MegatronAdaptation.post_execute()
@
classmethod
def
register
(
cls
,
orig_func_name
,
new_func
=
None
,
force_patch
=
False
,
create_dummy
=
False
,
apply_wrapper
=
False
):
def
register
(
cls
,
orig_func_name
,
new_func
=
None
,
force_patch
=
False
,
create_dummy
=
False
,
apply_wrapper
=
False
,
remove_origin_wrappers
=
False
):
"""
Register adaptations into collection.
"""
if
orig_func_name
not
in
cls
.
_patch_info_collection
:
from
.patch_utils
import
Patch
cls
.
_patch_info_collection
[
orig_func_name
]
=
Patch
(
orig_func_name
,
new_func
,
create_dummy
,
apply_wrapper
=
apply_wrapper
)
cls
.
_patch_info_collection
[
orig_func_name
]
=
Patch
(
orig_func_name
,
new_func
,
create_dummy
,
apply_wrapper
=
apply_wrapper
,
remove_origin_wrappers
=
remove_origin_wrappers
)
else
:
cls
.
_patch_info_collection
.
get
(
orig_func_name
).
set_patch_func
(
new_func
,
force_patch
,
apply_wrapper
=
apply_wrapper
)
cls
.
_patch_info_collection
.
get
(
orig_func_name
).
set_patch_func
(
new_func
,
force_patch
,
apply_wrapper
=
apply_wrapper
,
remove_origin_wrappers
=
remove_origin_wrappers
)
@
classmethod
def
apply
(
cls
):
...
...
@@ -166,9 +177,14 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation
.
register
(
'megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_predicted_logits'
,
VocabParallelCrossEntropy
.
calculate_predicted_logits
)
# _VocabParallelCrossEntropy
MegatronAdaptation
.
register
(
'megatron.core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy.forward'
,
remove_origin_wrappers
=
True
)
MegatronAdaptation
.
register
(
'megatron.core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy.forward'
,
torch
.
compile
(
mode
=
'max-autotune-no-cudagraphs'
),
apply_wrapper
=
True
)
MegatronAdaptation
.
register
(
'megatron.core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy.forward'
,
staticmethod
,
apply_wrapper
=
True
)
def
patch_training
(
self
):
from
..training.tokenizer
import
build_tokenizer
...
...
dcu_megatron/adaptor/patch_utils.py
View file @
0b492884
...
...
@@ -17,7 +17,7 @@ def dummy_function_wrapper(func_name):
class
Patch
:
def
__init__
(
self
,
orig_func_or_cls_name
,
new_func_or_cls
,
create_dummy
,
apply_wrapper
=
False
):
def
__init__
(
self
,
orig_func_or_cls_name
,
new_func_or_cls
,
create_dummy
,
apply_wrapper
=
False
,
remove_origin_wrappers
=
False
):
split_name
=
orig_func_or_cls_name
.
rsplit
(
'.'
,
1
)
if
len
(
split_name
)
==
1
:
self
.
orig_module_name
,
self
.
orig_func_or_cls_name
=
orig_func_or_cls_name
,
None
...
...
@@ -28,9 +28,14 @@ class Patch:
self
.
patch_func_or_cls
=
None
self
.
wrappers
=
[]
if
new_func_or_cls
is
None
:
self
.
remove_origin_wrappers
=
False
if
(
new_func_or_cls
is
None
and
not
remove_origin_wrappers
):
new_func_or_cls
=
dummy_function_wrapper
(
orig_func_or_cls_name
)
self
.
set_patch_func
(
new_func_or_cls
,
apply_wrapper
=
apply_wrapper
)
self
.
set_patch_func
(
new_func_or_cls
,
apply_wrapper
=
apply_wrapper
,
remove_origin_wrappers
=
remove_origin_wrappers
)
self
.
is_applied
=
False
self
.
create_dummy
=
create_dummy
...
...
@@ -42,7 +47,27 @@ class Patch:
def
patch_func_id
(
self
):
return
id
(
self
.
patch_func_or_cls
)
def
set_patch_func
(
self
,
new_func_or_cls
,
force_patch
=
False
,
apply_wrapper
=
False
):
@
staticmethod
def
remove_wrappers
(
func
):
while
True
:
if
hasattr
(
func
,
'__wrapped__'
)
and
func
.
__wrapped__
is
not
None
:
func
=
func
.
__wrapped__
elif
hasattr
(
func
,
'__closure__'
)
and
func
.
__closure__
is
not
None
:
func
=
func
.
__closure__
[
0
].
cell_contents
else
:
return
func
return
func
def
set_patch_func
(
self
,
new_func_or_cls
=
None
,
force_patch
=
False
,
apply_wrapper
=
False
,
remove_origin_wrappers
=
False
):
if
remove_origin_wrappers
:
self
.
remove_origin_wrappers
=
True
else
:
assert
new_func_or_cls
is
not
None
if
new_func_or_cls
is
None
:
return
if
(
apply_wrapper
or
(
hasattr
(
new_func_or_cls
,
'__name__'
)
and
new_func_or_cls
.
__name__
.
endswith
((
'wrapper'
,
'decorator'
)))
...
...
@@ -64,6 +89,11 @@ class Patch:
if
self
.
patch_func_or_cls
is
not
None
:
final_patch_func_or_cls
=
self
.
patch_func_or_cls
# remove original wrappers
if
self
.
remove_origin_wrappers
:
final_patch_func_or_cls
=
self
.
remove_wrappers
(
final_patch_func_or_cls
)
# add new wrappers
for
wrapper
in
self
.
wrappers
:
final_patch_func_or_cls
=
wrapper
(
final_patch_func_or_cls
)
...
...
dcu_megatron/training/arguments.py
View file @
0b492884
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment