Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
27e57f3d
Commit
27e57f3d
authored
Nov 05, 2025
by
zhushuang
Committed by
zhuyue
Nov 17, 2025
Browse files
issue/507: add infinicore.nn.Module referencing torch.nn.Module and test case
parent
b7d9252b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
864 additions
and
0 deletions
+864
-0
python/infinicore/nn/modules/__init__.py
python/infinicore/nn/modules/__init__.py
+1
-0
python/infinicore/nn/modules/module.py
python/infinicore/nn/modules/module.py
+717
-0
test/infinicore/infinicore_nn_test.py
test/infinicore/infinicore_nn_test.py
+146
-0
No files found.
python/infinicore/nn/modules/__init__.py
0 → 100644
View file @
27e57f3d
from
.module
import
InfiniCoreModule
as
Module
python/infinicore/nn/modules/module.py
0 → 100644
View file @
27e57f3d
# Copyright (c) 2025, InfiniCore
#
# This file contains modified code derived from PyTorch's `torch.nn.Module`
# implementation, which is licensed under the BSD 3-Clause License.
#
# The modifications include adaptations for the InfiniCore framework, custom
# parameter/buffer registration mechanisms, and simplified state_dict handling.
#
# Original PyTorch source:
# https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py
#
# Referencing PyTorch v2.4.0
#
# The use of this file is governed by the BSD 3-Clause License.
from
collections
import
OrderedDict
,
namedtuple
import
itertools
import
warnings
import
torch
from
typing
import
Union
,
Tuple
,
Any
,
Iterator
,
Set
,
Optional
,
overload
,
TypeVar
,
Mapping
,
Dict
,
List
from
torch.utils._python_dispatch
import
is_traceable_wrapper_subclass
_EXTRA_STATE_KEY_SUFFIX
=
'_extra_state'
T
=
TypeVar
(
'T'
,
bound
=
'InfiniCoreModule'
)
class
_IncompatibleKeys
(
namedtuple
(
'IncompatibleKeys'
,
[
'missing_keys'
,
'unexpected_keys'
])):
def
__repr__
(
self
):
if
not
self
.
missing_keys
and
not
self
.
unexpected_keys
:
return
'<All keys matched successfully>'
return
super
().
__repr__
()
__str__
=
__repr__
class
InfiniCoreModule
:
r
"""Base class for InfiniCore neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing
to nest them in a tree structure.
"""
_version
:
int
=
1
training
:
bool
_parameters
:
Dict
[
str
,
Optional
[
torch
.
nn
.
Parameter
]]
_buffers
:
Dict
[
str
,
Optional
[
torch
.
Tensor
]]
_non_persistent_buffers_set
:
Set
[
str
]
_modules
:
Dict
[
str
,
Optional
[
'InfiniCoreModule'
]]
def
__init__
(
self
):
super
().
__setattr__
(
"_parameters"
,
OrderedDict
())
super
().
__setattr__
(
"_buffers"
,
OrderedDict
())
super
().
__setattr__
(
"_non_persistent_buffers_set"
,
set
())
super
().
__setattr__
(
"_modules"
,
OrderedDict
())
def
__getattr__
(
self
,
name
:
str
)
->
Any
:
if
"_parameters"
in
self
.
__dict__
:
_parameters
=
self
.
__dict__
[
"_parameters"
]
if
name
in
_parameters
:
return
_parameters
[
name
]
if
"_buffers"
in
self
.
__dict__
:
_buffers
=
self
.
__dict__
[
"_buffers"
]
if
name
in
_buffers
:
return
_buffers
[
name
]
if
"_modules"
in
self
.
__dict__
:
modules
=
self
.
__dict__
[
"_modules"
]
if
name
in
modules
:
return
modules
[
name
]
raise
AttributeError
(
f
"'
{
type
(
self
).
__name__
}
' object has no attribute '
{
name
}
'"
)
def
__setattr__
(
self
,
name
:
str
,
value
:
Union
[
torch
.
Tensor
,
'InfiniCoreModule'
])
->
None
:
def
remove_from
(
*
dicts_or_sets
)
->
None
:
for
d
in
dicts_or_sets
:
if
name
in
d
:
if
isinstance
(
d
,
dict
):
del
d
[
name
]
else
:
d
.
discard
(
name
)
params
=
self
.
__dict__
.
get
(
"_parameters"
)
if
isinstance
(
value
,
torch
.
nn
.
Parameter
):
if
params
is
None
:
raise
AttributeError
(
"cannot assign parameters before Module.__init__() call"
)
remove_from
(
self
.
__dict__
,
self
.
_buffers
,
self
.
_modules
,
self
.
_non_persistent_buffers_set
,
)
self
.
register_parameter
(
name
,
value
)
elif
params
is
not
None
and
name
in
params
:
if
value
is
not
None
:
raise
TypeError
(
f
"cannot assign '
{
torch
.
typename
(
value
)
}
' as parameter '
{
name
}
' "
"(torch.nn.Parameter or None expected)"
)
self
.
register_parameter
(
name
,
value
)
else
:
modules
=
self
.
__dict__
.
get
(
"_modules"
)
if
isinstance
(
value
,
(
torch
.
nn
.
Module
)):
if
modules
is
None
:
raise
AttributeError
(
"cannot assign module before Module.__init__() call"
)
remove_from
(
self
.
__dict__
,
self
.
_parameters
,
self
.
_buffers
,
self
.
_non_persistent_buffers_set
,
)
modules
[
name
]
=
value
elif
modules
is
not
None
and
name
in
modules
:
if
value
is
not
None
:
raise
TypeError
(
f
"cannot assign '
{
torch
.
typename
(
value
)
}
' as child module '
{
name
}
' "
"(torch.nn.Module or None expected)"
)
modules
[
name
]
=
value
else
:
buffers
=
self
.
__dict__
.
get
(
"_buffers"
)
if
buffers
is
not
None
and
name
in
buffers
:
if
value
is
not
None
and
not
isinstance
(
value
,
torch
.
Tensor
):
raise
TypeError
(
f
"cannot assign '
{
torch
.
typename
(
value
)
}
' as buffer '
{
name
}
' "
"(torch.Tensor or None expected)"
)
buffers
[
name
]
=
value
else
:
super
().
__setattr__
(
name
,
value
)
def
register_buffer
(
self
,
name
:
str
,
tensor
:
Optional
[
torch
.
tensor
],
persistent
:
bool
=
True
)
->
None
:
r
"""Adds a buffer to the module.
This is typically used to register a buffer that should not to be
considered a model parameter.Buffers, by default, are persistent
and will be saved alongside parameters. This behavior can be changed
by setting :attr:`persistent` to ``False``. The only difference between
a persistent buffer and a non-persistent buffer is that the latter
will not be a part of this module's :attr:`state_dict`.
Buffers can be accessed as attributes using given names.
Args:
name (str): name of the buffer. The buffer can be accessed
from this module using the given name
tensor (Tensor or None): buffer to be registered. If ``None``, then operations
that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,
the buffer is **not** included in the module's :attr:`state_dict`.
persistent (bool): whether the buffer is part of this module's
:attr:`state_dict`.
"""
if
'_buffers'
not
in
self
.
__dict__
:
raise
AttributeError
(
"cannot assign buffer before Module.__init__() call"
)
elif
not
isinstance
(
name
,
str
):
raise
TypeError
(
"buffer name should be a string. "
"Got {}"
.
format
(
torch
.
typename
(
name
)))
elif
'.'
in
name
:
raise
KeyError
(
"buffer name can't contain
\"
.
\"
"
)
elif
name
==
''
:
raise
KeyError
(
"buffer name can't be empty string
\"\"
"
)
elif
hasattr
(
self
,
name
)
and
name
not
in
self
.
_buffers
:
raise
KeyError
(
"attribute '{}' already exists"
.
format
(
name
))
elif
tensor
is
not
None
and
not
isinstance
(
tensor
,
torch
.
Tensor
):
raise
TypeError
(
"cannot assign '{}' object to buffer '{}' "
"(torch Tensor or None required)"
.
format
(
torch
.
typename
(
tensor
),
name
))
else
:
self
.
_buffers
[
name
]
=
tensor
if
persistent
:
self
.
_non_persistent_buffers_set
.
discard
(
name
)
else
:
self
.
_non_persistent_buffers_set
.
add
(
name
)
def
register_parameter
(
self
,
name
:
str
,
param
:
Optional
[
torch
.
nn
.
Parameter
])
->
None
:
r
"""Add a parameter to the module.
The parameter can be accessed as an attribute using given name.
Args:
name (str): name of the parameter. The parameter can be accessed
from this module using the given name
param (Parameter or None): parameter to be added to the module. If
``None``, then operations that run on parameters, such as :attr:`cuda`,
are ignored. If ``None``, the parameter is **not** included in the
module's :attr:`state_dict`.
"""
if
"_parameters"
not
in
self
.
__dict__
:
raise
AttributeError
(
"cannot assign parameter before Module.__init__() call"
)
elif
not
isinstance
(
name
,
str
):
raise
TypeError
(
f
"parameter name should be a string. Got
{
torch
.
typename
(
name
)
}
"
)
elif
"."
in
name
:
raise
KeyError
(
'parameter name can
\'
t contain "."'
)
elif
name
==
""
:
raise
KeyError
(
'parameter name can
\'
t be empty string ""'
)
elif
hasattr
(
self
,
name
)
and
name
not
in
self
.
_parameters
:
raise
KeyError
(
f
"attribute '
{
name
}
' already exists"
)
if
param
is
None
:
self
.
_parameters
[
name
]
=
None
elif
not
isinstance
(
param
,
torch
.
nn
.
Parameter
):
raise
TypeError
(
f
"cannot assign '
{
torch
.
typename
(
param
)
}
' object to parameter '
{
name
}
' "
"(torch.nn.Parameter or None required)"
)
else
:
self
.
_parameters
[
name
]
=
param
def
get_extra_state
(
self
)
->
Any
:
"""Return any extra state to include in the module's state_dict.
Implement this and a corresponding :func:`set_extra_state` for your module
if you need to store extra state. This function is called when building the
module's `state_dict()`.
Note that extra state should be picklable to ensure working serialization
of the state_dict. We only provide provide backwards compatibility guarantees
for serializing Tensors; other objects may break backwards compatibility if
their serialized pickled form changes.
Returns:
object: Any extra state to store in the module's state_dict
"""
raise
RuntimeError
(
"Reached a code path in Module.get_extra_state() that should never be called. "
)
def
_save_to_state_dict
(
self
,
destination
,
prefix
,
keep_vars
):
r
"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.
In rare cases, subclasses can achieve class-specific behavior by
overriding this method with custom logic.
Args:
destination (dict): a dict where state will be stored
prefix (str): the prefix for parameters and buffers used in this
module
"""
for
name
,
param
in
self
.
_parameters
.
items
():
if
param
is
not
None
:
destination
[
prefix
+
name
]
=
param
if
keep_vars
else
param
.
detach
()
for
name
,
buf
in
self
.
_buffers
.
items
():
if
buf
is
not
None
and
name
not
in
self
.
_non_persistent_buffers_set
:
destination
[
prefix
+
name
]
=
buf
if
keep_vars
else
buf
.
detach
()
extra_state_key
=
prefix
+
_EXTRA_STATE_KEY_SUFFIX
if
getattr
(
self
.
__class__
,
"get_extra_state"
,
InfiniCoreModule
.
get_extra_state
)
is
not
InfiniCoreModule
.
get_extra_state
:
destination
[
extra_state_key
]
=
self
.
get_extra_state
()
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
# back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
T_destination
=
TypeVar
(
'T_destination'
,
bound
=
Dict
[
str
,
Any
])
@
overload
def
state_dict
(
self
,
*
,
destination
:
T_destination
,
prefix
:
str
=
...,
keep_vars
:
bool
=
...)
->
T_destination
:
...
@
overload
def
state_dict
(
self
,
*
,
prefix
:
str
=
...,
keep_vars
:
bool
=
...)
->
Dict
[
str
,
Any
]:
...
# TODO: Change `*args` to `*` and remove the copprespinding warning in docs when BC allows.
# Also remove the logic for arg parsing together.
def
state_dict
(
self
,
*
args
,
destination
=
None
,
prefix
=
''
,
keep_vars
=
False
):
r
"""Returns a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
.. note::
The returned object is a shallow copy. It contains references
to the module's parameters and buffers.
.. warning::
Currently ``state_dict()`` also accepts positional arguments for
``destination``, ``prefix`` and ``keep_vars`` in order. However,
this is being deprecated and keyword arguments will be enforced in
future releases.
.. warning::
Please avoid the use of argument ``destination`` as it is not
designed for end-users.
Args:
destination (dict, optional): If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an ``OrderedDict`` will be created and returned.
Default: ``None``.
prefix (str, optional): a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ``''``.
keep_vars (bool, optional): by default the :class:`~torch.Tensor` s
returned in the state dict are detached from autograd. If it's
set to ``True``, detaching will not be performed.
Default: ``False``.
Returns:
dict:
a dictionary containing a whole state of the module
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
"""
# TODO: Remove `args` and the parsing logic when BC allows.
if
len
(
args
)
>
0
:
# DeprecationWarning is ignored by default
warnings
.
warn
(
"Positional args are being deprecated, use kwargs instead. "
,
FutureWarning
,
stacklevel
=
2
,
)
if
destination
is
None
:
destination
=
args
[
0
]
if
len
(
args
)
>
1
and
prefix
==
''
:
prefix
=
args
[
1
]
if
len
(
args
)
>
2
and
keep_vars
is
False
:
keep_vars
=
args
[
2
]
if
destination
is
None
:
destination
=
OrderedDict
()
destination
.
_metadata
=
OrderedDict
()
local_metadata
=
dict
(
version
=
self
.
_version
)
if
hasattr
(
destination
,
"_metadata"
):
destination
.
_metadata
[
prefix
[:
-
1
]]
=
local_metadata
self
.
_save_to_state_dict
(
destination
,
prefix
,
keep_vars
)
for
name
,
module
in
self
.
_modules
.
items
():
if
module
is
not
None
:
module
.
state_dict
(
destination
=
destination
,
prefix
=
prefix
+
name
+
'.'
,
keep_vars
=
keep_vars
)
return
destination
def
set_extra_state
(
self
,
state
:
Any
):
"""
This function is called from :func:`load_state_dict` to handle any extra state
found within the `state_dict`. Implement this function and a corresponding
:func:`get_extra_state` for your module if you need to store extra state within its
`state_dict`.
Args:
state (dict): Extra state from the `state_dict`
"""
raise
RuntimeError
(
"Reached a code path in Module.set_extra_state() that should never be called. "
"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
"to report this bug."
)
def
_load_from_state_dict
(
self
,
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
):
r
"""Copies parameters and buffers from :attr:`state_dict` into only
this module, but not its descendants. This is called on every submodule
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
For state dicts without metadata, :attr:`local_metadata` is empty.
Subclasses can achieve class-specific backward compatible loading using
the version number at `local_metadata.get("version", None)`.
.. note::
:attr:`state_dict` is not the same object as the input
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
it can be modified.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
local_metadata (dict): a dict containing the metadata for this module.
See
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` with :attr:`prefix` match the names of
parameters and buffers in this module
missing_keys (list of str): if ``strict=True``, add missing keys to
this list
unexpected_keys (list of str): if ``strict=True``, add unexpected
keys to this list
error_msgs (list of str): error messages should be added to this
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
"""
persistent_buffers
=
{
k
:
v
for
k
,
v
in
self
.
_buffers
.
items
()
if
k
not
in
self
.
_non_persistent_buffers_set
}
local_name_params
=
itertools
.
chain
(
self
.
_parameters
.
items
(),
persistent_buffers
.
items
())
local_state
=
{
k
:
v
for
k
,
v
in
local_name_params
if
v
is
not
None
}
for
name
,
param
in
local_state
.
items
():
key
=
prefix
+
name
if
key
in
state_dict
:
input_param
=
state_dict
[
key
]
if
not
torch
.
overrides
.
is_tensor_like
(
input_param
):
error_msgs
.
append
(
'While copying the parameter named "{}", '
'expected torch.Tensor or Tensor-like object from checkpoint but '
'received {}'
.
format
(
key
,
type
(
input_param
)))
continue
# This is used to avoid copying uninitialized parameters into
# non-lazy modules, since they dont have the hook to do the checks
# in such case, it will error when accessing the .shape attribute.
is_param_lazy
=
torch
.
nn
.
parameter
.
is_lazy
(
param
)
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if
not
is_param_lazy
and
len
(
param
.
shape
)
==
0
and
len
(
input_param
.
shape
)
==
1
:
input_param
=
input_param
[
0
]
if
not
is_param_lazy
and
input_param
.
shape
!=
param
.
shape
:
# local shape should match the one in checkpoint
error_msgs
.
append
(
'size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'
.
format
(
key
,
input_param
.
shape
,
param
.
shape
))
continue
try
:
with
torch
.
no_grad
():
param
.
copy_
(
input_param
)
except
Exception
as
ex
:
error_msgs
.
append
(
'While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}, '
'an exception occurred : {}.'
.
format
(
key
,
param
.
size
(),
input_param
.
size
(),
ex
.
args
))
elif
strict
:
missing_keys
.
append
(
key
)
extra_state_key
=
prefix
+
_EXTRA_STATE_KEY_SUFFIX
if
getattr
(
self
.
__class__
,
"set_extra_state"
,
InfiniCoreModule
.
set_extra_state
)
is
not
InfiniCoreModule
.
set_extra_state
:
if
extra_state_key
in
state_dict
:
self
.
set_extra_state
(
state_dict
[
extra_state_key
])
elif
strict
:
missing_keys
.
append
(
extra_state_key
)
elif
strict
and
(
extra_state_key
in
state_dict
):
unexpected_keys
.
append
(
extra_state_key
)
if
strict
:
for
key
in
state_dict
.
keys
():
if
key
.
startswith
(
prefix
)
and
key
!=
extra_state_key
:
input_name
=
key
[
len
(
prefix
):].
split
(
"."
,
1
)
# Must be Module if it have attributes
if
len
(
input_name
)
>
1
:
if
input_name
[
0
]
not
in
self
.
_modules
:
unexpected_keys
.
append
(
key
)
elif
input_name
[
0
]
not
in
local_state
:
unexpected_keys
.
append
(
key
)
def
load_state_dict
(
self
,
state_dict
:
Mapping
[
str
,
Any
],
strict
:
bool
=
True
):
r
"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants. If :attr:`strict` is ``True``, then
the keys of :attr:`state_dict` must exactly match the keys returned
by this module's :meth:`~torch.nn.Module.state_dict` function.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
Returns:
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
* **missing_keys** is a list of str containing the missing keys
* **unexpected_keys** is a list of str containing the unexpected keys
Note:
If a parameter or buffer is registered as ``None`` and its corresponding key
exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
``RuntimeError``.
"""
if
not
isinstance
(
state_dict
,
Mapping
):
raise
TypeError
(
"Expected state_dict to be dict-like, got {}."
.
format
(
type
(
state_dict
)))
missing_keys
:
List
[
str
]
=
[]
unexpected_keys
:
List
[
str
]
=
[]
error_msgs
:
List
[
str
]
=
[]
# copy state_dict so _load_from_state_dict can modify it
metadata
=
getattr
(
state_dict
,
'_metadata'
,
None
)
state_dict
=
OrderedDict
(
state_dict
)
if
metadata
is
not
None
:
# mypy isn't aware that "_metadata" exists in state_dict
state_dict
.
_metadata
=
metadata
# type: ignore[attr-defined]
def
load
(
module
,
local_state_dict
,
prefix
=
''
):
local_metadata
=
{}
if
metadata
is
None
else
metadata
.
get
(
prefix
[:
-
1
],
{})
module
.
_load_from_state_dict
(
local_state_dict
,
prefix
,
local_metadata
,
True
,
missing_keys
,
unexpected_keys
,
error_msgs
)
for
name
,
child
in
module
.
_modules
.
items
():
if
child
is
not
None
:
child_prefix
=
prefix
+
name
+
'.'
child_state_dict
=
{
k
:
v
for
k
,
v
in
local_state_dict
.
items
()
if
k
.
startswith
(
child_prefix
)}
load
(
child
,
child_state_dict
,
child_prefix
)
load
(
self
,
state_dict
)
del
load
if
strict
:
if
len
(
unexpected_keys
)
>
0
:
error_msgs
.
insert
(
0
,
'Unexpected key(s) in state_dict: {}. '
.
format
(
', '
.
join
(
'"{}"'
.
format
(
k
)
for
k
in
unexpected_keys
)))
if
len
(
missing_keys
)
>
0
:
error_msgs
.
insert
(
0
,
'Missing key(s) in state_dict: {}. '
.
format
(
', '
.
join
(
'"{}"'
.
format
(
k
)
for
k
in
missing_keys
)))
if
len
(
error_msgs
)
>
0
:
raise
RuntimeError
(
'Error(s) in loading state_dict for {}:
\n\t
{}'
.
format
(
self
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
return
_IncompatibleKeys
(
missing_keys
,
unexpected_keys
)
def
children
(
self
)
->
Iterator
[
'InfiniCoreModule'
]:
r
"""Returns an iterator over immediate children modules.
Yields:
Module: a child module
"""
for
name
,
module
in
self
.
named_children
():
yield
module
def
named_children
(
self
)
->
Iterator
[
Tuple
[
str
,
'InfiniCoreModule'
]]:
r
"""Returns an iterator over immediate children modules, yielding both
the name of the module as well as the module itself.
Yields:
(str, Module): Tuple containing a name and child module
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, module in model.named_children():
>>> if name in ['conv4', 'conv5']:
>>> print(module)
"""
memo
=
set
()
for
name
,
module
in
self
.
_modules
.
items
():
if
module
is
not
None
and
module
not
in
memo
:
memo
.
add
(
module
)
yield
name
,
module
def
train
(
self
:
T
,
mode
:
bool
=
True
)
->
T
:
r
"""Sets the module in training mode.
This has any effect only on certain modules. See documentations of
particular modules for details of their behaviors in training/evaluation
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
Args:
mode (bool): whether to set training mode (``True``) or evaluation
mode (``False``). Default: ``True``.
Returns:
Module: self
"""
if
not
isinstance
(
mode
,
bool
):
raise
ValueError
(
"training mode is expected to be boolean"
)
self
.
training
=
mode
for
module
in
self
.
children
():
module
.
train
(
mode
)
return
self
def
eval
(
self
:
T
)
->
T
:
r
"""Sets the module in evaluation mode.
This has any effect only on certain modules. See documentations of
particular modules for details of their behaviors in training/evaluation
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.
See :ref:`locally-disable-grad-doc` for a comparison between
`.eval()` and several similar mechanisms that may be confused with it.
Returns:
Module: self
"""
return
self
.
train
(
False
)
def
_apply
(
self
,
fn
,
recurse
=
True
):
if
recurse
:
for
module
in
self
.
children
():
module
.
_apply
(
fn
)
def
compute_should_use_set_data
(
tensor
,
tensor_applied
):
if
torch
.
_has_compatible_shallow_copy_type
(
tensor
,
tensor_applied
):
# If the new tensor has compatible tensor type as the existing tensor,
# the current behavior is to change the tensor in-place using `.data =`,
# and the future behavior is to overwrite the existing tensor. However,
# changing the current behavior is a BC-breaking change, and we want it
# to happen in future releases. So for now we introduce the
# `torch.__future__.get_overwrite_module_params_on_conversion()`
# global flag to let the user control whether they want the future
# behavior of overwriting the existing tensor or not.
return
not
torch
.
__future__
.
get_overwrite_module_params_on_conversion
()
else
:
return
False
should_use_swap_tensors
=
torch
.
__future__
.
get_swap_module_params_on_conversion
()
for
key
,
param
in
self
.
_parameters
.
items
():
if
param
is
None
:
continue
# Tensors stored in modules are graph leaves, and we don't want to
# track autograd history of `param_applied`, so we have to use
# `with torch.no_grad():`
with
torch
.
no_grad
():
param_applied
=
fn
(
param
)
p_should_use_set_data
=
compute_should_use_set_data
(
param
,
param_applied
)
# subclasses may have multiple child tensors so we need to use swap_tensors
p_should_use_swap_tensors
=
should_use_swap_tensors
or
is_traceable_wrapper_subclass
(
param_applied
)
param_grad
=
param
.
grad
if
p_should_use_swap_tensors
:
try
:
if
param_grad
is
not
None
:
# Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping.
# Decrement use count of the gradient by setting to None
param
.
grad
=
None
param_applied
=
torch
.
nn
.
Parameter
(
param_applied
,
requires_grad
=
param
.
requires_grad
)
torch
.
utils
.
swap_tensors
(
param
,
param_applied
)
except
Exception
as
e
:
if
param_grad
is
not
None
:
param
.
grad
=
param_grad
raise
RuntimeError
(
f
"_apply(): Couldn't swap
{
self
.
_get_name
()
}
.
{
key
}
"
)
from
e
out_param
=
param
elif
p_should_use_set_data
:
param
.
data
=
param_applied
out_param
=
param
else
:
assert
isinstance
(
param
,
torch
.
nn
.
Parameter
)
assert
param
.
is_leaf
out_param
=
torch
.
nn
.
Parameter
(
param_applied
,
param
.
requires_grad
)
self
.
_parameters
[
key
]
=
out_param
if
param_grad
is
not
None
:
with
torch
.
no_grad
():
grad_applied
=
fn
(
param_grad
)
g_should_use_set_data
=
compute_should_use_set_data
(
param_grad
,
grad_applied
)
if
p_should_use_swap_tensors
:
grad_applied
.
requires_grad_
(
param_grad
.
requires_grad
)
try
:
torch
.
utils
.
swap_tensors
(
param_grad
,
grad_applied
)
except
Exception
as
e
:
raise
RuntimeError
(
f
"_apply(): Couldn't swap
{
self
.
_get_name
()
}
.
{
key
}
.grad"
)
from
e
out_param
.
grad
=
param_grad
elif
g_should_use_set_data
:
assert
out_param
.
grad
is
not
None
out_param
.
grad
.
data
=
grad_applied
else
:
assert
param_grad
.
is_leaf
out_param
.
grad
=
grad_applied
.
requires_grad_
(
param_grad
.
requires_grad
)
for
key
,
buf
in
self
.
_buffers
.
items
():
if
buf
is
not
None
:
self
.
_buffers
[
key
]
=
fn
(
buf
)
return
self
def
to
(
self
,
*
args
,
**
kwargs
):
device
,
dtype
,
non_blocking
,
convert_to_format
=
torch
.
_C
.
_nn
.
_parse_to
(
*
args
,
**
kwargs
)
if
dtype
is
not
None
:
if
not
(
dtype
.
is_floating_point
or
dtype
.
is_complex
):
raise
TypeError
(
'nn.Module.to only accepts floating point or complex '
f
'dtypes, but got desired dtype=
{
dtype
}
'
)
if
dtype
.
is_complex
:
warnings
.
warn
(
"Complex modules are a new feature under active development whose design may change, "
"and some modules might not work as expected when using complex tensors as parameters or buffers. "
)
def
convert
(
t
):
try
:
if
convert_to_format
is
not
None
and
t
.
dim
()
in
(
4
,
5
):
return
t
.
to
(
device
,
dtype
if
t
.
is_floating_point
()
or
t
.
is_complex
()
else
None
,
non_blocking
,
memory_format
=
convert_to_format
,
)
return
t
.
to
(
device
,
dtype
if
t
.
is_floating_point
()
or
t
.
is_complex
()
else
None
,
non_blocking
,
)
except
NotImplementedError
as
e
:
if
str
(
e
)
==
"Cannot copy out of meta tensor; no data!"
:
raise
NotImplementedError
(
f
"
{
e
}
Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
f
"when moving module from meta to a different device."
)
from
None
else
:
raise
return
self
.
_apply
(
convert
)
test/infinicore/infinicore_nn_test.py
0 → 100644
View file @
27e57f3d
import
safetensors.torch
import
torch
import
torch.nn
as
nn
import
safetensors
# ============================================================
# 0. Make the infinicore package importable and configure the temporary
#    safetensors path used by this test
# ============================================================
import sys
import os

# Put python/infinicore (relative to this file) on the import path.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../python/infinicore')))

# The checkpoint is written to <repo>/tmp; the directory is created on demand.
save_dir = os.path.join(os.path.dirname(__file__), '../../tmp')
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "torch_convnet_with_param.safetensors")

# ============================================================
# 1. Define and save the model with PyTorch
# ============================================================
print("===== 开始 CPU 一致性测试 =====")
class TorchConvNet(nn.Module):
    """Reference PyTorch network: two conv/BN/ReLU stages and a 1x1 projection.

    The projection output is multiplied by the learnable ``scale`` parameter
    and shifted by the ``offset`` buffer.
    """

    def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
        super().__init__()

        # Main network. Layers are created in the same order as they are used
        # so that registration order (and state_dict key order) is stable.
        same_size_conv = {"kernel_size": 3, "padding": 1}
        self.conv1 = nn.Conv2d(in_ch, hidden_ch, **same_size_conv)
        self.bn1 = nn.BatchNorm2d(hidden_ch)
        self.conv2 = nn.Conv2d(hidden_ch, hidden_ch, **same_size_conv)
        self.bn2 = nn.BatchNorm2d(hidden_ch)
        self.conv3 = nn.Conv2d(hidden_ch, out_ch, kernel_size=1)
        self.relu = nn.ReLU()

        # Custom Parameter
        initial_scale = torch.ones(1) * 0.5
        self.scale = nn.Parameter(initial_scale)

        # Register a buffer
        self.register_buffer("offset", torch.tensor(0.1))

    def forward(self, x):
        """Run the network on ``x`` and return the scaled, shifted projection."""
        features = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            features = self.relu(norm(conv(features)))
        projected = self.conv3(features)

        # Apply the custom parameter and buffer
        return projected * self.scale + self.offset
# ===== Save the Torch model =====
torch_model = TorchConvNet()
torch_state_dict = torch_model.state_dict()
safetensors.torch.save_file(torch_state_dict, save_path)

# ============================================================
# 2. Load the checkpoint the torch way and run inference
# ============================================================
torch_model_infer = TorchConvNet()
torch_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
torch_model_infer.eval()

# NOTE(review): `input` shadows the Python builtin of the same name; it is
# reused by the later InfiniCore and GPU sections of this script.
input = torch.rand(1, 3, 8, 8)
torch_model_out = torch_model_infer(input)
# ============================================================
# 3. Load the checkpoint with infiniCore.nn.module and run inference
# ============================================================
# NOTE(review): this appends the same directory to sys.path that was already
# appended at the top of the script.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../python/infinicore')))

# `nn` here is resolved through sys.path; the name `nn` bound in this script
# still refers to torch.nn because only `Module` is imported.
from nn import Module
class InfiniCoreConvNet(Module):
    """Same architecture as ``TorchConvNet``, built on the InfiniCore ``Module`` base."""

    def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
        super().__init__()

        # First conv stage
        self.conv1 = nn.Conv2d(
            in_ch, hidden_ch, kernel_size=3, padding=1
        )
        self.bn1 = nn.BatchNorm2d(hidden_ch)

        # Second conv stage
        self.conv2 = nn.Conv2d(
            hidden_ch, hidden_ch, kernel_size=3, padding=1
        )
        self.bn2 = nn.BatchNorm2d(hidden_ch)

        # 1x1 projection and shared activation
        self.conv3 = nn.Conv2d(hidden_ch, out_ch, kernel_size=1)
        self.relu = nn.ReLU()

        # Keep the custom parameter and buffer identical to the Torch model
        scale_init = torch.ones(1) * 0.5
        self.scale = nn.Parameter(scale_init)
        self.register_buffer("offset", torch.tensor(0.1))

    def forward(self, x):
        """Run the network on ``x`` and return the scaled, shifted projection."""
        stage1 = self.conv1(x)
        stage1 = self.bn1(stage1)
        stage1 = self.relu(stage1)

        stage2 = self.conv2(stage1)
        stage2 = self.bn2(stage2)
        stage2 = self.relu(stage2)

        projected = self.conv3(stage2)
        result = projected * self.scale + self.offset
        return result
# ===== Load the safetensors file with InfiniCoreConvNet and run inference =====
infinicore_model_infer = InfiniCoreConvNet()
infinicore_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
infinicore_model_infer.eval()

# forward() is called explicitly rather than calling the model object.
infinicore_model_out = infinicore_model_infer.forward(input)

# ============================================================
# 4. Compare the results
# ============================================================
diff_cpu = (infinicore_model_out - torch_model_out).abs().max().item()
print(f"InfiniCoreModule 与 Torch (CPU) 最大误差: {diff_cpu:.6e}")
# NOTE(review): the outcome is only printed; a mismatch does not make the
# script fail.
if diff_cpu < 1e-9:
    print("CPU 模式下 InfiniCore 与 Torch 输出完全一致.")
else:
    print("CPU 模式下输出存在差异.")

# ============================================================
# 5. GPU consistency test (optional)
# ============================================================
if torch.cuda.is_available():
    print("\n===== 开始 GPU 一致性测试 =====")

    # Move both models and the input to the GPU
    torch_model_infer_gpu = TorchConvNet().to("cuda")
    torch_model_infer_gpu.load_state_dict(safetensors.torch.load_file(save_path))
    torch_model_infer_gpu.eval()

    infinicore_model_infer_gpu = InfiniCoreConvNet().to("cuda")
    infinicore_model_infer_gpu.load_state_dict(safetensors.torch.load_file(save_path))
    infinicore_model_infer_gpu.eval()

    # Create the GPU input
    input_gpu = input.to("cuda")

    # Run the forward pass of each model
    torch_out_gpu = torch_model_infer_gpu(input_gpu)
    infinicore_out_gpu = infinicore_model_infer_gpu.forward(input_gpu)

    # Compare the results
    diff_gpu = (infinicore_out_gpu - torch_out_gpu).abs().max().item()
    print(f"InfiniCoreModule 与 Torch (GPU) 最大误差: {diff_gpu:.6e}")
    if diff_gpu < 1e-9:
        print("GPU 模式下 InfiniCore 与 Torch 输出完全一致.")
    else:
        print("GPU 模式下输出存在差异.")
else:
    print("\n未检测到 GPU,跳过 GPU 一致性测试。")
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment