Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
0c6803c6
Commit
0c6803c6
authored
Dec 03, 2025
by
Ceng23333
Browse files
issue/697: fix load_state_dict
Signed-off-by:
Ceng23333
<
441651826@qq.com
>
parent
986bb179
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
26 additions
and
16 deletions
+26
-16
include/infinicore/nn/module.hpp
include/infinicore/nn/module.hpp
+1
-0
src/infinicore/nn/module.cc
src/infinicore/nn/module.cc
+25
-16
No files found.
include/infinicore/nn/module.hpp
View file @
0c6803c6
...
@@ -78,6 +78,7 @@ protected:
...
@@ -78,6 +78,7 @@ protected:
std
::
unordered_map
<
std
::
string
,
Parameter
>
parameters_
;
std
::
unordered_map
<
std
::
string
,
Parameter
>
parameters_
;
private:
private:
void
load_state_dict_recursively
(
const
std
::
unordered_map
<
std
::
string
,
Tensor
>
&
_state_dict
,
const
std
::
string
&
prefix
=
""
);
void
collect_all_parameters
(
std
::
unordered_map
<
std
::
string
,
Parameter
>
&
all_params
,
const
std
::
string
&
prefix
=
""
)
const
;
void
collect_all_parameters
(
std
::
unordered_map
<
std
::
string
,
Parameter
>
&
all_params
,
const
std
::
string
&
prefix
=
""
)
const
;
};
};
...
...
src/infinicore/nn/module.cc
View file @
0c6803c6
...
@@ -13,23 +13,11 @@ const std::unordered_map<std::string, Parameter> &Module::state_dict() const {
...
@@ -13,23 +13,11 @@ const std::unordered_map<std::string, Parameter> &Module::state_dict() const {
}
}
void
Module
::
load_state_dict
(
const
std
::
unordered_map
<
std
::
string
,
Tensor
>
&
_state_dict
)
{
void
Module
::
load_state_dict
(
const
std
::
unordered_map
<
std
::
string
,
Tensor
>
&
_state_dict
)
{
// Collect all parameters from this module and its submodules with their full hierarchical names
load_state_dict_recursively
(
_state_dict
,
""
);
std
::
unordered_map
<
std
::
string
,
Parameter
>
all_params
;
collect_all_parameters
(
all_params
,
""
);
// For each parameter in this module hierarchy, load from the state dict
for
(
auto
&
[
param_full_name
,
param
]
:
all_params
)
{
// Look up the corresponding tensor in the input state dict using the full name
auto
it
=
_state_dict
.
find
(
param_full_name
);
if
(
it
!=
_state_dict
.
end
())
{
this
->
load_parameter
(
param_full_name
,
it
->
second
);
}
else
{
spdlog
::
warn
(
"Parameter '{}' provided but not found in module."
,
param_full_name
);
}
}
}
}
void
Module
::
load_parameter
(
const
std
::
string
&
name
,
const
Tensor
&
param
)
{
void
Module
::
load_parameter
(
const
std
::
string
&
name
,
const
Tensor
&
param
)
{
// This function only handles direct parameters (no hierarchical traversal)
auto
it
=
parameters_
.
find
(
name
);
auto
it
=
parameters_
.
find
(
name
);
if
(
it
!=
parameters_
.
end
())
{
if
(
it
!=
parameters_
.
end
())
{
auto
existing_param
=
it
->
second
;
auto
existing_param
=
it
->
second
;
...
@@ -41,9 +29,13 @@ void Module::load_parameter(const std::string &name, const Tensor &param) {
...
@@ -41,9 +29,13 @@ void Module::load_parameter(const std::string &name, const Tensor &param) {
+
std
::
to_string
(
static_cast
<
int
>
(
existing_param
->
dtype
()))
+
", got "
+
std
::
to_string
(
static_cast
<
int
>
(
param
->
dtype
())));
+
std
::
to_string
(
static_cast
<
int
>
(
existing_param
->
dtype
()))
+
", got "
+
std
::
to_string
(
static_cast
<
int
>
(
param
->
dtype
())));
}
}
existing_param
.
load
(
param
);
existing_param
.
load
(
param
);
}
else
{
return
;
throw
std
::
runtime_error
(
"Parameter '"
+
name
+
"' not found in module."
);
}
}
// Parameter not found
spdlog
::
debug
(
"load_parameter: Parameter '{}' not found. Available: {} params"
,
name
,
parameters_
.
size
());
throw
std
::
runtime_error
(
"Parameter '"
+
name
+
"' not found in module."
);
}
}
void
Module
::
load_parameter_from_blob
(
const
std
::
string
&
name
,
const
void
*
data
)
{
void
Module
::
load_parameter_from_blob
(
const
std
::
string
&
name
,
const
void
*
data
)
{
...
@@ -61,6 +53,23 @@ Tensor Module::register_buffer(const std::string &name, Parameter buffer) {
...
@@ -61,6 +53,23 @@ Tensor Module::register_buffer(const std::string &name, Parameter buffer) {
return
buffer
;
return
buffer
;
}
}
void
Module
::
load_state_dict_recursively
(
const
std
::
unordered_map
<
std
::
string
,
Tensor
>
&
_state_dict
,
const
std
::
string
&
prefix
)
{
// Load direct parameters with the given prefix
for
(
const
auto
&
[
param_name
,
param
]
:
parameters_
)
{
std
::
string
full_name
=
prefix
.
empty
()
?
param_name
:
prefix
+
"."
+
param_name
;
auto
it
=
_state_dict
.
find
(
full_name
);
if
(
it
!=
_state_dict
.
end
())
{
load_parameter
(
param_name
,
it
->
second
);
}
}
// Recursively load parameters from submodules with extended prefix
for
(
const
auto
&
[
sub_name
,
submodule
]
:
submodules_
)
{
std
::
string
sub_prefix
=
prefix
.
empty
()
?
sub_name
:
prefix
+
"."
+
sub_name
;
submodule
->
load_state_dict_recursively
(
_state_dict
,
sub_prefix
);
}
}
void
Module
::
collect_all_parameters
(
std
::
unordered_map
<
std
::
string
,
Parameter
>
&
all_params
,
const
std
::
string
&
prefix
)
const
{
void
Module
::
collect_all_parameters
(
std
::
unordered_map
<
std
::
string
,
Parameter
>
&
all_params
,
const
std
::
string
&
prefix
)
const
{
// Add direct parameters with the given prefix
// Add direct parameters with the given prefix
for
(
const
auto
&
[
param_name
,
param
]
:
parameters_
)
{
for
(
const
auto
&
[
param_name
,
param
]
:
parameters_
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment