Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
cad51297
Commit
cad51297
authored
Nov 06, 2025
by
zhuyue
Committed by
zhuyue
Nov 17, 2025
Browse files
Issue/568: feat: add infinicore.nn.InfiniCoreModuleList referencing torch.nn.ModuleList.
add some functions in InfiniCoreModule.
parent
27e57f3d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
728 additions
and
4 deletions
+728
-4
python/infinicore/nn/modules/__init__.py
python/infinicore/nn/modules/__init__.py
+1
-0
python/infinicore/nn/modules/module.py
python/infinicore/nn/modules/module.py
+227
-4
python/infinicore/nn/modules/module_list.py
python/infinicore/nn/modules/module_list.py
+183
-0
test/infinicore/infinicore_module_list_test.py
test/infinicore/infinicore_module_list_test.py
+317
-0
No files found.
python/infinicore/nn/modules/__init__.py
View file @
cad51297
from
.module
import
InfiniCoreModule
as
Module
from
.module
import
InfiniCoreModule
as
Module
from
.module_list
import
InfiniCoreModuleList
as
ModuleList
python/infinicore/nn/modules/module.py
View file @
cad51297
...
@@ -105,7 +105,7 @@ class InfiniCoreModule:
...
@@ -105,7 +105,7 @@ class InfiniCoreModule:
self
.
register_parameter
(
name
,
value
)
self
.
register_parameter
(
name
,
value
)
else
:
else
:
modules
=
self
.
__dict__
.
get
(
"_modules"
)
modules
=
self
.
__dict__
.
get
(
"_modules"
)
if
isinstance
(
value
,
(
torch
.
nn
.
Module
)):
if
isinstance
(
value
,
(
torch
.
nn
.
Module
,
InfiniCoreModule
)):
if
modules
is
None
:
if
modules
is
None
:
raise
AttributeError
(
raise
AttributeError
(
"cannot assign module before Module.__init__() call"
"cannot assign module before Module.__init__() call"
...
@@ -181,6 +181,33 @@ class InfiniCoreModule:
...
@@ -181,6 +181,33 @@ class InfiniCoreModule:
self
.
_non_persistent_buffers_set
.
add
(
name
)
self
.
_non_persistent_buffers_set
.
add
(
name
)
def
add_module
(
self
,
name
:
str
,
module
:
Optional
[
torch
.
nn
.
Module
])
->
None
:
r
"""Add a child module to the current module.
The module can be accessed as an attribute using the given name.
Args:
name (str): name of the child module. The child module can be
accessed from this module using the given name
module (Module or None): child module to be added to the module. If
``None``, then operations that run on modules, such as :attr:`eval`,
are ignored. If ``None``, the module is **not** included in the
module's :attr:`children`.
"""
if
not
isinstance
(
name
,
str
):
raise
TypeError
(
f
"module name should be a string. Got
{
torch
.
typename
(
name
)
}
"
)
elif
'.'
in
name
:
raise
KeyError
(
f
"module name can't contain
\"
.
\"
, got:
{
name
}
"
)
elif
name
==
''
:
raise
KeyError
(
"module name can't be empty string
\"\"
"
)
elif
hasattr
(
self
,
name
)
and
name
not
in
self
.
_modules
:
raise
KeyError
(
f
"attribute '
{
name
}
' already exists"
)
if
module
is
not
None
and
not
isinstance
(
module
,
(
torch
.
nn
.
Module
,
InfiniCoreModule
)):
raise
TypeError
(
f
"
{
torch
.
typename
(
module
)
}
is not a Module subclass"
)
self
.
_modules
[
name
]
=
module
def
register_parameter
(
self
,
name
:
str
,
param
:
Optional
[
torch
.
nn
.
Parameter
])
->
None
:
def
register_parameter
(
self
,
name
:
str
,
param
:
Optional
[
torch
.
nn
.
Parameter
])
->
None
:
r
"""Add a parameter to the module.
r
"""Add a parameter to the module.
...
@@ -526,16 +553,212 @@ class InfiniCoreModule:
...
@@ -526,16 +553,212 @@ class InfiniCoreModule:
self
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
self
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
return
_IncompatibleKeys
(
missing_keys
,
unexpected_keys
)
return
_IncompatibleKeys
(
missing_keys
,
unexpected_keys
)
def
children
(
self
)
->
Iterator
[
'InfiniCoreModule'
]:
def
parameters
(
self
,
recurse
:
bool
=
True
)
->
Iterator
[
torch
.
nn
.
Parameter
]:
r
"""Returns an iterator over module parameters.
Args:
recurse (bool): if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
Yields:
Parameter: module parameter
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
... print(type(param), param.size())
"""
for
name
,
param
in
self
.
named_parameters
(
recurse
=
recurse
):
yield
param
def
named_parameters
(
self
,
prefix
:
str
=
''
,
recurse
:
bool
=
True
)
->
Iterator
[
Tuple
[
str
,
torch
.
nn
.
Parameter
]]:
r
"""Returns an iterator over module parameters, yielding both the
name of the parameter as well as the parameter itself.
Args:
prefix (str): prefix to prepend to all parameter names.
recurse (bool): if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
Yields:
(str, Parameter): Tuple containing the name and parameter
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
... if name in ['bias']:
... print(param.size())
"""
gen
=
self
.
_named_members
(
lambda
module
:
module
.
_parameters
.
items
(),
prefix
=
prefix
,
recurse
=
recurse
)
for
elem
in
gen
:
yield
elem
def
buffers
(
self
,
recurse
:
bool
=
True
)
->
Iterator
[
torch
.
Tensor
]:
r
"""Returns an iterator over module buffers.
Args:
recurse (bool): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module.
Yields:
torch.Tensor: module buffer
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for buf in model.buffers():
... print(type(buf), buf.size())
"""
for
name
,
buf
in
self
.
named_buffers
(
recurse
=
recurse
):
yield
buf
def
named_buffers
(
self
,
prefix
:
str
=
''
,
recurse
:
bool
=
True
)
->
Iterator
[
Tuple
[
str
,
torch
.
Tensor
]]:
r
"""Returns an iterator over module buffers, yielding both the
name of the buffer as well as the buffer itself.
Args:
prefix (str): prefix to prepend to all buffer names.
recurse (bool): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module.
Yields:
(str, torch.Tensor): Tuple containing the name and buffer
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
... if name in ['running_mean']:
... print(buf.size())
"""
memo
=
set
()
modules
=
self
.
named_modules
(
prefix
=
prefix
)
if
recurse
else
[(
prefix
,
self
)]
for
module_prefix
,
module
in
modules
:
for
k
,
v
in
module
.
_buffers
.
items
():
if
v
is
None
or
v
in
memo
:
continue
if
k
in
module
.
_non_persistent_buffers_set
:
continue
memo
.
add
(
v
)
name
=
module_prefix
+
(
'.'
if
module_prefix
else
''
)
+
k
yield
(
name
,
v
)
def
_named_members
(
self
,
get_members_fn
,
prefix
=
''
,
recurse
=
True
):
r
"""Helper method to yield members with their names."""
memo
=
set
()
modules
=
self
.
named_modules
(
prefix
=
prefix
)
if
recurse
else
[(
prefix
,
self
)]
for
module_prefix
,
module
in
modules
:
members
=
get_members_fn
(
module
)
for
k
,
v
in
members
:
if
v
is
None
or
v
in
memo
:
continue
memo
.
add
(
v
)
name
=
module_prefix
+
(
'.'
if
module_prefix
else
''
)
+
k
yield
(
name
,
v
)
def
modules
(
self
)
->
Iterator
[
'InfiniCoreModule'
]:
r
"""Returns an iterator over all modules in the network.
Yields:
Module: a module in the network
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
... print(idx, '->', m)
0 -> Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
"""
for
name
,
module
in
self
.
named_modules
():
yield
module
def
named_modules
(
self
,
memo
:
Optional
[
Set
[
'InfiniCoreModule'
]]
=
None
,
prefix
:
str
=
''
,
remove_duplicate
:
bool
=
True
):
r
"""Returns an iterator over all modules in the network, yielding
both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
... print(idx, '->', m)
0 -> ('', Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
"""
if
memo
is
None
:
memo
=
set
()
if
remove_duplicate
:
if
self
in
memo
:
return
memo
.
add
(
self
)
yield
prefix
,
self
for
name
,
module
in
self
.
_modules
.
items
():
if
module
is
None
:
continue
submodule_prefix
=
prefix
+
(
'.'
if
prefix
else
''
)
+
name
# Handle both InfiniCoreModule and torch.nn.Module
if
isinstance
(
module
,
InfiniCoreModule
):
for
m
in
module
.
named_modules
(
memo
,
submodule_prefix
,
remove_duplicate
):
yield
m
elif
isinstance
(
module
,
torch
.
nn
.
Module
):
# For torch.nn.Module, use its named_modules method
# torch.nn.Module.named_modules returns (name, module) tuples
for
sub_name
,
sub_module
in
module
.
named_modules
(
prefix
=
submodule_prefix
,
remove_duplicate
=
remove_duplicate
):
yield
(
sub_name
,
sub_module
)
def
children
(
self
)
->
Iterator
[
Union
[
'InfiniCoreModule'
,
torch
.
nn
.
Module
]]:
r
"""Returns an iterator over immediate children modules.
r
"""Returns an iterator over immediate children modules.
Yields:
Yields:
Module: a child module
Module: a child module
(can be InfiniCoreModule or torch.nn.Module)
"""
"""
for
name
,
module
in
self
.
named_children
():
for
name
,
module
in
self
.
named_children
():
yield
module
yield
module
def
named_children
(
self
)
->
Iterator
[
Tuple
[
str
,
'InfiniCoreModule'
]]:
def
named_children
(
self
)
->
Iterator
[
Tuple
[
str
,
Union
[
'InfiniCoreModule'
,
torch
.
nn
.
Module
]
]]:
r
"""Returns an iterator over immediate children modules, yielding both
r
"""Returns an iterator over immediate children modules, yielding both
the name of the module as well as the module itself.
the name of the module as well as the module itself.
...
...
python/infinicore/nn/modules/module_list.py
0 → 100644
View file @
cad51297
# Copyright (c) 2025, InfiniCore
#
# This file implements InfiniCoreModuleList, which is similar to torch.nn.ModuleList
# but based on InfiniCoreModule for inference purposes.
from
typing
import
List
,
Optional
,
Iterator
,
Union
,
Sequence
,
TypeVar
import
torch
import
operator
from
itertools
import
chain
from
collections
import
OrderedDict
from
.module
import
InfiniCoreModule
# Define type variable for module compatibility (supports both torch.nn.Module and InfiniCoreModule)
ModuleType
=
TypeVar
(
'ModuleType'
,
bound
=
Union
[
torch
.
nn
.
Module
,
'InfiniCoreModule'
])
class
InfiniCoreModuleList
(
InfiniCoreModule
):
r
"""Holds submodules in a list.
InfiniCoreModuleList can be indexed like a regular Python list, but
modules it contains are properly registered, and will be visible by all
InfiniCoreModule methods.
Args:
modules (iterable, optional): an iterable of modules to add
Example::
>>> class MyModel(InfiniCoreModule):
... def __init__(self):
... super().__init__()
... self.linears = InfiniCoreModuleList([
... torch.nn.Linear(10, 10) for i in range(10)
... ])
...
... def forward(self, x):
... # ModuleList can act as an iterable, or be indexed using ints
... for i, l in enumerate(self.linears):
... x = self.linears[i // 2](x) + l(x)
... return x
"""
def
__init__
(
self
,
modules
:
Optional
[
Sequence
[
ModuleType
]]
=
None
):
super
().
__init__
()
if
modules
is
not
None
:
self
+=
modules
def
_get_abs_string_index
(
self
,
idx
):
"""Get the absolute index for the list of modules."""
idx
=
operator
.
index
(
idx
)
if
not
(
-
len
(
self
)
<=
idx
<
len
(
self
)):
raise
IndexError
(
f
"index
{
idx
}
is out of range"
)
if
idx
<
0
:
idx
+=
len
(
self
)
return
str
(
idx
)
def
__getitem__
(
self
,
idx
:
Union
[
int
,
slice
])
->
Union
[
ModuleType
,
'InfiniCoreModuleList'
]:
if
isinstance
(
idx
,
slice
):
return
self
.
__class__
(
list
(
self
.
_modules
.
values
())[
idx
])
else
:
return
self
.
_modules
[
self
.
_get_abs_string_index
(
idx
)]
def
__setitem__
(
self
,
idx
:
int
,
module
:
ModuleType
)
->
None
:
idx
=
self
.
_get_abs_string_index
(
idx
)
# Use add_module to register module
self
.
add_module
(
idx
,
module
)
def
__delitem__
(
self
,
idx
:
Union
[
int
,
slice
])
->
None
:
if
isinstance
(
idx
,
slice
):
indices_to_delete
=
list
(
range
(
len
(
self
.
_modules
)))[
idx
]
for
k
in
indices_to_delete
:
if
str
(
k
)
in
self
.
_modules
:
del
self
.
_modules
[
str
(
k
)]
else
:
idx_str
=
self
.
_get_abs_string_index
(
idx
)
if
idx_str
in
self
.
_modules
:
del
self
.
_modules
[
idx_str
]
# To preserve numbering, self._modules is being reconstructed with modules after deletion
if
len
(
self
.
_modules
)
>
0
:
str_indices
=
[
str
(
i
)
for
i
in
range
(
len
(
self
.
_modules
))]
self
.
_modules
=
OrderedDict
(
list
(
zip
(
str_indices
,
self
.
_modules
.
values
())))
def
__len__
(
self
)
->
int
:
return
len
(
self
.
_modules
)
def
__iter__
(
self
)
->
Iterator
[
ModuleType
]:
return
iter
(
self
.
_modules
.
values
())
def
__iadd__
(
self
,
modules
:
Sequence
[
ModuleType
])
->
'InfiniCoreModuleList'
:
return
self
.
extend
(
modules
)
def
__add__
(
self
,
other
:
Union
[
Sequence
[
ModuleType
],
'InfiniCoreModuleList'
])
->
'InfiniCoreModuleList'
:
r
"""Return a new InfiniCoreModuleList by concatenating with another iterable.
Args:
other (iterable): iterable of modules to concatenate
"""
if
not
isinstance
(
other
,
(
list
,
tuple
,
InfiniCoreModuleList
)):
raise
TypeError
(
f
"InfiniCoreModuleList can only be concatenated with list, tuple, or InfiniCoreModuleList, "
f
"got
{
type
(
other
).
__name__
}
"
)
combined
=
InfiniCoreModuleList
()
for
i
,
module
in
enumerate
(
chain
(
self
,
other
)):
combined
.
add_module
(
str
(
i
),
module
)
return
combined
def
append
(
self
,
module
:
ModuleType
)
->
'InfiniCoreModuleList'
:
r
"""Append a given module to the end of the list.
Args:
module (nn.Module or InfiniCoreModule): module to append
"""
self
.
add_module
(
str
(
len
(
self
)),
module
)
return
self
def
extend
(
self
,
modules
:
Sequence
[
ModuleType
])
->
'InfiniCoreModuleList'
:
r
"""Append modules from a Python iterable to the end of the list.
Args:
modules (iterable): iterable of modules to append
"""
if
not
isinstance
(
modules
,
(
list
,
tuple
)):
try
:
modules
=
list
(
modules
)
except
TypeError
:
raise
TypeError
(
f
"InfiniCoreModuleList.extend should be called with an "
f
"iterable, but got
{
type
(
modules
).
__name__
}
"
)
offset
=
len
(
self
)
for
i
,
module
in
enumerate
(
modules
):
self
.
add_module
(
str
(
offset
+
i
),
module
)
return
self
def
insert
(
self
,
index
:
int
,
module
:
ModuleType
)
->
None
:
r
"""Insert a given module before a given index in the list.
Args:
index (int): index to insert.
module (nn.Module or InfiniCoreModule): module to insert
"""
for
i
in
range
(
len
(
self
.
_modules
),
index
,
-
1
):
self
.
_modules
[
str
(
i
)]
=
self
.
_modules
[
str
(
i
-
1
)]
self
.
_modules
[
str
(
index
)]
=
module
def
pop
(
self
,
idx
:
int
=
-
1
)
->
ModuleType
:
r
"""Remove and return a module at the given index.
Args:
idx (int): index of the module to pop. Default: -1 (last module)
Returns:
Module: the module that was removed
"""
idx_str
=
self
.
_get_abs_string_index
(
idx
)
module
=
self
.
_modules
[
idx_str
]
# Use __delitem__ to ensure proper cleanup
self
.
__delitem__
(
int
(
idx_str
))
return
module
def
__repr__
(
self
)
->
str
:
"""Return a string representation of the ModuleList."""
if
len
(
self
)
==
0
:
return
self
.
__class__
.
__name__
+
"()"
lines
=
[]
for
i
,
module
in
enumerate
(
self
):
lines
.
append
(
f
"(
{
i
}
):
{
repr
(
module
)
}
"
)
main_str
=
self
.
__class__
.
__name__
+
"(
\n
"
main_str
+=
"
\n
"
.
join
(
lines
)
+
"
\n
)"
return
main_str
def
__dir__
(
self
)
->
List
[
str
]:
"""Return a list of attribute names, excluding numeric keys."""
keys
=
super
().
__dir__
()
# Filter out numeric keys to avoid cluttering dir() output
keys
=
[
key
for
key
in
keys
if
not
key
.
isdigit
()]
return
keys
test/infinicore/infinicore_module_list_test.py
0 → 100644
View file @
cad51297
import
safetensors.torch
import
torch
import
torch.nn
as
nn
import
safetensors
# ============================================================
# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径
# ============================================================
import
sys
import
os
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'../../python/infinicore'
)))
# 使用临时目录,如果不存在则自动创建
save_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'../../tmp'
)
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
save_path
=
os
.
path
.
join
(
save_dir
,
"torch_modulelist_with_param.safetensors"
)
# ============================================================
# 1. 使用 PyTorch 定义并保存模型(使用 torch.nn.ModuleList)
# ============================================================
class
TorchModuleListNet
(
nn
.
Module
):
def
__init__
(
self
,
in_ch
=
3
,
hidden_ch
=
8
,
out_ch
=
3
):
super
().
__init__
()
# 使用 torch.nn.ModuleList
self
.
layers
=
nn
.
ModuleList
([
nn
.
Conv2d
(
in_ch
,
hidden_ch
,
kernel_size
=
3
,
padding
=
1
),
nn
.
BatchNorm2d
(
hidden_ch
),
nn
.
ReLU
(),
nn
.
Conv2d
(
hidden_ch
,
hidden_ch
,
kernel_size
=
3
,
padding
=
1
),
nn
.
BatchNorm2d
(
hidden_ch
),
nn
.
ReLU
(),
nn
.
Conv2d
(
hidden_ch
,
out_ch
,
kernel_size
=
1
),
])
# 自定义 Parameter
self
.
scale
=
nn
.
Parameter
(
torch
.
ones
(
1
)
*
0.5
)
self
.
register_buffer
(
"offset"
,
torch
.
tensor
(
0.1
))
def
forward
(
self
,
x
):
# 遍历 ModuleList 中的所有层
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
# 应用自定义参数和 buffer
x
=
x
*
self
.
scale
+
self
.
offset
return
x
# ===== 保存 Torch 模型 =====
torch_model
=
TorchModuleListNet
()
torch_state_dict
=
torch_model
.
state_dict
()
safetensors
.
torch
.
save_file
(
torch_state_dict
,
save_path
)
print
(
"✓ PyTorch 模型已保存"
)
# ============================================================
# 2. 使用 torch 方式加载并推理
# ============================================================
torch_model_infer
=
TorchModuleListNet
()
torch_model_infer
.
load_state_dict
(
safetensors
.
torch
.
load_file
(
save_path
))
torch_model_infer
.
eval
()
input
=
torch
.
rand
(
1
,
3
,
8
,
8
)
torch_model_out
=
torch_model_infer
(
input
)
print
(
"✓ Torch 输出:"
,
torch_model_out
.
detach
().
numpy
().
mean
())
# ============================================================
# 3. 使用 ModuleList 加载并推理
# ============================================================
from
nn.modules
import
Module
,
ModuleList
class
InfiniCoreModuleListNet
(
Module
):
def
__init__
(
self
,
in_ch
=
3
,
hidden_ch
=
8
,
out_ch
=
3
):
super
().
__init__
()
# 使用 ModuleList
self
.
layers
=
ModuleList
([
nn
.
Conv2d
(
in_ch
,
hidden_ch
,
kernel_size
=
3
,
padding
=
1
),
nn
.
BatchNorm2d
(
hidden_ch
),
nn
.
ReLU
(),
nn
.
Conv2d
(
hidden_ch
,
hidden_ch
,
kernel_size
=
3
,
padding
=
1
),
nn
.
BatchNorm2d
(
hidden_ch
),
nn
.
ReLU
(),
nn
.
Conv2d
(
hidden_ch
,
out_ch
,
kernel_size
=
1
),
])
# 保持与 Torch 模型一致的自定义参数和 buffer
self
.
scale
=
nn
.
Parameter
(
torch
.
ones
(
1
)
*
0.5
)
self
.
register_buffer
(
"offset"
,
torch
.
tensor
(
0.1
))
def
forward
(
self
,
x
):
# 遍历 ModuleList 中的所有层
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x
=
x
*
self
.
scale
+
self
.
offset
return
x
# ===== 使用 ModuleListNet 读取 safetensors 并推理 =====
infinicore_model_infer
=
InfiniCoreModuleListNet
()
infinicore_model_infer
.
load_state_dict
(
safetensors
.
torch
.
load_file
(
save_path
))
infinicore_model_infer
.
eval
()
infinicore_model_out
=
infinicore_model_infer
.
forward
(
input
)
print
(
"✓ InfiniCore 输出:"
,
infinicore_model_out
.
detach
().
numpy
().
mean
())
# ============================================================
# 4. 对比结果
# ============================================================
diff
=
(
infinicore_model_out
-
torch_model_out
).
abs
().
max
().
item
()
print
(
f
"✓ ModuleList 与 Torch 最大误差:
{
diff
:.
8
f
}
"
)
if
diff
<
1e-9
:
print
(
"✓ ModuleList 与 Torch 精度一致."
)
else
:
print
(
"✗ ModuleList 与 Torch 精度存在差异."
)
# ============================================================
# 5. 测试 ModuleList 的基本功能
# ============================================================
print
(
"
\n
=== 测试 ModuleList 基本功能 ==="
)
# 测试 1: 创建和访问
module_list
=
ModuleList
([
nn
.
Linear
(
10
,
20
),
nn
.
ReLU
(),
nn
.
Linear
(
20
,
5
)
])
print
(
f
"✓ 创建 ModuleList,长度:
{
len
(
module_list
)
}
"
)
print
(
f
"✓ 访问第一个模块:
{
type
(
module_list
[
0
]).
__name__
}
"
)
print
(
f
"✓ 访问第二个模块:
{
type
(
module_list
[
1
]).
__name__
}
"
)
# 测试 2: append
module_list
.
append
(
nn
.
Softmax
(
dim
=-
1
))
print
(
f
"✓ append 后长度:
{
len
(
module_list
)
}
"
)
# 测试 3: extend
module_list
.
extend
([
nn
.
Dropout
(
0.1
),
nn
.
Linear
(
5
,
1
)])
print
(
f
"✓ extend 后长度:
{
len
(
module_list
)
}
"
)
# 测试 4: 迭代
print
(
"✓ 迭代 ModuleList:"
)
for
i
,
module
in
enumerate
(
module_list
):
print
(
f
" [
{
i
}
]
{
type
(
module
).
__name__
}
"
)
# 测试 5: 索引访问
print
(
f
"✓ 索引访问 module_list[0]:
{
type
(
module_list
[
0
]).
__name__
}
"
)
# 测试 6: state_dict
state_dict
=
module_list
.
state_dict
()
print
(
f
"✓ state_dict 键数量:
{
len
(
state_dict
)
}
"
)
print
(
f
"✓ state_dict 包含模块参数:
{
any
(
'0.'
in
k
for
k
in
state_dict
.
keys
())
}
"
)
# 测试 7: 使用 ModuleList 的模型
class
TestNet
(
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
layers
=
ModuleList
([
nn
.
Linear
(
10
,
20
),
nn
.
ReLU
(),
nn
.
Linear
(
20
,
5
)
])
def
forward
(
self
,
x
):
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
return
x
test_model
=
TestNet
()
test_input
=
torch
.
randn
(
2
,
10
)
test_output
=
test_model
.
forward
(
test_input
)
print
(
f
"✓ TestNet 输入形状:
{
test_input
.
shape
}
, 输出形状:
{
test_output
.
shape
}
"
)
# 测试 8: __add__ 方法
ml1
=
ModuleList
([
nn
.
Linear
(
10
,
5
),
nn
.
ReLU
()])
ml2
=
ModuleList
([
nn
.
Linear
(
5
,
3
),
nn
.
Sigmoid
()])
ml3
=
ml1
+
ml2
print
(
f
"✓ __add__ 方法测试:
{
len
(
ml1
)
}
+
{
len
(
ml2
)
}
=
{
len
(
ml3
)
}
"
)
assert
len
(
ml3
)
==
4
,
"合并后的长度应该为 4"
# 测试 9: pop 方法
ml4
=
ModuleList
([
nn
.
Linear
(
10
,
5
),
nn
.
ReLU
(),
nn
.
Linear
(
5
,
3
)])
popped
=
ml4
.
pop
()
print
(
f
"✓ pop 方法测试: 弹出后长度
{
len
(
ml4
)
}
, 弹出模块类型
{
type
(
popped
).
__name__
}
"
)
assert
len
(
ml4
)
==
2
,
"pop 后长度应该为 2"
assert
isinstance
(
popped
,
nn
.
Linear
),
"弹出的应该是 Linear 模块"
# 测试 10: __repr__ 方法
ml5
=
ModuleList
([
nn
.
Linear
(
10
,
5
),
nn
.
ReLU
()])
repr_str
=
repr
(
ml5
)
print
(
f
"✓ __repr__ 方法测试: 输出包含类名和模块信息"
)
assert
"ModuleList"
in
repr_str
or
"InfiniCoreModuleList"
in
repr_str
,
"repr 应该包含类名"
assert
"Linear"
in
repr_str
,
"repr 应该包含模块信息"
print
(
repr_str
)
print
(
"
\n
=== 所有测试通过! ==="
)
# ============================================================
# 6. 前向传播集成测试(参考 infinicore_nn_test.py)
# ============================================================
print
(
"
\n
=== 前向传播集成测试 ==="
)
# 使用 ModuleList 创建一个简单的模型
class
TorchModuleListModel
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
layers
=
nn
.
ModuleList
([
nn
.
Linear
(
10
,
20
),
nn
.
ReLU
(),
nn
.
Linear
(
20
,
5
)
])
self
.
scale
=
nn
.
Parameter
(
torch
.
ones
(
1
)
*
0.5
)
self
.
register_buffer
(
"offset"
,
torch
.
tensor
(
0.1
))
def
forward
(
self
,
x
):
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x
=
x
*
self
.
scale
+
self
.
offset
return
x
class
InfiniCoreModuleListModel
(
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
layers
=
ModuleList
([
nn
.
Linear
(
10
,
20
),
nn
.
ReLU
(),
nn
.
Linear
(
20
,
5
)
])
self
.
scale
=
nn
.
Parameter
(
torch
.
ones
(
1
)
*
0.5
)
self
.
register_buffer
(
"offset"
,
torch
.
tensor
(
0.1
))
def
forward
(
self
,
x
):
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x
=
x
*
self
.
scale
+
self
.
offset
return
x
# 创建模型
torch_model_forward
=
TorchModuleListModel
()
infinicore_model_forward
=
InfiniCoreModuleListModel
()
# 复制权重(确保初始权重一致)
infinicore_model_forward
.
load_state_dict
(
torch_model_forward
.
state_dict
(),
strict
=
False
)
# 设置为评估模式
torch_model_forward
.
eval
()
infinicore_model_forward
.
eval
()
# 创建测试输入
test_input
=
torch
.
randn
(
2
,
10
)
# 前向传播
with
torch
.
no_grad
():
torch_output
=
torch_model_forward
(
test_input
)
infinicore_output
=
infinicore_model_forward
.
forward
(
test_input
)
# 对比结果
diff
=
(
infinicore_output
-
torch_output
).
abs
().
max
().
item
()
print
(
f
"✓ 前向传播测试 - 输入形状:
{
test_input
.
shape
}
"
)
print
(
f
"✓ Torch 输出形状:
{
torch_output
.
shape
}
, 均值:
{
torch_output
.
detach
().
numpy
().
mean
():.
8
f
}
"
)
print
(
f
"✓ InfiniCore 输出形状:
{
infinicore_output
.
shape
}
, 均值:
{
infinicore_output
.
detach
().
numpy
().
mean
():.
8
f
}
"
)
print
(
f
"✓ 最大误差:
{
diff
:.
8
f
}
"
)
if
diff
<
1e-9
:
print
(
"✓ 前向传播集成测试通过:ModuleList 与 Torch ModuleList 结果一致!"
)
else
:
print
(
"✗ 前向传播集成测试失败:存在差异"
)
# ============================================================
# 7. 混合模块兼容性测试(PyTorch + InfiniCore 模块混合使用)
# ============================================================
print
(
"
\n
=== 混合模块兼容性测试 ==="
)
# 创建一个自定义的 InfiniCore 模块
class
CustomLinear
(
Module
):
def
__init__
(
self
,
in_features
,
out_features
):
super
().
__init__
()
self
.
weight
=
nn
.
Parameter
(
torch
.
randn
(
out_features
,
in_features
))
self
.
bias
=
nn
.
Parameter
(
torch
.
randn
(
out_features
))
def
forward
(
self
,
x
):
return
x
@
self
.
weight
.
t
()
+
self
.
bias
# 创建混合 ModuleList(包含 PyTorch 模块和 InfiniCore 模块)
mixed_list
=
ModuleList
([
nn
.
Linear
(
10
,
5
),
# PyTorch 模块
CustomLinear
(
5
,
3
),
# 自定义 InfiniCore 模块
nn
.
ReLU
(),
# PyTorch 模块
])
print
(
f
"✓ 创建混合 ModuleList,长度:
{
len
(
mixed_list
)
}
"
)
print
(
f
"✓ 模块类型:
{
[
type
(
m
).
__name__
for
m
in
mixed_list
]
}
"
)
# 测试参数注册
param_count
=
sum
(
1
for
_
in
mixed_list
.
parameters
())
print
(
f
"✓ 参数数量:
{
param_count
}
"
)
assert
param_count
==
4
,
f
"参数数量应该为 4 (Linear: weight+bias, CustomLinear: weight+bias), 实际为
{
param_count
}
"
# 测试 state_dict
mixed_state_dict
=
mixed_list
.
state_dict
()
print
(
f
"✓ state_dict 键数量:
{
len
(
mixed_state_dict
)
}
"
)
assert
len
(
mixed_state_dict
)
>=
4
,
"state_dict 应该包含至少 4 个参数"
# 测试前向传播
test_input_mixed
=
torch
.
randn
(
2
,
10
)
with
torch
.
no_grad
():
x
=
test_input_mixed
for
module
in
mixed_list
:
x
=
module
.
forward
(
x
)
print
(
f
"✓ 混合模块前向传播成功,输出形状:
{
x
.
shape
}
"
)
print
(
"✓ 混合模块兼容性测试通过!"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment