Commit f6107946 authored and committed by zhuyue

Issue/568: Support InfiniCoreParameter in InfiniCoreModule.

parent 0b2ea12d
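
In short: after this change, an InfiniCoreParameter can be assigned as a module attribute or passed to register_parameter() exactly like a torch.nn.Parameter, and _apply()-based conversions rebuild each parameter with its original class. A minimal usage sketch follows; the top-level import paths (infinicore.nn, infinicore.nn.parameter) and the InfiniCoreParameter constructor mirroring torch.nn.Parameter are assumptions for illustration, since the diff only shows the package-relative import `from .parameter import InfiniCoreParameter`.

    import torch

    # NOTE: hypothetical import paths; the diff only shows the package-relative
    # "from .parameter import InfiniCoreParameter".
    from infinicore.nn import InfiniCoreModule
    from infinicore.nn.parameter import InfiniCoreParameter

    class TinyLinear(InfiniCoreModule):
        def __init__(self, in_features, out_features):
            super().__init__()
            # __setattr__ now routes InfiniCoreParameter through
            # register_parameter(), just as it does torch.nn.Parameter.
            self.weight = InfiniCoreParameter(torch.empty(out_features, in_features))
            self.bias = torch.nn.Parameter(torch.zeros(out_features))  # still accepted

    m = TinyLinear(4, 2)
    for name, p in m.named_parameters():
        print(name, type(p).__name__)  # weight -> InfiniCoreParameter, bias -> Parameter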
@@ -16,12 +16,16 @@
 from collections import OrderedDict, namedtuple
 import itertools
 import warnings
+from typing import TYPE_CHECKING
 import torch
 from typing import Union, Tuple, Any, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
+if TYPE_CHECKING:
+    from .parameter import InfiniCoreParameter as Parameter
 _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
 T = TypeVar('T', bound='InfiniCoreModule')
@@ -46,7 +50,7 @@ class InfiniCoreModule:
     _version: int = 1
     training: bool
-    _parameters: Dict[str, Optional[Union[torch.nn.Parameter, 'InfiniCoreParameter']]]
+    _parameters: Dict[str, Optional[Union[torch.nn.Parameter, 'Parameter']]]
     _buffers: Dict[str, Optional[torch.Tensor]]
     _non_persistent_buffers_set: Set[str]
     _modules: Dict[str, Optional['InfiniCoreModule']]
@@ -84,9 +88,9 @@ class InfiniCoreModule:
                     d.discard(name)

         params = self.__dict__.get("_parameters")
-        # Support both torch.nn.Parameter and InfiniCoreParameter
-        from .parameter import InfiniCoreParameter
-        if isinstance(value, (torch.nn.Parameter, InfiniCoreParameter)):
+        # Support both torch.nn.Parameter and Parameter (InfiniCoreParameter)
+        from .parameter import InfiniCoreParameter as Parameter
+        if isinstance(value, (torch.nn.Parameter, Parameter)):
             if params is None:
                 raise AttributeError(
                     "cannot assign parameters before Module.__init__() call"
@@ -102,7 +106,7 @@ class InfiniCoreModule:
                 if value is not None:
                     raise TypeError(
                         f"cannot assign '{torch.typename(value)}' as parameter '{name}' "
-                        "(torch.nn.Parameter, InfiniCoreParameter or None expected)"
+                        "(torch.nn.Parameter, Parameter or None expected)"
                     )
                 self.register_parameter(name, value)
             else:
@@ -210,7 +214,7 @@ class InfiniCoreModule:
         self._modules[name] = module

-    def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
+    def register_parameter(self, name: str, param: Optional[Union[torch.nn.Parameter, 'Parameter']]) -> None:
         r"""Add a parameter to the module.

         The parameter can be accessed as an attribute using given name.
@@ -242,12 +246,12 @@ class InfiniCoreModule:
         if param is None:
             self._parameters[name] = None
         else:
-            # Support both torch.nn.Parameter and InfiniCoreParameter
-            from .parameter import InfiniCoreParameter
-            if not isinstance(param, (torch.nn.Parameter, InfiniCoreParameter)):
+            # Support both torch.nn.Parameter and Parameter (InfiniCoreParameter)
+            from .parameter import InfiniCoreParameter as Parameter
+            if not isinstance(param, (torch.nn.Parameter, Parameter)):
                 raise TypeError(
                     f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
-                    "(torch.nn.Parameter, InfiniCoreParameter or None required)"
+                    "(torch.nn.Parameter, Parameter or None required)"
                 )
             self._parameters[name] = param
@@ -557,7 +561,7 @@ class InfiniCoreModule:
                 self.__class__.__name__, "\n\t".join(error_msgs)))
         return _IncompatibleKeys(missing_keys, unexpected_keys)

-    def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
+    def parameters(self, recurse: bool = True) -> Iterator[Union[torch.nn.Parameter, 'Parameter']]:
         r"""Returns an iterator over module parameters.

         Args:
@@ -578,7 +582,7 @@ class InfiniCoreModule:
         for name, param in self.named_parameters(recurse=recurse):
             yield param

-    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Union[torch.nn.Parameter, 'Parameter']]]:
         r"""Returns an iterator over module parameters, yielding both the
         name of the parameter as well as the parameter itself.
@@ -845,6 +849,9 @@ class InfiniCoreModule:
                 return False

         should_use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion()

+        # Import Parameter (InfiniCoreParameter) for type checking and creation
+        from .parameter import InfiniCoreParameter as Parameter
+
         for key, param in self._parameters.items():
             if param is None:
@@ -859,6 +866,10 @@ class InfiniCoreModule:
             # subclasses may have multiple child tensors so we need to use swap_tensors
             p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied)

+            # Determine the Parameter class to use based on the original parameter type
+            is_infinicore_param = isinstance(param, Parameter)
+            ParamClass = Parameter if is_infinicore_param else torch.nn.Parameter
+
             param_grad = param.grad
             if p_should_use_swap_tensors:
                 try:
@@ -866,7 +877,7 @@
                         # Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping.
                         # Decrement use count of the gradient by setting to None
                         param.grad = None
-                    param_applied = torch.nn.Parameter(param_applied, requires_grad=param.requires_grad)
+                    param_applied = ParamClass(param_applied, requires_grad=param.requires_grad)
                     torch.utils.swap_tensors(param, param_applied)
                 except Exception as e:
                     if param_grad is not None:
@@ -877,9 +888,9 @@
                 param.data = param_applied
                 out_param = param
             else:
-                assert isinstance(param, torch.nn.Parameter)
+                assert isinstance(param, (torch.nn.Parameter, Parameter))
                 assert param.is_leaf
-                out_param = torch.nn.Parameter(param_applied, param.requires_grad)
+                out_param = ParamClass(param_applied, param.requires_grad)
                 self._parameters[key] = out_param

             if param_grad is not None:
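
A note on the _apply() change above: because ParamClass is chosen from the original parameter's type, conversions that rebuild parameters (e.g. .float() or .to(), assuming InfiniCoreModule exposes them the way torch.nn.Module does) should now preserve each parameter's class instead of collapsing everything to torch.nn.Parameter. A sketch, under the same hypothetical imports as the example above:

    m = TinyLinear(4, 2)
    m.float()  # dispatches through _apply(); each param is rebuilt via ParamClass

    # The subclass survives the conversion; a plain torch.nn.Parameter stays plain.
    assert type(m.weight).__name__ == "InfiniCoreParameter"
    assert type(m.bias) is torch.nn.Parameter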