"README_ORIGIN.md" did not exist on "81b67e8d9ea38fd8c0efc4fadc7e6e02c12ec00e"
Commit 0b492884 authored by dongcl's avatar dongcl
Browse files

support for removing wrappers

parent b0b00f4a
...@@ -24,15 +24,26 @@ class MegatronAdaptation: ...@@ -24,15 +24,26 @@ class MegatronAdaptation:
# MegatronAdaptation.post_execute() # MegatronAdaptation.post_execute()
@classmethod @classmethod
def register(cls, orig_func_name, new_func=None, force_patch=False, create_dummy=False, apply_wrapper=False): def register(cls, orig_func_name, new_func=None, force_patch=False, create_dummy=False, apply_wrapper=False, remove_origin_wrappers=False):
""" """
Register adaptations into collection. Register adaptations into collection.
""" """
if orig_func_name not in cls._patch_info_collection: if orig_func_name not in cls._patch_info_collection:
from .patch_utils import Patch from .patch_utils import Patch
cls._patch_info_collection[orig_func_name] = Patch(orig_func_name, new_func, create_dummy, apply_wrapper=apply_wrapper) cls._patch_info_collection[orig_func_name] = Patch(
orig_func_name,
new_func,
create_dummy,
apply_wrapper=apply_wrapper,
remove_origin_wrappers=remove_origin_wrappers
)
else: else:
cls._patch_info_collection.get(orig_func_name).set_patch_func(new_func, force_patch, apply_wrapper=apply_wrapper) cls._patch_info_collection.get(orig_func_name).set_patch_func(
new_func,
force_patch,
apply_wrapper=apply_wrapper,
remove_origin_wrappers=remove_origin_wrappers
)
@classmethod @classmethod
def apply(cls): def apply(cls):
...@@ -166,9 +177,14 @@ class CoreAdaptation(MegatronAdaptationABC): ...@@ -166,9 +177,14 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_predicted_logits', MegatronAdaptation.register('megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_predicted_logits',
VocabParallelCrossEntropy.calculate_predicted_logits) VocabParallelCrossEntropy.calculate_predicted_logits)
# _VocabParallelCrossEntropy # _VocabParallelCrossEntropy
MegatronAdaptation.register('megatron.core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy.forward',
remove_origin_wrappers=True)
MegatronAdaptation.register('megatron.core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy.forward', MegatronAdaptation.register('megatron.core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy.forward',
torch.compile(mode='max-autotune-no-cudagraphs'), torch.compile(mode='max-autotune-no-cudagraphs'),
apply_wrapper=True) apply_wrapper=True)
MegatronAdaptation.register('megatron.core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy.forward',
staticmethod,
apply_wrapper=True)
def patch_training(self): def patch_training(self):
from ..training.tokenizer import build_tokenizer from ..training.tokenizer import build_tokenizer
......
...@@ -17,7 +17,7 @@ def dummy_function_wrapper(func_name): ...@@ -17,7 +17,7 @@ def dummy_function_wrapper(func_name):
class Patch: class Patch:
def __init__(self, orig_func_or_cls_name, new_func_or_cls, create_dummy, apply_wrapper=False): def __init__(self, orig_func_or_cls_name, new_func_or_cls, create_dummy, apply_wrapper=False, remove_origin_wrappers=False):
split_name = orig_func_or_cls_name.rsplit('.', 1) split_name = orig_func_or_cls_name.rsplit('.', 1)
if len(split_name) == 1: if len(split_name) == 1:
self.orig_module_name, self.orig_func_or_cls_name = orig_func_or_cls_name, None self.orig_module_name, self.orig_func_or_cls_name = orig_func_or_cls_name, None
...@@ -28,9 +28,14 @@ class Patch: ...@@ -28,9 +28,14 @@ class Patch:
self.patch_func_or_cls = None self.patch_func_or_cls = None
self.wrappers = [] self.wrappers = []
if new_func_or_cls is None: self.remove_origin_wrappers = False
if (
new_func_or_cls is None
and not remove_origin_wrappers
):
new_func_or_cls = dummy_function_wrapper(orig_func_or_cls_name) new_func_or_cls = dummy_function_wrapper(orig_func_or_cls_name)
self.set_patch_func(new_func_or_cls, apply_wrapper=apply_wrapper)
self.set_patch_func(new_func_or_cls, apply_wrapper=apply_wrapper, remove_origin_wrappers=remove_origin_wrappers)
self.is_applied = False self.is_applied = False
self.create_dummy = create_dummy self.create_dummy = create_dummy
...@@ -42,7 +47,27 @@ class Patch: ...@@ -42,7 +47,27 @@ class Patch:
def patch_func_id(self): def patch_func_id(self):
return id(self.patch_func_or_cls) return id(self.patch_func_or_cls)
def set_patch_func(self, new_func_or_cls, force_patch=False, apply_wrapper=False): @staticmethod
def remove_wrappers(func):
while True:
if hasattr(func, '__wrapped__') and func.__wrapped__ is not None:
func = func.__wrapped__
elif hasattr(func, '__closure__') and func.__closure__ is not None:
func = func.__closure__[0].cell_contents
else:
return func
return func
def set_patch_func(self, new_func_or_cls=None, force_patch=False, apply_wrapper=False, remove_origin_wrappers=False):
if remove_origin_wrappers:
self.remove_origin_wrappers = True
else:
assert new_func_or_cls is not None
if new_func_or_cls is None:
return
if ( if (
apply_wrapper apply_wrapper
or (hasattr(new_func_or_cls, '__name__') and new_func_or_cls.__name__.endswith(('wrapper', 'decorator'))) or (hasattr(new_func_or_cls, '__name__') and new_func_or_cls.__name__.endswith(('wrapper', 'decorator')))
...@@ -64,6 +89,11 @@ class Patch: ...@@ -64,6 +89,11 @@ class Patch:
if self.patch_func_or_cls is not None: if self.patch_func_or_cls is not None:
final_patch_func_or_cls = self.patch_func_or_cls final_patch_func_or_cls = self.patch_func_or_cls
# remove original wrappers
if self.remove_origin_wrappers:
final_patch_func_or_cls = self.remove_wrappers(final_patch_func_or_cls)
# add new wrappers
for wrapper in self.wrappers: for wrapper in self.wrappers:
final_patch_func_or_cls = wrapper(final_patch_func_or_cls) final_patch_func_or_cls = wrapper(final_patch_func_or_cls)
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment