Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
f6107946
Commit
f6107946
authored
Nov 07, 2025
by
zhuyue
Committed by
zhuyue
Nov 17, 2025
Browse files
Issue/568: Support InfiniCoreParam used in InfiniCoreModule.
parent
0b2ea12d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
15 deletions
+26
-15
python/infinicore/nn/modules/module.py
python/infinicore/nn/modules/module.py
+26
-15
No files found.
python/infinicore/nn/modules/module.py
View file @
f6107946
...
...
@@ -16,12 +16,16 @@
from
collections
import
OrderedDict
,
namedtuple
import
itertools
import
warnings
from
typing
import
TYPE_CHECKING
import
torch
from
typing
import
Union
,
Tuple
,
Any
,
Iterator
,
Set
,
Optional
,
overload
,
TypeVar
,
Mapping
,
Dict
,
List
from
torch.utils._python_dispatch
import
is_traceable_wrapper_subclass
if
TYPE_CHECKING
:
from
.parameter
import
InfiniCoreParameter
as
Parameter
_EXTRA_STATE_KEY_SUFFIX
=
'_extra_state'
T
=
TypeVar
(
'T'
,
bound
=
'InfiniCoreModule'
)
...
...
@@ -46,7 +50,7 @@ class InfiniCoreModule:
_version
:
int
=
1
training
:
bool
_parameters
:
Dict
[
str
,
Optional
[
Union
[
torch
.
nn
.
Parameter
,
'
InfiniCore
Parameter'
]]]
_parameters
:
Dict
[
str
,
Optional
[
Union
[
torch
.
nn
.
Parameter
,
'Parameter'
]]]
_buffers
:
Dict
[
str
,
Optional
[
torch
.
Tensor
]]
_non_persistent_buffers_set
:
Set
[
str
]
_modules
:
Dict
[
str
,
Optional
[
'InfiniCoreModule'
]]
...
...
@@ -84,9 +88,9 @@ class InfiniCoreModule:
d
.
discard
(
name
)
params
=
self
.
__dict__
.
get
(
"_parameters"
)
# Support both torch.nn.Parameter and InfiniCoreParameter
from
.parameter
import
InfiniCoreParameter
if
isinstance
(
value
,
(
torch
.
nn
.
Parameter
,
InfiniCore
Parameter
)):
# Support both torch.nn.Parameter and
Parameter (
InfiniCoreParameter
)
from
.parameter
import
InfiniCoreParameter
as
Parameter
if
isinstance
(
value
,
(
torch
.
nn
.
Parameter
,
Parameter
)):
if
params
is
None
:
raise
AttributeError
(
"cannot assign parameters before Module.__init__() call"
...
...
@@ -102,7 +106,7 @@ class InfiniCoreModule:
if
value
is
not
None
:
raise
TypeError
(
f
"cannot assign '
{
torch
.
typename
(
value
)
}
' as parameter '
{
name
}
' "
"(torch.nn.Parameter,
InfiniCore
Parameter or None expected)"
"(torch.nn.Parameter, Parameter or None expected)"
)
self
.
register_parameter
(
name
,
value
)
else
:
...
...
@@ -210,7 +214,7 @@ class InfiniCoreModule:
self
.
_modules
[
name
]
=
module
def
register_parameter
(
self
,
name
:
str
,
param
:
Optional
[
torch
.
nn
.
Parameter
])
->
None
:
def
register_parameter
(
self
,
name
:
str
,
param
:
Optional
[
Union
[
torch
.
nn
.
Parameter
,
'Parameter'
]
])
->
None
:
r
"""Add a parameter to the module.
The parameter can be accessed as an attribute using given name.
...
...
@@ -242,12 +246,12 @@ class InfiniCoreModule:
if
param
is
None
:
self
.
_parameters
[
name
]
=
None
else
:
# Support both torch.nn.Parameter and InfiniCoreParameter
from
.parameter
import
InfiniCoreParameter
if
not
isinstance
(
param
,
(
torch
.
nn
.
Parameter
,
InfiniCore
Parameter
)):
# Support both torch.nn.Parameter and
Parameter (
InfiniCoreParameter
)
from
.parameter
import
InfiniCoreParameter
as
Parameter
if
not
isinstance
(
param
,
(
torch
.
nn
.
Parameter
,
Parameter
)):
raise
TypeError
(
f
"cannot assign '
{
torch
.
typename
(
param
)
}
' object to parameter '
{
name
}
' "
"(torch.nn.Parameter,
InfiniCore
Parameter or None required)"
"(torch.nn.Parameter, Parameter or None required)"
)
self
.
_parameters
[
name
]
=
param
...
...
@@ -557,7 +561,7 @@ class InfiniCoreModule:
self
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
return
_IncompatibleKeys
(
missing_keys
,
unexpected_keys
)
def
parameters
(
self
,
recurse
:
bool
=
True
)
->
Iterator
[
torch
.
nn
.
Parameter
]:
def
parameters
(
self
,
recurse
:
bool
=
True
)
->
Iterator
[
Union
[
torch
.
nn
.
Parameter
,
'Parameter'
]
]:
r
"""Returns an iterator over module parameters.
Args:
...
...
@@ -578,7 +582,7 @@ class InfiniCoreModule:
for
name
,
param
in
self
.
named_parameters
(
recurse
=
recurse
):
yield
param
def
named_parameters
(
self
,
prefix
:
str
=
''
,
recurse
:
bool
=
True
)
->
Iterator
[
Tuple
[
str
,
torch
.
nn
.
Parameter
]]:
def
named_parameters
(
self
,
prefix
:
str
=
''
,
recurse
:
bool
=
True
)
->
Iterator
[
Tuple
[
str
,
Union
[
torch
.
nn
.
Parameter
,
'Parameter'
]
]]:
r
"""Returns an iterator over module parameters, yielding both the
name of the parameter as well as the parameter itself.
...
...
@@ -845,6 +849,9 @@ class InfiniCoreModule:
return
False
should_use_swap_tensors
=
torch
.
__future__
.
get_swap_module_params_on_conversion
()
# Import Parameter (InfiniCoreParameter) for type checking and creation
from
.parameter
import
InfiniCoreParameter
as
Parameter
for
key
,
param
in
self
.
_parameters
.
items
():
if
param
is
None
:
...
...
@@ -859,6 +866,10 @@ class InfiniCoreModule:
# subclasses may have multiple child tensors so we need to use swap_tensors
p_should_use_swap_tensors
=
should_use_swap_tensors
or
is_traceable_wrapper_subclass
(
param_applied
)
# Determine the Parameter class to use based on the original parameter type
is_infinicore_param
=
isinstance
(
param
,
Parameter
)
ParamClass
=
Parameter
if
is_infinicore_param
else
torch
.
nn
.
Parameter
param_grad
=
param
.
grad
if
p_should_use_swap_tensors
:
try
:
...
...
@@ -866,7 +877,7 @@ class InfiniCoreModule:
# Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping.
# Decrement use count of the gradient by setting to None
param
.
grad
=
None
param_applied
=
torch
.
nn
.
Parameter
(
param_applied
,
requires_grad
=
param
.
requires_grad
)
param_applied
=
ParamClass
(
param_applied
,
requires_grad
=
param
.
requires_grad
)
torch
.
utils
.
swap_tensors
(
param
,
param_applied
)
except
Exception
as
e
:
if
param_grad
is
not
None
:
...
...
@@ -877,9 +888,9 @@ class InfiniCoreModule:
param
.
data
=
param_applied
out_param
=
param
else
:
assert
isinstance
(
param
,
torch
.
nn
.
Parameter
)
assert
isinstance
(
param
,
(
torch
.
nn
.
Parameter
,
Parameter
)
)
assert
param
.
is_leaf
out_param
=
torch
.
nn
.
Parameter
(
param_applied
,
param
.
requires_grad
)
out_param
=
ParamClass
(
param_applied
,
param
.
requires_grad
)
self
.
_parameters
[
key
]
=
out_param
if
param_grad
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment