Unverified Commit f99f56df authored by ver217's avatar ver217 Committed by GitHub
Browse files

fix colo parameter torch function (#1117)

parent e1620dda
...@@ -7,6 +7,23 @@ from colossalai.tensor.param_op_hook import ParamOpHookManager ...@@ -7,6 +7,23 @@ from colossalai.tensor.param_op_hook import ParamOpHookManager
from typing import Optional from typing import Optional
def filter_args(func, *args):
return [arg for arg in args if func(arg)]
def unpack_args(*args):
if len(args) == 1:
return args[0]
return args
def replace_args(args, kwargs, new_args):
args = new_args[:len(args)]
for k, v in zip(kwargs.keys(), new_args[len(args):]):
kwargs[k] = v
return unpack_args(args), kwargs
class ColoParameter(ColoTensor, torch.nn.Parameter): class ColoParameter(ColoTensor, torch.nn.Parameter):
r"""A kind of ColoTensor to be considered as a module parameter. r"""A kind of ColoTensor to be considered as a module parameter.
...@@ -50,12 +67,13 @@ class ColoParameter(ColoTensor, torch.nn.Parameter): ...@@ -50,12 +67,13 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
def __torch_function__(cls, func, types, args=..., kwargs=None): def __torch_function__(cls, func, types, args=..., kwargs=None):
if ParamOpHookManager.has_hook(): if ParamOpHookManager.has_hook():
if not func.__name__.startswith('__'): if not func.__name__.startswith('__'):
params = list(filter(lambda arg: isinstance(arg, ColoParameter), args)) if kwargs is None:
if kwargs is not None: kwargs = {}
params.extend(list(filter(lambda arg: isinstance(arg, ColoParameter), kwargs.values()))) params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
if len(params) > 0: if len(params) > 0:
with torch._C.DisableTorchFunction(): with torch._C.DisableTorchFunction():
args = ParamOpHookManager.pre_op(params, *args) new_args = ParamOpHookManager.pre_op(params, *args, *kwargs.values())
args, kwargs = replace_args(args, kwargs, new_args)
ret = super().__torch_function__(func, types, args, kwargs) ret = super().__torch_function__(func, types, args, kwargs)
with torch._C.DisableTorchFunction(): with torch._C.DisableTorchFunction():
ret = ParamOpHookManager.post_op(params, ret) ret = ParamOpHookManager.post_op(params, ret)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment