jerrrrry / infinicore · Commit 0f5e66ce (unverified)

Authored Dec 06, 2025 by PanZezhong1725; committed by GitHub on Dec 06, 2025.

Merge pull request #721 from InfiniTensor/issue/719-a

Module: support loading a single tensor

Parents: 9c4d4d1a, 420369bd
Changes: 6 changed files with 35 additions and 21 deletions (+35 −21).
Changed files:

  include/infinicore/nn/module.hpp        +5  −3
  src/infinicore-test/test_nn_module.cc   +9  −9
  src/infinicore/context/context_impl.cc  +2  −2
  src/infinicore/nn/module.cc             +18 −2
  src/infinicore/nn/rmsnorm.cc            +0  −5
  src/infinicore/tensor/tensor.cc         +1  −0
include/infinicore/nn/module.hpp

```diff
 #pragma once

-#include "parameter.hpp"
-#include <unordered_map>
+#include "../tensor.hpp"
+#include "parameter.hpp"
+#include <type_traits>
+#include <unordered_map>
 #include <vector>

 namespace infinicore::nn {
 ...

@@ -18,6 +18,8 @@ public:
     void load_parameter(const std::string &name, const Tensor &param);
+    void load_parameter_(const std::string &name, const Tensor &param);
     void load_parameter_from_blob(const std::string &name, const void *data);

 protected:
 ...

@@ -135,7 +137,7 @@ private:
 // Usage: INFINICORE_NN_PARAMETER_INIT(name, (shape, dtype, device))
 // Example: INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, DataType::F32, device))
 #define INFINICORE_NN_PARAMETER_INIT(name, args)  \
     name##_ = infinicore::nn::Parameter args;     \
     this->register_parameter(#name, name##_)

 // Declare a buffer member variable
 ...
```
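Taken together with module.cc below, the header now gives a module two single-tensor loaders: `load_parameter` resolves the name through `state_dict()` and throws `std::runtime_error` for unknown names, while `load_parameter_` only consults the module's own `parameters_` map. A minimal usage sketch, assuming a `Linear` module as constructed in the tests below; the include paths and the wrapper function are illustrative, not part of this commit:

```cpp
#include <infinicore/nn/linear.hpp> // assumed include path for nn::Linear
#include <infinicore/tensor.hpp>    // assumed include path for Tensor

using namespace infinicore;

void load_single_weight() {
    nn::Linear linear(8, 4, /*bias=*/true);

    // Replacement weight, shaped and typed as in the tests below.
    auto w = Tensor::ones({4, 8}, DataType::F32, Device());

    // New in this commit: load one named tensor; throws std::runtime_error
    // if "weight" is not a known parameter.
    linear.load_parameter("weight", w);

    // Direct-parameter variant (no hierarchical traversal), used internally
    // by load_state_dict_recursively after this change.
    linear.load_parameter_("weight", w);
}
```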
src/infinicore-test/test_nn_module.cc

```diff
@@ -90,18 +90,18 @@ TestResult NNModuleTest::testBasicModuleCreation() {
     auto new_weight = infinicore::Tensor::ones({4, 8}, infinicore::DataType::F32, infinicore::Device());
     auto new_bias = infinicore::Tensor::zeros({4}, infinicore::DataType::F32, infinicore::Device());

-    // Load using load_parameter
-    module.load_parameter("weight", new_weight);
-    module.load_parameter("bias", new_bias);
+    // Load using load_parameter_
+    module.load_parameter_("weight", new_weight);
+    module.load_parameter_("bias", new_bias);

     // Verify the parameters were updated
     auto updated_state_dict = module.state_dict();
     if (!tensorsAllClose(updated_state_dict.at("weight"), new_weight, 1e-6, 1e-6)) {
-        spdlog::error("Weight parameter values do not match after load_parameter");
+        spdlog::error("Weight parameter values do not match after load_parameter_");
         return false;
     }
     if (!tensorsAllClose(updated_state_dict.at("bias"), new_bias, 1e-6, 1e-6)) {
-        spdlog::error("Bias parameter values do not match after load_parameter");
+        spdlog::error("Bias parameter values do not match after load_parameter_");
         return false;
     }
 ...

@@ -1493,14 +1493,14 @@ TestResult NNModuleTest::testDtypeAssertion() {
     linear1.load_state_dict(matching_state);
     spdlog::debug("✓ Matching dtype load succeeded");

-    // Test 2: Failed load with mismatched dtype (load_parameter)
-    spdlog::info("Test 2: Failed load_parameter with mismatched dtype");
+    // Test 2: Failed load with mismatched dtype (load_parameter_)
+    spdlog::info("Test 2: Failed load_parameter_ with mismatched dtype");
     infinicore::nn::Linear linear2(8, 4, true);
     auto mismatched_weight = infinicore::Tensor::ones({4, 8}, infinicore::DataType::BF16, infinicore::Device());

     bool exception_thrown = false;
     try {
-        linear2.load_parameter("weight", mismatched_weight);
+        linear2.load_parameter_("weight", mismatched_weight);
     } catch (const std::runtime_error &e) {
         exception_thrown = true;
         std::string error_msg = e.what();
 ...

@@ -1512,7 +1512,7 @@ TestResult NNModuleTest::testDtypeAssertion() {
     }
     if (!exception_thrown) {
-        spdlog::error("Expected exception for dtype mismatch in load_parameter");
+        spdlog::error("Expected exception for dtype mismatch in load_parameter_");
         return false;
     }
 ...
```
src/infinicore/context/context_impl.cc

```diff
@@ -13,14 +13,14 @@ Runtime *ContextImpl::getCurrentRuntime() {
         // Try to find the first non-CPU device, fallback to CPU
         for (int i = int(Device::Type::COUNT) - 1; i > 0; i--) {
             if (!runtime_table_[i].empty() && runtime_table_[i][0] != nullptr) {
-                current_runtime_ = runtime_table_[i][0].get();
+                current_runtime_ = runtime_table_[i][0].get()->activate();
                 spdlog::debug("Lazy init: Set current_runtime_ to {} (ptr={})",
                               current_runtime_->device().toString(), static_cast<void *>(current_runtime_));
                 return current_runtime_;
             }
         }
         // Fallback to CPU runtime
         if (!runtime_table_[0].empty() && runtime_table_[0][0] != nullptr) {
-            current_runtime_ = runtime_table_[0][0].get();
+            current_runtime_ = runtime_table_[0][0].get()->activate();
             spdlog::debug("Lazy init: Set current_runtime_ to {} (ptr={})",
                           current_runtime_->device().toString(), static_cast<void *>(current_runtime_));
         }
     } else {
 ...
```
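The only change here is that the lazily selected runtime is activated as it is assigned. The commit shows just the call site; the sketch below is an assumption about what a chainable `activate()` could look like (returning `this` so the assignment above compiles), not the actual InfiniCore implementation:

```cpp
// Hypothetical shape of a chainable activate(); only the call site
// `current_runtime_ = runtime_table_[i][0].get()->activate();` comes from the commit.
class Runtime {
public:
    Runtime *activate() {
        // Assumed: make this runtime's device current for the calling context
        // (e.g. a set-device call on the backend), then return this so the
        // result can be assigned directly to current_runtime_.
        return this;
    }
};
```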
src/infinicore/nn/module.cc

```diff
@@ -17,6 +17,22 @@ void Module::load_state_dict(const std::unordered_map<std::string, Tensor> &_sta
 }

+void Module::load_parameter(const std::string &name, const Tensor &param) {
+    // This function only handles direct parameters (no hierarchical traversal)
+    auto all_params = state_dict();
+    auto it = all_params.find(name);
+    if (it != all_params.end()) {
+        auto existing_param = it->second;
+        existing_param.load(param);
+        return;
+    }
+    // Parameter not found
+    spdlog::debug("load_parameter_: Parameter '{}' not found. Available: {} params",
+                  name, parameters_.size());
+    throw std::runtime_error("Parameter '" + name + "' not found in module.");
+}
+
-void Module::load_parameter(const std::string &name, const Tensor &param) {
+void Module::load_parameter_(const std::string &name, const Tensor &param) {
     // This function only handles direct parameters (no hierarchical traversal)
     auto it = parameters_.find(name);
     if (it != parameters_.end()) {
 ...

@@ -33,7 +49,7 @@ void Module::load_parameter(const std::string &name, const Tensor &param) {
     }
     // Parameter not found
-    spdlog::debug("load_parameter: Parameter '{}' not found. Available: {} params",
+    spdlog::debug("load_parameter_: Parameter '{}' not found. Available: {} params",
                   name, parameters_.size());
     throw std::runtime_error("Parameter '" + name + "' not found in module.");
 }
 ...

@@ -59,7 +75,7 @@ void Module::load_state_dict_recursively(const std::unordered_map<std::string, T
         std::string full_name = prefix.empty() ? param_name : prefix + "." + param_name;
         auto it = _state_dict.find(full_name);
         if (it != _state_dict.end()) {
-            load_parameter(param_name, it->second);
+            load_parameter_(param_name, it->second);
         }
     }
 ...
```
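Since both loaders throw `std::runtime_error` when the name is unknown, a caller that wants to skip missing entries instead of aborting has to catch it. A small sketch; `try_load` and its template parameters are hypothetical helpers, and only the thrown exception type and message come from the diff above:

```cpp
#include <spdlog/spdlog.h>
#include <stdexcept>
#include <string>

// Hypothetical helper: Module stands in for infinicore::nn::Module (or any
// subclass) and Tensor for infinicore::Tensor.
template <typename Module, typename Tensor>
bool try_load(Module &module, const std::string &name, const Tensor &tensor) {
    try {
        module.load_parameter(name, tensor);
        return true;
    } catch (const std::runtime_error &e) {
        // module.cc throws "Parameter '<name>' not found in module." for unknown names.
        spdlog::warn("skipping '{}': {}", name, e.what());
        return false;
    }
}
```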
src/infinicore/nn/rmsnorm.cc

```diff
@@ -12,12 +12,7 @@ RMSNorm::RMSNorm(size_t normalized_shape, double eps, const DataType &dtype, con
     device_ = device;

     // Initialize parameter using macro
     INFINICORE_NN_PARAMETER_INIT(weight, ({normalized_shape}, dtype_, device));
-
-    // Initialize weight to ones (standard practice for RMSNorm)
-    auto ones_tensor = Tensor::ones({normalized_shape}, dtype_, device);
-    weight_->copy_from(ones_tensor);
 }

 Tensor RMSNorm::forward(const Tensor &x) const {
 ...
```
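With the all-ones initialization removed from the constructor, code that still wants the conventional ones-initialized RMSNorm weight has to set it explicitly, for example with the single-tensor loader added in this commit. A sketch, assuming the constructor signature shown in the hunk header and the usual infinicore headers; the shape, eps, dtype, and device values are illustrative:

```cpp
using namespace infinicore; // assumes the infinicore nn/tensor headers are included

void build_rmsnorm_with_ones() {
    // Constructor arguments follow the signature in the hunk header above;
    // the concrete values here are only examples.
    nn::RMSNorm norm(/*normalized_shape=*/4096, /*eps=*/1e-6, DataType::F32, Device());

    // The constructor no longer fills the weight with ones, so do it explicitly.
    auto ones = Tensor::ones({4096}, DataType::F32, Device());
    norm.load_parameter("weight", ones);
}
```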
src/infinicore/tensor/tensor.cc

```diff
@@ -162,6 +162,7 @@ std::string TensorImpl::info() const {
         ss << s << " ";
     }
     ss << "] dtype=" << toString(this->dtype());
+    ss << " device=" << this->device().toString();
     return ss.str();
 }
 ...
```
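The added line makes `info()` report the device after the dtype. A tiny illustration of how that surfaces to a caller; the handle syntax `t->info()` and the text before the `]` are assumptions, since this hunk only shows the tail of the string:

```cpp
#include <spdlog/spdlog.h> // logging, as used elsewhere in the repo

void print_tensor_info() {
    auto t = infinicore::Tensor::ones({4, 8}, infinicore::DataType::F32, infinicore::Device());
    // Indicative only: something like "... 4 8 ] dtype=F32 device=cpu";
    // the prefix and the exact dtype/device spellings are outside this hunk.
    spdlog::info("{}", t->info());
}
```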