Commit ee722eb9 authored by pengcheng888's avatar pengcheng888 Committed by zhuyue
Browse files

issue/567-只处理infinicore.Tensor,能够加载infinicore.Tensor的权重,修改了module.py paramter.py部分代码

parent f6107946
from infinicore.nn import functional from infinicore.nn import functional
from infinicore.nn.modules import * # noqa: F403
from infinicore.nn.parameter import InfiniCoreParameter as Parameter
__all__ = ["functional"] __all__ = ["functional", "Parameter"]
from .container import InfiniCoreModuleList as ModuleList
from .module import InfiniCoreModule as Module from .module import InfiniCoreModule as Module
from .module_list import InfiniCoreModuleList as ModuleList
from .parameter import InfiniCoreParameter as Parameter __all__ = ["ModuleList", "Module"]
# ============================================
# Copyright (c) 2025, InfiniCore # Copyright (c) 2025, InfiniCore
# #
# This file implements InfiniCoreModuleList, which is similar to torch.nn.ModuleList # This file implements InfiniCoreModuleList, which is similar to torch.nn.ModuleList
# but based on InfiniCoreModule for inference purposes. # but based on InfiniCoreModule for inference purposes.
from typing import List, Optional, Iterator, Union, Sequence, TypeVar
import torch
import operator import operator
from itertools import chain
from collections import OrderedDict from collections import OrderedDict
from .module import InfiniCoreModule from itertools import chain
from typing import Iterator, List, Optional, Sequence, TypeVar, Union
# Define type variable for module compatibility (supports both torch.nn.Module and InfiniCoreModule) from .module import InfiniCoreModule as Module
ModuleType = TypeVar('ModuleType', bound=Union[torch.nn.Module, 'InfiniCoreModule'])
# Define type variable for module compatibility (supports InfiniCoreModule)
ModuleType = TypeVar("ModuleType", bound=Union["Module"])
class InfiniCoreModuleList(InfiniCoreModule):
class InfiniCoreModuleList(Module):
r"""Holds submodules in a list. r"""Holds submodules in a list.
InfiniCoreModuleList can be indexed like a regular Python list, but InfiniCoreModuleList can be indexed like a regular Python list, but
...@@ -54,7 +55,9 @@ class InfiniCoreModuleList(InfiniCoreModule): ...@@ -54,7 +55,9 @@ class InfiniCoreModuleList(InfiniCoreModule):
idx += len(self) idx += len(self)
return str(idx) return str(idx)
def __getitem__(self, idx: Union[int, slice]) -> Union[ModuleType, 'InfiniCoreModuleList']: def __getitem__(
self, idx: Union[int, slice]
) -> Union[ModuleType, "InfiniCoreModuleList"]:
if isinstance(idx, slice): if isinstance(idx, slice):
return self.__class__(list(self._modules.values())[idx]) return self.__class__(list(self._modules.values())[idx])
else: else:
...@@ -75,7 +78,7 @@ class InfiniCoreModuleList(InfiniCoreModule): ...@@ -75,7 +78,7 @@ class InfiniCoreModuleList(InfiniCoreModule):
idx_str = self._get_abs_string_index(idx) idx_str = self._get_abs_string_index(idx)
if idx_str in self._modules: if idx_str in self._modules:
del self._modules[idx_str] del self._modules[idx_str]
# To preserve numbering, self._modules is being reconstructed with modules after deletion # To preserve numbering, self._modules is being reconstructed with modules after deletion
if len(self._modules) > 0: if len(self._modules) > 0:
str_indices = [str(i) for i in range(len(self._modules))] str_indices = [str(i) for i in range(len(self._modules))]
...@@ -87,10 +90,12 @@ class InfiniCoreModuleList(InfiniCoreModule): ...@@ -87,10 +90,12 @@ class InfiniCoreModuleList(InfiniCoreModule):
def __iter__(self) -> Iterator[ModuleType]: def __iter__(self) -> Iterator[ModuleType]:
return iter(self._modules.values()) return iter(self._modules.values())
def __iadd__(self, modules: Sequence[ModuleType]) -> 'InfiniCoreModuleList': def __iadd__(self, modules: Sequence[ModuleType]) -> "InfiniCoreModuleList":
return self.extend(modules) return self.extend(modules)
def __add__(self, other: Union[Sequence[ModuleType], 'InfiniCoreModuleList']) -> 'InfiniCoreModuleList': def __add__(
self, other: Union[Sequence[ModuleType], "InfiniCoreModuleList"]
) -> "InfiniCoreModuleList":
r"""Return a new InfiniCoreModuleList by concatenating with another iterable. r"""Return a new InfiniCoreModuleList by concatenating with another iterable.
Args: Args:
...@@ -101,22 +106,22 @@ class InfiniCoreModuleList(InfiniCoreModule): ...@@ -101,22 +106,22 @@ class InfiniCoreModuleList(InfiniCoreModule):
f"InfiniCoreModuleList can only be concatenated with list, tuple, or InfiniCoreModuleList, " f"InfiniCoreModuleList can only be concatenated with list, tuple, or InfiniCoreModuleList, "
f"got {type(other).__name__}" f"got {type(other).__name__}"
) )
combined = InfiniCoreModuleList() combined = InfiniCoreModuleList()
for i, module in enumerate(chain(self, other)): for i, module in enumerate(chain(self, other)):
combined.add_module(str(i), module) combined.add_module(str(i), module)
return combined return combined
def append(self, module: ModuleType) -> 'InfiniCoreModuleList': def append(self, module: ModuleType) -> "InfiniCoreModuleList":
r"""Append a given module to the end of the list. r"""Append a given module to the end of the list.
Args: Args:
module (nn.Module or InfiniCoreModule): module to append module (InfiniCoreModule): module to append
""" """
self.add_module(str(len(self)), module) self.add_module(str(len(self)), module)
return self return self
def extend(self, modules: Sequence[ModuleType]) -> 'InfiniCoreModuleList': def extend(self, modules: Sequence[ModuleType]) -> "InfiniCoreModuleList":
r"""Append modules from a Python iterable to the end of the list. r"""Append modules from a Python iterable to the end of the list.
Args: Args:
...@@ -130,7 +135,7 @@ class InfiniCoreModuleList(InfiniCoreModule): ...@@ -130,7 +135,7 @@ class InfiniCoreModuleList(InfiniCoreModule):
f"InfiniCoreModuleList.extend should be called with an " f"InfiniCoreModuleList.extend should be called with an "
f"iterable, but got {type(modules).__name__}" f"iterable, but got {type(modules).__name__}"
) )
offset = len(self) offset = len(self)
for i, module in enumerate(modules): for i, module in enumerate(modules):
self.add_module(str(offset + i), module) self.add_module(str(offset + i), module)
...@@ -141,7 +146,7 @@ class InfiniCoreModuleList(InfiniCoreModule): ...@@ -141,7 +146,7 @@ class InfiniCoreModuleList(InfiniCoreModule):
Args: Args:
index (int): index to insert. index (int): index to insert.
module (nn.Module or InfiniCoreModule): module to insert module ( InfiniCoreModule): module to insert
""" """
for i in range(len(self._modules), index, -1): for i in range(len(self._modules), index, -1):
self._modules[str(i)] = self._modules[str(i - 1)] self._modules[str(i)] = self._modules[str(i - 1)]
...@@ -166,11 +171,11 @@ class InfiniCoreModuleList(InfiniCoreModule): ...@@ -166,11 +171,11 @@ class InfiniCoreModuleList(InfiniCoreModule):
"""Return a string representation of the ModuleList.""" """Return a string representation of the ModuleList."""
if len(self) == 0: if len(self) == 0:
return self.__class__.__name__ + "()" return self.__class__.__name__ + "()"
lines = [] lines = []
for i, module in enumerate(self): for i, module in enumerate(self):
lines.append(f"({i}): {repr(module)}") lines.append(f"({i}): {repr(module)}")
main_str = self.__class__.__name__ + "(\n " main_str = self.__class__.__name__ + "(\n "
main_str += "\n ".join(lines) + "\n)" main_str += "\n ".join(lines) + "\n)"
return main_str return main_str
......
# Copyright (c) 2025, InfiniCore # Copyright (c) 2025, InfiniCore
# #
# This file contains modified code derived from PyTorch's `torch.nn.Module` # This file contains modified code derived from PyTorch's `torch.nn.Module`
# implementation, which is licensed under the BSD 3-Clause License. # implementation, which is licensed under the BSD 3-Clause License.
# #
...@@ -13,27 +13,38 @@ ...@@ -13,27 +13,38 @@
# #
# The use of this file is governed by the BSD 3-Clause License. # The use of this file is governed by the BSD 3-Clause License.
from collections import OrderedDict, namedtuple
import itertools import itertools
import warnings import warnings
from typing import TYPE_CHECKING from collections import OrderedDict, namedtuple
from typing import (
import torch Any,
Dict,
from typing import Union, Tuple, Any, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List Iterator,
from torch.utils._python_dispatch import is_traceable_wrapper_subclass List,
Mapping,
if TYPE_CHECKING: Optional,
from .parameter import InfiniCoreParameter as Parameter Set,
Tuple,
_EXTRA_STATE_KEY_SUFFIX = '_extra_state' TypeVar,
Union,
T = TypeVar('T', bound='InfiniCoreModule') overload,
)
class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])):
import infinicore
from ...tensor import Tensor
from ..parameter import InfiniCoreParameter as Parameter
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
T = TypeVar("T", bound="InfiniCoreModule")
class _IncompatibleKeys(
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])
):
def __repr__(self): def __repr__(self):
if not self.missing_keys and not self.unexpected_keys: if not self.missing_keys and not self.unexpected_keys:
return '<All keys matched successfully>' return "<All keys matched successfully>"
return super().__repr__() return super().__repr__()
__str__ = __repr__ __str__ = __repr__
...@@ -42,18 +53,14 @@ class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpec ...@@ -42,18 +53,14 @@ class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpec
class InfiniCoreModule: class InfiniCoreModule:
r"""Base class for InfiniCore neural network modules. r"""Base class for InfiniCore neural network modules.
Your models should also subclass this class. Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure.
Modules can also contain other Modules, allowing
to nest them in a tree structure.
""" """
_version: int = 1 _version: int = 1
_parameters: Dict[str, Optional[Parameter]]
training: bool _buffers: Dict[str, Optional[Tensor]]
_parameters: Dict[str, Optional[Union[torch.nn.Parameter, 'Parameter']]]
_buffers: Dict[str, Optional[torch.Tensor]]
_non_persistent_buffers_set: Set[str] _non_persistent_buffers_set: Set[str]
_modules: Dict[str, Optional['InfiniCoreModule']] _modules: Dict[str, Optional["InfiniCoreModule"]]
def __init__(self): def __init__(self):
super().__setattr__("_parameters", OrderedDict()) super().__setattr__("_parameters", OrderedDict())
...@@ -66,19 +73,22 @@ class InfiniCoreModule: ...@@ -66,19 +73,22 @@ class InfiniCoreModule:
_parameters = self.__dict__["_parameters"] _parameters = self.__dict__["_parameters"]
if name in _parameters: if name in _parameters:
return _parameters[name] return _parameters[name]
if "_buffers" in self.__dict__: if "_buffers" in self.__dict__:
_buffers = self.__dict__["_buffers"] _buffers = self.__dict__["_buffers"]
if name in _buffers: if name in _buffers:
return _buffers[name] return _buffers[name]
if "_modules" in self.__dict__: if "_modules" in self.__dict__:
modules = self.__dict__["_modules"] modules = self.__dict__["_modules"]
if name in modules: if name in modules:
return modules[name] return modules[name]
raise AttributeError( raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'" f"'{type(self).__name__}' object has no attribute '{name}'"
) )
def __setattr__(self, name: str, value: Union[torch.Tensor, 'InfiniCoreModule']) -> None: def __setattr__(self, name: str, value: Union[Tensor, "InfiniCoreModule"]) -> None:
def remove_from(*dicts_or_sets) -> None: def remove_from(*dicts_or_sets) -> None:
for d in dicts_or_sets: for d in dicts_or_sets:
if name in d: if name in d:
...@@ -88,13 +98,12 @@ class InfiniCoreModule: ...@@ -88,13 +98,12 @@ class InfiniCoreModule:
d.discard(name) d.discard(name)
params = self.__dict__.get("_parameters") params = self.__dict__.get("_parameters")
# Support both torch.nn.Parameter and Parameter (InfiniCoreParameter) if params is None:
from .parameter import InfiniCoreParameter as Parameter raise AttributeError(
if isinstance(value, (torch.nn.Parameter, Parameter)): "cannot assign parameters before Module.__init__() call"
if params is None: )
raise AttributeError(
"cannot assign parameters before Module.__init__() call" if isinstance(value, Parameter): # the value is of type Parameter
)
remove_from( remove_from(
self.__dict__, self.__dict__,
self._buffers, self._buffers,
...@@ -102,20 +111,21 @@ class InfiniCoreModule: ...@@ -102,20 +111,21 @@ class InfiniCoreModule:
self._non_persistent_buffers_set, self._non_persistent_buffers_set,
) )
self.register_parameter(name, value) self.register_parameter(name, value)
elif params is not None and name in params: elif name in params: # value will overwrite the name of params.
if value is not None: if not isinstance(value, Tensor):
raise TypeError( raise TypeError(
f"cannot assign '{torch.typename(value)}' as parameter '{name}' " f"cannot assign 'value' as parameter '{name}' (infinicore.nn.Parameter, Parameter or None expected)"
"(torch.nn.Parameter, Parameter or None expected)"
) )
self.register_parameter(name, value) self.register_parameter(name, value)
else: else:
modules = self.__dict__.get("_modules") modules = self.__dict__.get("_modules")
if isinstance(value, (torch.nn.Module, InfiniCoreModule)): if modules is None:
if modules is None: raise AttributeError(
raise AttributeError( "cannot assign module before Module.__init__() call"
"cannot assign module before Module.__init__() call" )
)
if isinstance(value, InfiniCoreModule):
remove_from( remove_from(
self.__dict__, self.__dict__,
self._parameters, self._parameters,
...@@ -123,32 +133,35 @@ class InfiniCoreModule: ...@@ -123,32 +133,35 @@ class InfiniCoreModule:
self._non_persistent_buffers_set, self._non_persistent_buffers_set,
) )
modules[name] = value modules[name] = value
elif modules is not None and name in modules: elif name in modules: # Do not overwrite this variable
if value is not None: raise TypeError(
raise TypeError( f"cannot assign 'value' as child module '{name}' (infinicore.nn.Module or None expected)"
f"cannot assign '{torch.typename(value)}' as child module '{name}' " )
"(torch.nn.Module or None expected)"
)
modules[name] = value
else: else:
buffers = self.__dict__.get("_buffers") buffers = self.__dict__.get("_buffers")
if buffers is not None and name in buffers: if buffers is not None and name in buffers:
if value is not None and not isinstance(value, torch.Tensor): if value is not None and not isinstance(value, Tensor):
raise TypeError(f"cannot assign '{torch.typename(value)}' as buffer '{name}' " raise TypeError(
"(torch.Tensor or None expected)" f"cannot assign 'value' as buffer '{name}' "
) "(torch.Tensor or None expected)"
)
buffers[name] = value buffers[name] = value
else: else:
super().__setattr__(name, value) super().__setattr__(name, value)
def register_buffer(self, name: str, tensor: Optional[torch.tensor], persistent: bool = True) -> None: def __call__(self, *input, **kwargs):
return self.forward(*input, **kwargs)
def register_buffer(
self, name: str, tensor: Optional[Tensor], persistent: bool = True
) -> None:
r"""Adds a buffer to the module. r"""Adds a buffer to the module.
This is typically used to register a buffer that should not to be This is typically used to register a buffer that should not to be
considered a model parameter.Buffers, by default, are persistent considered a model parameter.Buffers, by default, are persistent
and will be saved alongside parameters. This behavior can be changed and will be saved alongside parameters. This behavior can be changed
by setting :attr:`persistent` to ``False``. The only difference between by setting :attr:`persistent` to ``False``. The only difference between
a persistent buffer and a non-persistent buffer is that the latter a persistent buffer and a non-persistent buffer is that the latter
will not be a part of this module's :attr:`state_dict`. will not be a part of this module's :attr:`state_dict`.
Buffers can be accessed as attributes using given names. Buffers can be accessed as attributes using given names.
...@@ -163,22 +176,21 @@ class InfiniCoreModule: ...@@ -163,22 +176,21 @@ class InfiniCoreModule:
:attr:`state_dict`. :attr:`state_dict`.
""" """
if '_buffers' not in self.__dict__: if "_buffers" not in self.__dict__:
raise AttributeError( raise AttributeError("cannot assign buffer before Module.__init__() call")
"cannot assign buffer before Module.__init__() call")
elif not isinstance(name, str): elif not isinstance(name, str):
raise TypeError("buffer name should be a string. " raise TypeError("buffer name should be a string. Got {}".format("name"))
"Got {}".format(torch.typename(name))) elif "." in name:
elif '.' in name: raise KeyError('buffer name can\'t contain "."')
raise KeyError("buffer name can't contain \".\"") elif name == "":
elif name == '': raise KeyError('buffer name can\'t be empty string ""')
raise KeyError("buffer name can't be empty string \"\"")
elif hasattr(self, name) and name not in self._buffers: elif hasattr(self, name) and name not in self._buffers:
raise KeyError("attribute '{}' already exists".format(name)) raise KeyError("attribute '{}' already exists".format(name))
elif tensor is not None and not isinstance(tensor, torch.Tensor): elif tensor is not None and not isinstance(tensor, Tensor):
raise TypeError("cannot assign '{}' object to buffer '{}' " raise TypeError(
"(torch Tensor or None required)" "cannot assign '{}' object to buffer '{}' "
.format(torch.typename(tensor), name)) "(torch Tensor or None required)".format("tensor", name)
)
else: else:
self._buffers[name] = tensor self._buffers[name] = tensor
if persistent: if persistent:
...@@ -186,8 +198,7 @@ class InfiniCoreModule: ...@@ -186,8 +198,7 @@ class InfiniCoreModule:
else: else:
self._non_persistent_buffers_set.add(name) self._non_persistent_buffers_set.add(name)
def add_module(self, name: str, module: Optional["InfiniCoreModule"]) -> None:
def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
r"""Add a child module to the current module. r"""Add a child module to the current module.
The module can be accessed as an attribute using the given name. The module can be accessed as an attribute using the given name.
...@@ -201,20 +212,20 @@ class InfiniCoreModule: ...@@ -201,20 +212,20 @@ class InfiniCoreModule:
module's :attr:`children`. module's :attr:`children`.
""" """
if not isinstance(name, str): if not isinstance(name, str):
raise TypeError(f"module name should be a string. Got {torch.typename(name)}") raise TypeError(f"module name should be a string. Got {name}")
elif '.' in name: elif "." in name:
raise KeyError(f"module name can't contain \".\", got: {name}") raise KeyError(f'module name can\'t contain ".", got: {name}')
elif name == '': elif name == "":
raise KeyError("module name can't be empty string \"\"") raise KeyError('module name can\'t be empty string ""')
elif hasattr(self, name) and name not in self._modules: elif hasattr(self, name) and name not in self._modules:
raise KeyError(f"attribute '{name}' already exists") raise KeyError(f"attribute '{name}' already exists")
if module is not None and not isinstance(module, (torch.nn.Module, InfiniCoreModule)): if module is not None and not isinstance(module, InfiniCoreModule):
raise TypeError(f"{torch.typename(module)} is not a Module subclass") raise TypeError(f"{module} is not a Module subclass")
self._modules[name] = module self._modules[name] = module
def register_parameter(self, name: str, param: Optional[Union[torch.nn.Parameter, 'Parameter']]) -> None: def register_parameter(self, name: str, param: Parameter) -> None:
r"""Add a parameter to the module. r"""Add a parameter to the module.
The parameter can be accessed as an attribute using given name. The parameter can be accessed as an attribute using given name.
...@@ -227,15 +238,13 @@ class InfiniCoreModule: ...@@ -227,15 +238,13 @@ class InfiniCoreModule:
are ignored. If ``None``, the parameter is **not** included in the are ignored. If ``None``, the parameter is **not** included in the
module's :attr:`state_dict`. module's :attr:`state_dict`.
""" """
if "_parameters" not in self.__dict__: if "_parameters" not in self.__dict__:
raise AttributeError( raise AttributeError(
"cannot assign parameter before Module.__init__() call" "cannot assign parameter before Module.__init__() call"
) )
elif not isinstance(name, str): elif not isinstance(name, str):
raise TypeError( raise TypeError("parameter name should be a string.")
f"parameter name should be a string. Got {torch.typename(name)}"
)
elif "." in name: elif "." in name:
raise KeyError('parameter name can\'t contain "."') raise KeyError('parameter name can\'t contain "."')
elif name == "": elif name == "":
...@@ -244,16 +253,16 @@ class InfiniCoreModule: ...@@ -244,16 +253,16 @@ class InfiniCoreModule:
raise KeyError(f"attribute '{name}' already exists") raise KeyError(f"attribute '{name}' already exists")
if param is None: if param is None:
self._parameters[name] = None self._parameters[name] = None # 竟然可以是None
else: else:
# Support both torch.nn.Parameter and Parameter (InfiniCoreParameter) if not isinstance(param, (Parameter, Tensor)):
from .parameter import InfiniCoreParameter as Parameter
if not isinstance(param, (torch.nn.Parameter, Parameter)):
raise TypeError( raise TypeError(
f"cannot assign '{torch.typename(param)}' object to parameter '{name}' " f"cannot assign 'param' object to parameter '{name}' "
"(torch.nn.Parameter, Parameter or None required)" "(infinicore.nn.Parameter, Parameter or None required)"
) )
self._parameters[name] = param self._parameters[name] = param
super().__setattr__(name, param)
def get_extra_state(self) -> Any: def get_extra_state(self) -> Any:
"""Return any extra state to include in the module's state_dict. """Return any extra state to include in the module's state_dict.
...@@ -272,7 +281,7 @@ class InfiniCoreModule: ...@@ -272,7 +281,7 @@ class InfiniCoreModule:
""" """
raise RuntimeError( raise RuntimeError(
"Reached a code path in Module.get_extra_state() that should never be called. " "Reached a code path in Module.get_extra_state() that should never be called. "
) )
def _save_to_state_dict(self, destination, prefix, keep_vars): def _save_to_state_dict(self, destination, prefix, keep_vars):
r"""Saves module state to `destination` dictionary, containing a state r"""Saves module state to `destination` dictionary, containing a state
...@@ -289,29 +298,34 @@ class InfiniCoreModule: ...@@ -289,29 +298,34 @@ class InfiniCoreModule:
""" """
for name, param in self._parameters.items(): for name, param in self._parameters.items():
if param is not None: if param is not None:
destination[prefix + name] = param if keep_vars else param.detach() destination[prefix + name] = param if keep_vars else param
for name, buf in self._buffers.items(): for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set: if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach() destination[prefix + name] = buf if keep_vars else buf
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "get_extra_state", InfiniCoreModule.get_extra_state) is not InfiniCoreModule.get_extra_state: if (
getattr(self.__class__, "get_extra_state", InfiniCoreModule.get_extra_state)
is not InfiniCoreModule.get_extra_state
):
destination[extra_state_key] = self.get_extra_state() destination[extra_state_key] = self.get_extra_state()
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
# back that same object. But if they pass nothing, an `OrderedDict` is created and returned. # back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
T_destination = TypeVar('T_destination', bound=Dict[str, Any]) T_destination = TypeVar("T_destination", bound=Dict[str, Any])
@overload @overload
def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: def state_dict(
... self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...
) -> T_destination: ...
@overload @overload
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: def state_dict(
... self, *, prefix: str = ..., keep_vars: bool = ...
) -> Dict[str, Any]: ...
# TODO: Change `*args` to `*` and remove the copprespinding warning in docs when BC allows. # TODO: Change `*args` to `*` and remove the copprespinding warning in docs when BC allows.
# Also remove the logic for arg parsing together. # Also remove the logic for arg parsing together.
def state_dict(self, *args, destination=None, prefix='', keep_vars=False): def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
r"""Returns a dictionary containing references to the whole state of the module. r"""Returns a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are Both parameters and persistent buffers (e.g. running averages) are
...@@ -366,7 +380,7 @@ class InfiniCoreModule: ...@@ -366,7 +380,7 @@ class InfiniCoreModule:
) )
if destination is None: if destination is None:
destination = args[0] destination = args[0]
if len(args) > 1 and prefix == '': if len(args) > 1 and prefix == "":
prefix = args[1] prefix = args[1]
if len(args) > 2 and keep_vars is False: if len(args) > 2 and keep_vars is False:
keep_vars = args[2] keep_vars = args[2]
...@@ -382,9 +396,13 @@ class InfiniCoreModule: ...@@ -382,9 +396,13 @@ class InfiniCoreModule:
self._save_to_state_dict(destination, prefix, keep_vars) self._save_to_state_dict(destination, prefix, keep_vars)
for name, module in self._modules.items(): for name, module in self._modules.items():
if module is not None: if module is not None:
module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) module.state_dict(
destination=destination,
prefix=prefix + name + ".",
keep_vars=keep_vars,
)
return destination return destination
def set_extra_state(self, state: Any): def set_extra_state(self, state: Any):
""" """
This function is called from :func:`load_state_dict` to handle any extra state This function is called from :func:`load_state_dict` to handle any extra state
...@@ -398,10 +416,19 @@ class InfiniCoreModule: ...@@ -398,10 +416,19 @@ class InfiniCoreModule:
raise RuntimeError( raise RuntimeError(
"Reached a code path in Module.set_extra_state() that should never be called. " "Reached a code path in Module.set_extra_state() that should never be called. "
"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
"to report this bug.") "to report this bug."
)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, def _load_from_state_dict(
missing_keys, unexpected_keys, error_msgs): self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
r"""Copies parameters and buffers from :attr:`state_dict` into only r"""Copies parameters and buffers from :attr:`state_dict` into only
this module, but not its descendants. This is called on every submodule this module, but not its descendants. This is called on every submodule
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
...@@ -433,50 +460,45 @@ class InfiniCoreModule: ...@@ -433,50 +460,45 @@ class InfiniCoreModule:
list, and will be reported together in list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict` :meth:`~torch.nn.Module.load_state_dict`
""" """
persistent_buffers = {
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} k: v
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) for k, v in self._buffers.items()
if k not in self._non_persistent_buffers_set
}
local_name_params = itertools.chain(
self._parameters.items(), persistent_buffers.items()
)
local_state = {k: v for k, v in local_name_params if v is not None} local_state = {k: v for k, v in local_name_params if v is not None}
for name, param in local_state.items(): for name, param in local_state.items():
key = prefix + name key = prefix + name
if key in state_dict: if key in state_dict:
input_param = state_dict[key] input_param = state_dict[key]
if not torch.overrides.is_tensor_like(input_param):
error_msgs.append('While copying the parameter named "{}", '
'expected torch.Tensor or Tensor-like object from checkpoint but '
'received {}'
.format(key, type(input_param)))
continue
# This is used to avoid copying uninitialized parameters into # input_param must be of type infinicore.Tensor
# non-lazy modules, since they dont have the hook to do the checks if not isinstance(input_param, Tensor):
# in such case, it will error when accessing the .shape attribute. raise TypeError(
is_param_lazy = torch.nn.parameter.is_lazy(param) f"While copying the parameter named {key}, expected Tensor from checkpoint but received {type(input_param)}"
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ )
if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0] if (
(param.shape == input_param.shape)
if not is_param_lazy and input_param.shape != param.shape: and (param.dtype == input_param.dtype)
# local shape should match the one in checkpoint and (param.device == input_param.device)
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' ):
'the shape in current model is {}.' param.copy_(input_param)
.format(key, input_param.shape, param.shape)) else:
continue print(f"param '{name}' don't match input_param '{key}'")
try: setattr(self, name, input_param)
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append('While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}, '
'an exception occurred : {}.'
.format(key, param.size(), input_param.size(), ex.args))
elif strict: elif strict:
missing_keys.append(key) missing_keys.append(key)
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "set_extra_state", InfiniCoreModule.set_extra_state) is not InfiniCoreModule.set_extra_state: if (
getattr(self.__class__, "set_extra_state", InfiniCoreModule.set_extra_state)
is not InfiniCoreModule.set_extra_state
):
if extra_state_key in state_dict: if extra_state_key in state_dict:
self.set_extra_state(state_dict[extra_state_key]) self.set_extra_state(state_dict[extra_state_key])
elif strict: elif strict:
...@@ -486,8 +508,8 @@ class InfiniCoreModule: ...@@ -486,8 +508,8 @@ class InfiniCoreModule:
if strict: if strict:
for key in state_dict.keys(): for key in state_dict.keys():
if key.startswith(prefix) and key != extra_state_key: if key.startswith(prefix):
input_name = key[len(prefix):].split(".", 1) input_name = key[len(prefix) :].split(".", 1)
# Must be Module if it have attributes # Must be Module if it have attributes
if len(input_name) > 1: if len(input_name) > 1:
if input_name[0] not in self._modules: if input_name[0] not in self._modules:
...@@ -495,8 +517,7 @@ class InfiniCoreModule: ...@@ -495,8 +517,7 @@ class InfiniCoreModule:
elif input_name[0] not in local_state: elif input_name[0] not in local_state:
unexpected_keys.append(key) unexpected_keys.append(key)
def load_state_dict(self, state_dict: Mapping[str, Any], def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
strict: bool = True):
r"""Copies parameters and buffers from :attr:`state_dict` into r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants. If :attr:`strict` is ``True``, then this module and its descendants. If :attr:`strict` is ``True``, then
the keys of :attr:`state_dict` must exactly match the keys returned the keys of :attr:`state_dict` must exactly match the keys returned
...@@ -520,28 +541,40 @@ class InfiniCoreModule: ...@@ -520,28 +541,40 @@ class InfiniCoreModule:
``RuntimeError``. ``RuntimeError``.
""" """
if not isinstance(state_dict, Mapping): if not isinstance(state_dict, Mapping):
raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict))) raise TypeError(
"Expected state_dict to be dict-like, got {}.".format(type(state_dict))
)
missing_keys: List[str] = [] missing_keys: List[str] = []
unexpected_keys: List[str] = [] unexpected_keys: List[str] = []
error_msgs: List[str] = [] error_msgs: List[str] = []
# copy state_dict so _load_from_state_dict can modify it # copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None) metadata = getattr(state_dict, "_metadata", None)
state_dict = OrderedDict(state_dict) state_dict = OrderedDict(state_dict)
if metadata is not None: if metadata is not None:
# mypy isn't aware that "_metadata" exists in state_dict
state_dict._metadata = metadata # type: ignore[attr-defined] state_dict._metadata = metadata # type: ignore[attr-defined]
def load(module, local_state_dict, prefix=''): def load(module, local_state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict( module._load_from_state_dict(
local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) local_state_dict,
prefix,
local_metadata,
True,
missing_keys,
unexpected_keys,
error_msgs,
)
for name, child in module._modules.items(): for name, child in module._modules.items():
if child is not None: if child is not None:
child_prefix = prefix + name + '.' child_prefix = prefix + name + "."
child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} child_state_dict = {
load(child, child_state_dict, child_prefix) k: v
for k, v in local_state_dict.items()
if k.startswith(child_prefix)
}
load(child, child_state_dict, child_prefix) # noqa: F821
load(self, state_dict) load(self, state_dict)
del load del load
...@@ -549,19 +582,28 @@ class InfiniCoreModule: ...@@ -549,19 +582,28 @@ class InfiniCoreModule:
if strict: if strict:
if len(unexpected_keys) > 0: if len(unexpected_keys) > 0:
error_msgs.insert( error_msgs.insert(
0, 'Unexpected key(s) in state_dict: {}. '.format( 0,
', '.join('"{}"'.format(k) for k in unexpected_keys))) "Unexpected key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in unexpected_keys)
),
)
if len(missing_keys) > 0: if len(missing_keys) > 0:
error_msgs.insert( error_msgs.insert(
0, 'Missing key(s) in state_dict: {}. '.format( 0,
', '.join('"{}"'.format(k) for k in missing_keys))) "Missing key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in missing_keys)
),
)
if len(error_msgs) > 0: if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( raise RuntimeError(
self.__class__.__name__, "\n\t".join(error_msgs))) "Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs)
)
)
return _IncompatibleKeys(missing_keys, unexpected_keys) return _IncompatibleKeys(missing_keys, unexpected_keys)
def parameters(self, recurse: bool = True) -> Iterator[Union[torch.nn.Parameter, 'Parameter']]: def parameters(self, recurse: bool = True) -> Iterator["Parameter"]:
r"""Returns an iterator over module parameters. r"""Returns an iterator over module parameters.
Args: Args:
...@@ -582,7 +624,9 @@ class InfiniCoreModule: ...@@ -582,7 +624,9 @@ class InfiniCoreModule:
for name, param in self.named_parameters(recurse=recurse): for name, param in self.named_parameters(recurse=recurse):
yield param yield param
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Union[torch.nn.Parameter, 'Parameter']]]: def named_parameters(
self, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, "Parameter"]]:
r"""Returns an iterator over module parameters, yielding both the r"""Returns an iterator over module parameters, yielding both the
name of the parameter as well as the parameter itself. name of the parameter as well as the parameter itself.
...@@ -604,12 +648,12 @@ class InfiniCoreModule: ...@@ -604,12 +648,12 @@ class InfiniCoreModule:
""" """
gen = self._named_members( gen = self._named_members(
lambda module: module._parameters.items(), lambda module: module._parameters.items(), prefix=prefix, recurse=recurse
prefix=prefix, recurse=recurse) )
for elem in gen: for elem in gen:
yield elem yield elem
def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]: def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
r"""Returns an iterator over module buffers. r"""Returns an iterator over module buffers.
Args: Args:
...@@ -630,7 +674,9 @@ class InfiniCoreModule: ...@@ -630,7 +674,9 @@ class InfiniCoreModule:
for name, buf in self.named_buffers(recurse=recurse): for name, buf in self.named_buffers(recurse=recurse):
yield buf yield buf
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: def named_buffers(
self, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, Tensor]]:
r"""Returns an iterator over module buffers, yielding both the r"""Returns an iterator over module buffers, yielding both the
name of the buffer as well as the buffer itself. name of the buffer as well as the buffer itself.
...@@ -660,10 +706,10 @@ class InfiniCoreModule: ...@@ -660,10 +706,10 @@ class InfiniCoreModule:
if k in module._non_persistent_buffers_set: if k in module._non_persistent_buffers_set:
continue continue
memo.add(v) memo.add(v)
name = module_prefix + ('.' if module_prefix else '') + k name = module_prefix + ("." if module_prefix else "") + k
yield (name, v) yield (name, v)
def _named_members(self, get_members_fn, prefix='', recurse=True): def _named_members(self, get_members_fn, prefix="", recurse=True):
r"""Helper method to yield members with their names.""" r"""Helper method to yield members with their names."""
memo = set() memo = set()
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)] modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
...@@ -673,10 +719,10 @@ class InfiniCoreModule: ...@@ -673,10 +719,10 @@ class InfiniCoreModule:
if v is None or v in memo: if v is None or v in memo:
continue continue
memo.add(v) memo.add(v)
name = module_prefix + ('.' if module_prefix else '') + k name = module_prefix + ("." if module_prefix else "") + k
yield (name, v) yield (name, v)
def modules(self) -> Iterator['InfiniCoreModule']: def modules(self) -> Iterator["InfiniCoreModule"]:
r"""Returns an iterator over all modules in the network. r"""Returns an iterator over all modules in the network.
Yields: Yields:
...@@ -704,7 +750,12 @@ class InfiniCoreModule: ...@@ -704,7 +750,12 @@ class InfiniCoreModule:
for name, module in self.named_modules(): for name, module in self.named_modules():
yield module yield module
def named_modules(self, memo: Optional[Set['InfiniCoreModule']] = None, prefix: str = '', remove_duplicate: bool = True): def named_modules(
self,
memo: Optional[Set["InfiniCoreModule"]] = None,
prefix: str = "",
remove_duplicate: bool = True,
):
r"""Returns an iterator over all modules in the network, yielding r"""Returns an iterator over all modules in the network, yielding
both the name of the module as well as the module itself. both the name of the module as well as the module itself.
...@@ -746,18 +797,20 @@ class InfiniCoreModule: ...@@ -746,18 +797,20 @@ class InfiniCoreModule:
for name, module in self._modules.items(): for name, module in self._modules.items():
if module is None: if module is None:
continue continue
submodule_prefix = prefix + ('.' if prefix else '') + name submodule_prefix = prefix + ("." if prefix else "") + name
# Handle both InfiniCoreModule and torch.nn.Module # Handle both InfiniCoreModule and torch.nn.Module
if isinstance(module, InfiniCoreModule): if isinstance(module, InfiniCoreModule):
for m in module.named_modules(memo, submodule_prefix, remove_duplicate): for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
yield m yield m
elif isinstance(module, torch.nn.Module): elif isinstance(module, infinicore.nn.Module):
# For torch.nn.Module, use its named_modules method # For torch.nn.Module, use its named_modules method
# torch.nn.Module.named_modules returns (name, module) tuples # torch.nn.Module.named_modules returns (name, module) tuples
for sub_name, sub_module in module.named_modules(prefix=submodule_prefix, remove_duplicate=remove_duplicate): for sub_name, sub_module in module.named_modules(
prefix=submodule_prefix, remove_duplicate=remove_duplicate
):
yield (sub_name, sub_module) yield (sub_name, sub_module)
def children(self) -> Iterator[Union['InfiniCoreModule', torch.nn.Module]]: def children(self) -> Iterator["InfiniCoreModule"]:
r"""Returns an iterator over immediate children modules. r"""Returns an iterator over immediate children modules.
Yields: Yields:
...@@ -766,7 +819,9 @@ class InfiniCoreModule: ...@@ -766,7 +819,9 @@ class InfiniCoreModule:
for name, module in self.named_children(): for name, module in self.named_children():
yield module yield module
def named_children(self) -> Iterator[Tuple[str, Union['InfiniCoreModule', torch.nn.Module]]]: def named_children(
self,
) -> Iterator[Tuple[str, "InfiniCoreModule"]]:
r"""Returns an iterator over immediate children modules, yielding both r"""Returns an iterator over immediate children modules, yielding both
the name of the module as well as the module itself. the name of the module as well as the module itself.
...@@ -787,169 +842,16 @@ class InfiniCoreModule: ...@@ -787,169 +842,16 @@ class InfiniCoreModule:
memo.add(module) memo.add(module)
yield name, module yield name, module
def train(self: T, mode: bool = True) -> T:
r"""Sets the module in training mode.
This has any effect only on certain modules. See documentations of
particular modules for details of their behaviors in training/evaluation
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
Args:
mode (bool): whether to set training mode (``True``) or evaluation
mode (``False``). Default: ``True``.
Returns:
Module: self
"""
if not isinstance(mode, bool):
raise ValueError("training mode is expected to be boolean")
self.training = mode
for module in self.children():
module.train(mode)
return self
def eval(self: T) -> T: def eval(self: T) -> T:
r"""Sets the module in evaluation mode. r"""Sets the module in evaluation mode.
This has any effect only on certain modules. See documentations of
particular modules for details of their behaviors in training/evaluation
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.
See :ref:`locally-disable-grad-doc` for a comparison between
`.eval()` and several similar mechanisms that may be confused with it.
Returns: Returns:
Module: self Module: self
""" """
return self.train(False) pass
def _apply(self, fn, recurse=True): def _apply(self, fn, recurse=True):
if recurse: raise KeyError("not support")
for module in self.children():
module._apply(fn)
def compute_should_use_set_data(tensor, tensor_applied):
if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
# If the new tensor has compatible tensor type as the existing tensor,
# the current behavior is to change the tensor in-place using `.data =`,
# and the future behavior is to overwrite the existing tensor. However,
# changing the current behavior is a BC-breaking change, and we want it
# to happen in future releases. So for now we introduce the
# `torch.__future__.get_overwrite_module_params_on_conversion()`
# global flag to let the user control whether they want the future
# behavior of overwriting the existing tensor or not.
return not torch.__future__.get_overwrite_module_params_on_conversion()
else:
return False
should_use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion()
# Import Parameter (InfiniCoreParameter) for type checking and creation
from .parameter import InfiniCoreParameter as Parameter
for key, param in self._parameters.items():
if param is None:
continue
# Tensors stored in modules are graph leaves, and we don't want to
# track autograd history of `param_applied`, so we have to use
# `with torch.no_grad():`
with torch.no_grad():
param_applied = fn(param)
p_should_use_set_data = compute_should_use_set_data(param, param_applied)
# subclasses may have multiple child tensors so we need to use swap_tensors
p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied)
# Determine the Parameter class to use based on the original parameter type
is_infinicore_param = isinstance(param, Parameter)
ParamClass = Parameter if is_infinicore_param else torch.nn.Parameter
param_grad = param.grad
if p_should_use_swap_tensors:
try:
if param_grad is not None:
# Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping.
# Decrement use count of the gradient by setting to None
param.grad = None
param_applied = ParamClass(param_applied, requires_grad=param.requires_grad)
torch.utils.swap_tensors(param, param_applied)
except Exception as e:
if param_grad is not None:
param.grad = param_grad
raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}") from e
out_param = param
elif p_should_use_set_data:
param.data = param_applied
out_param = param
else:
assert isinstance(param, (torch.nn.Parameter, Parameter))
assert param.is_leaf
out_param = ParamClass(param_applied, param.requires_grad)
self._parameters[key] = out_param
if param_grad is not None:
with torch.no_grad():
grad_applied = fn(param_grad)
g_should_use_set_data = compute_should_use_set_data(param_grad, grad_applied)
if p_should_use_swap_tensors:
grad_applied.requires_grad_(param_grad.requires_grad)
try:
torch.utils.swap_tensors(param_grad, grad_applied)
except Exception as e:
raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}.grad") from e
out_param.grad = param_grad
elif g_should_use_set_data:
assert out_param.grad is not None
out_param.grad.data = grad_applied
else:
assert param_grad.is_leaf
out_param.grad = grad_applied.requires_grad_(param_grad.requires_grad)
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
return self
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) raise KeyError("not support")
if dtype is not None:
if not (dtype.is_floating_point or dtype.is_complex):
raise TypeError('nn.Module.to only accepts floating point or complex '
f'dtypes, but got desired dtype={dtype}')
if dtype.is_complex:
warnings.warn(
"Complex modules are a new feature under active development whose design may change, "
"and some modules might not work as expected when using complex tensors as parameters or buffers. ")
def convert(t):
try:
if convert_to_format is not None and t.dim() in (4, 5):
return t.to(
device,
dtype if t.is_floating_point() or t.is_complex() else None,
non_blocking,
memory_format=convert_to_format,
)
return t.to(
device,
dtype if t.is_floating_point() or t.is_complex() else None,
non_blocking,
)
except NotImplementedError as e:
if str(e) == "Cannot copy out of meta tensor; no data!":
raise NotImplementedError(
f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
f"when moving module from meta to a different device."
) from None
else:
raise
return self._apply(convert)
# Copyright (c) 2025, InfiniCore
#
# This file contains modified code derived from PyTorch's `torch.nn.Parameter`
# implementation, which is licensed under the BSD 3-Clause License.
#
# The modifications include adaptations for the InfiniCore framework.
#
# Original PyTorch source:
# https://github.com/pytorch/pytorch/blob/main/torch/nn/parameter.py
#
# Referencing PyTorch v2.4.0
#
# The use of this file is governed by the BSD 3-Clause License.
import torch
from typing import Optional
from collections import OrderedDict
class InfiniCoreParameter(torch.Tensor):
r"""A kind of Tensor that is to be considered a module parameter.
Parameters are :class:`~torch.Tensor` subclasses, that have a
very special property when used with :class:`InfiniCoreModule` s - when they're
assigned as Module attributes they are automatically added to the list of
its parameters, and will appear e.g. in :meth:`~InfiniCoreModule.parameters` iterator.
Assigning a Tensor doesn't have such effect. This is because one might
want to cache some temporary state, like last hidden state of the RNN, in
the model. If there was no such class as :class:`InfiniCoreParameter`, these
temporaries would get registered too.
Args:
data (Tensor, optional): parameter tensor. If None, creates an empty tensor.
requires_grad (bool, optional): if the parameter requires gradient. Note that
the torch.no_grad() context does NOT affect the default behavior of
Parameter creation--the Parameter will still have `requires_grad=True` in
:class:`~no_grad` mode. See :ref:`locally-disable-grad-doc` for more
details. Default: `True`
Example::
>>> import torch
>>> from infinicore.nn.modules import InfiniCoreModule, InfiniCoreParameter
>>>
>>> class MyModule(InfiniCoreModule):
... def __init__(self):
... super().__init__()
... self.weight = InfiniCoreParameter(torch.randn(10, 5))
... self.bias = InfiniCoreParameter(torch.randn(5))
...
>>> module = MyModule()
>>> for param in module.parameters():
... print(param.shape)
torch.Size([10, 5])
torch.Size([5])
"""
def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True):
if data is None:
data = torch.empty(0)
# Handle standard torch.Tensor or InfiniCoreParameter
if type(data) is torch.Tensor or type(data) is InfiniCoreParameter:
# For ease of BC maintenance, keep this path for standard Tensor.
# Eventually (tm), we should change the behavior for standard Tensor to match.
return torch.Tensor._make_subclass(cls, data, requires_grad)
# Path for custom tensors: set a flag on the instance to indicate parameter-ness.
t = data.detach().requires_grad_(requires_grad)
if type(t) is not type(data):
raise RuntimeError(
f"Creating a InfiniCoreParameter from an instance of type {type(data).__name__} "
"requires that detach() returns an instance of the same type, but return "
f"type {type(t).__name__} was found instead. To use the type as a "
"InfiniCoreParameter, please correct the detach() semantics defined by "
"its __torch_dispatch__() implementation."
)
t._is_param = True
return t
# Note: the 3 methods below only apply to standard Tensor. Parameters of custom tensor types
# are still considered that custom tensor type and these methods will not be called for them.
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
else:
result = type(self)(
self.data.clone(memory_format=torch.preserve_format), self.requires_grad
)
memo[id(self)] = result
return result
def __repr__(self):
return "InfiniCoreParameter containing:\n" + super().__repr__()
def __reduce_ex__(self, proto):
# Simplified version for serialization
# In a full implementation, you might want to handle hooks and state
state = getattr(self, '_state', None)
hooks = OrderedDict()
if not state:
return (
_rebuild_parameter,
(self.data, self.requires_grad, hooks),
)
return (
_rebuild_parameter_with_state,
(self.data, self.requires_grad, hooks, state),
)
# Note: __torch_function__ is handled by the Tensor base class
# We don't need to override it for standard Parameter behavior
def _rebuild_parameter(data, requires_grad, hooks):
"""Rebuild a parameter from serialized data."""
param = InfiniCoreParameter(data, requires_grad)
# Apply hooks if any (simplified - full implementation would restore hooks)
return param
def _rebuild_parameter_with_state(data, requires_grad, hooks, state):
"""Rebuild a parameter with extra state from serialized data."""
param = InfiniCoreParameter(data, requires_grad)
param._state = state
# Apply hooks if any (simplified - full implementation would restore hooks)
return param
# Copyright (c) 2025, InfiniCore
#
# This file contains modified code derived from PyTorch's `torch.nn.Parameter`
# implementation, which is licensed under the BSD 3-Clause License.
#
# The modifications include adaptations for the InfiniCore framework.
#
# Original PyTorch source:
# https://github.com/pytorch/pytorch/blob/main/torch/nn/parameter.py
#
# Referencing PyTorch v2.4.0
#
# The use of this file is governed by the BSD 3-Clause License.
from ..tensor import Tensor
class InfiniCoreParameter(Tensor):
r"""A kind of Tensor that is to be considered a module parameter."""
def __init__(self, data=None):
if not isinstance(data, Tensor):
raise ValueError("The `data` variable must be of type `infinicore.Tensor`.")
super().__init__(data._underlying)
def __repr__(self):
return "Parameter containing:\n" + super().__repr__()
def __deepcopy__(self, memo):
raise ValueError("not supported!")
def __reduce_ex__(self, proto):
raise ValueError("not supported!")
import safetensors.torch
import torch
import torch.nn as nn
import safetensors
# ============================================================
# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径
# ============================================================
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../python/infinicore')))
# 使用临时目录,如果不存在则自动创建
save_dir = os.path.join(os.path.dirname(__file__), '../../tmp')
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "torch_modulelist_with_param.safetensors")
# ============================================================
# 1. 使用 PyTorch 定义并保存模型(使用 torch.nn.ModuleList)
# ============================================================
class TorchModuleListNet(nn.Module):
def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
super().__init__()
# 使用 torch.nn.ModuleList
self.layers = nn.ModuleList([
nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_ch),
nn.ReLU(),
nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_ch),
nn.ReLU(),
nn.Conv2d(hidden_ch, out_ch, kernel_size=1),
])
# 自定义 Parameter
self.scale = nn.Parameter(torch.ones(1) * 0.5)
self.register_buffer("offset", torch.tensor(0.1))
def forward(self, x):
# 遍历 ModuleList 中的所有层
for layer in self.layers:
x = layer(x)
# 应用自定义参数和 buffer
x = x * self.scale + self.offset
return x
# ===== 保存 Torch 模型 =====
torch_model = TorchModuleListNet()
torch_state_dict = torch_model.state_dict()
safetensors.torch.save_file(torch_state_dict, save_path)
print("✓ PyTorch 模型已保存")
# ============================================================
# 2. 使用 torch 方式加载并推理
# ============================================================
torch_model_infer = TorchModuleListNet()
torch_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
torch_model_infer.eval()
input = torch.rand(1, 3, 8, 8)
torch_model_out = torch_model_infer(input)
print("✓ Torch 输出:", torch_model_out.detach().numpy().mean())
# ============================================================
# 3. 使用 ModuleList 加载并推理
# ============================================================
from nn.modules import Module, ModuleList
class InfiniCoreModuleListNet(Module):
def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
super().__init__()
# 使用 ModuleList
self.layers = ModuleList([
nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_ch),
nn.ReLU(),
nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_ch),
nn.ReLU(),
nn.Conv2d(hidden_ch, out_ch, kernel_size=1),
])
# 保持与 Torch 模型一致的自定义参数和 buffer
self.scale = nn.Parameter(torch.ones(1) * 0.5)
self.register_buffer("offset", torch.tensor(0.1))
def forward(self, x):
# 遍历 ModuleList 中的所有层
for layer in self.layers:
x = layer(x)
x = x * self.scale + self.offset
return x
# ===== 使用 ModuleListNet 读取 safetensors 并推理 =====
infinicore_model_infer = InfiniCoreModuleListNet()
infinicore_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
infinicore_model_infer.eval()
infinicore_model_out = infinicore_model_infer.forward(input)
print("✓ InfiniCore 输出:", infinicore_model_out.detach().numpy().mean())
# ============================================================
# 4. 对比结果
# ============================================================
diff = (infinicore_model_out - torch_model_out).abs().max().item()
print(f"✓ ModuleList 与 Torch 最大误差: {diff:.8f}")
if diff < 1e-9:
print("✓ ModuleList 与 Torch 精度一致.")
else:
print("✗ ModuleList 与 Torch 精度存在差异.")
# ============================================================
# 5. 测试 ModuleList 的基本功能
# ============================================================
print("\n=== 测试 ModuleList 基本功能 ===")
# 测试 1: 创建和访问
module_list = ModuleList([
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 5)
])
print(f"✓ 创建 ModuleList,长度: {len(module_list)}")
print(f"✓ 访问第一个模块: {type(module_list[0]).__name__}")
print(f"✓ 访问第二个模块: {type(module_list[1]).__name__}")
# 测试 2: append
module_list.append(nn.Softmax(dim=-1))
print(f"✓ append 后长度: {len(module_list)}")
# 测试 3: extend
module_list.extend([nn.Dropout(0.1), nn.Linear(5, 1)])
print(f"✓ extend 后长度: {len(module_list)}")
# 测试 4: 迭代
print("✓ 迭代 ModuleList:")
for i, module in enumerate(module_list):
print(f" [{i}] {type(module).__name__}")
# 测试 5: 索引访问
print(f"✓ 索引访问 module_list[0]: {type(module_list[0]).__name__}")
# 测试 6: state_dict
state_dict = module_list.state_dict()
print(f"✓ state_dict 键数量: {len(state_dict)}")
print(f"✓ state_dict 包含模块参数: {any('0.' in k for k in state_dict.keys())}")
# 测试 7: 使用 ModuleList 的模型
class TestNet(Module):
def __init__(self):
super().__init__()
self.layers = ModuleList([
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 5)
])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
test_model = TestNet()
test_input = torch.randn(2, 10)
test_output = test_model.forward(test_input)
print(f"✓ TestNet 输入形状: {test_input.shape}, 输出形状: {test_output.shape}")
# 测试 8: __add__ 方法
ml1 = ModuleList([nn.Linear(10, 5), nn.ReLU()])
ml2 = ModuleList([nn.Linear(5, 3), nn.Sigmoid()])
ml3 = ml1 + ml2
print(f"✓ __add__ 方法测试: {len(ml1)} + {len(ml2)} = {len(ml3)}")
assert len(ml3) == 4, "合并后的长度应该为 4"
# 测试 9: pop 方法
ml4 = ModuleList([nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 3)])
popped = ml4.pop()
print(f"✓ pop 方法测试: 弹出后长度 {len(ml4)}, 弹出模块类型 {type(popped).__name__}")
assert len(ml4) == 2, "pop 后长度应该为 2"
assert isinstance(popped, nn.Linear), "弹出的应该是 Linear 模块"
# 测试 10: __repr__ 方法
ml5 = ModuleList([nn.Linear(10, 5), nn.ReLU()])
repr_str = repr(ml5)
print(f"✓ __repr__ 方法测试: 输出包含类名和模块信息")
assert "ModuleList" in repr_str or "InfiniCoreModuleList" in repr_str, "repr 应该包含类名"
assert "Linear" in repr_str, "repr 应该包含模块信息"
print(repr_str)
print("\n=== 所有测试通过! ===")
# ============================================================
# 6. 前向传播集成测试(参考 infinicore_nn_test.py)
# ============================================================
print("\n=== 前向传播集成测试 ===")
# 使用 ModuleList 创建一个简单的模型
class TorchModuleListModel(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.ModuleList([
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 5)
])
self.scale = nn.Parameter(torch.ones(1) * 0.5)
self.register_buffer("offset", torch.tensor(0.1))
def forward(self, x):
for layer in self.layers:
x = layer(x)
x = x * self.scale + self.offset
return x
class InfiniCoreModuleListModel(Module):
def __init__(self):
super().__init__()
self.layers = ModuleList([
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 5)
])
self.scale = nn.Parameter(torch.ones(1) * 0.5)
self.register_buffer("offset", torch.tensor(0.1))
def forward(self, x):
for layer in self.layers:
x = layer(x)
x = x * self.scale + self.offset
return x
# 创建模型
torch_model_forward = TorchModuleListModel()
infinicore_model_forward = InfiniCoreModuleListModel()
# 复制权重(确保初始权重一致)
infinicore_model_forward.load_state_dict(torch_model_forward.state_dict(), strict=False)
# 设置为评估模式
torch_model_forward.eval()
infinicore_model_forward.eval()
# 创建测试输入
test_input = torch.randn(2, 10)
# 前向传播
with torch.no_grad():
torch_output = torch_model_forward(test_input)
infinicore_output = infinicore_model_forward.forward(test_input)
# 对比结果
diff = (infinicore_output - torch_output).abs().max().item()
print(f"✓ 前向传播测试 - 输入形状: {test_input.shape}")
print(f"✓ Torch 输出形状: {torch_output.shape}, 均值: {torch_output.detach().numpy().mean():.8f}")
print(f"✓ InfiniCore 输出形状: {infinicore_output.shape}, 均值: {infinicore_output.detach().numpy().mean():.8f}")
print(f"✓ 最大误差: {diff:.8f}")
if diff < 1e-9:
print("✓ 前向传播集成测试通过:ModuleList 与 Torch ModuleList 结果一致!")
else:
print("✗ 前向传播集成测试失败:存在差异")
# ============================================================
# 7. 混合模块兼容性测试(PyTorch + InfiniCore 模块混合使用)
# ============================================================
print("\n=== 混合模块兼容性测试 ===")
# 创建一个自定义的 InfiniCore 模块
class CustomLinear(Module):
def __init__(self, in_features, out_features):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_features, in_features))
self.bias = nn.Parameter(torch.randn(out_features))
def forward(self, x):
return x @ self.weight.t() + self.bias
# 创建混合 ModuleList(包含 PyTorch 模块和 InfiniCore 模块)
mixed_list = ModuleList([
nn.Linear(10, 5), # PyTorch 模块
CustomLinear(5, 3), # 自定义 InfiniCore 模块
nn.ReLU(), # PyTorch 模块
])
print(f"✓ 创建混合 ModuleList,长度: {len(mixed_list)}")
print(f"✓ 模块类型: {[type(m).__name__ for m in mixed_list]}")
# 测试参数注册
param_count = sum(1 for _ in mixed_list.parameters())
print(f"✓ 参数数量: {param_count}")
assert param_count == 4, f"参数数量应该为 4 (Linear: weight+bias, CustomLinear: weight+bias), 实际为 {param_count}"
# 测试 state_dict
mixed_state_dict = mixed_list.state_dict()
print(f"✓ state_dict 键数量: {len(mixed_state_dict)}")
assert len(mixed_state_dict) >= 4, "state_dict 应该包含至少 4 个参数"
# 测试前向传播
test_input_mixed = torch.randn(2, 10)
with torch.no_grad():
x = test_input_mixed
for module in mixed_list:
x = module.forward(x)
print(f"✓ 混合模块前向传播成功,输出形状: {x.shape}")
print("✓ 混合模块兼容性测试通过!")
import safetensors.torch
import torch
import torch.nn as nn
import safetensors
# ============================================================
# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径
# ============================================================
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../python/infinicore')))
save_dir = os.path.join(os.path.dirname(__file__), '../../tmp')
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "torch_convnet_with_param.safetensors")
# ============================================================
# 1. 使用 PyTorch 定义并保存模型
# ============================================================
print("===== 开始 CPU 一致性测试 =====")
class TorchConvNet(nn.Module):
def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
super().__init__()
# 主体网络
self.conv1 = nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(hidden_ch)
self.conv2 = nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(hidden_ch)
self.conv3 = nn.Conv2d(hidden_ch, out_ch, kernel_size=1)
self.relu = nn.ReLU()
# 自定义 Parameter
self.scale = nn.Parameter(torch.ones(1) * 0.5)
# 注册一个 buffer
self.register_buffer("offset", torch.tensor(0.1))
def forward(self, x):
x = self.relu(self.bn1(self.conv1(x)))
x = self.relu(self.bn2(self.conv2(x)))
x = self.conv3(x)
# 应用自定义参数和 buffer
x = x * self.scale + self.offset
return x
# ===== 保存 Torch 模型 =====
torch_model = TorchConvNet()
torch_state_dict = torch_model.state_dict()
safetensors.torch.save_file(torch_state_dict, save_path)
# ============================================================
# 2. 使用 torch 方式加载并推理
# ============================================================
torch_model_infer = TorchConvNet()
torch_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
torch_model_infer.eval()
input = torch.rand(1, 3, 8, 8)
torch_model_out = torch_model_infer(input)
# ============================================================
# 3. 使用 infiniCore.nn.module 加载并推理
# ============================================================
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../python/infinicore')))
from nn import Module
class InfiniCoreConvNet(Module):
def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
super().__init__()
self.conv1 = nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(hidden_ch)
self.conv2 = nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(hidden_ch)
self.conv3 = nn.Conv2d(hidden_ch, out_ch, kernel_size=1)
self.relu = nn.ReLU()
# 保持与 Torch 模型一致的自定义参数和 buffer
self.scale = nn.Parameter(torch.ones(1) * 0.5)
self.register_buffer("offset", torch.tensor(0.1))
def forward(self, x):
x = self.relu(self.bn1(self.conv1(x)))
x = self.relu(self.bn2(self.conv2(x)))
x = self.conv3(x)
x = x * self.scale + self.offset
return x
# ===== 使用 InfiniCoreConvNet 读取 safetensors 并推理 =====
infinicore_model_infer = InfiniCoreConvNet()
infinicore_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
infinicore_model_infer.eval()
infinicore_model_out = infinicore_model_infer.forward(input)
# ============================================================
# 4. 对比结果
# ============================================================
diff_cpu = (infinicore_model_out - torch_model_out).abs().max().item()
print(f"InfiniCoreModule 与 Torch (CPU) 最大误差: {diff_cpu:.6e}")
if diff_cpu < 1e-9:
print("CPU 模式下 InfiniCore 与 Torch 输出完全一致.")
else:
print("CPU 模式下输出存在差异.")
# ============================================================
# 5. GPU 一致性测试(可选)
# ============================================================
if torch.cuda.is_available():
print("\n===== 开始 GPU 一致性测试 =====")
# 将模型与输入都迁移到 GPU
torch_model_infer_gpu = TorchConvNet().to("cuda")
torch_model_infer_gpu.load_state_dict(safetensors.torch.load_file(save_path))
torch_model_infer_gpu.eval()
infinicore_model_infer_gpu = InfiniCoreConvNet().to("cuda")
infinicore_model_infer_gpu.load_state_dict(safetensors.torch.load_file(save_path))
infinicore_model_infer_gpu.eval()
# 生成 GPU 输入
input_gpu = input.to("cuda")
# 分别前向推理
torch_out_gpu = torch_model_infer_gpu(input_gpu)
infinicore_out_gpu = infinicore_model_infer_gpu.forward(input_gpu)
# 结果比较
diff_gpu = (infinicore_out_gpu - torch_out_gpu).abs().max().item()
print(f"InfiniCoreModule 与 Torch (GPU) 最大误差: {diff_gpu:.6e}")
if diff_gpu < 1e-9:
print("GPU 模式下 InfiniCore 与 Torch 输出完全一致.")
else:
print("GPU 模式下输出存在差异.")
else:
print("\n 未检测到 GPU,跳过 GPU 一致性测试。")
\ No newline at end of file
# ============================================================
# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径
# ============================================================
import os
import sys
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), "../../python/infinicore"))
)
save_dir = os.path.join(os.path.dirname(__file__), "../../tmp")
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "torch_convnet_with_param.safetensors")
import infinicore # noqa: E402
from infinicore.nn import Module # noqa: E402
# ============================================================
# 1. 定义模型
# ============================================================
device_str = "cuda"
class InfiniCoreNet(Module):
def __init__(self):
super().__init__()
self.a = infinicore.nn.Parameter(
infinicore.empty(
(1, 2, 3),
dtype=infinicore.float32,
device=infinicore.device(device_str),
)
)
self.b = infinicore.nn.Parameter(
infinicore.empty(
(1, 2, 3),
dtype=infinicore.float32,
device=infinicore.device(device_str),
)
)
def forward(self):
return infinicore.add(self.a, self.b)
infinicore_model_infer = InfiniCoreNet()
# ============================================================
# 2. 加载权重
# ============================================================
params_dict = {
"a": infinicore.empty(
(1, 2, 3), dtype=infinicore.float32, device=infinicore.device(device_str, 0)
),
"b": infinicore.empty(
(1, 2, 3), dtype=infinicore.float32, device=infinicore.device(device_str, 0)
),
}
infinicore_model_infer.load_state_dict(params_dict)
# ============================================================
# 3. 计算
# ============================================================
infinicore_model_out = infinicore_model_infer()
ref_out = infinicore.add(params_dict["a"], params_dict["b"])
# ============================================================
# 4. 对比结果
# ============================================================
print("InfiniCoreModule 与 Torch (CPU) 最大误差: 手动查看 ")
infinicore_model_out.debug()
ref_out.debug()
# ============================================================
# 5. to测试,buffer测试
# ============================================================
# 等待添加
import os
# ============================================================
# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径
# ============================================================
import sys
import safetensors
import safetensors.torch
import torch
import torch.nn as nn
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), "../../python/infinicore"))
)
# 使用临时目录,如果不存在则自动创建
save_dir = os.path.join(os.path.dirname(__file__), "../../tmp")
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "torch_modulelist_with_param.safetensors")
def test():
# ============================================================
# 1. 使用 PyTorch 定义并保存模型(使用 torch.nn.ModuleList)
# ============================================================
class TorchModuleListNet(nn.Module):
def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
super().__init__()
# 使用 torch.nn.ModuleList
self.layers = nn.ModuleList(
[
nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_ch),
nn.ReLU(),
nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_ch),
nn.ReLU(),
nn.Conv2d(hidden_ch, out_ch, kernel_size=1),
]
)
# 自定义 Parameter
self.scale = nn.Parameter(torch.ones(1) * 0.5)
self.register_buffer("offset", torch.tensor(0.1))
def forward(self, x):
# 遍历 ModuleList 中的所有层
for layer in self.layers:
x = layer(x)
# 应用自定义参数和 buffer
x = x * self.scale + self.offset
return x
# ===== 保存 Torch 模型 =====
torch_model = TorchModuleListNet()
torch_state_dict = torch_model.state_dict()
safetensors.torch.save_file(torch_state_dict, save_path)
print("✓ PyTorch 模型已保存")
# ============================================================
# 2. 使用 torch 方式加载并推理
# ============================================================
torch_model_infer = TorchModuleListNet()
torch_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
torch_model_infer.eval()
input = torch.rand(1, 3, 8, 8)
torch_model_out = torch_model_infer(input)
print("✓ Torch 输出:", torch_model_out.detach().numpy().mean())
# ============================================================
# 3. 使用 ModuleList 加载并推理
# ============================================================
from nn.modules import Module, ModuleList
class InfiniCoreModuleListNet(Module):
def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
super().__init__()
# 使用 ModuleList
self.layers = ModuleList(
[
nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_ch),
nn.ReLU(),
nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_ch),
nn.ReLU(),
nn.Conv2d(hidden_ch, out_ch, kernel_size=1),
]
)
# 保持与 Torch 模型一致的自定义参数和 buffer
self.scale = nn.Parameter(torch.ones(1) * 0.5)
self.register_buffer("offset", torch.tensor(0.1))
def forward(self, x):
# 遍历 ModuleList 中的所有层
for layer in self.layers:
x = layer(x)
x = x * self.scale + self.offset
return x
# ===== 使用 ModuleListNet 读取 safetensors 并推理 =====
infinicore_model_infer = InfiniCoreModuleListNet()
infinicore_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
infinicore_model_infer.eval()
infinicore_model_out = infinicore_model_infer.forward(input)
print("✓ InfiniCore 输出:", infinicore_model_out.detach().numpy().mean())
# ============================================================
# 4. 对比结果
# ============================================================
diff = (infinicore_model_out - torch_model_out).abs().max().item()
print(f"✓ ModuleList 与 Torch 最大误差: {diff:.8f}")
if diff < 1e-9:
print("✓ ModuleList 与 Torch 精度一致.")
else:
print("✗ ModuleList 与 Torch 精度存在差异.")
# ============================================================
# 5. 测试 ModuleList 的基本功能
# ============================================================
print("\n=== 测试 ModuleList 基本功能 ===")
# 测试 1: 创建和访问
module_list = ModuleList([nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)])
print(f"✓ 创建 ModuleList,长度: {len(module_list)}")
print(f"✓ 访问第一个模块: {type(module_list[0]).__name__}")
print(f"✓ 访问第二个模块: {type(module_list[1]).__name__}")
# 测试 2: append
module_list.append(nn.Softmax(dim=-1))
print(f"✓ append 后长度: {len(module_list)}")
# 测试 3: extend
module_list.extend([nn.Dropout(0.1), nn.Linear(5, 1)])
print(f"✓ extend 后长度: {len(module_list)}")
# 测试 4: 迭代
print("✓ 迭代 ModuleList:")
for i, module in enumerate(module_list):
print(f" [{i}] {type(module).__name__}")
# 测试 5: 索引访问
print(f"✓ 索引访问 module_list[0]: {type(module_list[0]).__name__}")
# 测试 6: state_dict
state_dict = module_list.state_dict()
print(f"✓ state_dict 键数量: {len(state_dict)}")
print(f"✓ state_dict 包含模块参数: {any('0.' in k for k in state_dict.keys())}")
# 测试 7: 使用 ModuleList 的模型
class TestNet(Module):
def __init__(self):
super().__init__()
self.layers = ModuleList([nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
test_model = TestNet()
test_input = torch.randn(2, 10)
test_output = test_model.forward(test_input)
print(f"✓ TestNet 输入形状: {test_input.shape}, 输出形状: {test_output.shape}")
# 测试 8: __add__ 方法
ml1 = ModuleList([nn.Linear(10, 5), nn.ReLU()])
ml2 = ModuleList([nn.Linear(5, 3), nn.Sigmoid()])
ml3 = ml1 + ml2
print(f"✓ __add__ 方法测试: {len(ml1)} + {len(ml2)} = {len(ml3)}")
assert len(ml3) == 4, "合并后的长度应该为 4"
# 测试 9: pop 方法
ml4 = ModuleList([nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 3)])
popped = ml4.pop()
print(
f"✓ pop 方法测试: 弹出后长度 {len(ml4)}, 弹出模块类型 {type(popped).__name__}"
)
assert len(ml4) == 2, "pop 后长度应该为 2"
assert isinstance(popped, nn.Linear), "弹出的应该是 Linear 模块"
# 测试 10: __repr__ 方法
ml5 = ModuleList([nn.Linear(10, 5), nn.ReLU()])
repr_str = repr(ml5)
print(f"✓ __repr__ 方法测试: 输出包含类名和模块信息")
assert "ModuleList" in repr_str or "InfiniCoreModuleList" in repr_str, (
"repr 应该包含类名"
)
assert "Linear" in repr_str, "repr 应该包含模块信息"
print(repr_str)
print("\n=== 所有测试通过! ===")
# ============================================================
# 6. 前向传播集成测试(参考 infinicore_nn_test.py)
# ============================================================
print("\n=== 前向传播集成测试 ===")
# 使用 ModuleList 创建一个简单的模型
class TorchModuleListModel(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.ModuleList(
[nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)]
)
self.scale = nn.Parameter(torch.ones(1) * 0.5)
self.register_buffer("offset", torch.tensor(0.1))
def forward(self, x):
for layer in self.layers:
x = layer(x)
x = x * self.scale + self.offset
return x
class InfiniCoreModuleListModel(Module):
def __init__(self):
super().__init__()
self.layers = ModuleList([nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)])
self.scale = nn.Parameter(torch.ones(1) * 0.5)
self.register_buffer("offset", torch.tensor(0.1))
def forward(self, x):
for layer in self.layers:
x = layer(x)
x = x * self.scale + self.offset
return x
# 创建模型
torch_model_forward = TorchModuleListModel()
infinicore_model_forward = InfiniCoreModuleListModel()
# 复制权重(确保初始权重一致)
infinicore_model_forward.load_state_dict(
torch_model_forward.state_dict(), strict=False
)
# 设置为评估模式
torch_model_forward.eval()
infinicore_model_forward.eval()
# 创建测试输入
test_input = torch.randn(2, 10)
# 前向传播
with torch.no_grad():
torch_output = torch_model_forward(test_input)
infinicore_output = infinicore_model_forward.forward(test_input)
# 对比结果
diff = (infinicore_output - torch_output).abs().max().item()
print(f"✓ 前向传播测试 - 输入形状: {test_input.shape}")
print(
f"✓ Torch 输出形状: {torch_output.shape}, 均值: {torch_output.detach().numpy().mean():.8f}"
)
print(
f"✓ InfiniCore 输出形状: {infinicore_output.shape}, 均值: {infinicore_output.detach().numpy().mean():.8f}"
)
print(f"✓ 最大误差: {diff:.8f}")
if diff < 1e-9:
print("✓ 前向传播集成测试通过:ModuleList 与 Torch ModuleList 结果一致!")
else:
print("✗ 前向传播集成测试失败:存在差异")
# ============================================================
# 7. 混合模块兼容性测试(PyTorch + InfiniCore 模块混合使用)
# ============================================================
print("\n=== 混合模块兼容性测试 ===")
# 创建一个自定义的 InfiniCore 模块
class CustomLinear(Module):
def __init__(self, in_features, out_features):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_features, in_features))
self.bias = nn.Parameter(torch.randn(out_features))
def forward(self, x):
return x @ self.weight.t() + self.bias
# 创建混合 ModuleList(包含 PyTorch 模块和 InfiniCore 模块)
mixed_list = ModuleList(
[
nn.Linear(10, 5), # PyTorch 模块
CustomLinear(5, 3), # 自定义 InfiniCore 模块
nn.ReLU(), # PyTorch 模块
]
)
print(f"✓ 创建混合 ModuleList,长度: {len(mixed_list)}")
print(f"✓ 模块类型: {[type(m).__name__ for m in mixed_list]}")
# 测试参数注册
param_count = sum(1 for _ in mixed_list.parameters())
print(f"✓ 参数数量: {param_count}")
assert param_count == 4, (
f"参数数量应该为 4 (Linear: weight+bias, CustomLinear: weight+bias), 实际为 {param_count}"
)
# 测试 state_dict
mixed_state_dict = mixed_list.state_dict()
print(f"✓ state_dict 键数量: {len(mixed_state_dict)}")
assert len(mixed_state_dict) >= 4, "state_dict 应该包含至少 4 个参数"
# 测试前向传播
test_input_mixed = torch.randn(2, 10)
with torch.no_grad():
x = test_input_mixed
for module in mixed_list:
x = module.forward(x)
print(f"✓ 混合模块前向传播成功,输出形状: {x.shape}")
print("✓ 混合模块兼容性测试通过!")
import safetensors.torch
import torch
import torch.nn as nn
import safetensors
# ============================================================ # ============================================================
# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径 # 0. infinicore 包导入,配置测试用 safetensors 临时存储路径
# ============================================================ # ============================================================
import sys
import os import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../python/infinicore')))
save_dir = os.path.join(os.path.dirname(__file__), '../../tmp') import torch
import torch.nn as nn
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), "../../python/infinicore"))
)
save_dir = os.path.join(os.path.dirname(__file__), "../../tmp")
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "infinicore_parameter_test.safetensors") save_path = os.path.join(save_dir, "infinicore_parameter_test.safetensors")
# ============================================================
# 1. 使用 PyTorch 定义并保存模型(使用 torch.nn.Parameter)
# ============================================================
class TorchParameterNet(nn.Module): import infinicore # noqa: E402
def __init__(self, in_features=10, out_features=5): from infinicore.nn import Module, Parameter # noqa: E402
device_str = "cuda"
class InfiniCoreParameterNet(Module):
def __init__(self):
super().__init__() super().__init__()
self.weight = nn.Parameter(torch.randn(out_features, in_features)) self.a = infinicore.nn.Parameter(
self.bias = nn.Parameter(torch.randn(out_features)) infinicore.empty(
self.scale = nn.Parameter(torch.ones(1) * 0.5) (1, 2, 3), dtype=infinicore.float32, device=infinicore.device("cpu", 0)
self.register_buffer("offset", torch.tensor(0.1)) )
)
def forward(self, x): def forward(self, x):
return (x @ self.weight.t() + self.bias) * self.scale + self.offset return infinicore.add(self.a, x)
# ===== 保存 Torch 模型 =====
torch_model = TorchParameterNet()
torch_state_dict = torch_model.state_dict()
safetensors.torch.save_file(torch_state_dict, save_path)
print("✓ PyTorch 模型已保存")
infinicore_model_infer = InfiniCoreParameterNet()
# ============================================================ # ============================================================
# 2. 使用 torch 方式加载并推理 # 2. 加载权重
# ============================================================ # ============================================================
params_dict = {
"a": infinicore.empty(
(1, 2, 3), dtype=infinicore.float32, device=infinicore.device(device_str, 0)
)
}
infinicore_model_infer.load_state_dict(params_dict)
torch_model_infer = TorchParameterNet()
torch_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
torch_model_infer.eval()
input = torch.randn(2, 10)
torch_model_out = torch_model_infer(input)
print("✓ Torch 输出:", torch_model_out.detach().numpy().mean())
# ============================================================ # ============================================================
# 3. 使用 Parameter 加载并推理 # 3. 计算
# ============================================================ # ============================================================
x = infinicore.empty(
(1, 2, 3), dtype=infinicore.float32, device=infinicore.device(device_str, 0)
)
from nn.modules import Module, Parameter infinicore_model_out = infinicore_model_infer(x)
ref_out = infinicore.add(params_dict["a"], x)
class InfiniCoreParameterNet(Module):
def __init__(self, in_features=10, out_features=5):
super().__init__()
# 使用 Parameter 替代 torch.nn.Parameter
self.weight = Parameter(torch.randn(out_features, in_features))
self.bias = Parameter(torch.randn(out_features))
self.scale = Parameter(torch.ones(1) * 0.5)
self.register_buffer("offset", torch.tensor(0.1))
def forward(self, x):
return (x @ self.weight.t() + self.bias) * self.scale + self.offset
# ===== 使用 InfiniCoreParameterNet 读取 safetensors 并推理 =====
infinicore_model_infer = InfiniCoreParameterNet()
infinicore_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
infinicore_model_infer.eval()
infinicore_model_out = infinicore_model_infer.forward(input)
print("✓ InfiniCore 输出:", infinicore_model_out.detach().numpy().mean())
# ============================================================ # ============================================================
# 4. 对比结果 # 4. 对比结果
# ============================================================ # ============================================================
print("InfiniCoreModule 与 Torch (CPU) 最大误差: 手动查看 ")
infinicore_model_out.debug()
ref_out.debug()
diff = (infinicore_model_out - torch_model_out).abs().max().item()
print(f"✓ Parameter 与 Torch 最大误差: {diff:.8f}")
if diff < 1e-9:
print("✓ Parameter 与 Torch 精度一致.")
else:
print("✗ Parameter 与 Torch 精度存在差异.")
# ============================================================ # ============================================================
# 5. 测试 Parameter 的基本功能 # 5. 测试 Parameter 的基本功能
...@@ -93,28 +73,37 @@ else: ...@@ -93,28 +73,37 @@ else:
print("\n=== 测试 Parameter 基本功能 ===") print("\n=== 测试 Parameter 基本功能 ===")
# 测试 1: 创建 Parameter # 测试 1: 创建 Parameter
param1 = Parameter(torch.randn(5, 10)) param1 = infinicore.nn.Parameter(
infinicore.empty(
(1, 2, 3), dtype=infinicore.float32, device=infinicore.device(device_str, 0)
)
)
print(f"✓ 创建 Parameter,形状: {param1.shape}") print(f"✓ 创建 Parameter,形状: {param1.shape}")
# 检查是否是 Parameter 类型(可能是 InfiniCoreParameter 的别名) # 检查是否是 Parameter 类型(可能是 InfiniCoreParameter 的别名)
from nn.modules.parameter import InfiniCoreParameter
assert isinstance(param1, (Parameter, InfiniCoreParameter)), "应该是 Parameter 类型"
assert isinstance(param1, torch.Tensor), "应该是 torch.Tensor 的子类"
# 测试 2: requires_grad assert isinstance(param1, infinicore.nn.Parameter), "应该是 Parameter 类型"
param2 = Parameter(torch.randn(3, 4), requires_grad=False) assert isinstance(param1, infinicore.Tensor), "应该是 torch.Tensor 的子类"
print(f"✓ 创建 requires_grad=False 的 Parameter: {param2.requires_grad}")
assert not param2.requires_grad, "requires_grad 应该为 False"
param3 = Parameter(torch.randn(3, 4), requires_grad=True)
print(f"✓ 创建 requires_grad=True 的 Parameter: {param3.requires_grad}")
assert param3.requires_grad, "requires_grad 应该为 True"
# 测试 3: 自动注册到 Module # 测试 3: 自动注册到 Module
class TestModule(Module): class TestModule(Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.weight = Parameter(torch.randn(5, 10)) self.weight = infinicore.nn.Parameter(
self.bias = Parameter(torch.randn(5)) infinicore.empty(
(1, 2, 3),
dtype=infinicore.float32,
device=infinicore.device(device_str),
)
)
self.bias = infinicore.nn.Parameter(
infinicore.empty(
(1, 2, 3),
dtype=infinicore.float32,
device=infinicore.device(device_str),
)
)
test_module = TestModule() test_module = TestModule()
param_count = sum(1 for _ in test_module.parameters()) param_count = sum(1 for _ in test_module.parameters())
...@@ -129,8 +118,8 @@ print("✓ 参数可以通过属性访问") ...@@ -129,8 +118,8 @@ print("✓ 参数可以通过属性访问")
# 测试 5: state_dict # 测试 5: state_dict
state_dict = test_module.state_dict() state_dict = test_module.state_dict()
print(f"✓ state_dict 键数量: {len(state_dict)}") print(f"✓ state_dict 键数量: {len(state_dict)}")
assert 'weight' in state_dict, "state_dict 应该包含 weight" assert "weight" in state_dict, "state_dict 应该包含 weight"
assert 'bias' in state_dict, "state_dict 应该包含 bias" assert "bias" in state_dict, "state_dict 应该包含 bias"
print(f"✓ state_dict 键: {list(state_dict.keys())}") print(f"✓ state_dict 键: {list(state_dict.keys())}")
# 测试 6: __repr__ # 测试 6: __repr__
...@@ -139,46 +128,21 @@ print(f"✓ __repr__ 方法: 输出包含类名") ...@@ -139,46 +128,21 @@ print(f"✓ __repr__ 方法: 输出包含类名")
assert "Parameter" in repr_str or "InfiniCoreParameter" in repr_str, "repr 应该包含类名" assert "Parameter" in repr_str or "InfiniCoreParameter" in repr_str, "repr 应该包含类名"
print(repr_str[:100] + "...") print(repr_str[:100] + "...")
# 测试 7: 与 torch.nn.Parameter 兼容性
class MixedModule(Module):
def __init__(self):
super().__init__()
self.torch_param = nn.Parameter(torch.randn(3, 4))
self.infinicore_param = Parameter(torch.randn(3, 4))
mixed_module = MixedModule()
mixed_param_count = sum(1 for _ in mixed_module.parameters())
print(f"✓ 混合使用 torch.nn.Parameter 和 Parameter,参数数量: {mixed_param_count}")
assert mixed_param_count == 2, f"应该有 2 个参数,实际为 {mixed_param_count}"
# 测试 8: 前向传播
class TestModuleWithForward(Module):
def __init__(self):
super().__init__()
self.weight = Parameter(torch.randn(5, 10))
self.bias = Parameter(torch.randn(5))
def forward(self, x):
return x @ self.weight.t() + self.bias
test_module_forward = TestModuleWithForward()
test_input = torch.randn(2, 10)
with torch.no_grad():
output = test_module_forward.forward(test_input)
print(f"✓ 前向传播成功,输出形状: {output.shape}")
assert output.shape == (2, 5), f"输出形状应该是 (2, 5),实际为 {output.shape}"
# 测试 9: 从 None 创建 # 测试 9: 从 None 创建
param_empty = Parameter(None) # param_empty = Parameter(None)
print(f"✓ 从 None 创建 Parameter,形状: {param_empty.shape}") # print(f"✓ 从 None 创建 Parameter,形状: {param_empty.shape}")
assert param_empty.shape == torch.Size([0]), "从 None 创建应该是空张量" # assert param_empty.shape == torch.Size([0]), "从 None 创建应该是空张量"
# 测试 10: 深拷贝 # 测试 10: 深拷贝
import copy # import copy
param_copy = copy.deepcopy(param1)
print(f"✓ 深拷贝 Parameter,形状: {param_copy.shape}")
assert param_copy.shape == param1.shape, "深拷贝后形状应该相同"
assert not torch.equal(param_copy, param1) or id(param_copy) != id(param1), "深拷贝应该是新对象"
print("\n=== 所有测试通过! ===") # param_copy = copy.deepcopy(param1)
# print(f"✓ 深拷贝 Parameter,形状: {param_copy.shape}")
# assert param_copy.shape == param1.shape, "深拷贝后形状应该相同"
# assert not torch.equal(param_copy, param1) or id(param_copy) != id(param1), (
# "深拷贝应该是新对象"
# )
print("\n=== 所有测试通过! ===")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment