Commit f6107946 authored by zhuyue, committed by zhuyue

Issue/568: Support InfiniCoreParameter in InfiniCoreModule.

parent 0b2ea12d
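For context, a minimal sketch of what this change enables. The import paths below are hypothetical, for illustration only; the diff only guarantees the relative import `from .parameter import InfiniCoreParameter` inside the module file.

```python
import torch

# Hypothetical absolute import paths; the diff only shows the relative
# import `from .parameter import InfiniCoreParameter`.
from infinicore.module import InfiniCoreModule
from infinicore.parameter import InfiniCoreParameter


class TinyLinear(InfiniCoreModule):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # After this commit, __setattr__ and register_parameter accept
        # InfiniCoreParameter alongside torch.nn.Parameter.
        self.weight = InfiniCoreParameter(torch.empty(out_features, in_features))
        self.bias = torch.nn.Parameter(torch.zeros(out_features))


m = TinyLinear(4, 2)
for name, p in m.named_parameters():
    print(name, type(p).__name__)  # weight InfiniCoreParameter / bias Parameter
```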
@@ -16,12 +16,16 @@
 from collections import OrderedDict, namedtuple
 import itertools
 import warnings
+from typing import TYPE_CHECKING
 
 import torch
 from typing import Union, Tuple, Any, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 
+if TYPE_CHECKING:
+    from .parameter import InfiniCoreParameter as Parameter
+
 _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
 
 T = TypeVar('T', bound='InfiniCoreModule')
 
@@ -46,7 +50,7 @@ class InfiniCoreModule:
     _version: int = 1
     training: bool
-    _parameters: Dict[str, Optional[Union[torch.nn.Parameter, 'InfiniCoreParameter']]]
+    _parameters: Dict[str, Optional[Union[torch.nn.Parameter, 'Parameter']]]
     _buffers: Dict[str, Optional[torch.Tensor]]
     _non_persistent_buffers_set: Set[str]
     _modules: Dict[str, Optional['InfiniCoreModule']]
@@ -84,9 +88,9 @@ class InfiniCoreModule:
                         d.discard(name)
 
         params = self.__dict__.get("_parameters")
-        # Support both torch.nn.Parameter and InfiniCoreParameter
-        from .parameter import InfiniCoreParameter
-        if isinstance(value, (torch.nn.Parameter, InfiniCoreParameter)):
+        # Support both torch.nn.Parameter and Parameter (InfiniCoreParameter)
+        from .parameter import InfiniCoreParameter as Parameter
+        if isinstance(value, (torch.nn.Parameter, Parameter)):
             if params is None:
                 raise AttributeError(
                     "cannot assign parameters before Module.__init__() call"
@@ -102,7 +106,7 @@ class InfiniCoreModule:
                 if value is not None:
                     raise TypeError(
                         f"cannot assign '{torch.typename(value)}' as parameter '{name}' "
-                        "(torch.nn.Parameter, InfiniCoreParameter or None expected)"
+                        "(torch.nn.Parameter, Parameter or None expected)"
                     )
                 self.register_parameter(name, value)
             else:
@@ -210,7 +214,7 @@ class InfiniCoreModule:
         self._modules[name] = module
 
-    def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
+    def register_parameter(self, name: str, param: Optional[Union[torch.nn.Parameter, 'Parameter']]) -> None:
         r"""Add a parameter to the module.
 
         The parameter can be accessed as an attribute using given name.
 
@@ -242,12 +246,12 @@ class InfiniCoreModule:
         if param is None:
             self._parameters[name] = None
         else:
-            # Support both torch.nn.Parameter and InfiniCoreParameter
-            from .parameter import InfiniCoreParameter
-            if not isinstance(param, (torch.nn.Parameter, InfiniCoreParameter)):
+            # Support both torch.nn.Parameter and Parameter (InfiniCoreParameter)
+            from .parameter import InfiniCoreParameter as Parameter
+            if not isinstance(param, (torch.nn.Parameter, Parameter)):
                 raise TypeError(
                     f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
-                    "(torch.nn.Parameter, InfiniCoreParameter or None required)"
+                    "(torch.nn.Parameter, Parameter or None required)"
                 )
             self._parameters[name] = param
 
@@ -557,7 +561,7 @@ class InfiniCoreModule:
                 self.__class__.__name__, "\n\t".join(error_msgs)))
         return _IncompatibleKeys(missing_keys, unexpected_keys)
 
-    def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
+    def parameters(self, recurse: bool = True) -> Iterator[Union[torch.nn.Parameter, 'Parameter']]:
        r"""Returns an iterator over module parameters.
 
         Args:
@@ -578,7 +582,7 @@ class InfiniCoreModule:
         for name, param in self.named_parameters(recurse=recurse):
             yield param
 
-    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Union[torch.nn.Parameter, 'Parameter']]]:
         r"""Returns an iterator over module parameters, yielding both the
         name of the parameter as well as the parameter itself.
 
@@ -845,6 +849,9 @@ class InfiniCoreModule:
             return False
 
         should_use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion()
 
+        # Import Parameter (InfiniCoreParameter) for type checking and creation
+        from .parameter import InfiniCoreParameter as Parameter
+
         for key, param in self._parameters.items():
             if param is None:
@@ -859,6 +866,10 @@ class InfiniCoreModule:
             # subclasses may have multiple child tensors so we need to use swap_tensors
             p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied)
 
+            # Determine the Parameter class to use based on the original parameter type
+            is_infinicore_param = isinstance(param, Parameter)
+            ParamClass = Parameter if is_infinicore_param else torch.nn.Parameter
+
             param_grad = param.grad
             if p_should_use_swap_tensors:
                 try:
@@ -866,7 +877,7 @@ class InfiniCoreModule:
                         # Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping.
                         # Decrement use count of the gradient by setting to None
                         param.grad = None
-                    param_applied = torch.nn.Parameter(param_applied, requires_grad=param.requires_grad)
+                    param_applied = ParamClass(param_applied, requires_grad=param.requires_grad)
                     torch.utils.swap_tensors(param, param_applied)
                 except Exception as e:
                     if param_grad is not None:
@@ -877,9 +888,9 @@ class InfiniCoreModule:
                 param.data = param_applied
                 out_param = param
             else:
-                assert isinstance(param, torch.nn.Parameter)
+                assert isinstance(param, (torch.nn.Parameter, Parameter))
                 assert param.is_leaf
-                out_param = torch.nn.Parameter(param_applied, param.requires_grad)
+                out_param = ParamClass(param_applied, param.requires_grad)
                 self._parameters[key] = out_param
 
             if param_grad is not None:
...
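The `_apply` hunks above preserve each parameter's class across conversions: rebuilt parameters use `ParamClass`, so an `InfiniCoreParameter` stays an `InfiniCoreParameter` instead of being downcast to `torch.nn.Parameter`. A sketch of the resulting behavior, assuming `_apply(fn)` mirrors `torch.nn.Module._apply` as the surrounding code suggests, and continuing the hypothetical `TinyLinear` from the sketch above:

```python
m = TinyLinear(4, 2)

# Rebuild every parameter in float64; with this commit each parameter is
# reconstructed with its original class rather than always torch.nn.Parameter.
m._apply(lambda t: t.double())

assert isinstance(m.weight, InfiniCoreParameter)  # class preserved
assert isinstance(m.bias, torch.nn.Parameter)     # plain parameters unaffected
assert m.weight.dtype == torch.float64
```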