Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
ee722eb9
Commit
ee722eb9
authored
Nov 13, 2025
by
pengcheng888
Committed by
zhuyue
Nov 17, 2025
Browse files
issue/567-只处理infinicore.Tensor,能够加载infinicore.Tensor的权重,修改了module.py paramter.py部分代码
parent
f6107946
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
851 additions
and
952 deletions
+851
-952
python/infinicore/nn/__init__.py
python/infinicore/nn/__init__.py
+3
-1
python/infinicore/nn/modules/__init__.py
python/infinicore/nn/modules/__init__.py
+3
-2
python/infinicore/nn/modules/container.py
python/infinicore/nn/modules/container.py
+25
-20
python/infinicore/nn/modules/module.py
python/infinicore/nn/modules/module.py
+235
-333
python/infinicore/nn/modules/parameter.py
python/infinicore/nn/modules/parameter.py
+0
-133
python/infinicore/nn/parameter.py
python/infinicore/nn/parameter.py
+34
-0
test/infinicore/infinicore_module_list_test.py
test/infinicore/infinicore_module_list_test.py
+0
-317
test/infinicore/infinicore_nn_test.py
test/infinicore/infinicore_nn_test.py
+0
-146
test/infinicore/nn/Module.py
test/infinicore/nn/Module.py
+80
-0
test/infinicore/nn/ModuleList.py
test/infinicore/nn/ModuleList.py
+323
-0
test/infinicore/nn/Parameter.py
test/infinicore/nn/Parameter.py
+148
-0
No files found.
python/infinicore/nn/__init__.py
View file @
ee722eb9
from
infinicore.nn
import
functional
from
infinicore.nn
import
functional
from
infinicore.nn.modules
import
*
# noqa: F403
from
infinicore.nn.parameter
import
InfiniCoreParameter
as
Parameter
__all__
=
[
"functional"
]
__all__
=
[
"functional"
,
"Parameter"
]
python/infinicore/nn/modules/__init__.py
View file @
ee722eb9
from
.container
import
InfiniCoreModuleList
as
ModuleList
from
.module
import
InfiniCoreModule
as
Module
from
.module
import
InfiniCoreModule
as
Module
from
.module_list
import
InfiniCoreModuleList
as
ModuleList
from
.parameter
import
InfiniCoreParameter
as
Parameter
__all__
=
[
"ModuleList"
,
"Module"
]
python/infinicore/nn/modules/
module_list
.py
→
python/infinicore/nn/modules/
container
.py
View file @
ee722eb9
# ============================================
# Copyright (c) 2025, InfiniCore
# Copyright (c) 2025, InfiniCore
#
#
# This file implements InfiniCoreModuleList, which is similar to torch.nn.ModuleList
# This file implements InfiniCoreModuleList, which is similar to torch.nn.ModuleList
# but based on InfiniCoreModule for inference purposes.
# but based on InfiniCoreModule for inference purposes.
from
typing
import
List
,
Optional
,
Iterator
,
Union
,
Sequence
,
TypeVar
import
torch
import
operator
import
operator
from
itertools
import
chain
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
.module
import
InfiniCoreModule
from
itertools
import
chain
from
typing
import
Iterator
,
List
,
Optional
,
Sequence
,
TypeVar
,
Union
# Define type variable for module compatibility (supports both torch.nn.Module and InfiniCoreModule)
from
.module
import
InfiniCoreModule
as
Module
ModuleType
=
TypeVar
(
'ModuleType'
,
bound
=
Union
[
torch
.
nn
.
Module
,
'InfiniCoreModule'
])
# Define type variable for module compatibility (supports InfiniCoreModule)
ModuleType
=
TypeVar
(
"ModuleType"
,
bound
=
Union
[
"Module"
])
class
InfiniCoreModuleList
(
InfiniCoreModule
):
class
InfiniCoreModuleList
(
Module
):
r
"""Holds submodules in a list.
r
"""Holds submodules in a list.
InfiniCoreModuleList can be indexed like a regular Python list, but
InfiniCoreModuleList can be indexed like a regular Python list, but
...
@@ -54,7 +55,9 @@ class InfiniCoreModuleList(InfiniCoreModule):
...
@@ -54,7 +55,9 @@ class InfiniCoreModuleList(InfiniCoreModule):
idx
+=
len
(
self
)
idx
+=
len
(
self
)
return
str
(
idx
)
return
str
(
idx
)
def
__getitem__
(
self
,
idx
:
Union
[
int
,
slice
])
->
Union
[
ModuleType
,
'InfiniCoreModuleList'
]:
def
__getitem__
(
self
,
idx
:
Union
[
int
,
slice
]
)
->
Union
[
ModuleType
,
"InfiniCoreModuleList"
]:
if
isinstance
(
idx
,
slice
):
if
isinstance
(
idx
,
slice
):
return
self
.
__class__
(
list
(
self
.
_modules
.
values
())[
idx
])
return
self
.
__class__
(
list
(
self
.
_modules
.
values
())[
idx
])
else
:
else
:
...
@@ -75,7 +78,7 @@ class InfiniCoreModuleList(InfiniCoreModule):
...
@@ -75,7 +78,7 @@ class InfiniCoreModuleList(InfiniCoreModule):
idx_str
=
self
.
_get_abs_string_index
(
idx
)
idx_str
=
self
.
_get_abs_string_index
(
idx
)
if
idx_str
in
self
.
_modules
:
if
idx_str
in
self
.
_modules
:
del
self
.
_modules
[
idx_str
]
del
self
.
_modules
[
idx_str
]
# To preserve numbering, self._modules is being reconstructed with modules after deletion
# To preserve numbering, self._modules is being reconstructed with modules after deletion
if
len
(
self
.
_modules
)
>
0
:
if
len
(
self
.
_modules
)
>
0
:
str_indices
=
[
str
(
i
)
for
i
in
range
(
len
(
self
.
_modules
))]
str_indices
=
[
str
(
i
)
for
i
in
range
(
len
(
self
.
_modules
))]
...
@@ -87,10 +90,12 @@ class InfiniCoreModuleList(InfiniCoreModule):
...
@@ -87,10 +90,12 @@ class InfiniCoreModuleList(InfiniCoreModule):
def
__iter__
(
self
)
->
Iterator
[
ModuleType
]:
def
__iter__
(
self
)
->
Iterator
[
ModuleType
]:
return
iter
(
self
.
_modules
.
values
())
return
iter
(
self
.
_modules
.
values
())
def
__iadd__
(
self
,
modules
:
Sequence
[
ModuleType
])
->
'
InfiniCoreModuleList
'
:
def
__iadd__
(
self
,
modules
:
Sequence
[
ModuleType
])
->
"
InfiniCoreModuleList
"
:
return
self
.
extend
(
modules
)
return
self
.
extend
(
modules
)
def
__add__
(
self
,
other
:
Union
[
Sequence
[
ModuleType
],
'InfiniCoreModuleList'
])
->
'InfiniCoreModuleList'
:
def
__add__
(
self
,
other
:
Union
[
Sequence
[
ModuleType
],
"InfiniCoreModuleList"
]
)
->
"InfiniCoreModuleList"
:
r
"""Return a new InfiniCoreModuleList by concatenating with another iterable.
r
"""Return a new InfiniCoreModuleList by concatenating with another iterable.
Args:
Args:
...
@@ -101,22 +106,22 @@ class InfiniCoreModuleList(InfiniCoreModule):
...
@@ -101,22 +106,22 @@ class InfiniCoreModuleList(InfiniCoreModule):
f
"InfiniCoreModuleList can only be concatenated with list, tuple, or InfiniCoreModuleList, "
f
"InfiniCoreModuleList can only be concatenated with list, tuple, or InfiniCoreModuleList, "
f
"got
{
type
(
other
).
__name__
}
"
f
"got
{
type
(
other
).
__name__
}
"
)
)
combined
=
InfiniCoreModuleList
()
combined
=
InfiniCoreModuleList
()
for
i
,
module
in
enumerate
(
chain
(
self
,
other
)):
for
i
,
module
in
enumerate
(
chain
(
self
,
other
)):
combined
.
add_module
(
str
(
i
),
module
)
combined
.
add_module
(
str
(
i
),
module
)
return
combined
return
combined
def
append
(
self
,
module
:
ModuleType
)
->
'
InfiniCoreModuleList
'
:
def
append
(
self
,
module
:
ModuleType
)
->
"
InfiniCoreModuleList
"
:
r
"""Append a given module to the end of the list.
r
"""Append a given module to the end of the list.
Args:
Args:
module (
nn.Module or
InfiniCoreModule): module to append
module (InfiniCoreModule): module to append
"""
"""
self
.
add_module
(
str
(
len
(
self
)),
module
)
self
.
add_module
(
str
(
len
(
self
)),
module
)
return
self
return
self
def
extend
(
self
,
modules
:
Sequence
[
ModuleType
])
->
'
InfiniCoreModuleList
'
:
def
extend
(
self
,
modules
:
Sequence
[
ModuleType
])
->
"
InfiniCoreModuleList
"
:
r
"""Append modules from a Python iterable to the end of the list.
r
"""Append modules from a Python iterable to the end of the list.
Args:
Args:
...
@@ -130,7 +135,7 @@ class InfiniCoreModuleList(InfiniCoreModule):
...
@@ -130,7 +135,7 @@ class InfiniCoreModuleList(InfiniCoreModule):
f
"InfiniCoreModuleList.extend should be called with an "
f
"InfiniCoreModuleList.extend should be called with an "
f
"iterable, but got
{
type
(
modules
).
__name__
}
"
f
"iterable, but got
{
type
(
modules
).
__name__
}
"
)
)
offset
=
len
(
self
)
offset
=
len
(
self
)
for
i
,
module
in
enumerate
(
modules
):
for
i
,
module
in
enumerate
(
modules
):
self
.
add_module
(
str
(
offset
+
i
),
module
)
self
.
add_module
(
str
(
offset
+
i
),
module
)
...
@@ -141,7 +146,7 @@ class InfiniCoreModuleList(InfiniCoreModule):
...
@@ -141,7 +146,7 @@ class InfiniCoreModuleList(InfiniCoreModule):
Args:
Args:
index (int): index to insert.
index (int): index to insert.
module (
nn.Module or
InfiniCoreModule): module to insert
module ( InfiniCoreModule): module to insert
"""
"""
for
i
in
range
(
len
(
self
.
_modules
),
index
,
-
1
):
for
i
in
range
(
len
(
self
.
_modules
),
index
,
-
1
):
self
.
_modules
[
str
(
i
)]
=
self
.
_modules
[
str
(
i
-
1
)]
self
.
_modules
[
str
(
i
)]
=
self
.
_modules
[
str
(
i
-
1
)]
...
@@ -166,11 +171,11 @@ class InfiniCoreModuleList(InfiniCoreModule):
...
@@ -166,11 +171,11 @@ class InfiniCoreModuleList(InfiniCoreModule):
"""Return a string representation of the ModuleList."""
"""Return a string representation of the ModuleList."""
if
len
(
self
)
==
0
:
if
len
(
self
)
==
0
:
return
self
.
__class__
.
__name__
+
"()"
return
self
.
__class__
.
__name__
+
"()"
lines
=
[]
lines
=
[]
for
i
,
module
in
enumerate
(
self
):
for
i
,
module
in
enumerate
(
self
):
lines
.
append
(
f
"(
{
i
}
):
{
repr
(
module
)
}
"
)
lines
.
append
(
f
"(
{
i
}
):
{
repr
(
module
)
}
"
)
main_str
=
self
.
__class__
.
__name__
+
"(
\n
"
main_str
=
self
.
__class__
.
__name__
+
"(
\n
"
main_str
+=
"
\n
"
.
join
(
lines
)
+
"
\n
)"
main_str
+=
"
\n
"
.
join
(
lines
)
+
"
\n
)"
return
main_str
return
main_str
...
...
python/infinicore/nn/modules/module.py
View file @
ee722eb9
# Copyright (c) 2025, InfiniCore
# Copyright (c) 2025, InfiniCore
#
#
# This file contains modified code derived from PyTorch's `torch.nn.Module`
# This file contains modified code derived from PyTorch's `torch.nn.Module`
# implementation, which is licensed under the BSD 3-Clause License.
# implementation, which is licensed under the BSD 3-Clause License.
#
#
...
@@ -13,27 +13,38 @@
...
@@ -13,27 +13,38 @@
#
#
# The use of this file is governed by the BSD 3-Clause License.
# The use of this file is governed by the BSD 3-Clause License.
from
collections
import
OrderedDict
,
namedtuple
import
itertools
import
itertools
import
warnings
import
warnings
from
typing
import
TYPE_CHECKING
from
collections
import
OrderedDict
,
namedtuple
from
typing
import
(
import
torch
Any
,
Dict
,
from
typing
import
Union
,
Tuple
,
Any
,
Iterator
,
Set
,
Optional
,
overload
,
TypeVar
,
Mapping
,
Dict
,
List
Iterator
,
from
torch.utils._python_dispatch
import
is_traceable_wrapper_subclass
List
,
Mapping
,
if
TYPE_CHECKING
:
Optional
,
from
.parameter
import
InfiniCoreParameter
as
Parameter
Set
,
Tuple
,
_EXTRA_STATE_KEY_SUFFIX
=
'_extra_state'
TypeVar
,
Union
,
T
=
TypeVar
(
'T'
,
bound
=
'InfiniCoreModule'
)
overload
,
)
class
_IncompatibleKeys
(
namedtuple
(
'IncompatibleKeys'
,
[
'missing_keys'
,
'unexpected_keys'
])):
import
infinicore
from
...tensor
import
Tensor
from
..parameter
import
InfiniCoreParameter
as
Parameter
_EXTRA_STATE_KEY_SUFFIX
=
"_extra_state"
T
=
TypeVar
(
"T"
,
bound
=
"InfiniCoreModule"
)
class
_IncompatibleKeys
(
namedtuple
(
"IncompatibleKeys"
,
[
"missing_keys"
,
"unexpected_keys"
])
):
def
__repr__
(
self
):
def
__repr__
(
self
):
if
not
self
.
missing_keys
and
not
self
.
unexpected_keys
:
if
not
self
.
missing_keys
and
not
self
.
unexpected_keys
:
return
'
<All keys matched successfully>
'
return
"
<All keys matched successfully>
"
return
super
().
__repr__
()
return
super
().
__repr__
()
__str__
=
__repr__
__str__
=
__repr__
...
@@ -42,18 +53,14 @@ class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpec
...
@@ -42,18 +53,14 @@ class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpec
class
InfiniCoreModule
:
class
InfiniCoreModule
:
r
"""Base class for InfiniCore neural network modules.
r
"""Base class for InfiniCore neural network modules.
Your models should also subclass this class.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure.
Modules can also contain other Modules, allowing
to nest them in a tree structure.
"""
"""
_version
:
int
=
1
_version
:
int
=
1
_parameters
:
Dict
[
str
,
Optional
[
Parameter
]]
training
:
bool
_buffers
:
Dict
[
str
,
Optional
[
Tensor
]]
_parameters
:
Dict
[
str
,
Optional
[
Union
[
torch
.
nn
.
Parameter
,
'Parameter'
]]]
_buffers
:
Dict
[
str
,
Optional
[
torch
.
Tensor
]]
_non_persistent_buffers_set
:
Set
[
str
]
_non_persistent_buffers_set
:
Set
[
str
]
_modules
:
Dict
[
str
,
Optional
[
'
InfiniCoreModule
'
]]
_modules
:
Dict
[
str
,
Optional
[
"
InfiniCoreModule
"
]]
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__setattr__
(
"_parameters"
,
OrderedDict
())
super
().
__setattr__
(
"_parameters"
,
OrderedDict
())
...
@@ -66,19 +73,22 @@ class InfiniCoreModule:
...
@@ -66,19 +73,22 @@ class InfiniCoreModule:
_parameters
=
self
.
__dict__
[
"_parameters"
]
_parameters
=
self
.
__dict__
[
"_parameters"
]
if
name
in
_parameters
:
if
name
in
_parameters
:
return
_parameters
[
name
]
return
_parameters
[
name
]
if
"_buffers"
in
self
.
__dict__
:
if
"_buffers"
in
self
.
__dict__
:
_buffers
=
self
.
__dict__
[
"_buffers"
]
_buffers
=
self
.
__dict__
[
"_buffers"
]
if
name
in
_buffers
:
if
name
in
_buffers
:
return
_buffers
[
name
]
return
_buffers
[
name
]
if
"_modules"
in
self
.
__dict__
:
if
"_modules"
in
self
.
__dict__
:
modules
=
self
.
__dict__
[
"_modules"
]
modules
=
self
.
__dict__
[
"_modules"
]
if
name
in
modules
:
if
name
in
modules
:
return
modules
[
name
]
return
modules
[
name
]
raise
AttributeError
(
raise
AttributeError
(
f
"'
{
type
(
self
).
__name__
}
' object has no attribute '
{
name
}
'"
f
"'
{
type
(
self
).
__name__
}
' object has no attribute '
{
name
}
'"
)
)
def
__setattr__
(
self
,
name
:
str
,
value
:
Union
[
torch
.
Tensor
,
'
InfiniCoreModule
'
])
->
None
:
def
__setattr__
(
self
,
name
:
str
,
value
:
Union
[
Tensor
,
"
InfiniCoreModule
"
])
->
None
:
def
remove_from
(
*
dicts_or_sets
)
->
None
:
def
remove_from
(
*
dicts_or_sets
)
->
None
:
for
d
in
dicts_or_sets
:
for
d
in
dicts_or_sets
:
if
name
in
d
:
if
name
in
d
:
...
@@ -88,13 +98,12 @@ class InfiniCoreModule:
...
@@ -88,13 +98,12 @@ class InfiniCoreModule:
d
.
discard
(
name
)
d
.
discard
(
name
)
params
=
self
.
__dict__
.
get
(
"_parameters"
)
params
=
self
.
__dict__
.
get
(
"_parameters"
)
# Support both torch.nn.Parameter and Parameter (InfiniCoreParameter)
if
params
is
None
:
from
.parameter
import
InfiniCoreParameter
as
Parameter
raise
AttributeError
(
if
isinstance
(
value
,
(
torch
.
nn
.
Parameter
,
Parameter
)):
"cannot assign parameters before Module.__init__() call"
if
params
is
None
:
)
raise
AttributeError
(
"cannot assign parameters before Module.__init__() call"
if
isinstance
(
value
,
Parameter
):
# the value is of type Parameter
)
remove_from
(
remove_from
(
self
.
__dict__
,
self
.
__dict__
,
self
.
_buffers
,
self
.
_buffers
,
...
@@ -102,20 +111,21 @@ class InfiniCoreModule:
...
@@ -102,20 +111,21 @@ class InfiniCoreModule:
self
.
_non_persistent_buffers_set
,
self
.
_non_persistent_buffers_set
,
)
)
self
.
register_parameter
(
name
,
value
)
self
.
register_parameter
(
name
,
value
)
elif
params
is
not
None
and
name
in
params
:
elif
name
in
params
:
# value will overwrite the
name
of
params
.
if
value
is
not
None
:
if
not
isinstance
(
value
,
Tensor
)
:
raise
TypeError
(
raise
TypeError
(
f
"cannot assign '
{
torch
.
typename
(
value
)
}
' as parameter '
{
name
}
' "
f
"cannot assign 'value' as parameter '
{
name
}
' (infinicore.nn.Parameter, Parameter or None expected)"
"(torch.nn.Parameter, Parameter or None expected)"
)
)
self
.
register_parameter
(
name
,
value
)
self
.
register_parameter
(
name
,
value
)
else
:
else
:
modules
=
self
.
__dict__
.
get
(
"_modules"
)
modules
=
self
.
__dict__
.
get
(
"_modules"
)
if
isinstance
(
value
,
(
torch
.
nn
.
Module
,
InfiniCoreModule
)):
if
modules
is
None
:
if
modules
is
None
:
raise
AttributeError
(
raise
AttributeError
(
"cannot assign module before Module.__init__() call"
"cannot assign module before Module.__init__() call"
)
)
if
isinstance
(
value
,
InfiniCoreModule
):
remove_from
(
remove_from
(
self
.
__dict__
,
self
.
__dict__
,
self
.
_parameters
,
self
.
_parameters
,
...
@@ -123,32 +133,35 @@ class InfiniCoreModule:
...
@@ -123,32 +133,35 @@ class InfiniCoreModule:
self
.
_non_persistent_buffers_set
,
self
.
_non_persistent_buffers_set
,
)
)
modules
[
name
]
=
value
modules
[
name
]
=
value
elif
modules
is
not
None
and
name
in
modules
:
elif
name
in
modules
:
# Do not overwrite this variable
if
value
is
not
None
:
raise
TypeError
(
raise
TypeError
(
f
"cannot assign 'value' as child module '
{
name
}
' (infinicore.nn.Module or None expected)"
f
"cannot assign '
{
torch
.
typename
(
value
)
}
' as child module '
{
name
}
' "
)
"(torch.nn.Module or None expected)"
)
modules
[
name
]
=
value
else
:
else
:
buffers
=
self
.
__dict__
.
get
(
"_buffers"
)
buffers
=
self
.
__dict__
.
get
(
"_buffers"
)
if
buffers
is
not
None
and
name
in
buffers
:
if
buffers
is
not
None
and
name
in
buffers
:
if
value
is
not
None
and
not
isinstance
(
value
,
torch
.
Tensor
):
if
value
is
not
None
and
not
isinstance
(
value
,
Tensor
):
raise
TypeError
(
f
"cannot assign '
{
torch
.
typename
(
value
)
}
' as buffer '
{
name
}
' "
raise
TypeError
(
"(torch.Tensor or None expected)"
f
"cannot assign 'value' as buffer '
{
name
}
' "
)
"(torch.Tensor or None expected)"
)
buffers
[
name
]
=
value
buffers
[
name
]
=
value
else
:
else
:
super
().
__setattr__
(
name
,
value
)
super
().
__setattr__
(
name
,
value
)
def
register_buffer
(
self
,
name
:
str
,
tensor
:
Optional
[
torch
.
tensor
],
persistent
:
bool
=
True
)
->
None
:
def
__call__
(
self
,
*
input
,
**
kwargs
):
return
self
.
forward
(
*
input
,
**
kwargs
)
def
register_buffer
(
self
,
name
:
str
,
tensor
:
Optional
[
Tensor
],
persistent
:
bool
=
True
)
->
None
:
r
"""Adds a buffer to the module.
r
"""Adds a buffer to the module.
This is typically used to register a buffer that should not to be
This is typically used to register a buffer that should not to be
considered a model parameter.Buffers, by default, are persistent
considered a model parameter.Buffers, by default, are persistent
and will be saved alongside parameters. This behavior can be changed
and will be saved alongside parameters. This behavior can be changed
by setting :attr:`persistent` to ``False``. The only difference between
by setting :attr:`persistent` to ``False``. The only difference between
a persistent buffer and a non-persistent buffer is that the latter
a persistent buffer and a non-persistent buffer is that the latter
will not be a part of this module's :attr:`state_dict`.
will not be a part of this module's :attr:`state_dict`.
Buffers can be accessed as attributes using given names.
Buffers can be accessed as attributes using given names.
...
@@ -163,22 +176,21 @@ class InfiniCoreModule:
...
@@ -163,22 +176,21 @@ class InfiniCoreModule:
:attr:`state_dict`.
:attr:`state_dict`.
"""
"""
if
'_buffers'
not
in
self
.
__dict__
:
if
"_buffers"
not
in
self
.
__dict__
:
raise
AttributeError
(
raise
AttributeError
(
"cannot assign buffer before Module.__init__() call"
)
"cannot assign buffer before Module.__init__() call"
)
elif
not
isinstance
(
name
,
str
):
elif
not
isinstance
(
name
,
str
):
raise
TypeError
(
"buffer name should be a string. "
raise
TypeError
(
"buffer name should be a string. Got {}"
.
format
(
"name"
))
"Got {}"
.
format
(
torch
.
typename
(
name
)))
elif
"."
in
name
:
elif
'.'
in
name
:
raise
KeyError
(
'buffer name can
\'
t contain "."'
)
raise
KeyError
(
"buffer name can't contain
\"
.
\"
"
)
elif
name
==
""
:
elif
name
==
''
:
raise
KeyError
(
'buffer name can
\'
t be empty string ""'
)
raise
KeyError
(
"buffer name can't be empty string
\"\"
"
)
elif
hasattr
(
self
,
name
)
and
name
not
in
self
.
_buffers
:
elif
hasattr
(
self
,
name
)
and
name
not
in
self
.
_buffers
:
raise
KeyError
(
"attribute '{}' already exists"
.
format
(
name
))
raise
KeyError
(
"attribute '{}' already exists"
.
format
(
name
))
elif
tensor
is
not
None
and
not
isinstance
(
tensor
,
torch
.
Tensor
):
elif
tensor
is
not
None
and
not
isinstance
(
tensor
,
Tensor
):
raise
TypeError
(
"cannot assign '{}' object to buffer '{}' "
raise
TypeError
(
"(torch Tensor or None required)"
"cannot assign '{}' object to buffer '{}' "
.
format
(
torch
.
typename
(
tensor
),
name
))
"(torch Tensor or None required)"
.
format
(
"tensor"
,
name
)
)
else
:
else
:
self
.
_buffers
[
name
]
=
tensor
self
.
_buffers
[
name
]
=
tensor
if
persistent
:
if
persistent
:
...
@@ -186,8 +198,7 @@ class InfiniCoreModule:
...
@@ -186,8 +198,7 @@ class InfiniCoreModule:
else
:
else
:
self
.
_non_persistent_buffers_set
.
add
(
name
)
self
.
_non_persistent_buffers_set
.
add
(
name
)
def
add_module
(
self
,
name
:
str
,
module
:
Optional
[
"InfiniCoreModule"
])
->
None
:
def
add_module
(
self
,
name
:
str
,
module
:
Optional
[
torch
.
nn
.
Module
])
->
None
:
r
"""Add a child module to the current module.
r
"""Add a child module to the current module.
The module can be accessed as an attribute using the given name.
The module can be accessed as an attribute using the given name.
...
@@ -201,20 +212,20 @@ class InfiniCoreModule:
...
@@ -201,20 +212,20 @@ class InfiniCoreModule:
module's :attr:`children`.
module's :attr:`children`.
"""
"""
if
not
isinstance
(
name
,
str
):
if
not
isinstance
(
name
,
str
):
raise
TypeError
(
f
"module name should be a string. Got
{
torch
.
typename
(
name
)
}
"
)
raise
TypeError
(
f
"module name should be a string. Got
{
name
}
"
)
elif
'.'
in
name
:
elif
"."
in
name
:
raise
KeyError
(
f
"
module name can't contain
\
"
.
\
"
, got:
{
name
}
"
)
raise
KeyError
(
f
'
module name can
\
'
t contain ".", got:
{
name
}
'
)
elif
name
==
''
:
elif
name
==
""
:
raise
KeyError
(
"
module name can't be empty string
\"\
"
"
)
raise
KeyError
(
'
module name can
\
'
t be empty string ""
'
)
elif
hasattr
(
self
,
name
)
and
name
not
in
self
.
_modules
:
elif
hasattr
(
self
,
name
)
and
name
not
in
self
.
_modules
:
raise
KeyError
(
f
"attribute '
{
name
}
' already exists"
)
raise
KeyError
(
f
"attribute '
{
name
}
' already exists"
)
if
module
is
not
None
and
not
isinstance
(
module
,
(
torch
.
nn
.
Module
,
InfiniCoreModule
)
)
:
if
module
is
not
None
and
not
isinstance
(
module
,
InfiniCoreModule
):
raise
TypeError
(
f
"
{
torch
.
typename
(
module
)
}
is not a Module subclass"
)
raise
TypeError
(
f
"
{
module
}
is not a Module subclass"
)
self
.
_modules
[
name
]
=
module
self
.
_modules
[
name
]
=
module
def
register_parameter
(
self
,
name
:
str
,
param
:
Optional
[
Union
[
torch
.
nn
.
Parameter
,
'Parameter'
]]
)
->
None
:
def
register_parameter
(
self
,
name
:
str
,
param
:
Parameter
)
->
None
:
r
"""Add a parameter to the module.
r
"""Add a parameter to the module.
The parameter can be accessed as an attribute using given name.
The parameter can be accessed as an attribute using given name.
...
@@ -227,15 +238,13 @@ class InfiniCoreModule:
...
@@ -227,15 +238,13 @@ class InfiniCoreModule:
are ignored. If ``None``, the parameter is **not** included in the
are ignored. If ``None``, the parameter is **not** included in the
module's :attr:`state_dict`.
module's :attr:`state_dict`.
"""
"""
if
"_parameters"
not
in
self
.
__dict__
:
if
"_parameters"
not
in
self
.
__dict__
:
raise
AttributeError
(
raise
AttributeError
(
"cannot assign parameter before Module.__init__() call"
"cannot assign parameter before Module.__init__() call"
)
)
elif
not
isinstance
(
name
,
str
):
elif
not
isinstance
(
name
,
str
):
raise
TypeError
(
raise
TypeError
(
"parameter name should be a string."
)
f
"parameter name should be a string. Got
{
torch
.
typename
(
name
)
}
"
)
elif
"."
in
name
:
elif
"."
in
name
:
raise
KeyError
(
'parameter name can
\'
t contain "."'
)
raise
KeyError
(
'parameter name can
\'
t contain "."'
)
elif
name
==
""
:
elif
name
==
""
:
...
@@ -244,16 +253,16 @@ class InfiniCoreModule:
...
@@ -244,16 +253,16 @@ class InfiniCoreModule:
raise
KeyError
(
f
"attribute '
{
name
}
' already exists"
)
raise
KeyError
(
f
"attribute '
{
name
}
' already exists"
)
if
param
is
None
:
if
param
is
None
:
self
.
_parameters
[
name
]
=
None
self
.
_parameters
[
name
]
=
None
# 竟然可以是None
else
:
else
:
# Support both torch.nn.Parameter and Parameter (InfiniCoreParameter)
if
not
isinstance
(
param
,
(
Parameter
,
Tensor
)):
from
.parameter
import
InfiniCoreParameter
as
Parameter
if
not
isinstance
(
param
,
(
torch
.
nn
.
Parameter
,
Parameter
)):
raise
TypeError
(
raise
TypeError
(
f
"cannot assign '
{
torch
.
typename
(
param
)
}
' object to parameter '
{
name
}
' "
f
"cannot assign
'param' object to parameter '
{
name
}
' "
"(
torch
.nn.Parameter, Parameter or None required)"
"(
infinicore
.nn.Parameter, Parameter or None required)"
)
)
self
.
_parameters
[
name
]
=
param
self
.
_parameters
[
name
]
=
param
super
().
__setattr__
(
name
,
param
)
def
get_extra_state
(
self
)
->
Any
:
def
get_extra_state
(
self
)
->
Any
:
"""Return any extra state to include in the module's state_dict.
"""Return any extra state to include in the module's state_dict.
...
@@ -272,7 +281,7 @@ class InfiniCoreModule:
...
@@ -272,7 +281,7 @@ class InfiniCoreModule:
"""
"""
raise
RuntimeError
(
raise
RuntimeError
(
"Reached a code path in Module.get_extra_state() that should never be called. "
"Reached a code path in Module.get_extra_state() that should never be called. "
)
)
def
_save_to_state_dict
(
self
,
destination
,
prefix
,
keep_vars
):
def
_save_to_state_dict
(
self
,
destination
,
prefix
,
keep_vars
):
r
"""Saves module state to `destination` dictionary, containing a state
r
"""Saves module state to `destination` dictionary, containing a state
...
@@ -289,29 +298,34 @@ class InfiniCoreModule:
...
@@ -289,29 +298,34 @@ class InfiniCoreModule:
"""
"""
for
name
,
param
in
self
.
_parameters
.
items
():
for
name
,
param
in
self
.
_parameters
.
items
():
if
param
is
not
None
:
if
param
is
not
None
:
destination
[
prefix
+
name
]
=
param
if
keep_vars
else
param
.
detach
()
destination
[
prefix
+
name
]
=
param
if
keep_vars
else
param
for
name
,
buf
in
self
.
_buffers
.
items
():
for
name
,
buf
in
self
.
_buffers
.
items
():
if
buf
is
not
None
and
name
not
in
self
.
_non_persistent_buffers_set
:
if
buf
is
not
None
and
name
not
in
self
.
_non_persistent_buffers_set
:
destination
[
prefix
+
name
]
=
buf
if
keep_vars
else
buf
.
detach
()
destination
[
prefix
+
name
]
=
buf
if
keep_vars
else
buf
extra_state_key
=
prefix
+
_EXTRA_STATE_KEY_SUFFIX
extra_state_key
=
prefix
+
_EXTRA_STATE_KEY_SUFFIX
if
getattr
(
self
.
__class__
,
"get_extra_state"
,
InfiniCoreModule
.
get_extra_state
)
is
not
InfiniCoreModule
.
get_extra_state
:
if
(
getattr
(
self
.
__class__
,
"get_extra_state"
,
InfiniCoreModule
.
get_extra_state
)
is
not
InfiniCoreModule
.
get_extra_state
):
destination
[
extra_state_key
]
=
self
.
get_extra_state
()
destination
[
extra_state_key
]
=
self
.
get_extra_state
()
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
# back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
# back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
T_destination
=
TypeVar
(
'
T_destination
'
,
bound
=
Dict
[
str
,
Any
])
T_destination
=
TypeVar
(
"
T_destination
"
,
bound
=
Dict
[
str
,
Any
])
@
overload
@
overload
def
state_dict
(
self
,
*
,
destination
:
T_destination
,
prefix
:
str
=
...,
keep_vars
:
bool
=
...)
->
T_destination
:
def
state_dict
(
...
self
,
*
,
destination
:
T_destination
,
prefix
:
str
=
...,
keep_vars
:
bool
=
...
)
->
T_destination
:
...
@
overload
@
overload
def
state_dict
(
self
,
*
,
prefix
:
str
=
...,
keep_vars
:
bool
=
...)
->
Dict
[
str
,
Any
]:
def
state_dict
(
...
self
,
*
,
prefix
:
str
=
...,
keep_vars
:
bool
=
...
)
->
Dict
[
str
,
Any
]:
...
# TODO: Change `*args` to `*` and remove the copprespinding warning in docs when BC allows.
# TODO: Change `*args` to `*` and remove the copprespinding warning in docs when BC allows.
# Also remove the logic for arg parsing together.
# Also remove the logic for arg parsing together.
def
state_dict
(
self
,
*
args
,
destination
=
None
,
prefix
=
''
,
keep_vars
=
False
):
def
state_dict
(
self
,
*
args
,
destination
=
None
,
prefix
=
""
,
keep_vars
=
False
):
r
"""Returns a dictionary containing references to the whole state of the module.
r
"""Returns a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
Both parameters and persistent buffers (e.g. running averages) are
...
@@ -366,7 +380,7 @@ class InfiniCoreModule:
...
@@ -366,7 +380,7 @@ class InfiniCoreModule:
)
)
if
destination
is
None
:
if
destination
is
None
:
destination
=
args
[
0
]
destination
=
args
[
0
]
if
len
(
args
)
>
1
and
prefix
==
''
:
if
len
(
args
)
>
1
and
prefix
==
""
:
prefix
=
args
[
1
]
prefix
=
args
[
1
]
if
len
(
args
)
>
2
and
keep_vars
is
False
:
if
len
(
args
)
>
2
and
keep_vars
is
False
:
keep_vars
=
args
[
2
]
keep_vars
=
args
[
2
]
...
@@ -382,9 +396,13 @@ class InfiniCoreModule:
...
@@ -382,9 +396,13 @@ class InfiniCoreModule:
self
.
_save_to_state_dict
(
destination
,
prefix
,
keep_vars
)
self
.
_save_to_state_dict
(
destination
,
prefix
,
keep_vars
)
for
name
,
module
in
self
.
_modules
.
items
():
for
name
,
module
in
self
.
_modules
.
items
():
if
module
is
not
None
:
if
module
is
not
None
:
module
.
state_dict
(
destination
=
destination
,
prefix
=
prefix
+
name
+
'.'
,
keep_vars
=
keep_vars
)
module
.
state_dict
(
destination
=
destination
,
prefix
=
prefix
+
name
+
"."
,
keep_vars
=
keep_vars
,
)
return
destination
return
destination
def
set_extra_state
(
self
,
state
:
Any
):
def
set_extra_state
(
self
,
state
:
Any
):
"""
"""
This function is called from :func:`load_state_dict` to handle any extra state
This function is called from :func:`load_state_dict` to handle any extra state
...
@@ -398,10 +416,19 @@ class InfiniCoreModule:
...
@@ -398,10 +416,19 @@ class InfiniCoreModule:
raise
RuntimeError
(
raise
RuntimeError
(
"Reached a code path in Module.set_extra_state() that should never be called. "
"Reached a code path in Module.set_extra_state() that should never be called. "
"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
"to report this bug."
)
"to report this bug."
)
def
_load_from_state_dict
(
self
,
state_dict
,
prefix
,
local_metadata
,
strict
,
def
_load_from_state_dict
(
missing_keys
,
unexpected_keys
,
error_msgs
):
self
,
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
,
):
r
"""Copies parameters and buffers from :attr:`state_dict` into only
r
"""Copies parameters and buffers from :attr:`state_dict` into only
this module, but not its descendants. This is called on every submodule
this module, but not its descendants. This is called on every submodule
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
...
@@ -433,50 +460,45 @@ class InfiniCoreModule:
...
@@ -433,50 +460,45 @@ class InfiniCoreModule:
list, and will be reported together in
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
:meth:`~torch.nn.Module.load_state_dict`
"""
"""
persistent_buffers
=
{
persistent_buffers
=
{
k
:
v
for
k
,
v
in
self
.
_buffers
.
items
()
if
k
not
in
self
.
_non_persistent_buffers_set
}
k
:
v
local_name_params
=
itertools
.
chain
(
self
.
_parameters
.
items
(),
persistent_buffers
.
items
())
for
k
,
v
in
self
.
_buffers
.
items
()
if
k
not
in
self
.
_non_persistent_buffers_set
}
local_name_params
=
itertools
.
chain
(
self
.
_parameters
.
items
(),
persistent_buffers
.
items
()
)
local_state
=
{
k
:
v
for
k
,
v
in
local_name_params
if
v
is
not
None
}
local_state
=
{
k
:
v
for
k
,
v
in
local_name_params
if
v
is
not
None
}
for
name
,
param
in
local_state
.
items
():
for
name
,
param
in
local_state
.
items
():
key
=
prefix
+
name
key
=
prefix
+
name
if
key
in
state_dict
:
if
key
in
state_dict
:
input_param
=
state_dict
[
key
]
input_param
=
state_dict
[
key
]
if
not
torch
.
overrides
.
is_tensor_like
(
input_param
):
error_msgs
.
append
(
'While copying the parameter named "{}", '
'expected torch.Tensor or Tensor-like object from checkpoint but '
'received {}'
.
format
(
key
,
type
(
input_param
)))
continue
# This is used to avoid copying uninitialized parameters into
# input_param must be of type infinicore.Tensor
# non-lazy modules, since they dont have the hook to do the checks
if
not
isinstance
(
input_param
,
Tensor
):
# in such case, it will error when accessing the .shape attribute.
raise
TypeError
(
is_param_lazy
=
torch
.
nn
.
parameter
.
is_lazy
(
param
)
f
"While copying the parameter named
{
key
}
, expected Tensor from checkpoint but received
{
type
(
input_param
)
}
"
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
)
if
not
is_param_lazy
and
len
(
param
.
shape
)
==
0
and
len
(
input_param
.
shape
)
==
1
:
input_param
=
input_param
[
0
]
if
(
(
param
.
shape
==
input_param
.
shape
)
if
not
is_param_lazy
and
input_param
.
shape
!=
param
.
shape
:
and
(
param
.
dtype
==
input_param
.
dtype
)
# local shape should match the one in checkpoint
and
(
param
.
device
==
input_param
.
device
)
error_msgs
.
append
(
'size mismatch for {}: copying a param with shape {} from checkpoint, '
):
'the shape in current model is {}.'
param
.
copy_
(
input_param
)
.
format
(
key
,
input_param
.
shape
,
param
.
shape
))
else
:
continue
print
(
f
"param '
{
name
}
' don't match input_param '
{
key
}
'"
)
try
:
setattr
(
self
,
name
,
input_param
)
with
torch
.
no_grad
():
param
.
copy_
(
input_param
)
except
Exception
as
ex
:
error_msgs
.
append
(
'While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}, '
'an exception occurred : {}.'
.
format
(
key
,
param
.
size
(),
input_param
.
size
(),
ex
.
args
))
elif
strict
:
elif
strict
:
missing_keys
.
append
(
key
)
missing_keys
.
append
(
key
)
extra_state_key
=
prefix
+
_EXTRA_STATE_KEY_SUFFIX
extra_state_key
=
prefix
+
_EXTRA_STATE_KEY_SUFFIX
if
getattr
(
self
.
__class__
,
"set_extra_state"
,
InfiniCoreModule
.
set_extra_state
)
is
not
InfiniCoreModule
.
set_extra_state
:
if
(
getattr
(
self
.
__class__
,
"set_extra_state"
,
InfiniCoreModule
.
set_extra_state
)
is
not
InfiniCoreModule
.
set_extra_state
):
if
extra_state_key
in
state_dict
:
if
extra_state_key
in
state_dict
:
self
.
set_extra_state
(
state_dict
[
extra_state_key
])
self
.
set_extra_state
(
state_dict
[
extra_state_key
])
elif
strict
:
elif
strict
:
...
@@ -486,8 +508,8 @@ class InfiniCoreModule:
...
@@ -486,8 +508,8 @@ class InfiniCoreModule:
if
strict
:
if
strict
:
for
key
in
state_dict
.
keys
():
for
key
in
state_dict
.
keys
():
if
key
.
startswith
(
prefix
)
and
key
!=
extra_state_key
:
if
key
.
startswith
(
prefix
):
input_name
=
key
[
len
(
prefix
):].
split
(
"."
,
1
)
input_name
=
key
[
len
(
prefix
)
:].
split
(
"."
,
1
)
# Must be Module if it have attributes
# Must be Module if it have attributes
if
len
(
input_name
)
>
1
:
if
len
(
input_name
)
>
1
:
if
input_name
[
0
]
not
in
self
.
_modules
:
if
input_name
[
0
]
not
in
self
.
_modules
:
...
@@ -495,8 +517,7 @@ class InfiniCoreModule:
...
@@ -495,8 +517,7 @@ class InfiniCoreModule:
elif
input_name
[
0
]
not
in
local_state
:
elif
input_name
[
0
]
not
in
local_state
:
unexpected_keys
.
append
(
key
)
unexpected_keys
.
append
(
key
)
def
load_state_dict
(
self
,
state_dict
:
Mapping
[
str
,
Any
],
def
load_state_dict
(
self
,
state_dict
:
Mapping
[
str
,
Any
],
strict
:
bool
=
True
):
strict
:
bool
=
True
):
r
"""Copies parameters and buffers from :attr:`state_dict` into
r
"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants. If :attr:`strict` is ``True``, then
this module and its descendants. If :attr:`strict` is ``True``, then
the keys of :attr:`state_dict` must exactly match the keys returned
the keys of :attr:`state_dict` must exactly match the keys returned
...
@@ -520,28 +541,40 @@ class InfiniCoreModule:
...
@@ -520,28 +541,40 @@ class InfiniCoreModule:
``RuntimeError``.
``RuntimeError``.
"""
"""
if
not
isinstance
(
state_dict
,
Mapping
):
if
not
isinstance
(
state_dict
,
Mapping
):
raise
TypeError
(
"Expected state_dict to be dict-like, got {}."
.
format
(
type
(
state_dict
)))
raise
TypeError
(
"Expected state_dict to be dict-like, got {}."
.
format
(
type
(
state_dict
))
)
missing_keys
:
List
[
str
]
=
[]
missing_keys
:
List
[
str
]
=
[]
unexpected_keys
:
List
[
str
]
=
[]
unexpected_keys
:
List
[
str
]
=
[]
error_msgs
:
List
[
str
]
=
[]
error_msgs
:
List
[
str
]
=
[]
# copy state_dict so _load_from_state_dict can modify it
# copy state_dict so _load_from_state_dict can modify it
metadata
=
getattr
(
state_dict
,
'
_metadata
'
,
None
)
metadata
=
getattr
(
state_dict
,
"
_metadata
"
,
None
)
state_dict
=
OrderedDict
(
state_dict
)
state_dict
=
OrderedDict
(
state_dict
)
if
metadata
is
not
None
:
if
metadata
is
not
None
:
# mypy isn't aware that "_metadata" exists in state_dict
state_dict
.
_metadata
=
metadata
# type: ignore[attr-defined]
state_dict
.
_metadata
=
metadata
# type: ignore[attr-defined]
def
load
(
module
,
local_state_dict
,
prefix
=
''
):
def
load
(
module
,
local_state_dict
,
prefix
=
""
):
local_metadata
=
{}
if
metadata
is
None
else
metadata
.
get
(
prefix
[:
-
1
],
{})
local_metadata
=
{}
if
metadata
is
None
else
metadata
.
get
(
prefix
[:
-
1
],
{})
module
.
_load_from_state_dict
(
module
.
_load_from_state_dict
(
local_state_dict
,
prefix
,
local_metadata
,
True
,
missing_keys
,
unexpected_keys
,
error_msgs
)
local_state_dict
,
prefix
,
local_metadata
,
True
,
missing_keys
,
unexpected_keys
,
error_msgs
,
)
for
name
,
child
in
module
.
_modules
.
items
():
for
name
,
child
in
module
.
_modules
.
items
():
if
child
is
not
None
:
if
child
is
not
None
:
child_prefix
=
prefix
+
name
+
'.'
child_prefix
=
prefix
+
name
+
"."
child_state_dict
=
{
k
:
v
for
k
,
v
in
local_state_dict
.
items
()
if
k
.
startswith
(
child_prefix
)}
child_state_dict
=
{
load
(
child
,
child_state_dict
,
child_prefix
)
k
:
v
for
k
,
v
in
local_state_dict
.
items
()
if
k
.
startswith
(
child_prefix
)
}
load
(
child
,
child_state_dict
,
child_prefix
)
# noqa: F821
load
(
self
,
state_dict
)
load
(
self
,
state_dict
)
del
load
del
load
...
@@ -549,19 +582,28 @@ class InfiniCoreModule:
...
@@ -549,19 +582,28 @@ class InfiniCoreModule:
if
strict
:
if
strict
:
if
len
(
unexpected_keys
)
>
0
:
if
len
(
unexpected_keys
)
>
0
:
error_msgs
.
insert
(
error_msgs
.
insert
(
0
,
'Unexpected key(s) in state_dict: {}. '
.
format
(
0
,
', '
.
join
(
'"{}"'
.
format
(
k
)
for
k
in
unexpected_keys
)))
"Unexpected key(s) in state_dict: {}. "
.
format
(
", "
.
join
(
'"{}"'
.
format
(
k
)
for
k
in
unexpected_keys
)
),
)
if
len
(
missing_keys
)
>
0
:
if
len
(
missing_keys
)
>
0
:
error_msgs
.
insert
(
error_msgs
.
insert
(
0
,
'Missing key(s) in state_dict: {}. '
.
format
(
0
,
', '
.
join
(
'"{}"'
.
format
(
k
)
for
k
in
missing_keys
)))
"Missing key(s) in state_dict: {}. "
.
format
(
", "
.
join
(
'"{}"'
.
format
(
k
)
for
k
in
missing_keys
)
),
)
if
len
(
error_msgs
)
>
0
:
if
len
(
error_msgs
)
>
0
:
raise
RuntimeError
(
'Error(s) in loading state_dict for {}:
\n\t
{}'
.
format
(
raise
RuntimeError
(
self
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
"Error(s) in loading state_dict for {}:
\n\t
{}"
.
format
(
self
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)
)
)
return
_IncompatibleKeys
(
missing_keys
,
unexpected_keys
)
return
_IncompatibleKeys
(
missing_keys
,
unexpected_keys
)
def
parameters
(
self
,
recurse
:
bool
=
True
)
->
Iterator
[
Union
[
torch
.
nn
.
Parameter
,
'Parameter'
]
]:
def
parameters
(
self
,
recurse
:
bool
=
True
)
->
Iterator
[
"Parameter"
]:
r
"""Returns an iterator over module parameters.
r
"""Returns an iterator over module parameters.
Args:
Args:
...
@@ -582,7 +624,9 @@ class InfiniCoreModule:
...
@@ -582,7 +624,9 @@ class InfiniCoreModule:
for
name
,
param
in
self
.
named_parameters
(
recurse
=
recurse
):
for
name
,
param
in
self
.
named_parameters
(
recurse
=
recurse
):
yield
param
yield
param
def
named_parameters
(
self
,
prefix
:
str
=
''
,
recurse
:
bool
=
True
)
->
Iterator
[
Tuple
[
str
,
Union
[
torch
.
nn
.
Parameter
,
'Parameter'
]]]:
def
named_parameters
(
self
,
prefix
:
str
=
""
,
recurse
:
bool
=
True
)
->
Iterator
[
Tuple
[
str
,
"Parameter"
]]:
r
"""Returns an iterator over module parameters, yielding both the
r
"""Returns an iterator over module parameters, yielding both the
name of the parameter as well as the parameter itself.
name of the parameter as well as the parameter itself.
...
@@ -604,12 +648,12 @@ class InfiniCoreModule:
...
@@ -604,12 +648,12 @@ class InfiniCoreModule:
"""
"""
gen
=
self
.
_named_members
(
gen
=
self
.
_named_members
(
lambda
module
:
module
.
_parameters
.
items
(),
lambda
module
:
module
.
_parameters
.
items
(),
prefix
=
prefix
,
recurse
=
recurse
prefix
=
prefix
,
recurse
=
recurse
)
)
for
elem
in
gen
:
for
elem
in
gen
:
yield
elem
yield
elem
def
buffers
(
self
,
recurse
:
bool
=
True
)
->
Iterator
[
torch
.
Tensor
]:
def
buffers
(
self
,
recurse
:
bool
=
True
)
->
Iterator
[
Tensor
]:
r
"""Returns an iterator over module buffers.
r
"""Returns an iterator over module buffers.
Args:
Args:
...
@@ -630,7 +674,9 @@ class InfiniCoreModule:
...
@@ -630,7 +674,9 @@ class InfiniCoreModule:
for
name
,
buf
in
self
.
named_buffers
(
recurse
=
recurse
):
for
name
,
buf
in
self
.
named_buffers
(
recurse
=
recurse
):
yield
buf
yield
buf
def
named_buffers
(
self
,
prefix
:
str
=
''
,
recurse
:
bool
=
True
)
->
Iterator
[
Tuple
[
str
,
torch
.
Tensor
]]:
def
named_buffers
(
self
,
prefix
:
str
=
""
,
recurse
:
bool
=
True
)
->
Iterator
[
Tuple
[
str
,
Tensor
]]:
r
"""Returns an iterator over module buffers, yielding both the
r
"""Returns an iterator over module buffers, yielding both the
name of the buffer as well as the buffer itself.
name of the buffer as well as the buffer itself.
...
@@ -660,10 +706,10 @@ class InfiniCoreModule:
...
@@ -660,10 +706,10 @@ class InfiniCoreModule:
if
k
in
module
.
_non_persistent_buffers_set
:
if
k
in
module
.
_non_persistent_buffers_set
:
continue
continue
memo
.
add
(
v
)
memo
.
add
(
v
)
name
=
module_prefix
+
(
'.'
if
module_prefix
else
''
)
+
k
name
=
module_prefix
+
(
"."
if
module_prefix
else
""
)
+
k
yield
(
name
,
v
)
yield
(
name
,
v
)
def
_named_members
(
self
,
get_members_fn
,
prefix
=
''
,
recurse
=
True
):
def
_named_members
(
self
,
get_members_fn
,
prefix
=
""
,
recurse
=
True
):
r
"""Helper method to yield members with their names."""
r
"""Helper method to yield members with their names."""
memo
=
set
()
memo
=
set
()
modules
=
self
.
named_modules
(
prefix
=
prefix
)
if
recurse
else
[(
prefix
,
self
)]
modules
=
self
.
named_modules
(
prefix
=
prefix
)
if
recurse
else
[(
prefix
,
self
)]
...
@@ -673,10 +719,10 @@ class InfiniCoreModule:
...
@@ -673,10 +719,10 @@ class InfiniCoreModule:
if
v
is
None
or
v
in
memo
:
if
v
is
None
or
v
in
memo
:
continue
continue
memo
.
add
(
v
)
memo
.
add
(
v
)
name
=
module_prefix
+
(
'.'
if
module_prefix
else
''
)
+
k
name
=
module_prefix
+
(
"."
if
module_prefix
else
""
)
+
k
yield
(
name
,
v
)
yield
(
name
,
v
)
def
modules
(
self
)
->
Iterator
[
'
InfiniCoreModule
'
]:
def
modules
(
self
)
->
Iterator
[
"
InfiniCoreModule
"
]:
r
"""Returns an iterator over all modules in the network.
r
"""Returns an iterator over all modules in the network.
Yields:
Yields:
...
@@ -704,7 +750,12 @@ class InfiniCoreModule:
...
@@ -704,7 +750,12 @@ class InfiniCoreModule:
for
name
,
module
in
self
.
named_modules
():
for
name
,
module
in
self
.
named_modules
():
yield
module
yield
module
def
named_modules
(
self
,
memo
:
Optional
[
Set
[
'InfiniCoreModule'
]]
=
None
,
prefix
:
str
=
''
,
remove_duplicate
:
bool
=
True
):
def
named_modules
(
self
,
memo
:
Optional
[
Set
[
"InfiniCoreModule"
]]
=
None
,
prefix
:
str
=
""
,
remove_duplicate
:
bool
=
True
,
):
r
"""Returns an iterator over all modules in the network, yielding
r
"""Returns an iterator over all modules in the network, yielding
both the name of the module as well as the module itself.
both the name of the module as well as the module itself.
...
@@ -746,18 +797,20 @@ class InfiniCoreModule:
...
@@ -746,18 +797,20 @@ class InfiniCoreModule:
for
name
,
module
in
self
.
_modules
.
items
():
for
name
,
module
in
self
.
_modules
.
items
():
if
module
is
None
:
if
module
is
None
:
continue
continue
submodule_prefix
=
prefix
+
(
'.'
if
prefix
else
''
)
+
name
submodule_prefix
=
prefix
+
(
"."
if
prefix
else
""
)
+
name
# Handle both InfiniCoreModule and torch.nn.Module
# Handle both InfiniCoreModule and torch.nn.Module
if
isinstance
(
module
,
InfiniCoreModule
):
if
isinstance
(
module
,
InfiniCoreModule
):
for
m
in
module
.
named_modules
(
memo
,
submodule_prefix
,
remove_duplicate
):
for
m
in
module
.
named_modules
(
memo
,
submodule_prefix
,
remove_duplicate
):
yield
m
yield
m
elif
isinstance
(
module
,
torch
.
nn
.
Module
):
elif
isinstance
(
module
,
infinicore
.
nn
.
Module
):
# For torch.nn.Module, use its named_modules method
# For torch.nn.Module, use its named_modules method
# torch.nn.Module.named_modules returns (name, module) tuples
# torch.nn.Module.named_modules returns (name, module) tuples
for
sub_name
,
sub_module
in
module
.
named_modules
(
prefix
=
submodule_prefix
,
remove_duplicate
=
remove_duplicate
):
for
sub_name
,
sub_module
in
module
.
named_modules
(
prefix
=
submodule_prefix
,
remove_duplicate
=
remove_duplicate
):
yield
(
sub_name
,
sub_module
)
yield
(
sub_name
,
sub_module
)
def
children
(
self
)
->
Iterator
[
Union
[
'
InfiniCoreModule
'
,
torch
.
nn
.
Module
]
]:
def
children
(
self
)
->
Iterator
[
"
InfiniCoreModule
"
]:
r
"""Returns an iterator over immediate children modules.
r
"""Returns an iterator over immediate children modules.
Yields:
Yields:
...
@@ -766,7 +819,9 @@ class InfiniCoreModule:
...
@@ -766,7 +819,9 @@ class InfiniCoreModule:
for
name
,
module
in
self
.
named_children
():
for
name
,
module
in
self
.
named_children
():
yield
module
yield
module
def
named_children
(
self
)
->
Iterator
[
Tuple
[
str
,
Union
[
'InfiniCoreModule'
,
torch
.
nn
.
Module
]]]:
def
named_children
(
self
,
)
->
Iterator
[
Tuple
[
str
,
"InfiniCoreModule"
]]:
r
"""Returns an iterator over immediate children modules, yielding both
r
"""Returns an iterator over immediate children modules, yielding both
the name of the module as well as the module itself.
the name of the module as well as the module itself.
...
@@ -787,169 +842,16 @@ class InfiniCoreModule:
...
@@ -787,169 +842,16 @@ class InfiniCoreModule:
memo
.
add
(
module
)
memo
.
add
(
module
)
yield
name
,
module
yield
name
,
module
def
train
(
self
:
T
,
mode
:
bool
=
True
)
->
T
:
r
"""Sets the module in training mode.
This has any effect only on certain modules. See documentations of
particular modules for details of their behaviors in training/evaluation
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
Args:
mode (bool): whether to set training mode (``True``) or evaluation
mode (``False``). Default: ``True``.
Returns:
Module: self
"""
if
not
isinstance
(
mode
,
bool
):
raise
ValueError
(
"training mode is expected to be boolean"
)
self
.
training
=
mode
for
module
in
self
.
children
():
module
.
train
(
mode
)
return
self
def
eval
(
self
:
T
)
->
T
:
def
eval
(
self
:
T
)
->
T
:
r
"""Sets the module in evaluation mode.
r
"""Sets the module in evaluation mode.
This has any effect only on certain modules. See documentations of
particular modules for details of their behaviors in training/evaluation
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.
See :ref:`locally-disable-grad-doc` for a comparison between
`.eval()` and several similar mechanisms that may be confused with it.
Returns:
Returns:
Module: self
Module: self
"""
"""
return
self
.
train
(
False
)
pass
def
_apply
(
self
,
fn
,
recurse
=
True
):
def
_apply
(
self
,
fn
,
recurse
=
True
):
if
recurse
:
raise
KeyError
(
"not support"
)
for
module
in
self
.
children
():
module
.
_apply
(
fn
)
def
compute_should_use_set_data
(
tensor
,
tensor_applied
):
if
torch
.
_has_compatible_shallow_copy_type
(
tensor
,
tensor_applied
):
# If the new tensor has compatible tensor type as the existing tensor,
# the current behavior is to change the tensor in-place using `.data =`,
# and the future behavior is to overwrite the existing tensor. However,
# changing the current behavior is a BC-breaking change, and we want it
# to happen in future releases. So for now we introduce the
# `torch.__future__.get_overwrite_module_params_on_conversion()`
# global flag to let the user control whether they want the future
# behavior of overwriting the existing tensor or not.
return
not
torch
.
__future__
.
get_overwrite_module_params_on_conversion
()
else
:
return
False
should_use_swap_tensors
=
torch
.
__future__
.
get_swap_module_params_on_conversion
()
# Import Parameter (InfiniCoreParameter) for type checking and creation
from
.parameter
import
InfiniCoreParameter
as
Parameter
for
key
,
param
in
self
.
_parameters
.
items
():
if
param
is
None
:
continue
# Tensors stored in modules are graph leaves, and we don't want to
# track autograd history of `param_applied`, so we have to use
# `with torch.no_grad():`
with
torch
.
no_grad
():
param_applied
=
fn
(
param
)
p_should_use_set_data
=
compute_should_use_set_data
(
param
,
param_applied
)
# subclasses may have multiple child tensors so we need to use swap_tensors
p_should_use_swap_tensors
=
should_use_swap_tensors
or
is_traceable_wrapper_subclass
(
param_applied
)
# Determine the Parameter class to use based on the original parameter type
is_infinicore_param
=
isinstance
(
param
,
Parameter
)
ParamClass
=
Parameter
if
is_infinicore_param
else
torch
.
nn
.
Parameter
param_grad
=
param
.
grad
if
p_should_use_swap_tensors
:
try
:
if
param_grad
is
not
None
:
# Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping.
# Decrement use count of the gradient by setting to None
param
.
grad
=
None
param_applied
=
ParamClass
(
param_applied
,
requires_grad
=
param
.
requires_grad
)
torch
.
utils
.
swap_tensors
(
param
,
param_applied
)
except
Exception
as
e
:
if
param_grad
is
not
None
:
param
.
grad
=
param_grad
raise
RuntimeError
(
f
"_apply(): Couldn't swap
{
self
.
_get_name
()
}
.
{
key
}
"
)
from
e
out_param
=
param
elif
p_should_use_set_data
:
param
.
data
=
param_applied
out_param
=
param
else
:
assert
isinstance
(
param
,
(
torch
.
nn
.
Parameter
,
Parameter
))
assert
param
.
is_leaf
out_param
=
ParamClass
(
param_applied
,
param
.
requires_grad
)
self
.
_parameters
[
key
]
=
out_param
if
param_grad
is
not
None
:
with
torch
.
no_grad
():
grad_applied
=
fn
(
param_grad
)
g_should_use_set_data
=
compute_should_use_set_data
(
param_grad
,
grad_applied
)
if
p_should_use_swap_tensors
:
grad_applied
.
requires_grad_
(
param_grad
.
requires_grad
)
try
:
torch
.
utils
.
swap_tensors
(
param_grad
,
grad_applied
)
except
Exception
as
e
:
raise
RuntimeError
(
f
"_apply(): Couldn't swap
{
self
.
_get_name
()
}
.
{
key
}
.grad"
)
from
e
out_param
.
grad
=
param_grad
elif
g_should_use_set_data
:
assert
out_param
.
grad
is
not
None
out_param
.
grad
.
data
=
grad_applied
else
:
assert
param_grad
.
is_leaf
out_param
.
grad
=
grad_applied
.
requires_grad_
(
param_grad
.
requires_grad
)
for
key
,
buf
in
self
.
_buffers
.
items
():
if
buf
is
not
None
:
self
.
_buffers
[
key
]
=
fn
(
buf
)
return
self
def
to
(
self
,
*
args
,
**
kwargs
):
def
to
(
self
,
*
args
,
**
kwargs
):
device
,
dtype
,
non_blocking
,
convert_to_format
=
torch
.
_C
.
_nn
.
_parse_to
(
*
args
,
**
kwargs
)
raise
KeyError
(
"not support"
)
if
dtype
is
not
None
:
if
not
(
dtype
.
is_floating_point
or
dtype
.
is_complex
):
raise
TypeError
(
'nn.Module.to only accepts floating point or complex '
f
'dtypes, but got desired dtype=
{
dtype
}
'
)
if
dtype
.
is_complex
:
warnings
.
warn
(
"Complex modules are a new feature under active development whose design may change, "
"and some modules might not work as expected when using complex tensors as parameters or buffers. "
)
def
convert
(
t
):
try
:
if
convert_to_format
is
not
None
and
t
.
dim
()
in
(
4
,
5
):
return
t
.
to
(
device
,
dtype
if
t
.
is_floating_point
()
or
t
.
is_complex
()
else
None
,
non_blocking
,
memory_format
=
convert_to_format
,
)
return
t
.
to
(
device
,
dtype
if
t
.
is_floating_point
()
or
t
.
is_complex
()
else
None
,
non_blocking
,
)
except
NotImplementedError
as
e
:
if
str
(
e
)
==
"Cannot copy out of meta tensor; no data!"
:
raise
NotImplementedError
(
f
"
{
e
}
Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
f
"when moving module from meta to a different device."
)
from
None
else
:
raise
return
self
.
_apply
(
convert
)
python/infinicore/nn/modules/parameter.py
deleted
100644 → 0
View file @
f6107946
# Copyright (c) 2025, InfiniCore
#
# This file contains modified code derived from PyTorch's `torch.nn.Parameter`
# implementation, which is licensed under the BSD 3-Clause License.
#
# The modifications include adaptations for the InfiniCore framework.
#
# Original PyTorch source:
# https://github.com/pytorch/pytorch/blob/main/torch/nn/parameter.py
#
# Referencing PyTorch v2.4.0
#
# The use of this file is governed by the BSD 3-Clause License.
import
torch
from
typing
import
Optional
from
collections
import
OrderedDict
class
InfiniCoreParameter
(
torch
.
Tensor
):
r
"""A kind of Tensor that is to be considered a module parameter.
Parameters are :class:`~torch.Tensor` subclasses, that have a
very special property when used with :class:`InfiniCoreModule` s - when they're
assigned as Module attributes they are automatically added to the list of
its parameters, and will appear e.g. in :meth:`~InfiniCoreModule.parameters` iterator.
Assigning a Tensor doesn't have such effect. This is because one might
want to cache some temporary state, like last hidden state of the RNN, in
the model. If there was no such class as :class:`InfiniCoreParameter`, these
temporaries would get registered too.
Args:
data (Tensor, optional): parameter tensor. If None, creates an empty tensor.
requires_grad (bool, optional): if the parameter requires gradient. Note that
the torch.no_grad() context does NOT affect the default behavior of
Parameter creation--the Parameter will still have `requires_grad=True` in
:class:`~no_grad` mode. See :ref:`locally-disable-grad-doc` for more
details. Default: `True`
Example::
>>> import torch
>>> from infinicore.nn.modules import InfiniCoreModule, InfiniCoreParameter
>>>
>>> class MyModule(InfiniCoreModule):
... def __init__(self):
... super().__init__()
... self.weight = InfiniCoreParameter(torch.randn(10, 5))
... self.bias = InfiniCoreParameter(torch.randn(5))
...
>>> module = MyModule()
>>> for param in module.parameters():
... print(param.shape)
torch.Size([10, 5])
torch.Size([5])
"""
def
__new__
(
cls
,
data
:
Optional
[
torch
.
Tensor
]
=
None
,
requires_grad
:
bool
=
True
):
if
data
is
None
:
data
=
torch
.
empty
(
0
)
# Handle standard torch.Tensor or InfiniCoreParameter
if
type
(
data
)
is
torch
.
Tensor
or
type
(
data
)
is
InfiniCoreParameter
:
# For ease of BC maintenance, keep this path for standard Tensor.
# Eventually (tm), we should change the behavior for standard Tensor to match.
return
torch
.
Tensor
.
_make_subclass
(
cls
,
data
,
requires_grad
)
# Path for custom tensors: set a flag on the instance to indicate parameter-ness.
t
=
data
.
detach
().
requires_grad_
(
requires_grad
)
if
type
(
t
)
is
not
type
(
data
):
raise
RuntimeError
(
f
"Creating a InfiniCoreParameter from an instance of type
{
type
(
data
).
__name__
}
"
"requires that detach() returns an instance of the same type, but return "
f
"type
{
type
(
t
).
__name__
}
was found instead. To use the type as a "
"InfiniCoreParameter, please correct the detach() semantics defined by "
"its __torch_dispatch__() implementation."
)
t
.
_is_param
=
True
return
t
# Note: the 3 methods below only apply to standard Tensor. Parameters of custom tensor types
# are still considered that custom tensor type and these methods will not be called for them.
def
__deepcopy__
(
self
,
memo
):
if
id
(
self
)
in
memo
:
return
memo
[
id
(
self
)]
else
:
result
=
type
(
self
)(
self
.
data
.
clone
(
memory_format
=
torch
.
preserve_format
),
self
.
requires_grad
)
memo
[
id
(
self
)]
=
result
return
result
def
__repr__
(
self
):
return
"InfiniCoreParameter containing:
\n
"
+
super
().
__repr__
()
def
__reduce_ex__
(
self
,
proto
):
# Simplified version for serialization
# In a full implementation, you might want to handle hooks and state
state
=
getattr
(
self
,
'_state'
,
None
)
hooks
=
OrderedDict
()
if
not
state
:
return
(
_rebuild_parameter
,
(
self
.
data
,
self
.
requires_grad
,
hooks
),
)
return
(
_rebuild_parameter_with_state
,
(
self
.
data
,
self
.
requires_grad
,
hooks
,
state
),
)
# Note: __torch_function__ is handled by the Tensor base class
# We don't need to override it for standard Parameter behavior
def
_rebuild_parameter
(
data
,
requires_grad
,
hooks
):
"""Rebuild a parameter from serialized data."""
param
=
InfiniCoreParameter
(
data
,
requires_grad
)
# Apply hooks if any (simplified - full implementation would restore hooks)
return
param
def
_rebuild_parameter_with_state
(
data
,
requires_grad
,
hooks
,
state
):
"""Rebuild a parameter with extra state from serialized data."""
param
=
InfiniCoreParameter
(
data
,
requires_grad
)
param
.
_state
=
state
# Apply hooks if any (simplified - full implementation would restore hooks)
return
param
python/infinicore/nn/parameter.py
0 → 100644
View file @
ee722eb9
# Copyright (c) 2025, InfiniCore
#
# This file contains modified code derived from PyTorch's `torch.nn.Parameter`
# implementation, which is licensed under the BSD 3-Clause License.
#
# The modifications include adaptations for the InfiniCore framework.
#
# Original PyTorch source:
# https://github.com/pytorch/pytorch/blob/main/torch/nn/parameter.py
#
# Referencing PyTorch v2.4.0
#
# The use of this file is governed by the BSD 3-Clause License.
from
..tensor
import
Tensor
class
InfiniCoreParameter
(
Tensor
):
r
"""A kind of Tensor that is to be considered a module parameter."""
def
__init__
(
self
,
data
=
None
):
if
not
isinstance
(
data
,
Tensor
):
raise
ValueError
(
"The `data` variable must be of type `infinicore.Tensor`."
)
super
().
__init__
(
data
.
_underlying
)
def
__repr__
(
self
):
return
"Parameter containing:
\n
"
+
super
().
__repr__
()
def
__deepcopy__
(
self
,
memo
):
raise
ValueError
(
"not supported!"
)
def
__reduce_ex__
(
self
,
proto
):
raise
ValueError
(
"not supported!"
)
test/infinicore/infinicore_module_list_test.py
deleted
100644 → 0
View file @
f6107946
import
safetensors.torch
import
torch
import
torch.nn
as
nn
import
safetensors
# ============================================================
# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径
# ============================================================
import
sys
import
os
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'../../python/infinicore'
)))
# 使用临时目录,如果不存在则自动创建
save_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'../../tmp'
)
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
save_path
=
os
.
path
.
join
(
save_dir
,
"torch_modulelist_with_param.safetensors"
)
# ============================================================
# 1. 使用 PyTorch 定义并保存模型(使用 torch.nn.ModuleList)
# ============================================================
class
TorchModuleListNet
(
nn
.
Module
):
def
__init__
(
self
,
in_ch
=
3
,
hidden_ch
=
8
,
out_ch
=
3
):
super
().
__init__
()
# 使用 torch.nn.ModuleList
self
.
layers
=
nn
.
ModuleList
([
nn
.
Conv2d
(
in_ch
,
hidden_ch
,
kernel_size
=
3
,
padding
=
1
),
nn
.
BatchNorm2d
(
hidden_ch
),
nn
.
ReLU
(),
nn
.
Conv2d
(
hidden_ch
,
hidden_ch
,
kernel_size
=
3
,
padding
=
1
),
nn
.
BatchNorm2d
(
hidden_ch
),
nn
.
ReLU
(),
nn
.
Conv2d
(
hidden_ch
,
out_ch
,
kernel_size
=
1
),
])
# 自定义 Parameter
self
.
scale
=
nn
.
Parameter
(
torch
.
ones
(
1
)
*
0.5
)
self
.
register_buffer
(
"offset"
,
torch
.
tensor
(
0.1
))
def
forward
(
self
,
x
):
# 遍历 ModuleList 中的所有层
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
# 应用自定义参数和 buffer
x
=
x
*
self
.
scale
+
self
.
offset
return
x
# ===== 保存 Torch 模型 =====
torch_model
=
TorchModuleListNet
()
torch_state_dict
=
torch_model
.
state_dict
()
safetensors
.
torch
.
save_file
(
torch_state_dict
,
save_path
)
print
(
"✓ PyTorch 模型已保存"
)
# ============================================================
# 2. 使用 torch 方式加载并推理
# ============================================================
torch_model_infer
=
TorchModuleListNet
()
torch_model_infer
.
load_state_dict
(
safetensors
.
torch
.
load_file
(
save_path
))
torch_model_infer
.
eval
()
input
=
torch
.
rand
(
1
,
3
,
8
,
8
)
torch_model_out
=
torch_model_infer
(
input
)
print
(
"✓ Torch 输出:"
,
torch_model_out
.
detach
().
numpy
().
mean
())
# ============================================================
# 3. 使用 ModuleList 加载并推理
# ============================================================
from
nn.modules
import
Module
,
ModuleList
class
InfiniCoreModuleListNet
(
Module
):
def
__init__
(
self
,
in_ch
=
3
,
hidden_ch
=
8
,
out_ch
=
3
):
super
().
__init__
()
# 使用 ModuleList
self
.
layers
=
ModuleList
([
nn
.
Conv2d
(
in_ch
,
hidden_ch
,
kernel_size
=
3
,
padding
=
1
),
nn
.
BatchNorm2d
(
hidden_ch
),
nn
.
ReLU
(),
nn
.
Conv2d
(
hidden_ch
,
hidden_ch
,
kernel_size
=
3
,
padding
=
1
),
nn
.
BatchNorm2d
(
hidden_ch
),
nn
.
ReLU
(),
nn
.
Conv2d
(
hidden_ch
,
out_ch
,
kernel_size
=
1
),
])
# 保持与 Torch 模型一致的自定义参数和 buffer
self
.
scale
=
nn
.
Parameter
(
torch
.
ones
(
1
)
*
0.5
)
self
.
register_buffer
(
"offset"
,
torch
.
tensor
(
0.1
))
def
forward
(
self
,
x
):
# 遍历 ModuleList 中的所有层
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x
=
x
*
self
.
scale
+
self
.
offset
return
x
# ===== 使用 ModuleListNet 读取 safetensors 并推理 =====
infinicore_model_infer
=
InfiniCoreModuleListNet
()
infinicore_model_infer
.
load_state_dict
(
safetensors
.
torch
.
load_file
(
save_path
))
infinicore_model_infer
.
eval
()
infinicore_model_out
=
infinicore_model_infer
.
forward
(
input
)
print
(
"✓ InfiniCore 输出:"
,
infinicore_model_out
.
detach
().
numpy
().
mean
())
# ============================================================
# 4. 对比结果
# ============================================================
diff
=
(
infinicore_model_out
-
torch_model_out
).
abs
().
max
().
item
()
print
(
f
"✓ ModuleList 与 Torch 最大误差:
{
diff
:.
8
f
}
"
)
if
diff
<
1e-9
:
print
(
"✓ ModuleList 与 Torch 精度一致."
)
else
:
print
(
"✗ ModuleList 与 Torch 精度存在差异."
)
# ============================================================
# 5. 测试 ModuleList 的基本功能
# ============================================================
print
(
"
\n
=== 测试 ModuleList 基本功能 ==="
)
# 测试 1: 创建和访问
module_list
=
ModuleList
([
nn
.
Linear
(
10
,
20
),
nn
.
ReLU
(),
nn
.
Linear
(
20
,
5
)
])
print
(
f
"✓ 创建 ModuleList,长度:
{
len
(
module_list
)
}
"
)
print
(
f
"✓ 访问第一个模块:
{
type
(
module_list
[
0
]).
__name__
}
"
)
print
(
f
"✓ 访问第二个模块:
{
type
(
module_list
[
1
]).
__name__
}
"
)
# 测试 2: append
module_list
.
append
(
nn
.
Softmax
(
dim
=-
1
))
print
(
f
"✓ append 后长度:
{
len
(
module_list
)
}
"
)
# 测试 3: extend
module_list
.
extend
([
nn
.
Dropout
(
0.1
),
nn
.
Linear
(
5
,
1
)])
print
(
f
"✓ extend 后长度:
{
len
(
module_list
)
}
"
)
# 测试 4: 迭代
print
(
"✓ 迭代 ModuleList:"
)
for
i
,
module
in
enumerate
(
module_list
):
print
(
f
" [
{
i
}
]
{
type
(
module
).
__name__
}
"
)
# 测试 5: 索引访问
print
(
f
"✓ 索引访问 module_list[0]:
{
type
(
module_list
[
0
]).
__name__
}
"
)
# 测试 6: state_dict
state_dict
=
module_list
.
state_dict
()
print
(
f
"✓ state_dict 键数量:
{
len
(
state_dict
)
}
"
)
print
(
f
"✓ state_dict 包含模块参数:
{
any
(
'0.'
in
k
for
k
in
state_dict
.
keys
())
}
"
)
# 测试 7: 使用 ModuleList 的模型
class
TestNet
(
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
layers
=
ModuleList
([
nn
.
Linear
(
10
,
20
),
nn
.
ReLU
(),
nn
.
Linear
(
20
,
5
)
])
def
forward
(
self
,
x
):
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
return
x
test_model
=
TestNet
()
test_input
=
torch
.
randn
(
2
,
10
)
test_output
=
test_model
.
forward
(
test_input
)
print
(
f
"✓ TestNet 输入形状:
{
test_input
.
shape
}
, 输出形状:
{
test_output
.
shape
}
"
)
# 测试 8: __add__ 方法
ml1
=
ModuleList
([
nn
.
Linear
(
10
,
5
),
nn
.
ReLU
()])
ml2
=
ModuleList
([
nn
.
Linear
(
5
,
3
),
nn
.
Sigmoid
()])
ml3
=
ml1
+
ml2
print
(
f
"✓ __add__ 方法测试:
{
len
(
ml1
)
}
+
{
len
(
ml2
)
}
=
{
len
(
ml3
)
}
"
)
assert
len
(
ml3
)
==
4
,
"合并后的长度应该为 4"
# 测试 9: pop 方法
ml4
=
ModuleList
([
nn
.
Linear
(
10
,
5
),
nn
.
ReLU
(),
nn
.
Linear
(
5
,
3
)])
popped
=
ml4
.
pop
()
print
(
f
"✓ pop 方法测试: 弹出后长度
{
len
(
ml4
)
}
, 弹出模块类型
{
type
(
popped
).
__name__
}
"
)
assert
len
(
ml4
)
==
2
,
"pop 后长度应该为 2"
assert
isinstance
(
popped
,
nn
.
Linear
),
"弹出的应该是 Linear 模块"
# 测试 10: __repr__ 方法
ml5
=
ModuleList
([
nn
.
Linear
(
10
,
5
),
nn
.
ReLU
()])
repr_str
=
repr
(
ml5
)
print
(
f
"✓ __repr__ 方法测试: 输出包含类名和模块信息"
)
assert
"ModuleList"
in
repr_str
or
"InfiniCoreModuleList"
in
repr_str
,
"repr 应该包含类名"
assert
"Linear"
in
repr_str
,
"repr 应该包含模块信息"
print
(
repr_str
)
print
(
"
\n
=== 所有测试通过! ==="
)
# ============================================================
# 6. 前向传播集成测试(参考 infinicore_nn_test.py)
# ============================================================
print
(
"
\n
=== 前向传播集成测试 ==="
)
# 使用 ModuleList 创建一个简单的模型
class
TorchModuleListModel
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
layers
=
nn
.
ModuleList
([
nn
.
Linear
(
10
,
20
),
nn
.
ReLU
(),
nn
.
Linear
(
20
,
5
)
])
self
.
scale
=
nn
.
Parameter
(
torch
.
ones
(
1
)
*
0.5
)
self
.
register_buffer
(
"offset"
,
torch
.
tensor
(
0.1
))
def
forward
(
self
,
x
):
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x
=
x
*
self
.
scale
+
self
.
offset
return
x
class
InfiniCoreModuleListModel
(
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
layers
=
ModuleList
([
nn
.
Linear
(
10
,
20
),
nn
.
ReLU
(),
nn
.
Linear
(
20
,
5
)
])
self
.
scale
=
nn
.
Parameter
(
torch
.
ones
(
1
)
*
0.5
)
self
.
register_buffer
(
"offset"
,
torch
.
tensor
(
0.1
))
def
forward
(
self
,
x
):
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x
=
x
*
self
.
scale
+
self
.
offset
return
x
# 创建模型
torch_model_forward
=
TorchModuleListModel
()
infinicore_model_forward
=
InfiniCoreModuleListModel
()
# 复制权重(确保初始权重一致)
infinicore_model_forward
.
load_state_dict
(
torch_model_forward
.
state_dict
(),
strict
=
False
)
# 设置为评估模式
torch_model_forward
.
eval
()
infinicore_model_forward
.
eval
()
# 创建测试输入
test_input
=
torch
.
randn
(
2
,
10
)
# 前向传播
with
torch
.
no_grad
():
torch_output
=
torch_model_forward
(
test_input
)
infinicore_output
=
infinicore_model_forward
.
forward
(
test_input
)
# 对比结果
diff
=
(
infinicore_output
-
torch_output
).
abs
().
max
().
item
()
print
(
f
"✓ 前向传播测试 - 输入形状:
{
test_input
.
shape
}
"
)
print
(
f
"✓ Torch 输出形状:
{
torch_output
.
shape
}
, 均值:
{
torch_output
.
detach
().
numpy
().
mean
():.
8
f
}
"
)
print
(
f
"✓ InfiniCore 输出形状:
{
infinicore_output
.
shape
}
, 均值:
{
infinicore_output
.
detach
().
numpy
().
mean
():.
8
f
}
"
)
print
(
f
"✓ 最大误差:
{
diff
:.
8
f
}
"
)
if
diff
<
1e-9
:
print
(
"✓ 前向传播集成测试通过:ModuleList 与 Torch ModuleList 结果一致!"
)
else
:
print
(
"✗ 前向传播集成测试失败:存在差异"
)
# ============================================================
# 7. 混合模块兼容性测试(PyTorch + InfiniCore 模块混合使用)
# ============================================================
print
(
"
\n
=== 混合模块兼容性测试 ==="
)
# 创建一个自定义的 InfiniCore 模块
class
CustomLinear
(
Module
):
def
__init__
(
self
,
in_features
,
out_features
):
super
().
__init__
()
self
.
weight
=
nn
.
Parameter
(
torch
.
randn
(
out_features
,
in_features
))
self
.
bias
=
nn
.
Parameter
(
torch
.
randn
(
out_features
))
def
forward
(
self
,
x
):
return
x
@
self
.
weight
.
t
()
+
self
.
bias
# 创建混合 ModuleList(包含 PyTorch 模块和 InfiniCore 模块)
mixed_list
=
ModuleList
([
nn
.
Linear
(
10
,
5
),
# PyTorch 模块
CustomLinear
(
5
,
3
),
# 自定义 InfiniCore 模块
nn
.
ReLU
(),
# PyTorch 模块
])
print
(
f
"✓ 创建混合 ModuleList,长度:
{
len
(
mixed_list
)
}
"
)
print
(
f
"✓ 模块类型:
{
[
type
(
m
).
__name__
for
m
in
mixed_list
]
}
"
)
# 测试参数注册
param_count
=
sum
(
1
for
_
in
mixed_list
.
parameters
())
print
(
f
"✓ 参数数量:
{
param_count
}
"
)
assert
param_count
==
4
,
f
"参数数量应该为 4 (Linear: weight+bias, CustomLinear: weight+bias), 实际为
{
param_count
}
"
# 测试 state_dict
mixed_state_dict
=
mixed_list
.
state_dict
()
print
(
f
"✓ state_dict 键数量:
{
len
(
mixed_state_dict
)
}
"
)
assert
len
(
mixed_state_dict
)
>=
4
,
"state_dict 应该包含至少 4 个参数"
# 测试前向传播
test_input_mixed
=
torch
.
randn
(
2
,
10
)
with
torch
.
no_grad
():
x
=
test_input_mixed
for
module
in
mixed_list
:
x
=
module
.
forward
(
x
)
print
(
f
"✓ 混合模块前向传播成功,输出形状:
{
x
.
shape
}
"
)
print
(
"✓ 混合模块兼容性测试通过!"
)
test/infinicore/infinicore_nn_test.py
deleted
100644 → 0
View file @
f6107946
import
safetensors.torch
import
torch
import
torch.nn
as
nn
import
safetensors
# ============================================================
# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径
# ============================================================
import
sys
import
os
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'../../python/infinicore'
)))
save_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'../../tmp'
)
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
save_path
=
os
.
path
.
join
(
save_dir
,
"torch_convnet_with_param.safetensors"
)
# ============================================================
# 1. 使用 PyTorch 定义并保存模型
# ============================================================
print
(
"===== 开始 CPU 一致性测试 ====="
)
class
TorchConvNet
(
nn
.
Module
):
def
__init__
(
self
,
in_ch
=
3
,
hidden_ch
=
8
,
out_ch
=
3
):
super
().
__init__
()
# 主体网络
self
.
conv1
=
nn
.
Conv2d
(
in_ch
,
hidden_ch
,
kernel_size
=
3
,
padding
=
1
)
self
.
bn1
=
nn
.
BatchNorm2d
(
hidden_ch
)
self
.
conv2
=
nn
.
Conv2d
(
hidden_ch
,
hidden_ch
,
kernel_size
=
3
,
padding
=
1
)
self
.
bn2
=
nn
.
BatchNorm2d
(
hidden_ch
)
self
.
conv3
=
nn
.
Conv2d
(
hidden_ch
,
out_ch
,
kernel_size
=
1
)
self
.
relu
=
nn
.
ReLU
()
# 自定义 Parameter
self
.
scale
=
nn
.
Parameter
(
torch
.
ones
(
1
)
*
0.5
)
# 注册一个 buffer
self
.
register_buffer
(
"offset"
,
torch
.
tensor
(
0.1
))
def
forward
(
self
,
x
):
x
=
self
.
relu
(
self
.
bn1
(
self
.
conv1
(
x
)))
x
=
self
.
relu
(
self
.
bn2
(
self
.
conv2
(
x
)))
x
=
self
.
conv3
(
x
)
# 应用自定义参数和 buffer
x
=
x
*
self
.
scale
+
self
.
offset
return
x
# ===== 保存 Torch 模型 =====
torch_model
=
TorchConvNet
()
torch_state_dict
=
torch_model
.
state_dict
()
safetensors
.
torch
.
save_file
(
torch_state_dict
,
save_path
)
# ============================================================
# 2. 使用 torch 方式加载并推理
# ============================================================
torch_model_infer
=
TorchConvNet
()
torch_model_infer
.
load_state_dict
(
safetensors
.
torch
.
load_file
(
save_path
))
torch_model_infer
.
eval
()
input
=
torch
.
rand
(
1
,
3
,
8
,
8
)
torch_model_out
=
torch_model_infer
(
input
)
# ============================================================
# 3. 使用 infiniCore.nn.module 加载并推理
# ============================================================
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'../../python/infinicore'
)))
from
nn
import
Module
class
InfiniCoreConvNet
(
Module
):
def
__init__
(
self
,
in_ch
=
3
,
hidden_ch
=
8
,
out_ch
=
3
):
super
().
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
in_ch
,
hidden_ch
,
kernel_size
=
3
,
padding
=
1
)
self
.
bn1
=
nn
.
BatchNorm2d
(
hidden_ch
)
self
.
conv2
=
nn
.
Conv2d
(
hidden_ch
,
hidden_ch
,
kernel_size
=
3
,
padding
=
1
)
self
.
bn2
=
nn
.
BatchNorm2d
(
hidden_ch
)
self
.
conv3
=
nn
.
Conv2d
(
hidden_ch
,
out_ch
,
kernel_size
=
1
)
self
.
relu
=
nn
.
ReLU
()
# 保持与 Torch 模型一致的自定义参数和 buffer
self
.
scale
=
nn
.
Parameter
(
torch
.
ones
(
1
)
*
0.5
)
self
.
register_buffer
(
"offset"
,
torch
.
tensor
(
0.1
))
def
forward
(
self
,
x
):
x
=
self
.
relu
(
self
.
bn1
(
self
.
conv1
(
x
)))
x
=
self
.
relu
(
self
.
bn2
(
self
.
conv2
(
x
)))
x
=
self
.
conv3
(
x
)
x
=
x
*
self
.
scale
+
self
.
offset
return
x
# ===== 使用 InfiniCoreConvNet 读取 safetensors 并推理 =====
infinicore_model_infer
=
InfiniCoreConvNet
()
infinicore_model_infer
.
load_state_dict
(
safetensors
.
torch
.
load_file
(
save_path
))
infinicore_model_infer
.
eval
()
infinicore_model_out
=
infinicore_model_infer
.
forward
(
input
)
# ============================================================
# 4. 对比结果
# ============================================================
diff_cpu
=
(
infinicore_model_out
-
torch_model_out
).
abs
().
max
().
item
()
print
(
f
"InfiniCoreModule 与 Torch (CPU) 最大误差:
{
diff_cpu
:.
6
e
}
"
)
if
diff_cpu
<
1e-9
:
print
(
"CPU 模式下 InfiniCore 与 Torch 输出完全一致."
)
else
:
print
(
"CPU 模式下输出存在差异."
)
# ============================================================
# 5. GPU 一致性测试(可选)
# ============================================================
if
torch
.
cuda
.
is_available
():
print
(
"
\n
===== 开始 GPU 一致性测试 ====="
)
# 将模型与输入都迁移到 GPU
torch_model_infer_gpu
=
TorchConvNet
().
to
(
"cuda"
)
torch_model_infer_gpu
.
load_state_dict
(
safetensors
.
torch
.
load_file
(
save_path
))
torch_model_infer_gpu
.
eval
()
infinicore_model_infer_gpu
=
InfiniCoreConvNet
().
to
(
"cuda"
)
infinicore_model_infer_gpu
.
load_state_dict
(
safetensors
.
torch
.
load_file
(
save_path
))
infinicore_model_infer_gpu
.
eval
()
# 生成 GPU 输入
input_gpu
=
input
.
to
(
"cuda"
)
# 分别前向推理
torch_out_gpu
=
torch_model_infer_gpu
(
input_gpu
)
infinicore_out_gpu
=
infinicore_model_infer_gpu
.
forward
(
input_gpu
)
# 结果比较
diff_gpu
=
(
infinicore_out_gpu
-
torch_out_gpu
).
abs
().
max
().
item
()
print
(
f
"InfiniCoreModule 与 Torch (GPU) 最大误差:
{
diff_gpu
:.
6
e
}
"
)
if
diff_gpu
<
1e-9
:
print
(
"GPU 模式下 InfiniCore 与 Torch 输出完全一致."
)
else
:
print
(
"GPU 模式下输出存在差异."
)
else
:
print
(
"
\n
未检测到 GPU,跳过 GPU 一致性测试。"
)
\ No newline at end of file
test/infinicore/nn/Module.py
0 → 100644
View file @
ee722eb9
# ============================================================
# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径
# ============================================================
import
os
import
sys
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"../../python/infinicore"
))
)
save_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"../../tmp"
)
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
save_path
=
os
.
path
.
join
(
save_dir
,
"torch_convnet_with_param.safetensors"
)
import
infinicore
# noqa: E402
from
infinicore.nn
import
Module
# noqa: E402
# ============================================================
# 1. 定义模型
# ============================================================
device_str
=
"cuda"
class
InfiniCoreNet
(
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
a
=
infinicore
.
nn
.
Parameter
(
infinicore
.
empty
(
(
1
,
2
,
3
),
dtype
=
infinicore
.
float32
,
device
=
infinicore
.
device
(
device_str
),
)
)
self
.
b
=
infinicore
.
nn
.
Parameter
(
infinicore
.
empty
(
(
1
,
2
,
3
),
dtype
=
infinicore
.
float32
,
device
=
infinicore
.
device
(
device_str
),
)
)
def
forward
(
self
):
return
infinicore
.
add
(
self
.
a
,
self
.
b
)
infinicore_model_infer
=
InfiniCoreNet
()
# ============================================================
# 2. 加载权重
# ============================================================
params_dict
=
{
"a"
:
infinicore
.
empty
(
(
1
,
2
,
3
),
dtype
=
infinicore
.
float32
,
device
=
infinicore
.
device
(
device_str
,
0
)
),
"b"
:
infinicore
.
empty
(
(
1
,
2
,
3
),
dtype
=
infinicore
.
float32
,
device
=
infinicore
.
device
(
device_str
,
0
)
),
}
infinicore_model_infer
.
load_state_dict
(
params_dict
)
# ============================================================
# 3. 计算
# ============================================================
infinicore_model_out
=
infinicore_model_infer
()
ref_out
=
infinicore
.
add
(
params_dict
[
"a"
],
params_dict
[
"b"
])
# ============================================================
# 4. 对比结果
# ============================================================
print
(
"InfiniCoreModule 与 Torch (CPU) 最大误差: 手动查看 "
)
infinicore_model_out
.
debug
()
ref_out
.
debug
()
# ============================================================
# 5. to测试,buffer测试
# ============================================================
# 等待添加
test/infinicore/nn/ModuleList.py
0 → 100644
View file @
ee722eb9
import
os
# ============================================================
# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径
# ============================================================
import
sys
import
safetensors
import
safetensors.torch
import
torch
import
torch.nn
as
nn
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"../../python/infinicore"
))
)
# 使用临时目录,如果不存在则自动创建
save_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"../../tmp"
)
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
save_path
=
os
.
path
.
join
(
save_dir
,
"torch_modulelist_with_param.safetensors"
)
def
test
():
# ============================================================
# 1. 使用 PyTorch 定义并保存模型(使用 torch.nn.ModuleList)
# ============================================================
class
TorchModuleListNet
(
nn
.
Module
):
def
__init__
(
self
,
in_ch
=
3
,
hidden_ch
=
8
,
out_ch
=
3
):
super
().
__init__
()
# 使用 torch.nn.ModuleList
self
.
layers
=
nn
.
ModuleList
(
[
nn
.
Conv2d
(
in_ch
,
hidden_ch
,
kernel_size
=
3
,
padding
=
1
),
nn
.
BatchNorm2d
(
hidden_ch
),
nn
.
ReLU
(),
nn
.
Conv2d
(
hidden_ch
,
hidden_ch
,
kernel_size
=
3
,
padding
=
1
),
nn
.
BatchNorm2d
(
hidden_ch
),
nn
.
ReLU
(),
nn
.
Conv2d
(
hidden_ch
,
out_ch
,
kernel_size
=
1
),
]
)
# 自定义 Parameter
self
.
scale
=
nn
.
Parameter
(
torch
.
ones
(
1
)
*
0.5
)
self
.
register_buffer
(
"offset"
,
torch
.
tensor
(
0.1
))
def
forward
(
self
,
x
):
# 遍历 ModuleList 中的所有层
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
# 应用自定义参数和 buffer
x
=
x
*
self
.
scale
+
self
.
offset
return
x
# ===== 保存 Torch 模型 =====
torch_model
=
TorchModuleListNet
()
torch_state_dict
=
torch_model
.
state_dict
()
safetensors
.
torch
.
save_file
(
torch_state_dict
,
save_path
)
print
(
"✓ PyTorch 模型已保存"
)
# ============================================================
# 2. 使用 torch 方式加载并推理
# ============================================================
torch_model_infer
=
TorchModuleListNet
()
torch_model_infer
.
load_state_dict
(
safetensors
.
torch
.
load_file
(
save_path
))
torch_model_infer
.
eval
()
input
=
torch
.
rand
(
1
,
3
,
8
,
8
)
torch_model_out
=
torch_model_infer
(
input
)
print
(
"✓ Torch 输出:"
,
torch_model_out
.
detach
().
numpy
().
mean
())
# ============================================================
# 3. 使用 ModuleList 加载并推理
# ============================================================
from
nn.modules
import
Module
,
ModuleList
class
InfiniCoreModuleListNet
(
Module
):
def
__init__
(
self
,
in_ch
=
3
,
hidden_ch
=
8
,
out_ch
=
3
):
super
().
__init__
()
# 使用 ModuleList
self
.
layers
=
ModuleList
(
[
nn
.
Conv2d
(
in_ch
,
hidden_ch
,
kernel_size
=
3
,
padding
=
1
),
nn
.
BatchNorm2d
(
hidden_ch
),
nn
.
ReLU
(),
nn
.
Conv2d
(
hidden_ch
,
hidden_ch
,
kernel_size
=
3
,
padding
=
1
),
nn
.
BatchNorm2d
(
hidden_ch
),
nn
.
ReLU
(),
nn
.
Conv2d
(
hidden_ch
,
out_ch
,
kernel_size
=
1
),
]
)
# 保持与 Torch 模型一致的自定义参数和 buffer
self
.
scale
=
nn
.
Parameter
(
torch
.
ones
(
1
)
*
0.5
)
self
.
register_buffer
(
"offset"
,
torch
.
tensor
(
0.1
))
def
forward
(
self
,
x
):
# 遍历 ModuleList 中的所有层
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x
=
x
*
self
.
scale
+
self
.
offset
return
x
# ===== 使用 ModuleListNet 读取 safetensors 并推理 =====
infinicore_model_infer
=
InfiniCoreModuleListNet
()
infinicore_model_infer
.
load_state_dict
(
safetensors
.
torch
.
load_file
(
save_path
))
infinicore_model_infer
.
eval
()
infinicore_model_out
=
infinicore_model_infer
.
forward
(
input
)
print
(
"✓ InfiniCore 输出:"
,
infinicore_model_out
.
detach
().
numpy
().
mean
())
# ============================================================
# 4. 对比结果
# ============================================================
diff
=
(
infinicore_model_out
-
torch_model_out
).
abs
().
max
().
item
()
print
(
f
"✓ ModuleList 与 Torch 最大误差:
{
diff
:.
8
f
}
"
)
if
diff
<
1e-9
:
print
(
"✓ ModuleList 与 Torch 精度一致."
)
else
:
print
(
"✗ ModuleList 与 Torch 精度存在差异."
)
# ============================================================
# 5. 测试 ModuleList 的基本功能
# ============================================================
print
(
"
\n
=== 测试 ModuleList 基本功能 ==="
)
# 测试 1: 创建和访问
module_list
=
ModuleList
([
nn
.
Linear
(
10
,
20
),
nn
.
ReLU
(),
nn
.
Linear
(
20
,
5
)])
print
(
f
"✓ 创建 ModuleList,长度:
{
len
(
module_list
)
}
"
)
print
(
f
"✓ 访问第一个模块:
{
type
(
module_list
[
0
]).
__name__
}
"
)
print
(
f
"✓ 访问第二个模块:
{
type
(
module_list
[
1
]).
__name__
}
"
)
# 测试 2: append
module_list
.
append
(
nn
.
Softmax
(
dim
=-
1
))
print
(
f
"✓ append 后长度:
{
len
(
module_list
)
}
"
)
# 测试 3: extend
module_list
.
extend
([
nn
.
Dropout
(
0.1
),
nn
.
Linear
(
5
,
1
)])
print
(
f
"✓ extend 后长度:
{
len
(
module_list
)
}
"
)
# 测试 4: 迭代
print
(
"✓ 迭代 ModuleList:"
)
for
i
,
module
in
enumerate
(
module_list
):
print
(
f
" [
{
i
}
]
{
type
(
module
).
__name__
}
"
)
# 测试 5: 索引访问
print
(
f
"✓ 索引访问 module_list[0]:
{
type
(
module_list
[
0
]).
__name__
}
"
)
# 测试 6: state_dict
state_dict
=
module_list
.
state_dict
()
print
(
f
"✓ state_dict 键数量:
{
len
(
state_dict
)
}
"
)
print
(
f
"✓ state_dict 包含模块参数:
{
any
(
'0.'
in
k
for
k
in
state_dict
.
keys
())
}
"
)
# 测试 7: 使用 ModuleList 的模型
class
TestNet
(
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
layers
=
ModuleList
([
nn
.
Linear
(
10
,
20
),
nn
.
ReLU
(),
nn
.
Linear
(
20
,
5
)])
def
forward
(
self
,
x
):
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
return
x
test_model
=
TestNet
()
test_input
=
torch
.
randn
(
2
,
10
)
test_output
=
test_model
.
forward
(
test_input
)
print
(
f
"✓ TestNet 输入形状:
{
test_input
.
shape
}
, 输出形状:
{
test_output
.
shape
}
"
)
# 测试 8: __add__ 方法
ml1
=
ModuleList
([
nn
.
Linear
(
10
,
5
),
nn
.
ReLU
()])
ml2
=
ModuleList
([
nn
.
Linear
(
5
,
3
),
nn
.
Sigmoid
()])
ml3
=
ml1
+
ml2
print
(
f
"✓ __add__ 方法测试:
{
len
(
ml1
)
}
+
{
len
(
ml2
)
}
=
{
len
(
ml3
)
}
"
)
assert
len
(
ml3
)
==
4
,
"合并后的长度应该为 4"
# 测试 9: pop 方法
ml4
=
ModuleList
([
nn
.
Linear
(
10
,
5
),
nn
.
ReLU
(),
nn
.
Linear
(
5
,
3
)])
popped
=
ml4
.
pop
()
print
(
f
"✓ pop 方法测试: 弹出后长度
{
len
(
ml4
)
}
, 弹出模块类型
{
type
(
popped
).
__name__
}
"
)
assert
len
(
ml4
)
==
2
,
"pop 后长度应该为 2"
assert
isinstance
(
popped
,
nn
.
Linear
),
"弹出的应该是 Linear 模块"
# 测试 10: __repr__ 方法
ml5
=
ModuleList
([
nn
.
Linear
(
10
,
5
),
nn
.
ReLU
()])
repr_str
=
repr
(
ml5
)
print
(
f
"✓ __repr__ 方法测试: 输出包含类名和模块信息"
)
assert
"ModuleList"
in
repr_str
or
"InfiniCoreModuleList"
in
repr_str
,
(
"repr 应该包含类名"
)
assert
"Linear"
in
repr_str
,
"repr 应该包含模块信息"
print
(
repr_str
)
print
(
"
\n
=== 所有测试通过! ==="
)
# ============================================================
# 6. 前向传播集成测试(参考 infinicore_nn_test.py)
# ============================================================
print
(
"
\n
=== 前向传播集成测试 ==="
)
# 使用 ModuleList 创建一个简单的模型
class
TorchModuleListModel
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
layers
=
nn
.
ModuleList
(
[
nn
.
Linear
(
10
,
20
),
nn
.
ReLU
(),
nn
.
Linear
(
20
,
5
)]
)
self
.
scale
=
nn
.
Parameter
(
torch
.
ones
(
1
)
*
0.5
)
self
.
register_buffer
(
"offset"
,
torch
.
tensor
(
0.1
))
def
forward
(
self
,
x
):
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x
=
x
*
self
.
scale
+
self
.
offset
return
x
class
InfiniCoreModuleListModel
(
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
layers
=
ModuleList
([
nn
.
Linear
(
10
,
20
),
nn
.
ReLU
(),
nn
.
Linear
(
20
,
5
)])
self
.
scale
=
nn
.
Parameter
(
torch
.
ones
(
1
)
*
0.5
)
self
.
register_buffer
(
"offset"
,
torch
.
tensor
(
0.1
))
def
forward
(
self
,
x
):
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x
=
x
*
self
.
scale
+
self
.
offset
return
x
# 创建模型
torch_model_forward
=
TorchModuleListModel
()
infinicore_model_forward
=
InfiniCoreModuleListModel
()
# 复制权重(确保初始权重一致)
infinicore_model_forward
.
load_state_dict
(
torch_model_forward
.
state_dict
(),
strict
=
False
)
# 设置为评估模式
torch_model_forward
.
eval
()
infinicore_model_forward
.
eval
()
# 创建测试输入
test_input
=
torch
.
randn
(
2
,
10
)
# 前向传播
with
torch
.
no_grad
():
torch_output
=
torch_model_forward
(
test_input
)
infinicore_output
=
infinicore_model_forward
.
forward
(
test_input
)
# 对比结果
diff
=
(
infinicore_output
-
torch_output
).
abs
().
max
().
item
()
print
(
f
"✓ 前向传播测试 - 输入形状:
{
test_input
.
shape
}
"
)
print
(
f
"✓ Torch 输出形状:
{
torch_output
.
shape
}
, 均值:
{
torch_output
.
detach
().
numpy
().
mean
():.
8
f
}
"
)
print
(
f
"✓ InfiniCore 输出形状:
{
infinicore_output
.
shape
}
, 均值:
{
infinicore_output
.
detach
().
numpy
().
mean
():.
8
f
}
"
)
print
(
f
"✓ 最大误差:
{
diff
:.
8
f
}
"
)
if
diff
<
1e-9
:
print
(
"✓ 前向传播集成测试通过:ModuleList 与 Torch ModuleList 结果一致!"
)
else
:
print
(
"✗ 前向传播集成测试失败:存在差异"
)
# ============================================================
# 7. 混合模块兼容性测试(PyTorch + InfiniCore 模块混合使用)
# ============================================================
print
(
"
\n
=== 混合模块兼容性测试 ==="
)
# 创建一个自定义的 InfiniCore 模块
class
CustomLinear
(
Module
):
def
__init__
(
self
,
in_features
,
out_features
):
super
().
__init__
()
self
.
weight
=
nn
.
Parameter
(
torch
.
randn
(
out_features
,
in_features
))
self
.
bias
=
nn
.
Parameter
(
torch
.
randn
(
out_features
))
def
forward
(
self
,
x
):
return
x
@
self
.
weight
.
t
()
+
self
.
bias
# 创建混合 ModuleList(包含 PyTorch 模块和 InfiniCore 模块)
mixed_list
=
ModuleList
(
[
nn
.
Linear
(
10
,
5
),
# PyTorch 模块
CustomLinear
(
5
,
3
),
# 自定义 InfiniCore 模块
nn
.
ReLU
(),
# PyTorch 模块
]
)
print
(
f
"✓ 创建混合 ModuleList,长度:
{
len
(
mixed_list
)
}
"
)
print
(
f
"✓ 模块类型:
{
[
type
(
m
).
__name__
for
m
in
mixed_list
]
}
"
)
# 测试参数注册
param_count
=
sum
(
1
for
_
in
mixed_list
.
parameters
())
print
(
f
"✓ 参数数量:
{
param_count
}
"
)
assert
param_count
==
4
,
(
f
"参数数量应该为 4 (Linear: weight+bias, CustomLinear: weight+bias), 实际为
{
param_count
}
"
)
# 测试 state_dict
mixed_state_dict
=
mixed_list
.
state_dict
()
print
(
f
"✓ state_dict 键数量:
{
len
(
mixed_state_dict
)
}
"
)
assert
len
(
mixed_state_dict
)
>=
4
,
"state_dict 应该包含至少 4 个参数"
# 测试前向传播
test_input_mixed
=
torch
.
randn
(
2
,
10
)
with
torch
.
no_grad
():
x
=
test_input_mixed
for
module
in
mixed_list
:
x
=
module
.
forward
(
x
)
print
(
f
"✓ 混合模块前向传播成功,输出形状:
{
x
.
shape
}
"
)
print
(
"✓ 混合模块兼容性测试通过!"
)
test/infinicore/
infinicore_p
arameter
_test
.py
→
test/infinicore/
nn/P
arameter.py
View file @
ee722eb9
import
safetensors.torch
import
torch
import
torch.nn
as
nn
import
safetensors
# ============================================================
# ============================================================
# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径
# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径
# ============================================================
# ============================================================
import
sys
import
os
import
os
import
sys
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'../../python/infinicore'
)))
save_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'../../tmp'
)
import
torch
import
torch.nn
as
nn
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"../../python/infinicore"
))
)
save_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"../../tmp"
)
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
save_path
=
os
.
path
.
join
(
save_dir
,
"infinicore_parameter_test.safetensors"
)
save_path
=
os
.
path
.
join
(
save_dir
,
"infinicore_parameter_test.safetensors"
)
# ============================================================
# 1. 使用 PyTorch 定义并保存模型(使用 torch.nn.Parameter)
# ============================================================
class
TorchParameterNet
(
nn
.
Module
):
import
infinicore
# noqa: E402
def
__init__
(
self
,
in_features
=
10
,
out_features
=
5
):
from
infinicore.nn
import
Module
,
Parameter
# noqa: E402
device_str
=
"cuda"
class
InfiniCoreParameterNet
(
Module
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
self
.
weight
=
nn
.
Parameter
(
torch
.
randn
(
out_features
,
in_features
))
self
.
a
=
infinicore
.
nn
.
Parameter
(
self
.
bias
=
nn
.
Parameter
(
torch
.
randn
(
out_features
))
infinicore
.
empty
(
self
.
scale
=
nn
.
Parameter
(
torch
.
ones
(
1
)
*
0.5
)
(
1
,
2
,
3
),
dtype
=
infinicore
.
float32
,
device
=
infinicore
.
device
(
"cpu"
,
0
)
self
.
register_buffer
(
"offset"
,
torch
.
tensor
(
0.1
))
)
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
return
(
x
@
self
.
weight
.
t
()
+
self
.
bias
)
*
self
.
scale
+
self
.
offset
return
infinicore
.
add
(
self
.
a
,
x
)
# ===== 保存 Torch 模型 =====
torch_model
=
TorchParameterNet
()
torch_state_dict
=
torch_model
.
state_dict
()
safetensors
.
torch
.
save_file
(
torch_state_dict
,
save_path
)
print
(
"✓ PyTorch 模型已保存"
)
infinicore_model_infer
=
InfiniCoreParameterNet
()
# ============================================================
# ============================================================
# 2.
使用 torch 方式加载并推理
# 2.
加载权重
# ============================================================
# ============================================================
params_dict
=
{
"a"
:
infinicore
.
empty
(
(
1
,
2
,
3
),
dtype
=
infinicore
.
float32
,
device
=
infinicore
.
device
(
device_str
,
0
)
)
}
infinicore_model_infer
.
load_state_dict
(
params_dict
)
torch_model_infer
=
TorchParameterNet
()
torch_model_infer
.
load_state_dict
(
safetensors
.
torch
.
load_file
(
save_path
))
torch_model_infer
.
eval
()
input
=
torch
.
randn
(
2
,
10
)
torch_model_out
=
torch_model_infer
(
input
)
print
(
"✓ Torch 输出:"
,
torch_model_out
.
detach
().
numpy
().
mean
())
# ============================================================
# ============================================================
# 3.
使用 Parameter 加载并推理
# 3.
计算
# ============================================================
# ============================================================
x
=
infinicore
.
empty
(
(
1
,
2
,
3
),
dtype
=
infinicore
.
float32
,
device
=
infinicore
.
device
(
device_str
,
0
)
)
from
nn.modules
import
Module
,
Parameter
infinicore_model_out
=
infinicore_model_infer
(
x
)
ref_out
=
infinicore
.
add
(
params_dict
[
"a"
],
x
)
class
InfiniCoreParameterNet
(
Module
):
def
__init__
(
self
,
in_features
=
10
,
out_features
=
5
):
super
().
__init__
()
# 使用 Parameter 替代 torch.nn.Parameter
self
.
weight
=
Parameter
(
torch
.
randn
(
out_features
,
in_features
))
self
.
bias
=
Parameter
(
torch
.
randn
(
out_features
))
self
.
scale
=
Parameter
(
torch
.
ones
(
1
)
*
0.5
)
self
.
register_buffer
(
"offset"
,
torch
.
tensor
(
0.1
))
def
forward
(
self
,
x
):
return
(
x
@
self
.
weight
.
t
()
+
self
.
bias
)
*
self
.
scale
+
self
.
offset
# ===== 使用 InfiniCoreParameterNet 读取 safetensors 并推理 =====
infinicore_model_infer
=
InfiniCoreParameterNet
()
infinicore_model_infer
.
load_state_dict
(
safetensors
.
torch
.
load_file
(
save_path
))
infinicore_model_infer
.
eval
()
infinicore_model_out
=
infinicore_model_infer
.
forward
(
input
)
print
(
"✓ InfiniCore 输出:"
,
infinicore_model_out
.
detach
().
numpy
().
mean
())
# ============================================================
# ============================================================
# 4. 对比结果
# 4. 对比结果
# ============================================================
# ============================================================
print
(
"InfiniCoreModule 与 Torch (CPU) 最大误差: 手动查看 "
)
infinicore_model_out
.
debug
()
ref_out
.
debug
()
diff
=
(
infinicore_model_out
-
torch_model_out
).
abs
().
max
().
item
()
print
(
f
"✓ Parameter 与 Torch 最大误差:
{
diff
:.
8
f
}
"
)
if
diff
<
1e-9
:
print
(
"✓ Parameter 与 Torch 精度一致."
)
else
:
print
(
"✗ Parameter 与 Torch 精度存在差异."
)
# ============================================================
# ============================================================
# 5. 测试 Parameter 的基本功能
# 5. 测试 Parameter 的基本功能
...
@@ -93,28 +73,37 @@ else:
...
@@ -93,28 +73,37 @@ else:
print
(
"
\n
=== 测试 Parameter 基本功能 ==="
)
print
(
"
\n
=== 测试 Parameter 基本功能 ==="
)
# 测试 1: 创建 Parameter
# 测试 1: 创建 Parameter
param1
=
Parameter
(
torch
.
randn
(
5
,
10
))
param1
=
infinicore
.
nn
.
Parameter
(
infinicore
.
empty
(
(
1
,
2
,
3
),
dtype
=
infinicore
.
float32
,
device
=
infinicore
.
device
(
device_str
,
0
)
)
)
print
(
f
"✓ 创建 Parameter,形状:
{
param1
.
shape
}
"
)
print
(
f
"✓ 创建 Parameter,形状:
{
param1
.
shape
}
"
)
# 检查是否是 Parameter 类型(可能是 InfiniCoreParameter 的别名)
# 检查是否是 Parameter 类型(可能是 InfiniCoreParameter 的别名)
from
nn.modules.parameter
import
InfiniCoreParameter
assert
isinstance
(
param1
,
(
Parameter
,
InfiniCoreParameter
)),
"应该是 Parameter 类型"
assert
isinstance
(
param1
,
torch
.
Tensor
),
"应该是 torch.Tensor 的子类"
# 测试 2: requires_grad
assert
isinstance
(
param1
,
infinicore
.
nn
.
Parameter
),
"应该是 Parameter 类型"
param2
=
Parameter
(
torch
.
randn
(
3
,
4
),
requires_grad
=
False
)
assert
isinstance
(
param1
,
infinicore
.
Tensor
),
"应该是 torch.Tensor 的子类"
print
(
f
"✓ 创建 requires_grad=False 的 Parameter:
{
param2
.
requires_grad
}
"
)
assert
not
param2
.
requires_grad
,
"requires_grad 应该为 False"
param3
=
Parameter
(
torch
.
randn
(
3
,
4
),
requires_grad
=
True
)
print
(
f
"✓ 创建 requires_grad=True 的 Parameter:
{
param3
.
requires_grad
}
"
)
assert
param3
.
requires_grad
,
"requires_grad 应该为 True"
# 测试 3: 自动注册到 Module
# 测试 3: 自动注册到 Module
class
TestModule
(
Module
):
class
TestModule
(
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
self
.
weight
=
Parameter
(
torch
.
randn
(
5
,
10
))
self
.
weight
=
infinicore
.
nn
.
Parameter
(
self
.
bias
=
Parameter
(
torch
.
randn
(
5
))
infinicore
.
empty
(
(
1
,
2
,
3
),
dtype
=
infinicore
.
float32
,
device
=
infinicore
.
device
(
device_str
),
)
)
self
.
bias
=
infinicore
.
nn
.
Parameter
(
infinicore
.
empty
(
(
1
,
2
,
3
),
dtype
=
infinicore
.
float32
,
device
=
infinicore
.
device
(
device_str
),
)
)
test_module
=
TestModule
()
test_module
=
TestModule
()
param_count
=
sum
(
1
for
_
in
test_module
.
parameters
())
param_count
=
sum
(
1
for
_
in
test_module
.
parameters
())
...
@@ -129,8 +118,8 @@ print("✓ 参数可以通过属性访问")
...
@@ -129,8 +118,8 @@ print("✓ 参数可以通过属性访问")
# 测试 5: state_dict
# 测试 5: state_dict
state_dict
=
test_module
.
state_dict
()
state_dict
=
test_module
.
state_dict
()
print
(
f
"✓ state_dict 键数量:
{
len
(
state_dict
)
}
"
)
print
(
f
"✓ state_dict 键数量:
{
len
(
state_dict
)
}
"
)
assert
'
weight
'
in
state_dict
,
"state_dict 应该包含 weight"
assert
"
weight
"
in
state_dict
,
"state_dict 应该包含 weight"
assert
'
bias
'
in
state_dict
,
"state_dict 应该包含 bias"
assert
"
bias
"
in
state_dict
,
"state_dict 应该包含 bias"
print
(
f
"✓ state_dict 键:
{
list
(
state_dict
.
keys
())
}
"
)
print
(
f
"✓ state_dict 键:
{
list
(
state_dict
.
keys
())
}
"
)
# 测试 6: __repr__
# 测试 6: __repr__
...
@@ -139,46 +128,21 @@ print(f"✓ __repr__ 方法: 输出包含类名")
...
@@ -139,46 +128,21 @@ print(f"✓ __repr__ 方法: 输出包含类名")
assert
"Parameter"
in
repr_str
or
"InfiniCoreParameter"
in
repr_str
,
"repr 应该包含类名"
assert
"Parameter"
in
repr_str
or
"InfiniCoreParameter"
in
repr_str
,
"repr 应该包含类名"
print
(
repr_str
[:
100
]
+
"..."
)
print
(
repr_str
[:
100
]
+
"..."
)
# 测试 7: 与 torch.nn.Parameter 兼容性
class
MixedModule
(
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
torch_param
=
nn
.
Parameter
(
torch
.
randn
(
3
,
4
))
self
.
infinicore_param
=
Parameter
(
torch
.
randn
(
3
,
4
))
mixed_module
=
MixedModule
()
mixed_param_count
=
sum
(
1
for
_
in
mixed_module
.
parameters
())
print
(
f
"✓ 混合使用 torch.nn.Parameter 和 Parameter,参数数量:
{
mixed_param_count
}
"
)
assert
mixed_param_count
==
2
,
f
"应该有 2 个参数,实际为
{
mixed_param_count
}
"
# 测试 8: 前向传播
class
TestModuleWithForward
(
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
weight
=
Parameter
(
torch
.
randn
(
5
,
10
))
self
.
bias
=
Parameter
(
torch
.
randn
(
5
))
def
forward
(
self
,
x
):
return
x
@
self
.
weight
.
t
()
+
self
.
bias
test_module_forward
=
TestModuleWithForward
()
test_input
=
torch
.
randn
(
2
,
10
)
with
torch
.
no_grad
():
output
=
test_module_forward
.
forward
(
test_input
)
print
(
f
"✓ 前向传播成功,输出形状:
{
output
.
shape
}
"
)
assert
output
.
shape
==
(
2
,
5
),
f
"输出形状应该是 (2, 5),实际为
{
output
.
shape
}
"
# 测试 9: 从 None 创建
# 测试 9: 从 None 创建
param_empty
=
Parameter
(
None
)
# param_empty = Parameter(None)
print
(
f
"✓ 从 None 创建 Parameter,形状:
{
param_empty
.
shape
}
"
)
# print(f"✓ 从 None 创建 Parameter,形状: {param_empty.shape}")
assert
param_empty
.
shape
==
torch
.
Size
([
0
]),
"从 None 创建应该是空张量"
# assert param_empty.shape == torch.Size([0]), "从 None 创建应该是空张量"
# 测试 10: 深拷贝
# 测试 10: 深拷贝
import
copy
# import copy
param_copy
=
copy
.
deepcopy
(
param1
)
print
(
f
"✓ 深拷贝 Parameter,形状:
{
param_copy
.
shape
}
"
)
assert
param_copy
.
shape
==
param1
.
shape
,
"深拷贝后形状应该相同"
assert
not
torch
.
equal
(
param_copy
,
param1
)
or
id
(
param_copy
)
!=
id
(
param1
),
"深拷贝应该是新对象"
print
(
"
\n
=== 所有测试通过! ==="
)
# param_copy = copy.deepcopy(param1)
# print(f"✓ 深拷贝 Parameter,形状: {param_copy.shape}")
# assert param_copy.shape == param1.shape, "深拷贝后形状应该相同"
# assert not torch.equal(param_copy, param1) or id(param_copy) != id(param1), (
# "深拷贝应该是新对象"
# )
print
(
"
\n
=== 所有测试通过! ==="
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment