Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
cad51297
Commit
cad51297
authored
Nov 06, 2025
by
zhuyue
Committed by
zhuyue
Nov 17, 2025
Browse files
Issue/568: feat: add infinicore.nn.InfiniCoreModuleList referencing torch.nn.ModuleList.
add some functions in InfiniCoreModule.
parent
27e57f3d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
728 additions
and
4 deletions
+728
-4
python/infinicore/nn/modules/__init__.py
python/infinicore/nn/modules/__init__.py
+1
-0
python/infinicore/nn/modules/module.py
python/infinicore/nn/modules/module.py
+227
-4
python/infinicore/nn/modules/module_list.py
python/infinicore/nn/modules/module_list.py
+183
-0
test/infinicore/infinicore_module_list_test.py
test/infinicore/infinicore_module_list_test.py
+317
-0
No files found.
python/infinicore/nn/modules/__init__.py
View file @
cad51297
from
.module
import
InfiniCoreModule
as
Module
from
.module_list
import
InfiniCoreModuleList
as
ModuleList
python/infinicore/nn/modules/module.py
View file @
cad51297
...
...
@@ -105,7 +105,7 @@ class InfiniCoreModule:
self
.
register_parameter
(
name
,
value
)
else
:
modules
=
self
.
__dict__
.
get
(
"_modules"
)
if
isinstance
(
value
,
(
torch
.
nn
.
Module
)):
if
isinstance
(
value
,
(
torch
.
nn
.
Module
,
InfiniCoreModule
)):
if
modules
is
None
:
raise
AttributeError
(
"cannot assign module before Module.__init__() call"
...
...
@@ -181,6 +181,33 @@ class InfiniCoreModule:
self
.
_non_persistent_buffers_set
.
add
(
name
)
def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
    r"""Register a child module on this module.

    After registration the child is reachable as an attribute under *name*
    and participates in module traversal (``children``, ``named_modules``,
    ``state_dict``, ...).

    Args:
        name (str): name of the child module. The child module can be
            accessed from this module using the given name.
        module (Module or None): child module to be added to the module. If
            ``None``, then operations that run on modules, such as :attr:`eval`,
            are ignored. If ``None``, the module is **not** included in the
            module's :attr:`children`.

    Raises:
        TypeError: if *name* is not a string, or *module* is neither ``None``
            nor a ``torch.nn.Module``/``InfiniCoreModule``.
        KeyError: if *name* contains a dot, is empty, or clashes with an
            existing non-module attribute.
    """
    # Validate the name first, in the same order torch does.
    if not isinstance(name, str):
        raise TypeError(f"module name should be a string. Got {torch.typename(name)}")
    if '.' in name:
        raise KeyError(f"module name can't contain \".\", got: {name}")
    if name == '':
        raise KeyError("module name can't be empty string \"\"")
    # Re-registering an existing child is fine; shadowing another attribute is not.
    if hasattr(self, name) and name not in self._modules:
        raise KeyError(f"attribute '{name}' already exists")
    # Both torch modules and InfiniCore modules are accepted as children.
    if module is not None and not isinstance(module, (torch.nn.Module, InfiniCoreModule)):
        raise TypeError(f"{torch.typename(module)} is not a Module subclass")
    self._modules[name] = module
def
register_parameter
(
self
,
name
:
str
,
param
:
Optional
[
torch
.
nn
.
Parameter
])
->
None
:
r
"""Add a parameter to the module.
...
...
@@ -526,16 +553,212 @@ class InfiniCoreModule:
self
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
return
_IncompatibleKeys
(
missing_keys
,
unexpected_keys
)
def
children
(
self
)
->
Iterator
[
'InfiniCoreModule'
]:
def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
    r"""Return an iterator over module parameters.

    Thin wrapper over :meth:`named_parameters` that discards the names.

    Args:
        recurse (bool): if True, yields parameters of this module and all
            submodules; otherwise yields only parameters that are direct
            members of this module.

    Yields:
        Parameter: module parameter

    Example::
        >>> # xdoctest: +SKIP("undefined vars")
        >>> for param in model.parameters():
        ...     print(type(param), param.size())
    """
    yield from (param for _, param in self.named_parameters(recurse=recurse))
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
    r"""Return an iterator over module parameters as ``(name, parameter)`` pairs.

    Args:
        prefix (str): prefix to prepend to all parameter names.
        recurse (bool): if True, yields parameters of this module and all
            submodules; otherwise yields only parameters that are direct
            members of this module.

    Yields:
        (str, Parameter): Tuple containing the name and parameter

    Example::
        >>> # xdoctest: +SKIP("undefined vars")
        >>> for name, param in self.named_parameters():
        ...     if name in ['bias']:
        ...         print(param.size())
    """
    # Delegate traversal and de-duplication to the shared helper.
    yield from self._named_members(
        lambda module: module._parameters.items(), prefix=prefix, recurse=recurse
    )
def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
    r"""Return an iterator over module buffers.

    Thin wrapper over :meth:`named_buffers` that discards the names.

    Args:
        recurse (bool): if True, yields buffers of this module and all
            submodules; otherwise yields only buffers that are direct
            members of this module.

    Yields:
        torch.Tensor: module buffer

    Example::
        >>> # xdoctest: +SKIP("undefined vars")
        >>> for buf in model.buffers():
        ...     print(type(buf), buf.size())
    """
    yield from (buf for _, buf in self.named_buffers(recurse=recurse))
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
    r"""Returns an iterator over module buffers, yielding both the
    name of the buffer as well as the buffer itself.

    Note:
        Only persistent buffers are yielded: any buffer whose name is in
        ``_non_persistent_buffers_set`` is skipped. (NOTE(review): torch's
        ``named_buffers`` yields non-persistent buffers too — confirm this
        deviation is intended for inference use.)

    Args:
        prefix (str): prefix to prepend to all buffer names.
        recurse (bool): if True, then yields buffers of this module
            and all submodules. Otherwise, yields only buffers that
            are direct members of this module.

    Yields:
        (str, torch.Tensor): Tuple containing the name and buffer

    Example::
        >>> # xdoctest: +SKIP("undefined vars")
        >>> for name, buf in self.named_buffers():
        ...     if name in ['running_mean']:
        ...         print(buf.size())
    """
    # Track yielded buffer objects so a buffer shared by several modules
    # is reported only once.
    memo = set()
    modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
    for module_prefix, module in modules:
        for k, v in module._buffers.items():
            if v is None or v in memo:
                continue
            # Skip buffers registered with persistent=False.
            if k in module._non_persistent_buffers_set:
                continue
            memo.add(v)
            # Qualify the local buffer name with the owning module's prefix.
            name = module_prefix + ('.' if module_prefix else '') + k
            yield (name, v)
def
_named_members
(
self
,
get_members_fn
,
prefix
=
''
,
recurse
=
True
):
r
"""Helper method to yield members with their names."""
memo
=
set
()
modules
=
self
.
named_modules
(
prefix
=
prefix
)
if
recurse
else
[(
prefix
,
self
)]
for
module_prefix
,
module
in
modules
:
members
=
get_members_fn
(
module
)
for
k
,
v
in
members
:
if
v
is
None
or
v
in
memo
:
continue
memo
.
add
(
v
)
name
=
module_prefix
+
(
'.'
if
module_prefix
else
''
)
+
k
yield
(
name
,
v
)
def modules(self) -> Iterator['InfiniCoreModule']:
    r"""Return an iterator over all modules in the network.

    Yields:
        Module: a module in the network

    Note:
        Duplicate modules are returned only once. In the following
        example, ``l`` will be returned only once.

    Example::
        >>> # xdoctest: +SKIP("undefined vars")
        >>> l = nn.Linear(2, 2)
        >>> net = nn.Sequential(l, l)
        >>> for idx, m in enumerate(net.modules()):
        ...     print(idx, '->', m)
        0 -> Sequential(
          (0): Linear(in_features=2, out_features=2, bias=True)
          (1): Linear(in_features=2, out_features=2, bias=True)
        )
        1 -> Linear(in_features=2, out_features=2, bias=True)
    """
    yield from (mod for _, mod in self.named_modules())
def named_modules(self, memo: Optional[Set['InfiniCoreModule']] = None, prefix: str = '', remove_duplicate: bool = True):
    r"""Returns an iterator over all modules in the network, yielding
    both the name of the module as well as the module itself.

    Args:
        memo: a memo to store the set of modules already added to the result
        prefix: a prefix that will be added to the name of the module
        remove_duplicate: whether to remove the duplicated module instances
            in the result or not

    Yields:
        (str, Module): Tuple of name and module

    Note:
        Duplicate modules are returned only once. In the following
        example, ``l`` will be returned only once.

    Example::
        >>> # xdoctest: +SKIP("undefined vars")
        >>> l = nn.Linear(2, 2)
        >>> net = nn.Sequential(l, l)
        >>> for idx, m in enumerate(net.named_modules()):
        ...     print(idx, '->', m)
        0 -> ('', Sequential(
          (0): Linear(in_features=2, out_features=2, bias=True)
          (1): Linear(in_features=2, out_features=2, bias=True)
        ))
        1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
    """
    if memo is None:
        memo = set()
    if remove_duplicate:
        if self in memo:
            return
        memo.add(self)
    yield prefix, self
    for name, module in self._modules.items():
        if module is None:
            continue
        submodule_prefix = prefix + ('.' if prefix else '') + name
        # Handle both InfiniCoreModule and torch.nn.Module children.
        if isinstance(module, InfiniCoreModule):
            yield from module.named_modules(memo, submodule_prefix, remove_duplicate)
        elif isinstance(module, torch.nn.Module):
            # Bug fix: share `memo` with torch's named_modules. Previously each
            # torch subtree used its own private memo, so a torch module
            # reachable through several parents was yielded more than once
            # even with remove_duplicate=True.
            yield from module.named_modules(
                memo=memo, prefix=submodule_prefix, remove_duplicate=remove_duplicate
            )
def children(self) -> Iterator[Union['InfiniCoreModule', torch.nn.Module]]:
    r"""Return an iterator over immediate children modules.

    Yields:
        Module: a child module (can be InfiniCoreModule or torch.nn.Module)
    """
    yield from (child for _, child in self.named_children())
def
named_children
(
self
)
->
Iterator
[
Tuple
[
str
,
'InfiniCoreModule'
]]:
def
named_children
(
self
)
->
Iterator
[
Tuple
[
str
,
Union
[
'InfiniCoreModule'
,
torch
.
nn
.
Module
]
]]:
r
"""Returns an iterator over immediate children modules, yielding both
the name of the module as well as the module itself.
...
...
python/infinicore/nn/modules/module_list.py
0 → 100644
View file @
cad51297
# Copyright (c) 2025, InfiniCore
#
# This file implements InfiniCoreModuleList, which is similar to torch.nn.ModuleList
# but based on InfiniCoreModule for inference purposes.
from
typing
import
List
,
Optional
,
Iterator
,
Union
,
Sequence
,
TypeVar
import
torch
import
operator
from
itertools
import
chain
from
collections
import
OrderedDict
from
.module
import
InfiniCoreModule
# Type variable for module compatibility: bounded to the two module flavors
# this package accepts as children — plain ``torch.nn.Module`` and the local
# ``InfiniCoreModule`` (forward-referenced as a string to avoid an import cycle).
ModuleType = TypeVar('ModuleType', bound=Union[torch.nn.Module, 'InfiniCoreModule'])
class InfiniCoreModuleList(InfiniCoreModule):
    r"""Holds submodules in a list.

    InfiniCoreModuleList can be indexed like a regular Python list, but
    modules it contains are properly registered, and will be visible by all
    InfiniCoreModule methods.

    Args:
        modules (iterable, optional): an iterable of modules to add

    Example::
        >>> class MyModel(InfiniCoreModule):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.linears = InfiniCoreModuleList([
        ...             torch.nn.Linear(10, 10) for i in range(10)
        ...         ])
        ...
        ...     def forward(self, x):
        ...         # ModuleList can act as an iterable, or be indexed using ints
        ...         for i, l in enumerate(self.linears):
        ...             x = self.linears[i // 2](x) + l(x)
        ...         return x
    """

    def __init__(self, modules: Optional[Sequence[ModuleType]] = None):
        super().__init__()
        if modules is not None:
            # Delegates to extend() through __iadd__.
            self += modules

    def _get_abs_string_index(self, idx):
        """Normalize *idx* (possibly negative) to the string key used in ``_modules``."""
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f"index {idx} is out of range")
        return str(idx + len(self) if idx < 0 else idx)

    def __getitem__(self, idx: Union[int, slice]) -> Union[ModuleType, 'InfiniCoreModuleList']:
        """Slice access returns a new list; int access returns one module."""
        if isinstance(idx, slice):
            return self.__class__(list(self._modules.values())[idx])
        return self._modules[self._get_abs_string_index(idx)]

    def __setitem__(self, idx: int, module: ModuleType) -> None:
        """Replace the module at *idx*, re-registering it via add_module."""
        self.add_module(self._get_abs_string_index(idx), module)

    def __delitem__(self, idx: Union[int, slice]) -> None:
        """Delete one module (or a slice of modules) and renumber the rest."""
        if isinstance(idx, slice):
            for key in (str(k) for k in list(range(len(self._modules)))[idx]):
                if key in self._modules:
                    del self._modules[key]
        else:
            key = self._get_abs_string_index(idx)
            if key in self._modules:
                del self._modules[key]
        # Rebuild _modules so the surviving entries keep consecutive numbering.
        if len(self._modules) > 0:
            fresh_keys = [str(i) for i in range(len(self._modules))]
            self._modules = OrderedDict(zip(fresh_keys, self._modules.values()))

    def __len__(self) -> int:
        return len(self._modules)

    def __iter__(self) -> Iterator[ModuleType]:
        return iter(self._modules.values())

    def __iadd__(self, modules: Sequence[ModuleType]) -> 'InfiniCoreModuleList':
        return self.extend(modules)

    def __add__(self, other: Union[Sequence[ModuleType], 'InfiniCoreModuleList']) -> 'InfiniCoreModuleList':
        r"""Return a new InfiniCoreModuleList by concatenating with another iterable.

        Args:
            other (iterable): iterable of modules to concatenate
        """
        if not isinstance(other, (list, tuple, InfiniCoreModuleList)):
            raise TypeError(
                f"InfiniCoreModuleList can only be concatenated with list, tuple, or InfiniCoreModuleList, "
                f"got {type(other).__name__}"
            )
        result = InfiniCoreModuleList()
        for position, entry in enumerate(chain(self, other)):
            result.add_module(str(position), entry)
        return result

    def append(self, module: ModuleType) -> 'InfiniCoreModuleList':
        r"""Append a given module to the end of the list.

        Args:
            module (nn.Module or InfiniCoreModule): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def extend(self, modules: Sequence[ModuleType]) -> 'InfiniCoreModuleList':
        r"""Append modules from a Python iterable to the end of the list.

        Args:
            modules (iterable): iterable of modules to append
        """
        if not isinstance(modules, (list, tuple)):
            # Accept any iterable by materializing it first.
            try:
                modules = list(modules)
            except TypeError:
                raise TypeError(
                    f"InfiniCoreModuleList.extend should be called with an "
                    f"iterable, but got {type(modules).__name__}"
                )
        start = len(self)
        for shift, entry in enumerate(modules):
            self.add_module(str(start + shift), entry)
        return self

    def insert(self, index: int, module: ModuleType) -> None:
        r"""Insert a given module before a given index in the list.

        Args:
            index (int): index to insert.
            module (nn.Module or InfiniCoreModule): module to insert
        """
        # Shift entries [index, end) one slot right, then place the new module.
        for slot in range(len(self._modules), index, -1):
            self._modules[str(slot)] = self._modules[str(slot - 1)]
        self._modules[str(index)] = module

    def pop(self, idx: int = -1) -> ModuleType:
        r"""Remove and return a module at the given index.

        Args:
            idx (int): index of the module to pop. Default: -1 (last module)

        Returns:
            Module: the module that was removed
        """
        key = self._get_abs_string_index(idx)
        removed = self._modules[key]
        # __delitem__ also renumbers the remaining entries.
        self.__delitem__(int(key))
        return removed

    def __repr__(self) -> str:
        """Return a string representation of the ModuleList."""
        if len(self) == 0:
            return self.__class__.__name__ + "()"
        entries = [f"({i}): {repr(module)}" for i, module in enumerate(self)]
        return self.__class__.__name__ + "(\n" + "\n".join(entries) + "\n)"

    def __dir__(self) -> List[str]:
        """Return attribute names, hiding the numeric child keys."""
        return [key for key in super().__dir__() if not key.isdigit()]
test/infinicore/infinicore_module_list_test.py
0 → 100644
View file @
cad51297
import safetensors.torch
import torch
import torch.nn as nn
import safetensors

# ============================================================
# 0. Import the infinicore package; configure a temporary
#    safetensors storage path for the test.
# ============================================================
import sys
import os

# Make the in-repo infinicore package importable without installation.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../python/infinicore')))

# Use a temporary directory; create it automatically if it does not exist.
save_dir = os.path.join(os.path.dirname(__file__), '../../tmp')
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "torch_modulelist_with_param.safetensors")
# ============================================================
# 1. Define and save a model with PyTorch (using torch.nn.ModuleList)
# ============================================================
class TorchModuleListNet(nn.Module):
    """Small conv net whose layers live in a torch.nn.ModuleList, plus a
    custom Parameter (``scale``) and a registered buffer (``offset``)."""

    def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
        super().__init__()
        # Layer stack held in torch.nn.ModuleList.
        stack = [
            nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_ch),
            nn.ReLU(),
            nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_ch),
            nn.ReLU(),
            nn.Conv2d(hidden_ch, out_ch, kernel_size=1),
        ]
        self.layers = nn.ModuleList(stack)
        # Custom parameter and buffer applied after the conv stack.
        self.scale = nn.Parameter(torch.ones(1) * 0.5)
        self.register_buffer("offset", torch.tensor(0.1))

    def forward(self, x):
        # Run every layer of the ModuleList in order.
        for stage in self.layers:
            x = stage(x)
        # Apply the custom parameter and buffer.
        return x * self.scale + self.offset
# ===== Save the Torch model's weights =====
torch_model = TorchModuleListNet()
torch_state_dict = torch_model.state_dict()
safetensors.torch.save_file(torch_state_dict, save_path)
print("✓ PyTorch 模型已保存")
# ============================================================
# 2. Load the weights with plain torch and run inference
# ============================================================
torch_model_infer = TorchModuleListNet()
torch_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
torch_model_infer.eval()
# NOTE(review): `input` shadows the builtin of the same name.
input = torch.rand(1, 3, 8, 8)
torch_model_out = torch_model_infer(input)
print("✓ Torch 输出:", torch_model_out.detach().numpy().mean())
# ============================================================
# 3. Load and run inference using the InfiniCore ModuleList
# ============================================================
# Resolved through the sys.path entry added above (python/infinicore).
from nn.modules import Module, ModuleList
class InfiniCoreModuleListNet(Module):
    """InfiniCore mirror of TorchModuleListNet: same layer stack, parameter
    and buffer names, so the torch state_dict loads unchanged."""

    def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
        super().__init__()
        # Same layer stack, but held in the InfiniCore ModuleList.
        self.layers = ModuleList([
            nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_ch),
            nn.ReLU(),
            nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_ch),
            nn.ReLU(),
            nn.Conv2d(hidden_ch, out_ch, kernel_size=1),
        ])
        # Keep the custom parameter and buffer identical to the torch model.
        self.scale = nn.Parameter(torch.ones(1) * 0.5)
        self.register_buffer("offset", torch.tensor(0.1))

    def forward(self, x):
        # Run every layer of the ModuleList in order.
        out = x
        for stage in self.layers:
            out = stage(out)
        return out * self.scale + self.offset
# ===== Load the safetensors file with ModuleListNet and run inference =====
infinicore_model_infer = InfiniCoreModuleListNet()
infinicore_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
infinicore_model_infer.eval()
infinicore_model_out = infinicore_model_infer.forward(input)
print("✓ InfiniCore 输出:", infinicore_model_out.detach().numpy().mean())
# ============================================================
# 4. Compare the results
# ============================================================
diff = (infinicore_model_out - torch_model_out).abs().max().item()
print(f"✓ ModuleList 与 Torch 最大误差: {diff:.8f}")
if diff < 1e-9:
    print("✓ ModuleList 与 Torch 精度一致.")
else:
    print("✗ ModuleList 与 Torch 精度存在差异.")
# ============================================================
# 5. Test the basic functionality of ModuleList
# ============================================================
print("\n=== 测试 ModuleList 基本功能 ===")
# Test 1: construction and indexed access
module_list = ModuleList([
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5)
])
print(f"✓ 创建 ModuleList,长度: {len(module_list)}")
print(f"✓ 访问第一个模块: {type(module_list[0]).__name__}")
print(f"✓ 访问第二个模块: {type(module_list[1]).__name__}")
# Test 2: append
module_list.append(nn.Softmax(dim=-1))
print(f"✓ append 后长度: {len(module_list)}")
# Test 3: extend
module_list.extend([nn.Dropout(0.1), nn.Linear(5, 1)])
print(f"✓ extend 后长度: {len(module_list)}")
# Test 4: iteration
print("✓ 迭代 ModuleList:")
for i, module in enumerate(module_list):
    print(f"  [{i}] {type(module).__name__}")
# Test 5: indexed access
print(f"✓ 索引访问 module_list[0]: {type(module_list[0]).__name__}")
# Test 6: state_dict
state_dict = module_list.state_dict()
print(f"✓ state_dict 键数量: {len(state_dict)}")
print(f"✓ state_dict 包含模块参数: {any('0.' in k for k in state_dict.keys())}")
# Test 7: a model built on ModuleList
class TestNet(Module):
    """Minimal MLP whose layers are stored in an InfiniCore ModuleList."""

    def __init__(self):
        super().__init__()
        self.layers = ModuleList([
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 5)
        ])

    def forward(self, x):
        out = x
        for stage in self.layers:
            out = stage(out)
        return out
# Exercise TestNet with a random batch and report the shapes.
test_model = TestNet()
test_input = torch.randn(2, 10)
test_output = test_model.forward(test_input)
print(f"✓ TestNet 输入形状: {test_input.shape}, 输出形状: {test_output.shape}")
# Test 8: __add__ (concatenation of two lists)
ml1 = ModuleList([nn.Linear(10, 5), nn.ReLU()])
ml2 = ModuleList([nn.Linear(5, 3), nn.Sigmoid()])
ml3 = ml1 + ml2
print(f"✓ __add__ 方法测试: {len(ml1)} + {len(ml2)} = {len(ml3)}")
assert len(ml3) == 4, "合并后的长度应该为 4"
# Test 9: pop (remove and return the last module)
ml4 = ModuleList([nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 3)])
popped = ml4.pop()
print(f"✓ pop 方法测试: 弹出后长度 {len(ml4)}, 弹出模块类型 {type(popped).__name__}")
assert len(ml4) == 2, "pop 后长度应该为 2"
assert isinstance(popped, nn.Linear), "弹出的应该是 Linear 模块"
# Test 10: __repr__ (class name and module info in the output)
ml5 = ModuleList([nn.Linear(10, 5), nn.ReLU()])
repr_str = repr(ml5)
print(f"✓ __repr__ 方法测试: 输出包含类名和模块信息")
assert "ModuleList" in repr_str or "InfiniCoreModuleList" in repr_str, "repr 应该包含类名"
assert "Linear" in repr_str, "repr 应该包含模块信息"
print(repr_str)
print("\n=== 所有测试通过! ===")
# ============================================================
# 6. Forward-pass integration test (modeled on infinicore_nn_test.py)
# ============================================================
print("\n=== 前向传播集成测试 ===")


# Build a simple torch model around nn.ModuleList.
class TorchModuleListModel(nn.Module):
    """Reference torch model: Linear-ReLU-Linear in an nn.ModuleList, with a
    scale parameter and an offset buffer applied to the result."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 5)
        ])
        self.scale = nn.Parameter(torch.ones(1) * 0.5)
        self.register_buffer("offset", torch.tensor(0.1))

    def forward(self, x):
        out = x
        for stage in self.layers:
            out = stage(out)
        return out * self.scale + self.offset
class InfiniCoreModuleListModel(Module):
    """InfiniCore counterpart of TorchModuleListModel with matching parameter
    and buffer names, so the torch state_dict loads directly."""

    def __init__(self):
        super().__init__()
        self.layers = ModuleList([
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 5)
        ])
        self.scale = nn.Parameter(torch.ones(1) * 0.5)
        self.register_buffer("offset", torch.tensor(0.1))

    def forward(self, x):
        out = x
        for stage in self.layers:
            out = stage(out)
        return out * self.scale + self.offset
# Create both models
torch_model_forward = TorchModuleListModel()
infinicore_model_forward = InfiniCoreModuleListModel()
# Copy weights (ensure both start from identical initialization)
infinicore_model_forward.load_state_dict(torch_model_forward.state_dict(), strict=False)
# Switch both models to eval mode
torch_model_forward.eval()
infinicore_model_forward.eval()
# Build a test input
test_input = torch.randn(2, 10)
# Forward pass through both models
with torch.no_grad():
    torch_output = torch_model_forward(test_input)
    infinicore_output = infinicore_model_forward.forward(test_input)
# Compare the outputs
diff = (infinicore_output - torch_output).abs().max().item()
print(f"✓ 前向传播测试 - 输入形状: {test_input.shape}")
print(f"✓ Torch 输出形状: {torch_output.shape}, 均值: {torch_output.detach().numpy().mean():.8f}")
print(f"✓ InfiniCore 输出形状: {infinicore_output.shape}, 均值: {infinicore_output.detach().numpy().mean():.8f}")
print(f"✓ 最大误差: {diff:.8f}")
if diff < 1e-9:
    print("✓ 前向传播集成测试通过:ModuleList 与 Torch ModuleList 结果一致!")
else:
    print("✗ 前向传播集成测试失败:存在差异")
# ============================================================
# 7. Mixed-module compatibility test (PyTorch + InfiniCore modules together)
# ============================================================
print("\n=== 混合模块兼容性测试 ===")


# A hand-rolled InfiniCore module
class CustomLinear(Module):
    """Linear layer built directly on InfiniCoreModule: y = x @ W.T + b."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def forward(self, x):
        return x @ self.weight.t() + self.bias
# Build a mixed ModuleList (PyTorch modules and an InfiniCore module together)
mixed_list = ModuleList([
    nn.Linear(10, 5),  # PyTorch module
    CustomLinear(5, 3),  # custom InfiniCore module
    nn.ReLU(),  # PyTorch module
])
print(f"✓ 创建混合 ModuleList,长度: {len(mixed_list)}")
print(f"✓ 模块类型: {[type(m).__name__ for m in mixed_list]}")
# Check parameter registration across both module flavors
param_count = sum(1 for _ in mixed_list.parameters())
print(f"✓ 参数数量: {param_count}")
assert param_count == 4, f"参数数量应该为 4 (Linear: weight+bias, CustomLinear: weight+bias), 实际为 {param_count}"
# Check state_dict collection
mixed_state_dict = mixed_list.state_dict()
print(f"✓ state_dict 键数量: {len(mixed_state_dict)}")
assert len(mixed_state_dict) >= 4, "state_dict 应该包含至少 4 个参数"
# Check forward propagation through the mixed list
test_input_mixed = torch.randn(2, 10)
with torch.no_grad():
    x = test_input_mixed
    for module in mixed_list:
        x = module.forward(x)
print(f"✓ 混合模块前向传播成功,输出形状: {x.shape}")
print("✓ 混合模块兼容性测试通过!")
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment