Commit cad51297 authored by zhuyue's avatar zhuyue Committed by zhuyue
Browse files

Issue/568: feat: add infinicore.nn.InfiniCoreModuleList referencing torch.nn.ModuleList.

           add some functions in InfiniCoreModule.
parent 27e57f3d
from .module import InfiniCoreModule as Module
from .module_list import InfiniCoreModuleList as ModuleList
......@@ -105,7 +105,7 @@ class InfiniCoreModule:
self.register_parameter(name, value)
else:
modules = self.__dict__.get("_modules")
if isinstance(value, (torch.nn.Module)):
if isinstance(value, (torch.nn.Module, InfiniCoreModule)):
if modules is None:
raise AttributeError(
"cannot assign module before Module.__init__() call"
......@@ -181,6 +181,33 @@ class InfiniCoreModule:
self._non_persistent_buffers_set.add(name)
def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
    r"""Register *module* as a child of this module under *name*.

    The child becomes reachable as an attribute with the given name.

    Args:
        name (str): name under which the child is registered; the child
            can later be accessed from this module via that name.
        module (Module or None): child module to register. ``None`` is a
            valid placeholder: such entries are skipped by operations
            that walk child modules, and the entry is **not** reported
            by :attr:`children`.
    """
    # Validate the key with guard clauses; each raises the same
    # exception type and message as the torch implementation.
    if not isinstance(name, str):
        raise TypeError(f"module name should be a string. Got {torch.typename(name)}")
    if '.' in name:
        raise KeyError(f"module name can't contain \".\", got: {name}")
    if name == '':
        raise KeyError("module name can't be empty string \"\"")
    if hasattr(self, name) and name not in self._modules:
        raise KeyError(f"attribute '{name}' already exists")
    # Both torch modules and InfiniCore modules are accepted as children.
    if module is not None and not isinstance(module, (torch.nn.Module, InfiniCoreModule)):
        raise TypeError(f"{torch.typename(module)} is not a Module subclass")
    self._modules[name] = module
def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
r"""Add a parameter to the module.
......@@ -526,16 +553,212 @@ class InfiniCoreModule:
self.__class__.__name__, "\n\t".join(error_msgs)))
return _IncompatibleKeys(missing_keys, unexpected_keys)
def children(self) -> Iterator['InfiniCoreModule']:
def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
    r"""Return an iterator over module parameters.

    Args:
        recurse (bool): if True, yields parameters of this module and of
            every submodule; otherwise only parameters held directly by
            this module.

    Yields:
        Parameter: module parameter

    Example::
        >>> # xdoctest: +SKIP("undefined vars")
        >>> for param in model.parameters():
        ...     print(type(param), param.size())
    """
    # Delegate to named_parameters and discard the names.
    for _, parameter in self.named_parameters(recurse=recurse):
        yield parameter
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
    r"""Iterate over module parameters as ``(name, parameter)`` pairs.

    Args:
        prefix (str): prefix prepended to every parameter name.
        recurse (bool): if True, include parameters of all submodules;
            otherwise only parameters held directly by this module.

    Yields:
        (str, Parameter): tuple containing the name and the parameter

    Example::
        >>> # xdoctest: +SKIP("undefined vars")
        >>> for name, param in self.named_parameters():
        ...     if name in ['bias']:
        ...         print(param.size())
    """
    # _named_members handles prefixing, recursion and de-duplication.
    yield from self._named_members(
        lambda module: module._parameters.items(),
        prefix=prefix,
        recurse=recurse,
    )
def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
    r"""Return an iterator over module buffers.

    Args:
        recurse (bool): if True, yields buffers of this module and of
            every submodule; otherwise only buffers held directly by
            this module.

    Yields:
        torch.Tensor: module buffer

    Example::
        >>> # xdoctest: +SKIP("undefined vars")
        >>> for buf in model.buffers():
        ...     print(type(buf), buf.size())
    """
    # Delegate to named_buffers and discard the names.
    for _, buffer in self.named_buffers(recurse=recurse):
        yield buffer
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
    r"""Iterate over module buffers as ``(name, buffer)`` pairs.

    Only persistent buffers are reported; each distinct buffer object is
    yielded at most once even when it is shared between submodules.

    Args:
        prefix (str): prefix prepended to every buffer name.
        recurse (bool): if True, include buffers of all submodules;
            otherwise only buffers held directly by this module.

    Yields:
        (str, torch.Tensor): tuple containing the name and the buffer

    Example::
        >>> # xdoctest: +SKIP("undefined vars")
        >>> for name, buf in self.named_buffers():
        ...     if name in ['running_mean']:
        ...         print(buf.size())
    """
    seen = set()
    scope = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
    for owner_prefix, owner in scope:
        for buf_name, buf in owner._buffers.items():
            # Skip placeholders and buffers already reported.
            if buf is None or buf in seen:
                continue
            # Non-persistent buffers are excluded (and not memoized).
            if buf_name in owner._non_persistent_buffers_set:
                continue
            seen.add(buf)
            separator = '.' if owner_prefix else ''
            yield (owner_prefix + separator + buf_name, buf)
def _named_members(self, get_members_fn, prefix='', recurse=True):
r"""Helper method to yield members with their names."""
memo = set()
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
for module_prefix, module in modules:
members = get_members_fn(module)
for k, v in members:
if v is None or v in memo:
continue
memo.add(v)
name = module_prefix + ('.' if module_prefix else '') + k
yield (name, v)
def modules(self) -> Iterator['InfiniCoreModule']:
    r"""Return an iterator over all modules in the network.

    Yields:
        Module: a module in the network

    Note:
        Duplicate modules are returned only once. In the following
        example, ``l`` will be returned only once.

    Example::
        >>> # xdoctest: +SKIP("undefined vars")
        >>> l = nn.Linear(2, 2)
        >>> net = nn.Sequential(l, l)
        >>> for idx, m in enumerate(net.modules()):
        ...     print(idx, '->', m)
        0 -> Sequential(
          (0): Linear(in_features=2, out_features=2, bias=True)
          (1): Linear(in_features=2, out_features=2, bias=True)
        )
        1 -> Linear(in_features=2, out_features=2, bias=True)
    """
    # Delegate to named_modules and discard the names.
    for _, mod in self.named_modules():
        yield mod
def named_modules(self, memo: Optional[Set['InfiniCoreModule']] = None, prefix: str = '', remove_duplicate: bool = True):
    r"""Iterate over all modules in the network as ``(name, module)`` pairs.

    Args:
        memo: set of modules already yielded, used to filter duplicates.
        prefix: prefix prepended to every module name.
        remove_duplicate: whether to remove duplicated module instances
            from the result.

    Yields:
        (str, Module): tuple of name and module

    Note:
        Duplicate modules are returned only once. In the following
        example, ``l`` will be returned only once.

    Example::
        >>> # xdoctest: +SKIP("undefined vars")
        >>> l = nn.Linear(2, 2)
        >>> net = nn.Sequential(l, l)
        >>> for idx, m in enumerate(net.named_modules()):
        ...     print(idx, '->', m)
        0 -> ('', Sequential(
          (0): Linear(in_features=2, out_features=2, bias=True)
          (1): Linear(in_features=2, out_features=2, bias=True)
        ))
        1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
    """
    if memo is None:
        memo = set()
    if remove_duplicate:
        if self in memo:
            return
        memo.add(self)
    yield prefix, self
    for name, module in self._modules.items():
        if module is None:
            continue
        submodule_prefix = prefix + ('.' if prefix else '') + name
        # Handle both InfiniCoreModule and torch.nn.Module children.
        if isinstance(module, InfiniCoreModule):
            yield from module.named_modules(memo, submodule_prefix, remove_duplicate)
        elif isinstance(module, torch.nn.Module):
            # Share `memo` with torch's traversal so that a torch module
            # registered more than once in the tree is still yielded only
            # once, as the docstring promises. (Previously each torch
            # child restarted with a fresh memo, so duplicates leaked.)
            yield from module.named_modules(
                memo=memo, prefix=submodule_prefix, remove_duplicate=remove_duplicate)
def children(self) -> Iterator[Union['InfiniCoreModule', torch.nn.Module]]:
    r"""Return an iterator over immediate children modules.

    Yields:
        Module: a child module (can be InfiniCoreModule or torch.nn.Module)
    """
    # Delegate to named_children and discard the names.
    for _, child in self.named_children():
        yield child
def named_children(self) -> Iterator[Tuple[str, 'InfiniCoreModule']]:
def named_children(self) -> Iterator[Tuple[str, Union['InfiniCoreModule', torch.nn.Module]]]:
r"""Returns an iterator over immediate children modules, yielding both
the name of the module as well as the module itself.
......
# Copyright (c) 2025, InfiniCore
#
# This file implements InfiniCoreModuleList, which is similar to torch.nn.ModuleList
# but based on InfiniCoreModule for inference purposes.
from typing import List, Optional, Iterator, Union, Sequence, TypeVar
import torch
import operator
from itertools import chain
from collections import OrderedDict
from .module import InfiniCoreModule
# Define type variable for module compatibility (supports both torch.nn.Module and InfiniCoreModule)
ModuleType = TypeVar('ModuleType', bound=Union[torch.nn.Module, 'InfiniCoreModule'])
class InfiniCoreModuleList(InfiniCoreModule):
    r"""Holds submodules in a list.

    InfiniCoreModuleList can be indexed like a regular Python list, but
    modules it contains are properly registered, and will be visible by all
    InfiniCoreModule methods.

    Args:
        modules (iterable, optional): an iterable of modules to add

    Example::
        >>> class MyModel(InfiniCoreModule):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.linears = InfiniCoreModuleList([
        ...             torch.nn.Linear(10, 10) for i in range(10)
        ...         ])
        ...
        ...     def forward(self, x):
        ...         # ModuleList can act as an iterable, or be indexed using ints
        ...         for i, l in enumerate(self.linears):
        ...             x = self.linears[i // 2](x) + l(x)
        ...         return x
    """

    def __init__(self, modules: Optional[Sequence[ModuleType]] = None):
        super().__init__()
        if modules is not None:
            self += modules

    def _get_abs_string_index(self, idx):
        """Return the string key for *idx*, normalizing negative indices.

        Raises:
            IndexError: if *idx* is outside ``[-len(self), len(self))``.
        """
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f"index {idx} is out of range")
        if idx < 0:
            idx += len(self)
        return str(idx)

    def __getitem__(self, idx: Union[int, slice]) -> Union[ModuleType, 'InfiniCoreModuleList']:
        # A slice produces a new list of the same class; an int returns
        # the registered module itself.
        if isinstance(idx, slice):
            return self.__class__(list(self._modules.values())[idx])
        return self._modules[self._get_abs_string_index(idx)]

    def __setitem__(self, idx: int, module: ModuleType) -> None:
        # add_module performs name/type validation and registration.
        self.add_module(self._get_abs_string_index(idx), module)

    def __delitem__(self, idx: Union[int, slice]) -> None:
        if isinstance(idx, slice):
            for k in list(range(len(self._modules)))[idx]:
                if str(k) in self._modules:
                    del self._modules[str(k)]
        else:
            idx_str = self._get_abs_string_index(idx)
            if idx_str in self._modules:
                del self._modules[idx_str]
        # Rebuild the dict so the surviving modules are renumbered 0..n-1.
        if len(self._modules) > 0:
            self._modules = OrderedDict(
                (str(i), m) for i, m in enumerate(self._modules.values()))

    def __len__(self) -> int:
        return len(self._modules)

    def __iter__(self) -> Iterator[ModuleType]:
        return iter(self._modules.values())

    def __iadd__(self, modules: Sequence[ModuleType]) -> 'InfiniCoreModuleList':
        return self.extend(modules)

    def __add__(self, other: Union[Sequence[ModuleType], 'InfiniCoreModuleList']) -> 'InfiniCoreModuleList':
        r"""Return a new InfiniCoreModuleList concatenating self and *other*.

        Args:
            other (iterable): iterable of modules to concatenate
        """
        # Accept any iterable of modules, as the docstring promises
        # (previously only list/tuple/InfiniCoreModuleList were accepted);
        # plain non-iterables still raise TypeError.
        try:
            other_modules = list(other)
        except TypeError:
            raise TypeError(
                f"InfiniCoreModuleList can only be concatenated with an "
                f"iterable of modules, got {type(other).__name__}"
            ) from None
        combined = InfiniCoreModuleList()
        for i, module in enumerate(chain(self, other_modules)):
            combined.add_module(str(i), module)
        return combined

    def append(self, module: ModuleType) -> 'InfiniCoreModuleList':
        r"""Append a given module to the end of the list.

        Args:
            module (nn.Module or InfiniCoreModule): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def extend(self, modules: Sequence[ModuleType]) -> 'InfiniCoreModuleList':
        r"""Append modules from a Python iterable to the end of the list.

        Args:
            modules (iterable): iterable of modules to append
        """
        if not isinstance(modules, (list, tuple)):
            try:
                modules = list(modules)
            except TypeError:
                # `from None` hides the uninformative internal TypeError.
                raise TypeError(
                    f"InfiniCoreModuleList.extend should be called with an "
                    f"iterable, but got {type(modules).__name__}"
                ) from None
        offset = len(self)
        for i, module in enumerate(modules):
            self.add_module(str(offset + i), module)
        return self

    def insert(self, index: int, module: ModuleType) -> None:
        r"""Insert a given module before *index*, like :meth:`list.insert`.

        Negative indices count from the end and out-of-range indices are
        clamped, matching built-in list semantics (the previous version
        raised ``KeyError`` for them).

        Args:
            index (int): index to insert.
            module (nn.Module or InfiniCoreModule): module to insert
        """
        index = operator.index(index)
        if index < 0:
            index = max(0, index + len(self))
        else:
            index = min(index, len(self))
        # Shift trailing entries one slot to the right, then drop in place.
        for i in range(len(self._modules), index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module

    def pop(self, idx: int = -1) -> ModuleType:
        r"""Remove and return the module at *idx*.

        Args:
            idx (int): index of the module to pop. Default: -1 (last module)

        Returns:
            Module: the module that was removed
        """
        key = self._get_abs_string_index(idx)
        module = self._modules[key]
        # __delitem__ also renumbers the remaining entries.
        self.__delitem__(int(key))
        return module

    def __repr__(self) -> str:
        """Return a string representation of the ModuleList."""
        if len(self) == 0:
            return self.__class__.__name__ + "()"
        lines = [f"({i}): {repr(module)}" for i, module in enumerate(self)]
        main_str = self.__class__.__name__ + "(\n "
        main_str += "\n ".join(lines) + "\n)"
        return main_str

    def __dir__(self) -> List[str]:
        """Return attribute names, hiding the numeric child keys."""
        # Filter out numeric keys to avoid cluttering dir() output.
        return [key for key in super().__dir__() if not key.isdigit()]
import safetensors.torch
import torch
import torch.nn as nn
import safetensors
# ============================================================
# 0. Import the infinicore package; configure a temporary
#    safetensors storage path used by the tests.
# ============================================================
import sys
import os
# Make the local infinicore python package importable from this test file.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../python/infinicore')))
# Use a temporary directory; create it automatically if missing.
save_dir = os.path.join(os.path.dirname(__file__), '../../tmp')
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "torch_modulelist_with_param.safetensors")
# ============================================================
# 1. Define and save a model with PyTorch (using torch.nn.ModuleList)
# ============================================================
class TorchModuleListNet(nn.Module):
    """Reference network built with torch.nn.ModuleList.

    A small Conv/BN/ReLU stack followed by a learned scalar scale and an
    additive buffer offset; serves as the PyTorch baseline that the
    InfiniCore port is compared against.
    """

    def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
        super().__init__()
        # The layer stack is held in a torch.nn.ModuleList.
        self.layers = nn.ModuleList([
            nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_ch),
            nn.ReLU(),
            nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_ch),
            nn.ReLU(),
            nn.Conv2d(hidden_ch, out_ch, kernel_size=1),
        ])
        # Custom parameter and buffer, mirrored by the InfiniCore model.
        self.scale = nn.Parameter(torch.ones(1) * 0.5)
        self.register_buffer("offset", torch.tensor(0.1))

    def forward(self, x):
        # Run the input through every layer in registration order.
        for stage in self.layers:
            x = stage(x)
        # Apply the custom parameter and buffer.
        return x * self.scale + self.offset
# ===== Save the Torch model =====
torch_model = TorchModuleListNet()
torch_state_dict = torch_model.state_dict()
safetensors.torch.save_file(torch_state_dict, save_path)
print("✓ PyTorch 模型已保存")
# ============================================================
# 2. Load with torch and run inference
# ============================================================
torch_model_infer = TorchModuleListNet()
torch_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
torch_model_infer.eval()
input = torch.rand(1, 3, 8, 8)
torch_model_out = torch_model_infer(input)
print("✓ Torch 输出:", torch_model_out.detach().numpy().mean())
# ============================================================
# 3. Load with the InfiniCore ModuleList and run inference
# ============================================================
from nn.modules import Module, ModuleList
class InfiniCoreModuleListNet(Module):
    """InfiniCore counterpart of TorchModuleListNet.

    Identical architecture, but the layer stack is registered through the
    InfiniCore ModuleList so that the two models can share a state_dict.
    """

    def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
        super().__init__()
        # Same Conv/BN/ReLU stack, held in the InfiniCore ModuleList.
        self.layers = ModuleList([
            nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_ch),
            nn.ReLU(),
            nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_ch),
            nn.ReLU(),
            nn.Conv2d(hidden_ch, out_ch, kernel_size=1),
        ])
        # Custom parameter and buffer kept identical to the torch model.
        self.scale = nn.Parameter(torch.ones(1) * 0.5)
        self.register_buffer("offset", torch.tensor(0.1))

    def forward(self, x):
        # Apply every registered layer in order, then scale and offset.
        for stage in self.layers:
            x = stage(x)
        return x * self.scale + self.offset
# ===== Load the safetensors file with ModuleListNet and run inference =====
infinicore_model_infer = InfiniCoreModuleListNet()
infinicore_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
infinicore_model_infer.eval()
infinicore_model_out = infinicore_model_infer.forward(input)
print("✓ InfiniCore 输出:", infinicore_model_out.detach().numpy().mean())
# ============================================================
# 4. Compare the results
# ============================================================
diff = (infinicore_model_out - torch_model_out).abs().max().item()
print(f"✓ ModuleList 与 Torch 最大误差: {diff:.8f}")
if diff < 1e-9:
    print("✓ ModuleList 与 Torch 精度一致.")
else:
    print("✗ ModuleList 与 Torch 精度存在差异.")
# ============================================================
# 5. Exercise the basic ModuleList API
# ============================================================
print("\n=== 测试 ModuleList 基本功能 ===")
# Test 1: construction and indexed access
module_list = ModuleList([
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5)
])
print(f"✓ 创建 ModuleList,长度: {len(module_list)}")
print(f"✓ 访问第一个模块: {type(module_list[0]).__name__}")
print(f"✓ 访问第二个模块: {type(module_list[1]).__name__}")
# Test 2: append
module_list.append(nn.Softmax(dim=-1))
print(f"✓ append 后长度: {len(module_list)}")
# Test 3: extend
module_list.extend([nn.Dropout(0.1), nn.Linear(5, 1)])
print(f"✓ extend 后长度: {len(module_list)}")
# Test 4: iteration
print("✓ 迭代 ModuleList:")
for i, module in enumerate(module_list):
    print(f" [{i}] {type(module).__name__}")
# Test 5: integer indexing
print(f"✓ 索引访问 module_list[0]: {type(module_list[0]).__name__}")
# Test 6: state_dict
state_dict = module_list.state_dict()
print(f"✓ state_dict 键数量: {len(state_dict)}")
print(f"✓ state_dict 包含模块参数: {any('0.' in k for k in state_dict.keys())}")
# Test 7: a model that uses ModuleList
class TestNet(Module):
    """Minimal InfiniCore model: Linear -> ReLU -> Linear via ModuleList."""

    def __init__(self):
        super().__init__()
        self.layers = ModuleList([
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 5)
        ])

    def forward(self, x):
        # Chain the registered layers in order.
        out = x
        for stage in self.layers:
            out = stage(out)
        return out
test_model = TestNet()
test_input = torch.randn(2, 10)
test_output = test_model.forward(test_input)
print(f"✓ TestNet 输入形状: {test_input.shape}, 输出形状: {test_output.shape}")
# Test 8: the __add__ method
ml1 = ModuleList([nn.Linear(10, 5), nn.ReLU()])
ml2 = ModuleList([nn.Linear(5, 3), nn.Sigmoid()])
ml3 = ml1 + ml2
print(f"✓ __add__ 方法测试: {len(ml1)} + {len(ml2)} = {len(ml3)}")
assert len(ml3) == 4, "合并后的长度应该为 4"
# Test 9: the pop method
ml4 = ModuleList([nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 3)])
popped = ml4.pop()
print(f"✓ pop 方法测试: 弹出后长度 {len(ml4)}, 弹出模块类型 {type(popped).__name__}")
assert len(ml4) == 2, "pop 后长度应该为 2"
assert isinstance(popped, nn.Linear), "弹出的应该是 Linear 模块"
# Test 10: the __repr__ method
ml5 = ModuleList([nn.Linear(10, 5), nn.ReLU()])
repr_str = repr(ml5)
print(f"✓ __repr__ 方法测试: 输出包含类名和模块信息")
assert "ModuleList" in repr_str or "InfiniCoreModuleList" in repr_str, "repr 应该包含类名"
assert "Linear" in repr_str, "repr 应该包含模块信息"
print(repr_str)
print("\n=== 所有测试通过! ===")
# ============================================================
# 6. Forward-pass integration test (modeled on infinicore_nn_test.py)
# ============================================================
print("\n=== 前向传播集成测试 ===")
# Build a simple model with ModuleList
class TorchModuleListModel(nn.Module):
    """Torch baseline: Linear -> ReLU -> Linear with scale parameter and offset buffer."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 5)
        ])
        # Learned scalar scale plus a constant additive offset buffer.
        self.scale = nn.Parameter(torch.ones(1) * 0.5)
        self.register_buffer("offset", torch.tensor(0.1))

    def forward(self, x):
        out = x
        for stage in self.layers:
            out = stage(out)
        return out * self.scale + self.offset
class InfiniCoreModuleListModel(Module):
    """InfiniCore twin of TorchModuleListModel, using the InfiniCore ModuleList."""

    def __init__(self):
        super().__init__()
        self.layers = ModuleList([
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 5)
        ])
        # Same custom parameter/buffer as the torch baseline.
        self.scale = nn.Parameter(torch.ones(1) * 0.5)
        self.register_buffer("offset", torch.tensor(0.1))

    def forward(self, x):
        out = x
        for stage in self.layers:
            out = stage(out)
        return out * self.scale + self.offset
# Build the two models
torch_model_forward = TorchModuleListModel()
infinicore_model_forward = InfiniCoreModuleListModel()
# Copy the weights so both models start from identical parameters
infinicore_model_forward.load_state_dict(torch_model_forward.state_dict(), strict=False)
# Switch to evaluation mode
torch_model_forward.eval()
infinicore_model_forward.eval()
# Create the test input
test_input = torch.randn(2, 10)
# Forward passes
with torch.no_grad():
    torch_output = torch_model_forward(test_input)
    infinicore_output = infinicore_model_forward.forward(test_input)
# Compare the results
diff = (infinicore_output - torch_output).abs().max().item()
print(f"✓ 前向传播测试 - 输入形状: {test_input.shape}")
print(f"✓ Torch 输出形状: {torch_output.shape}, 均值: {torch_output.detach().numpy().mean():.8f}")
print(f"✓ InfiniCore 输出形状: {infinicore_output.shape}, 均值: {infinicore_output.detach().numpy().mean():.8f}")
print(f"✓ 最大误差: {diff:.8f}")
if diff < 1e-9:
    print("✓ 前向传播集成测试通过:ModuleList 与 Torch ModuleList 结果一致!")
else:
    print("✗ 前向传播集成测试失败:存在差异")
# ============================================================
# 7. Mixed-module compatibility test (PyTorch + InfiniCore modules together)
# ============================================================
print("\n=== 混合模块兼容性测试 ===")
# Define a custom InfiniCore module
class CustomLinear(Module):
    """Hand-rolled linear layer implemented as an InfiniCore module."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Randomly initialized weight (out x in) and bias (out), registered
        # as parameters so they appear in parameters()/state_dict().
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def forward(self, x):
        # y = x @ W^T + b, same contract as nn.Linear.
        return x @ self.weight.t() + self.bias
# Build a mixed ModuleList (PyTorch modules + an InfiniCore module)
mixed_list = ModuleList([
    nn.Linear(10, 5),  # PyTorch module
    CustomLinear(5, 3),  # custom InfiniCore module
    nn.ReLU(),  # PyTorch module
])
print(f"✓ 创建混合 ModuleList,长度: {len(mixed_list)}")
print(f"✓ 模块类型: {[type(m).__name__ for m in mixed_list]}")
# Check parameter registration
param_count = sum(1 for _ in mixed_list.parameters())
print(f"✓ 参数数量: {param_count}")
assert param_count == 4, f"参数数量应该为 4 (Linear: weight+bias, CustomLinear: weight+bias), 实际为 {param_count}"
# Check state_dict
mixed_state_dict = mixed_list.state_dict()
print(f"✓ state_dict 键数量: {len(mixed_state_dict)}")
assert len(mixed_state_dict) >= 4, "state_dict 应该包含至少 4 个参数"
# Check forward propagation through the mixed list
test_input_mixed = torch.randn(2, 10)
with torch.no_grad():
    x = test_input_mixed
    for module in mixed_list:
        x = module.forward(x)
print(f"✓ 混合模块前向传播成功,输出形状: {x.shape}")
print("✓ 混合模块兼容性测试通过!")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment