Unverified Commit 3567f92e authored by pengcheng888's avatar pengcheng888 Committed by GitHub
Browse files

Merge pull request #567 from gongchensu/feature/add_nn_interface

 #567 - 增加 nn.Module, nn.Parameter,nn.ModuleList类
parents 0883d6ee ee722eb9
from infinicore.nn import functional
from infinicore.nn.modules import * # noqa: F403
from infinicore.nn.parameter import InfiniCoreParameter as Parameter
__all__ = ["functional"]
__all__ = ["functional", "Parameter"]
from .container import InfiniCoreModuleList as ModuleList
from .module import InfiniCoreModule as Module
__all__ = ["ModuleList", "Module"]
# ============================================
# Copyright (c) 2025, InfiniCore
#
# This file implements InfiniCoreModuleList, which is similar to torch.nn.ModuleList
# but based on InfiniCoreModule for inference purposes.
import operator
from collections import OrderedDict
from itertools import chain
from typing import Iterator, List, Optional, Sequence, TypeVar, Union
from .module import InfiniCoreModule as Module
# Define type variable for module compatibility (supports InfiniCoreModule)
ModuleType = TypeVar("ModuleType", bound=Union["Module"])
class InfiniCoreModuleList(Module):
r"""Holds submodules in a list.
InfiniCoreModuleList can be indexed like a regular Python list, but
modules it contains are properly registered, and will be visible by all
InfiniCoreModule methods.
Args:
modules (iterable, optional): an iterable of modules to add
Example::
>>> class MyModel(InfiniCoreModule):
... def __init__(self):
... super().__init__()
... self.linears = InfiniCoreModuleList([
... torch.nn.Linear(10, 10) for i in range(10)
... ])
...
... def forward(self, x):
... # ModuleList can act as an iterable, or be indexed using ints
... for i, l in enumerate(self.linears):
... x = self.linears[i // 2](x) + l(x)
... return x
"""
def __init__(self, modules: Optional[Sequence[ModuleType]] = None):
super().__init__()
if modules is not None:
self += modules
def _get_abs_string_index(self, idx):
"""Get the absolute index for the list of modules."""
idx = operator.index(idx)
if not (-len(self) <= idx < len(self)):
raise IndexError(f"index {idx} is out of range")
if idx < 0:
idx += len(self)
return str(idx)
def __getitem__(
self, idx: Union[int, slice]
) -> Union[ModuleType, "InfiniCoreModuleList"]:
if isinstance(idx, slice):
return self.__class__(list(self._modules.values())[idx])
else:
return self._modules[self._get_abs_string_index(idx)]
def __setitem__(self, idx: int, module: ModuleType) -> None:
idx = self._get_abs_string_index(idx)
# Use add_module to register module
self.add_module(idx, module)
def __delitem__(self, idx: Union[int, slice]) -> None:
if isinstance(idx, slice):
indices_to_delete = list(range(len(self._modules)))[idx]
for k in indices_to_delete:
if str(k) in self._modules:
del self._modules[str(k)]
else:
idx_str = self._get_abs_string_index(idx)
if idx_str in self._modules:
del self._modules[idx_str]
# To preserve numbering, self._modules is being reconstructed with modules after deletion
if len(self._modules) > 0:
str_indices = [str(i) for i in range(len(self._modules))]
self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
def __len__(self) -> int:
return len(self._modules)
def __iter__(self) -> Iterator[ModuleType]:
return iter(self._modules.values())
def __iadd__(self, modules: Sequence[ModuleType]) -> "InfiniCoreModuleList":
return self.extend(modules)
def __add__(
self, other: Union[Sequence[ModuleType], "InfiniCoreModuleList"]
) -> "InfiniCoreModuleList":
r"""Return a new InfiniCoreModuleList by concatenating with another iterable.
Args:
other (iterable): iterable of modules to concatenate
"""
if not isinstance(other, (list, tuple, InfiniCoreModuleList)):
raise TypeError(
f"InfiniCoreModuleList can only be concatenated with list, tuple, or InfiniCoreModuleList, "
f"got {type(other).__name__}"
)
combined = InfiniCoreModuleList()
for i, module in enumerate(chain(self, other)):
combined.add_module(str(i), module)
return combined
def append(self, module: ModuleType) -> "InfiniCoreModuleList":
r"""Append a given module to the end of the list.
Args:
module (InfiniCoreModule): module to append
"""
self.add_module(str(len(self)), module)
return self
def extend(self, modules: Sequence[ModuleType]) -> "InfiniCoreModuleList":
r"""Append modules from a Python iterable to the end of the list.
Args:
modules (iterable): iterable of modules to append
"""
if not isinstance(modules, (list, tuple)):
try:
modules = list(modules)
except TypeError:
raise TypeError(
f"InfiniCoreModuleList.extend should be called with an "
f"iterable, but got {type(modules).__name__}"
)
offset = len(self)
for i, module in enumerate(modules):
self.add_module(str(offset + i), module)
return self
def insert(self, index: int, module: ModuleType) -> None:
r"""Insert a given module before a given index in the list.
Args:
index (int): index to insert.
module ( InfiniCoreModule): module to insert
"""
for i in range(len(self._modules), index, -1):
self._modules[str(i)] = self._modules[str(i - 1)]
self._modules[str(index)] = module
def pop(self, idx: int = -1) -> ModuleType:
r"""Remove and return a module at the given index.
Args:
idx (int): index of the module to pop. Default: -1 (last module)
Returns:
Module: the module that was removed
"""
idx_str = self._get_abs_string_index(idx)
module = self._modules[idx_str]
# Use __delitem__ to ensure proper cleanup
self.__delitem__(int(idx_str))
return module
def __repr__(self) -> str:
"""Return a string representation of the ModuleList."""
if len(self) == 0:
return self.__class__.__name__ + "()"
lines = []
for i, module in enumerate(self):
lines.append(f"({i}): {repr(module)}")
main_str = self.__class__.__name__ + "(\n "
main_str += "\n ".join(lines) + "\n)"
return main_str
def __dir__(self) -> List[str]:
"""Return a list of attribute names, excluding numeric keys."""
keys = super().__dir__()
# Filter out numeric keys to avoid cluttering dir() output
keys = [key for key in keys if not key.isdigit()]
return keys
# Copyright (c) 2025, InfiniCore
#
# This file contains modified code derived from PyTorch's `torch.nn.Module`
# implementation, which is licensed under the BSD 3-Clause License.
#
# The modifications include adaptations for the InfiniCore framework, custom
# parameter/buffer registration mechanisms, and simplified state_dict handling.
#
# Original PyTorch source:
# https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py
#
# Referencing PyTorch v2.4.0
#
# The use of this file is governed by the BSD 3-Clause License.
import itertools
import warnings
from collections import OrderedDict, namedtuple
from typing import (
Any,
Dict,
Iterator,
List,
Mapping,
Optional,
Set,
Tuple,
TypeVar,
Union,
overload,
)
import infinicore
from ...tensor import Tensor
from ..parameter import InfiniCoreParameter as Parameter
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
T = TypeVar("T", bound="InfiniCoreModule")
class _IncompatibleKeys(
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])
):
def __repr__(self):
if not self.missing_keys and not self.unexpected_keys:
return "<All keys matched successfully>"
return super().__repr__()
__str__ = __repr__
class InfiniCoreModule:
r"""Base class for InfiniCore neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure.
"""
_version: int = 1
_parameters: Dict[str, Optional[Parameter]]
_buffers: Dict[str, Optional[Tensor]]
_non_persistent_buffers_set: Set[str]
_modules: Dict[str, Optional["InfiniCoreModule"]]
def __init__(self):
super().__setattr__("_parameters", OrderedDict())
super().__setattr__("_buffers", OrderedDict())
super().__setattr__("_non_persistent_buffers_set", set())
super().__setattr__("_modules", OrderedDict())
def __getattr__(self, name: str) -> Any:
if "_parameters" in self.__dict__:
_parameters = self.__dict__["_parameters"]
if name in _parameters:
return _parameters[name]
if "_buffers" in self.__dict__:
_buffers = self.__dict__["_buffers"]
if name in _buffers:
return _buffers[name]
if "_modules" in self.__dict__:
modules = self.__dict__["_modules"]
if name in modules:
return modules[name]
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
def __setattr__(self, name: str, value: Union[Tensor, "InfiniCoreModule"]) -> None:
def remove_from(*dicts_or_sets) -> None:
for d in dicts_or_sets:
if name in d:
if isinstance(d, dict):
del d[name]
else:
d.discard(name)
params = self.__dict__.get("_parameters")
if params is None:
raise AttributeError(
"cannot assign parameters before Module.__init__() call"
)
if isinstance(value, Parameter): # the value is of type Parameter
remove_from(
self.__dict__,
self._buffers,
self._modules,
self._non_persistent_buffers_set,
)
self.register_parameter(name, value)
elif name in params: # value will overwrite the name of params.
if not isinstance(value, Tensor):
raise TypeError(
f"cannot assign 'value' as parameter '{name}' (infinicore.nn.Parameter, Parameter or None expected)"
)
self.register_parameter(name, value)
else:
modules = self.__dict__.get("_modules")
if modules is None:
raise AttributeError(
"cannot assign module before Module.__init__() call"
)
if isinstance(value, InfiniCoreModule):
remove_from(
self.__dict__,
self._parameters,
self._buffers,
self._non_persistent_buffers_set,
)
modules[name] = value
elif name in modules: # Do not overwrite this variable
raise TypeError(
f"cannot assign 'value' as child module '{name}' (infinicore.nn.Module or None expected)"
)
else:
buffers = self.__dict__.get("_buffers")
if buffers is not None and name in buffers:
if value is not None and not isinstance(value, Tensor):
raise TypeError(
f"cannot assign 'value' as buffer '{name}' "
"(torch.Tensor or None expected)"
)
buffers[name] = value
else:
super().__setattr__(name, value)
def __call__(self, *input, **kwargs):
return self.forward(*input, **kwargs)
def register_buffer(
self, name: str, tensor: Optional[Tensor], persistent: bool = True
) -> None:
r"""Adds a buffer to the module.
This is typically used to register a buffer that should not to be
considered a model parameter.Buffers, by default, are persistent
and will be saved alongside parameters. This behavior can be changed
by setting :attr:`persistent` to ``False``. The only difference between
a persistent buffer and a non-persistent buffer is that the latter
will not be a part of this module's :attr:`state_dict`.
Buffers can be accessed as attributes using given names.
Args:
name (str): name of the buffer. The buffer can be accessed
from this module using the given name
tensor (Tensor or None): buffer to be registered. If ``None``, then operations
that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,
the buffer is **not** included in the module's :attr:`state_dict`.
persistent (bool): whether the buffer is part of this module's
:attr:`state_dict`.
"""
if "_buffers" not in self.__dict__:
raise AttributeError("cannot assign buffer before Module.__init__() call")
elif not isinstance(name, str):
raise TypeError("buffer name should be a string. Got {}".format("name"))
elif "." in name:
raise KeyError('buffer name can\'t contain "."')
elif name == "":
raise KeyError('buffer name can\'t be empty string ""')
elif hasattr(self, name) and name not in self._buffers:
raise KeyError("attribute '{}' already exists".format(name))
elif tensor is not None and not isinstance(tensor, Tensor):
raise TypeError(
"cannot assign '{}' object to buffer '{}' "
"(torch Tensor or None required)".format("tensor", name)
)
else:
self._buffers[name] = tensor
if persistent:
self._non_persistent_buffers_set.discard(name)
else:
self._non_persistent_buffers_set.add(name)
def add_module(self, name: str, module: Optional["InfiniCoreModule"]) -> None:
r"""Add a child module to the current module.
The module can be accessed as an attribute using the given name.
Args:
name (str): name of the child module. The child module can be
accessed from this module using the given name
module (Module or None): child module to be added to the module. If
``None``, then operations that run on modules, such as :attr:`eval`,
are ignored. If ``None``, the module is **not** included in the
module's :attr:`children`.
"""
if not isinstance(name, str):
raise TypeError(f"module name should be a string. Got {name}")
elif "." in name:
raise KeyError(f'module name can\'t contain ".", got: {name}')
elif name == "":
raise KeyError('module name can\'t be empty string ""')
elif hasattr(self, name) and name not in self._modules:
raise KeyError(f"attribute '{name}' already exists")
if module is not None and not isinstance(module, InfiniCoreModule):
raise TypeError(f"{module} is not a Module subclass")
self._modules[name] = module
def register_parameter(self, name: str, param: Parameter) -> None:
r"""Add a parameter to the module.
The parameter can be accessed as an attribute using given name.
Args:
name (str): name of the parameter. The parameter can be accessed
from this module using the given name
param (Parameter or None): parameter to be added to the module. If
``None``, then operations that run on parameters, such as :attr:`cuda`,
are ignored. If ``None``, the parameter is **not** included in the
module's :attr:`state_dict`.
"""
if "_parameters" not in self.__dict__:
raise AttributeError(
"cannot assign parameter before Module.__init__() call"
)
elif not isinstance(name, str):
raise TypeError("parameter name should be a string.")
elif "." in name:
raise KeyError('parameter name can\'t contain "."')
elif name == "":
raise KeyError('parameter name can\'t be empty string ""')
elif hasattr(self, name) and name not in self._parameters:
raise KeyError(f"attribute '{name}' already exists")
if param is None:
self._parameters[name] = None # 竟然可以是None
else:
if not isinstance(param, (Parameter, Tensor)):
raise TypeError(
f"cannot assign 'param' object to parameter '{name}' "
"(infinicore.nn.Parameter, Parameter or None required)"
)
self._parameters[name] = param
super().__setattr__(name, param)
def get_extra_state(self) -> Any:
"""Return any extra state to include in the module's state_dict.
Implement this and a corresponding :func:`set_extra_state` for your module
if you need to store extra state. This function is called when building the
module's `state_dict()`.
Note that extra state should be picklable to ensure working serialization
of the state_dict. We only provide provide backwards compatibility guarantees
for serializing Tensors; other objects may break backwards compatibility if
their serialized pickled form changes.
Returns:
object: Any extra state to store in the module's state_dict
"""
raise RuntimeError(
"Reached a code path in Module.get_extra_state() that should never be called. "
)
def _save_to_state_dict(self, destination, prefix, keep_vars):
r"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.
In rare cases, subclasses can achieve class-specific behavior by
overriding this method with custom logic.
Args:
destination (dict): a dict where state will be stored
prefix (str): the prefix for parameters and buffers used in this
module
"""
for name, param in self._parameters.items():
if param is not None:
destination[prefix + name] = param if keep_vars else param
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if (
getattr(self.__class__, "get_extra_state", InfiniCoreModule.get_extra_state)
is not InfiniCoreModule.get_extra_state
):
destination[extra_state_key] = self.get_extra_state()
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
# back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
T_destination = TypeVar("T_destination", bound=Dict[str, Any])
@overload
def state_dict(
self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...
) -> T_destination: ...
@overload
def state_dict(
self, *, prefix: str = ..., keep_vars: bool = ...
) -> Dict[str, Any]: ...
# TODO: Change `*args` to `*` and remove the copprespinding warning in docs when BC allows.
# Also remove the logic for arg parsing together.
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
r"""Returns a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
.. note::
The returned object is a shallow copy. It contains references
to the module's parameters and buffers.
.. warning::
Currently ``state_dict()`` also accepts positional arguments for
``destination``, ``prefix`` and ``keep_vars`` in order. However,
this is being deprecated and keyword arguments will be enforced in
future releases.
.. warning::
Please avoid the use of argument ``destination`` as it is not
designed for end-users.
Args:
destination (dict, optional): If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an ``OrderedDict`` will be created and returned.
Default: ``None``.
prefix (str, optional): a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ``''``.
keep_vars (bool, optional): by default the :class:`~torch.Tensor` s
returned in the state dict are detached from autograd. If it's
set to ``True``, detaching will not be performed.
Default: ``False``.
Returns:
dict:
a dictionary containing a whole state of the module
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
"""
# TODO: Remove `args` and the parsing logic when BC allows.
if len(args) > 0:
# DeprecationWarning is ignored by default
warnings.warn(
"Positional args are being deprecated, use kwargs instead. ",
FutureWarning,
stacklevel=2,
)
if destination is None:
destination = args[0]
if len(args) > 1 and prefix == "":
prefix = args[1]
if len(args) > 2 and keep_vars is False:
keep_vars = args[2]
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
local_metadata = dict(version=self._version)
if hasattr(destination, "_metadata"):
destination._metadata[prefix[:-1]] = local_metadata
self._save_to_state_dict(destination, prefix, keep_vars)
for name, module in self._modules.items():
if module is not None:
module.state_dict(
destination=destination,
prefix=prefix + name + ".",
keep_vars=keep_vars,
)
return destination
def set_extra_state(self, state: Any):
"""
This function is called from :func:`load_state_dict` to handle any extra state
found within the `state_dict`. Implement this function and a corresponding
:func:`get_extra_state` for your module if you need to store extra state within its
`state_dict`.
Args:
state (dict): Extra state from the `state_dict`
"""
raise RuntimeError(
"Reached a code path in Module.set_extra_state() that should never be called. "
"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
"to report this bug."
)
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
r"""Copies parameters and buffers from :attr:`state_dict` into only
this module, but not its descendants. This is called on every submodule
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
For state dicts without metadata, :attr:`local_metadata` is empty.
Subclasses can achieve class-specific backward compatible loading using
the version number at `local_metadata.get("version", None)`.
.. note::
:attr:`state_dict` is not the same object as the input
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
it can be modified.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
local_metadata (dict): a dict containing the metadata for this module.
See
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` with :attr:`prefix` match the names of
parameters and buffers in this module
missing_keys (list of str): if ``strict=True``, add missing keys to
this list
unexpected_keys (list of str): if ``strict=True``, add unexpected
keys to this list
error_msgs (list of str): error messages should be added to this
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
"""
persistent_buffers = {
k: v
for k, v in self._buffers.items()
if k not in self._non_persistent_buffers_set
}
local_name_params = itertools.chain(
self._parameters.items(), persistent_buffers.items()
)
local_state = {k: v for k, v in local_name_params if v is not None}
for name, param in local_state.items():
key = prefix + name
if key in state_dict:
input_param = state_dict[key]
# input_param must be of type infinicore.Tensor
if not isinstance(input_param, Tensor):
raise TypeError(
f"While copying the parameter named {key}, expected Tensor from checkpoint but received {type(input_param)}"
)
if (
(param.shape == input_param.shape)
and (param.dtype == input_param.dtype)
and (param.device == input_param.device)
):
param.copy_(input_param)
else:
print(f"param '{name}' don't match input_param '{key}'")
setattr(self, name, input_param)
elif strict:
missing_keys.append(key)
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if (
getattr(self.__class__, "set_extra_state", InfiniCoreModule.set_extra_state)
is not InfiniCoreModule.set_extra_state
):
if extra_state_key in state_dict:
self.set_extra_state(state_dict[extra_state_key])
elif strict:
missing_keys.append(extra_state_key)
elif strict and (extra_state_key in state_dict):
unexpected_keys.append(extra_state_key)
if strict:
for key in state_dict.keys():
if key.startswith(prefix):
input_name = key[len(prefix) :].split(".", 1)
# Must be Module if it have attributes
if len(input_name) > 1:
if input_name[0] not in self._modules:
unexpected_keys.append(key)
elif input_name[0] not in local_state:
unexpected_keys.append(key)
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants. If :attr:`strict` is ``True``, then
the keys of :attr:`state_dict` must exactly match the keys returned
by this module's :meth:`~torch.nn.Module.state_dict` function.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
Returns:
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
* **missing_keys** is a list of str containing the missing keys
* **unexpected_keys** is a list of str containing the unexpected keys
Note:
If a parameter or buffer is registered as ``None`` and its corresponding key
exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
``RuntimeError``.
"""
if not isinstance(state_dict, Mapping):
raise TypeError(
"Expected state_dict to be dict-like, got {}.".format(type(state_dict))
)
missing_keys: List[str] = []
unexpected_keys: List[str] = []
error_msgs: List[str] = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = OrderedDict(state_dict)
if metadata is not None:
state_dict._metadata = metadata # type: ignore[attr-defined]
def load(module, local_state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
local_state_dict,
prefix,
local_metadata,
True,
missing_keys,
unexpected_keys,
error_msgs,
)
for name, child in module._modules.items():
if child is not None:
child_prefix = prefix + name + "."
child_state_dict = {
k: v
for k, v in local_state_dict.items()
if k.startswith(child_prefix)
}
load(child, child_state_dict, child_prefix) # noqa: F821
load(self, state_dict)
del load
if strict:
if len(unexpected_keys) > 0:
error_msgs.insert(
0,
"Unexpected key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in unexpected_keys)
),
)
if len(missing_keys) > 0:
error_msgs.insert(
0,
"Missing key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in missing_keys)
),
)
if len(error_msgs) > 0:
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs)
)
)
return _IncompatibleKeys(missing_keys, unexpected_keys)
def parameters(self, recurse: bool = True) -> Iterator["Parameter"]:
r"""Returns an iterator over module parameters.
Args:
recurse (bool): if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
Yields:
Parameter: module parameter
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
... print(type(param), param.size())
"""
for name, param in self.named_parameters(recurse=recurse):
yield param
def named_parameters(
self, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, "Parameter"]]:
r"""Returns an iterator over module parameters, yielding both the
name of the parameter as well as the parameter itself.
Args:
prefix (str): prefix to prepend to all parameter names.
recurse (bool): if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
Yields:
(str, Parameter): Tuple containing the name and parameter
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
... if name in ['bias']:
... print(param.size())
"""
gen = self._named_members(
lambda module: module._parameters.items(), prefix=prefix, recurse=recurse
)
for elem in gen:
yield elem
def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
r"""Returns an iterator over module buffers.
Args:
recurse (bool): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module.
Yields:
torch.Tensor: module buffer
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for buf in model.buffers():
... print(type(buf), buf.size())
"""
for name, buf in self.named_buffers(recurse=recurse):
yield buf
def named_buffers(
self, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, Tensor]]:
r"""Returns an iterator over module buffers, yielding both the
name of the buffer as well as the buffer itself.
Args:
prefix (str): prefix to prepend to all buffer names.
recurse (bool): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module.
Yields:
(str, torch.Tensor): Tuple containing the name and buffer
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
... if name in ['running_mean']:
... print(buf.size())
"""
memo = set()
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
for module_prefix, module in modules:
for k, v in module._buffers.items():
if v is None or v in memo:
continue
if k in module._non_persistent_buffers_set:
continue
memo.add(v)
name = module_prefix + ("." if module_prefix else "") + k
yield (name, v)
def _named_members(self, get_members_fn, prefix="", recurse=True):
r"""Helper method to yield members with their names."""
memo = set()
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
for module_prefix, module in modules:
members = get_members_fn(module)
for k, v in members:
if v is None or v in memo:
continue
memo.add(v)
name = module_prefix + ("." if module_prefix else "") + k
yield (name, v)
def modules(self) -> Iterator["InfiniCoreModule"]:
r"""Returns an iterator over all modules in the network.
Yields:
Module: a module in the network
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
... print(idx, '->', m)
0 -> Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
"""
for name, module in self.named_modules():
yield module
def named_modules(
self,
memo: Optional[Set["InfiniCoreModule"]] = None,
prefix: str = "",
remove_duplicate: bool = True,
):
r"""Returns an iterator over all modules in the network, yielding
both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
... print(idx, '->', m)
0 -> ('', Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
"""
if memo is None:
memo = set()
if remove_duplicate:
if self in memo:
return
memo.add(self)
yield prefix, self
for name, module in self._modules.items():
if module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
# Handle both InfiniCoreModule and torch.nn.Module
if isinstance(module, InfiniCoreModule):
for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
yield m
elif isinstance(module, infinicore.nn.Module):
# For torch.nn.Module, use its named_modules method
# torch.nn.Module.named_modules returns (name, module) tuples
for sub_name, sub_module in module.named_modules(
prefix=submodule_prefix, remove_duplicate=remove_duplicate
):
yield (sub_name, sub_module)
def children(self) -> Iterator["InfiniCoreModule"]:
r"""Returns an iterator over immediate children modules.
Yields:
Module: a child module (can be InfiniCoreModule or torch.nn.Module)
"""
for name, module in self.named_children():
yield module
def named_children(
self,
) -> Iterator[Tuple[str, "InfiniCoreModule"]]:
r"""Returns an iterator over immediate children modules, yielding both
the name of the module as well as the module itself.
Yields:
(str, Module): Tuple containing a name and child module
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, module in model.named_children():
>>> if name in ['conv4', 'conv5']:
>>> print(module)
"""
memo = set()
for name, module in self._modules.items():
if module is not None and module not in memo:
memo.add(module)
yield name, module
def eval(self: T) -> T:
r"""Sets the module in evaluation mode.
Returns:
Module: self
"""
pass
def _apply(self, fn, recurse=True):
raise KeyError("not support")
def to(self, *args, **kwargs):
raise KeyError("not support")
# Copyright (c) 2025, InfiniCore
#
# This file contains modified code derived from PyTorch's `torch.nn.Parameter`
# implementation, which is licensed under the BSD 3-Clause License.
#
# The modifications include adaptations for the InfiniCore framework.
#
# Original PyTorch source:
# https://github.com/pytorch/pytorch/blob/main/torch/nn/parameter.py
#
# Referencing PyTorch v2.4.0
#
# The use of this file is governed by the BSD 3-Clause License.
from ..tensor import Tensor
class InfiniCoreParameter(Tensor):
r"""A kind of Tensor that is to be considered a module parameter."""
def __init__(self, data=None):
if not isinstance(data, Tensor):
raise ValueError("The `data` variable must be of type `infinicore.Tensor`.")
super().__init__(data._underlying)
def __repr__(self):
return "Parameter containing:\n" + super().__repr__()
def __deepcopy__(self, memo):
raise ValueError("not supported!")
def __reduce_ex__(self, proto):
raise ValueError("not supported!")
# ============================================================
# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径
# ============================================================
import os
import sys
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), "../../python/infinicore"))
)
save_dir = os.path.join(os.path.dirname(__file__), "../../tmp")
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "torch_convnet_with_param.safetensors")
import infinicore # noqa: E402
from infinicore.nn import Module # noqa: E402
# ============================================================
# 1. 定义模型
# ============================================================
device_str = "cuda"
class InfiniCoreNet(Module):
def __init__(self):
super().__init__()
self.a = infinicore.nn.Parameter(
infinicore.empty(
(1, 2, 3),
dtype=infinicore.float32,
device=infinicore.device(device_str),
)
)
self.b = infinicore.nn.Parameter(
infinicore.empty(
(1, 2, 3),
dtype=infinicore.float32,
device=infinicore.device(device_str),
)
)
def forward(self):
return infinicore.add(self.a, self.b)
infinicore_model_infer = InfiniCoreNet()
# ============================================================
# 2. 加载权重
# ============================================================
params_dict = {
"a": infinicore.empty(
(1, 2, 3), dtype=infinicore.float32, device=infinicore.device(device_str, 0)
),
"b": infinicore.empty(
(1, 2, 3), dtype=infinicore.float32, device=infinicore.device(device_str, 0)
),
}
infinicore_model_infer.load_state_dict(params_dict)
# ============================================================
# 3. 计算
# ============================================================
infinicore_model_out = infinicore_model_infer()
ref_out = infinicore.add(params_dict["a"], params_dict["b"])
# ============================================================
# 4. 对比结果
# ============================================================
print("InfiniCoreModule 与 Torch (CPU) 最大误差: 手动查看 ")
infinicore_model_out.debug()
ref_out.debug()
# ============================================================
# 5. to测试,buffer测试
# ============================================================
# 等待添加
import os
# ============================================================
# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径
# ============================================================
import sys
import safetensors
import safetensors.torch
import torch
import torch.nn as nn
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), "../../python/infinicore"))
)
# 使用临时目录,如果不存在则自动创建
save_dir = os.path.join(os.path.dirname(__file__), "../../tmp")
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "torch_modulelist_with_param.safetensors")
def test():
# ============================================================
# 1. 使用 PyTorch 定义并保存模型(使用 torch.nn.ModuleList)
# ============================================================
class TorchModuleListNet(nn.Module):
def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
super().__init__()
# 使用 torch.nn.ModuleList
self.layers = nn.ModuleList(
[
nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_ch),
nn.ReLU(),
nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_ch),
nn.ReLU(),
nn.Conv2d(hidden_ch, out_ch, kernel_size=1),
]
)
# 自定义 Parameter
self.scale = nn.Parameter(torch.ones(1) * 0.5)
self.register_buffer("offset", torch.tensor(0.1))
def forward(self, x):
# 遍历 ModuleList 中的所有层
for layer in self.layers:
x = layer(x)
# 应用自定义参数和 buffer
x = x * self.scale + self.offset
return x
# ===== 保存 Torch 模型 =====
torch_model = TorchModuleListNet()
torch_state_dict = torch_model.state_dict()
safetensors.torch.save_file(torch_state_dict, save_path)
print("✓ PyTorch 模型已保存")
# ============================================================
# 2. 使用 torch 方式加载并推理
# ============================================================
torch_model_infer = TorchModuleListNet()
torch_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
torch_model_infer.eval()
input = torch.rand(1, 3, 8, 8)
torch_model_out = torch_model_infer(input)
print("✓ Torch 输出:", torch_model_out.detach().numpy().mean())
# ============================================================
# 3. 使用 ModuleList 加载并推理
# ============================================================
from nn.modules import Module, ModuleList
class InfiniCoreModuleListNet(Module):
def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
super().__init__()
# 使用 ModuleList
self.layers = ModuleList(
[
nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_ch),
nn.ReLU(),
nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_ch),
nn.ReLU(),
nn.Conv2d(hidden_ch, out_ch, kernel_size=1),
]
)
# 保持与 Torch 模型一致的自定义参数和 buffer
self.scale = nn.Parameter(torch.ones(1) * 0.5)
self.register_buffer("offset", torch.tensor(0.1))
def forward(self, x):
# 遍历 ModuleList 中的所有层
for layer in self.layers:
x = layer(x)
x = x * self.scale + self.offset
return x
# ===== 使用 ModuleListNet 读取 safetensors 并推理 =====
infinicore_model_infer = InfiniCoreModuleListNet()
infinicore_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
infinicore_model_infer.eval()
infinicore_model_out = infinicore_model_infer.forward(input)
print("✓ InfiniCore 输出:", infinicore_model_out.detach().numpy().mean())
# ============================================================
# 4. 对比结果
# ============================================================
diff = (infinicore_model_out - torch_model_out).abs().max().item()
print(f"✓ ModuleList 与 Torch 最大误差: {diff:.8f}")
if diff < 1e-9:
print("✓ ModuleList 与 Torch 精度一致.")
else:
print("✗ ModuleList 与 Torch 精度存在差异.")
# ============================================================
# 5. 测试 ModuleList 的基本功能
# ============================================================
print("\n=== 测试 ModuleList 基本功能 ===")
# 测试 1: 创建和访问
module_list = ModuleList([nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)])
print(f"✓ 创建 ModuleList,长度: {len(module_list)}")
print(f"✓ 访问第一个模块: {type(module_list[0]).__name__}")
print(f"✓ 访问第二个模块: {type(module_list[1]).__name__}")
# 测试 2: append
module_list.append(nn.Softmax(dim=-1))
print(f"✓ append 后长度: {len(module_list)}")
# 测试 3: extend
module_list.extend([nn.Dropout(0.1), nn.Linear(5, 1)])
print(f"✓ extend 后长度: {len(module_list)}")
# 测试 4: 迭代
print("✓ 迭代 ModuleList:")
for i, module in enumerate(module_list):
print(f" [{i}] {type(module).__name__}")
# 测试 5: 索引访问
print(f"✓ 索引访问 module_list[0]: {type(module_list[0]).__name__}")
# 测试 6: state_dict
state_dict = module_list.state_dict()
print(f"✓ state_dict 键数量: {len(state_dict)}")
print(f"✓ state_dict 包含模块参数: {any('0.' in k for k in state_dict.keys())}")
# 测试 7: 使用 ModuleList 的模型
class TestNet(Module):
def __init__(self):
super().__init__()
self.layers = ModuleList([nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
test_model = TestNet()
test_input = torch.randn(2, 10)
test_output = test_model.forward(test_input)
print(f"✓ TestNet 输入形状: {test_input.shape}, 输出形状: {test_output.shape}")
# 测试 8: __add__ 方法
ml1 = ModuleList([nn.Linear(10, 5), nn.ReLU()])
ml2 = ModuleList([nn.Linear(5, 3), nn.Sigmoid()])
ml3 = ml1 + ml2
print(f"✓ __add__ 方法测试: {len(ml1)} + {len(ml2)} = {len(ml3)}")
assert len(ml3) == 4, "合并后的长度应该为 4"
# 测试 9: pop 方法
ml4 = ModuleList([nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 3)])
popped = ml4.pop()
print(
f"✓ pop 方法测试: 弹出后长度 {len(ml4)}, 弹出模块类型 {type(popped).__name__}"
)
assert len(ml4) == 2, "pop 后长度应该为 2"
assert isinstance(popped, nn.Linear), "弹出的应该是 Linear 模块"
# 测试 10: __repr__ 方法
ml5 = ModuleList([nn.Linear(10, 5), nn.ReLU()])
repr_str = repr(ml5)
print(f"✓ __repr__ 方法测试: 输出包含类名和模块信息")
assert "ModuleList" in repr_str or "InfiniCoreModuleList" in repr_str, (
"repr 应该包含类名"
)
assert "Linear" in repr_str, "repr 应该包含模块信息"
print(repr_str)
print("\n=== 所有测试通过! ===")
# ============================================================
# 6. 前向传播集成测试(参考 infinicore_nn_test.py)
# ============================================================
print("\n=== 前向传播集成测试 ===")
# 使用 ModuleList 创建一个简单的模型
class TorchModuleListModel(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.ModuleList(
[nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)]
)
self.scale = nn.Parameter(torch.ones(1) * 0.5)
self.register_buffer("offset", torch.tensor(0.1))
def forward(self, x):
for layer in self.layers:
x = layer(x)
x = x * self.scale + self.offset
return x
class InfiniCoreModuleListModel(Module):
def __init__(self):
super().__init__()
self.layers = ModuleList([nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)])
self.scale = nn.Parameter(torch.ones(1) * 0.5)
self.register_buffer("offset", torch.tensor(0.1))
def forward(self, x):
for layer in self.layers:
x = layer(x)
x = x * self.scale + self.offset
return x
# 创建模型
torch_model_forward = TorchModuleListModel()
infinicore_model_forward = InfiniCoreModuleListModel()
# 复制权重(确保初始权重一致)
infinicore_model_forward.load_state_dict(
torch_model_forward.state_dict(), strict=False
)
# 设置为评估模式
torch_model_forward.eval()
infinicore_model_forward.eval()
# 创建测试输入
test_input = torch.randn(2, 10)
# 前向传播
with torch.no_grad():
torch_output = torch_model_forward(test_input)
infinicore_output = infinicore_model_forward.forward(test_input)
# 对比结果
diff = (infinicore_output - torch_output).abs().max().item()
print(f"✓ 前向传播测试 - 输入形状: {test_input.shape}")
print(
f"✓ Torch 输出形状: {torch_output.shape}, 均值: {torch_output.detach().numpy().mean():.8f}"
)
print(
f"✓ InfiniCore 输出形状: {infinicore_output.shape}, 均值: {infinicore_output.detach().numpy().mean():.8f}"
)
print(f"✓ 最大误差: {diff:.8f}")
if diff < 1e-9:
print("✓ 前向传播集成测试通过:ModuleList 与 Torch ModuleList 结果一致!")
else:
print("✗ 前向传播集成测试失败:存在差异")
# ============================================================
# 7. 混合模块兼容性测试(PyTorch + InfiniCore 模块混合使用)
# ============================================================
print("\n=== 混合模块兼容性测试 ===")
# 创建一个自定义的 InfiniCore 模块
class CustomLinear(Module):
def __init__(self, in_features, out_features):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_features, in_features))
self.bias = nn.Parameter(torch.randn(out_features))
def forward(self, x):
return x @ self.weight.t() + self.bias
# 创建混合 ModuleList(包含 PyTorch 模块和 InfiniCore 模块)
mixed_list = ModuleList(
[
nn.Linear(10, 5), # PyTorch 模块
CustomLinear(5, 3), # 自定义 InfiniCore 模块
nn.ReLU(), # PyTorch 模块
]
)
print(f"✓ 创建混合 ModuleList,长度: {len(mixed_list)}")
print(f"✓ 模块类型: {[type(m).__name__ for m in mixed_list]}")
# 测试参数注册
param_count = sum(1 for _ in mixed_list.parameters())
print(f"✓ 参数数量: {param_count}")
assert param_count == 4, (
f"参数数量应该为 4 (Linear: weight+bias, CustomLinear: weight+bias), 实际为 {param_count}"
)
# 测试 state_dict
mixed_state_dict = mixed_list.state_dict()
print(f"✓ state_dict 键数量: {len(mixed_state_dict)}")
assert len(mixed_state_dict) >= 4, "state_dict 应该包含至少 4 个参数"
# 测试前向传播
test_input_mixed = torch.randn(2, 10)
with torch.no_grad():
x = test_input_mixed
for module in mixed_list:
x = module.forward(x)
print(f"✓ 混合模块前向传播成功,输出形状: {x.shape}")
print("✓ 混合模块兼容性测试通过!")
# ============================================================
# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径
# ============================================================
import os
import sys
import torch
import torch.nn as nn
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), "../../python/infinicore"))
)
save_dir = os.path.join(os.path.dirname(__file__), "../../tmp")
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "infinicore_parameter_test.safetensors")
import infinicore # noqa: E402
from infinicore.nn import Module, Parameter # noqa: E402
device_str = "cuda"
class InfiniCoreParameterNet(Module):
def __init__(self):
super().__init__()
self.a = infinicore.nn.Parameter(
infinicore.empty(
(1, 2, 3), dtype=infinicore.float32, device=infinicore.device("cpu", 0)
)
)
def forward(self, x):
return infinicore.add(self.a, x)
infinicore_model_infer = InfiniCoreParameterNet()
# ============================================================
# 2. 加载权重
# ============================================================
params_dict = {
"a": infinicore.empty(
(1, 2, 3), dtype=infinicore.float32, device=infinicore.device(device_str, 0)
)
}
infinicore_model_infer.load_state_dict(params_dict)
# ============================================================
# 3. 计算
# ============================================================
x = infinicore.empty(
(1, 2, 3), dtype=infinicore.float32, device=infinicore.device(device_str, 0)
)
infinicore_model_out = infinicore_model_infer(x)
ref_out = infinicore.add(params_dict["a"], x)
# ============================================================
# 4. 对比结果
# ============================================================
print("InfiniCoreModule 与 Torch (CPU) 最大误差: 手动查看 ")
infinicore_model_out.debug()
ref_out.debug()
# ============================================================
# 5. 测试 Parameter 的基本功能
# ============================================================
print("\n=== 测试 Parameter 基本功能 ===")
# 测试 1: 创建 Parameter
param1 = infinicore.nn.Parameter(
infinicore.empty(
(1, 2, 3), dtype=infinicore.float32, device=infinicore.device(device_str, 0)
)
)
print(f"✓ 创建 Parameter,形状: {param1.shape}")
# 检查是否是 Parameter 类型(可能是 InfiniCoreParameter 的别名)
assert isinstance(param1, infinicore.nn.Parameter), "应该是 Parameter 类型"
assert isinstance(param1, infinicore.Tensor), "应该是 torch.Tensor 的子类"
# 测试 3: 自动注册到 Module
class TestModule(Module):
def __init__(self):
super().__init__()
self.weight = infinicore.nn.Parameter(
infinicore.empty(
(1, 2, 3),
dtype=infinicore.float32,
device=infinicore.device(device_str),
)
)
self.bias = infinicore.nn.Parameter(
infinicore.empty(
(1, 2, 3),
dtype=infinicore.float32,
device=infinicore.device(device_str),
)
)
test_module = TestModule()
param_count = sum(1 for _ in test_module.parameters())
print(f"✓ 自动注册到 Module,参数数量: {param_count}")
assert param_count == 2, f"应该有 2 个参数,实际为 {param_count}"
# 测试 4: 参数访问
assert test_module.weight is not None, "weight 应该可以访问"
assert test_module.bias is not None, "bias 应该可以访问"
print("✓ 参数可以通过属性访问")
# 测试 5: state_dict
state_dict = test_module.state_dict()
print(f"✓ state_dict 键数量: {len(state_dict)}")
assert "weight" in state_dict, "state_dict 应该包含 weight"
assert "bias" in state_dict, "state_dict 应该包含 bias"
print(f"✓ state_dict 键: {list(state_dict.keys())}")
# 测试 6: __repr__
repr_str = repr(param1)
print(f"✓ __repr__ 方法: 输出包含类名")
assert "Parameter" in repr_str or "InfiniCoreParameter" in repr_str, "repr 应该包含类名"
print(repr_str[:100] + "...")
# 测试 9: 从 None 创建
# param_empty = Parameter(None)
# print(f"✓ 从 None 创建 Parameter,形状: {param_empty.shape}")
# assert param_empty.shape == torch.Size([0]), "从 None 创建应该是空张量"
# 测试 10: 深拷贝
# import copy
# param_copy = copy.deepcopy(param1)
# print(f"✓ 深拷贝 Parameter,形状: {param_copy.shape}")
# assert param_copy.shape == param1.shape, "深拷贝后形状应该相同"
# assert not torch.equal(param_copy, param1) or id(param_copy) != id(param1), (
# "深拷贝应该是新对象"
# )
print("\n=== 所有测试通过! ===")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment