Commit 5d997698 authored by dongcl's avatar dongcl
Browse files

megatron patch

parent 950d42b4
......@@ -17,62 +17,62 @@ def dummy_function_wrapper(func_name):
class Patch:
def __init__(self, orig_func_name, new_func, create_dummy, apply_wrapper=False):
split_name = orig_func_name.rsplit('.', 1)
def __init__(self, orig_func_or_cls_name, new_func_or_cls, create_dummy, apply_wrapper=False):
split_name = orig_func_or_cls_name.rsplit('.', 1)
if len(split_name) == 1:
self.orig_module_name, self.orig_func_name = orig_func_name, None
self.orig_module_name, self.orig_func_or_cls_name = orig_func_or_cls_name, None
else:
self.orig_module_name, self.orig_func_name = split_name
self.orig_module_name, self.orig_func_or_cls_name = split_name
self.orig_module = None
self.orig_func = None
self.orig_func_or_cls = None
self.patch_func = None
self.patch_func_or_cls = None
self.wrappers = []
if new_func is None:
new_func = dummy_function_wrapper(orig_func_name)
self.set_patch_func(new_func, apply_wrapper=apply_wrapper)
if new_func_or_cls is None:
new_func_or_cls = dummy_function_wrapper(orig_func_or_cls_name)
self.set_patch_func(new_func_or_cls, apply_wrapper=apply_wrapper)
self.is_applied = False
self.create_dummy = create_dummy
@property
def orig_func_id(self):
return id(self.orig_func)
def orig_func_or_cls_id(self):
return id(self.orig_func_or_cls)
@property
def patch_func_id(self):
return id(self.patch_func)
return id(self.patch_func_or_cls)
def set_patch_func(self, new_func, force_patch=False, apply_wrapper=False):
def set_patch_func(self, new_func_or_cls, force_patch=False, apply_wrapper=False):
if (
apply_wrapper
or (hasattr(new_func, '__name__') and new_func.__name__.endswith(('wrapper', 'decorator')))
or (hasattr(new_func_or_cls, '__name__') and new_func_or_cls.__name__.endswith(('wrapper', 'decorator')))
):
self.wrappers.append(new_func)
self.wrappers.append(new_func_or_cls)
else:
if self.patch_func and not force_patch:
raise RuntimeError('the patch of {} exist !'.format(self.orig_func_name))
self.patch_func = new_func
if self.patch_func_or_cls and not force_patch:
raise RuntimeError('the patch of {} exist !'.format(self.orig_func_or_cls_name))
self.patch_func_or_cls = new_func_or_cls
self.is_applied = False
def apply_patch(self):
if self.is_applied:
return
self.orig_module, self.orig_func = Patch.parse_path(self.orig_module_name, self.orig_func_name, self.create_dummy)
self.orig_module, self.orig_func_or_cls = Patch.parse_path(self.orig_module_name, self.orig_func_or_cls_name, self.create_dummy)
final_patch_func = self.orig_func
if self.patch_func is not None:
final_patch_func = self.patch_func
final_patch_func_or_cls = self.orig_func_or_cls
if self.patch_func_or_cls is not None:
final_patch_func_or_cls = self.patch_func_or_cls
for wrapper in self.wrappers:
final_patch_func = wrapper(final_patch_func)
final_patch_func_or_cls = wrapper(final_patch_func_or_cls)
if self.orig_func_name is not None:
setattr(self.orig_module, self.orig_func_name, final_patch_func)
if self.orig_func_or_cls_name is not None:
setattr(self.orig_module, self.orig_func_or_cls_name, final_patch_func_or_cls)
for key, value in sys.modules.copy().items():
if self.orig_func_name is not None and hasattr(value, self.orig_func_name) \
and id(getattr(value, self.orig_func_name)) == self.orig_func_id:
setattr(value, self.orig_func_name, final_patch_func)
if self.orig_func_or_cls_name is not None and hasattr(value, self.orig_func_or_cls_name) \
and id(getattr(value, self.orig_func_or_cls_name)) == self.orig_func_or_cls_id:
setattr(value, self.orig_func_or_cls_name, final_patch_func_or_cls)
self.is_applied = True
@staticmethod
......@@ -111,11 +111,11 @@ class MegatronPatchesManager:
patches_info = {}
@staticmethod
def register_patch(orig_func_name, new_func=None, force_patch=False, create_dummy=False):
if orig_func_name not in MegatronPatchesManager.patches_info:
MegatronPatchesManager.patches_info[orig_func_name] = Patch(orig_func_name, new_func, create_dummy)
def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False):
if orig_func_or_cls_name not in MegatronPatchesManager.patches_info:
MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(orig_func_or_cls_name, new_func_or_cls, create_dummy)
else:
MegatronPatchesManager.patches_info.get(orig_func_name).set_patch_func(new_func, force_patch)
MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(new_func_or_cls, force_patch)
@staticmethod
def apply_patches():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment