Commit 777b3233 authored Oct 31, 2025 by Ceng23333

do assertion at load_parameter && update Module definition with macros

Signed-off-by: Ceng23333 <441651826@qq.com>

parent 69c1c352
Showing 9 changed files with 202 additions and 51 deletions (+202 -51)
include/infinicore/nn/embedding.hpp   +1 -1
include/infinicore/nn/linear.hpp      +5 -3
include/infinicore/nn/module.hpp      +2 -2
include/infinicore/nn/rmsnorm.hpp     +5 -1
src/infinicore-test/test_nn_module.cc +140 -16
src/infinicore-test/test_nn_module.h  +14 -8
src/infinicore/nn/linear.cc           +10 -12
src/infinicore/nn/module.cc           +17 -1
src/infinicore/nn/rmsnorm.cc          +8 -7
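Taken together, the diff does two things: every nn module (Embedding, Linear, RMSNorm, and the test-only MockLinearModule) switches from hand-written Parameter members plus explicit register_parameter calls to the INFINICORE_NN_PARAMETER / INFINICORE_NN_PARAMETER_INIT macros, and Module::load_parameter / Module::load_state_dict now throw std::runtime_error when an incoming tensor's dtype differs from the registered parameter's. A minimal before/after sketch of the declaration change (the class name and shape here are illustrative, not from the diff):

// Before: manual member + explicit registration in the constructor
class MyLayer : public infinicore::nn::Module {
protected:
    Parameter weight_;
public:
    explicit MyLayer(const Device &device) {
        weight_ = Parameter({4, 8}, DataType::F32, device);
        register_parameter("weight", weight_);
    }
};

// After: the macros generate the member and the registration together
class MyLayer : public infinicore::nn::Module {
protected:
    INFINICORE_NN_PARAMETER(weight); // declares infinicore::nn::Parameter weight_
public:
    explicit MyLayer(const Device &device) {
        // assigns weight_ and calls this->register_parameter("weight", weight_)
        INFINICORE_NN_PARAMETER_INIT(weight, ({4, 8}, DataType::F32, device));
    }
};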
include/infinicore/nn/embedding.hpp

@@ -75,7 +75,7 @@ public:
 protected:
     // Parameters
-    Parameter weight_;
+    INFINICORE_NN_PARAMETER(weight);
 
 private:
     size_t num_embeddings_; // Vocabulary size
include/infinicore/nn/linear.hpp

@@ -7,7 +7,7 @@ namespace infinicore::nn {
 class Linear : public Module {
 public:
-    Linear(size_t in_features, size_t out_features, bool bias = true, const Device &device = Device());
+    Linear(size_t in_features, size_t out_features, bool bias = true, const DataType &dtype = DataType::F32, const Device &device = Device());
 
     // Forward pass: output = input @ weight.T + bias
     Tensor forward(Tensor &input) const;

@@ -20,6 +20,7 @@ public:
     size_t in_features() const { return in_features_; }
     size_t out_features() const { return out_features_; }
     bool has_bias() const { return has_bias_; }
+    DataType dtype() const { return dtype_; }
 
     // String representation
     std::string extra_repr() const;

@@ -30,8 +31,8 @@ public:
 protected:
     // Parameters
-    Parameter weight_;
-    Parameter bias_;
+    INFINICORE_NN_PARAMETER(weight);
+    INFINICORE_NN_PARAMETER(bias);
 
 private:
     // Helper method for common forward computation

@@ -40,6 +41,7 @@ private:
     size_t in_features_;
     size_t out_features_;
     bool has_bias_;
+    DataType dtype_;
 };
 
 } // namespace infinicore::nn
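The new dtype parameter is inserted before device and defaulted, so short call sites keep compiling. A hedged usage sketch under the defaults above:

infinicore::nn::Linear fp32_layer(8, 4, true);                             // dtype defaults to DataType::F32
infinicore::nn::Linear bf16_layer(8, 4, true, infinicore::DataType::BF16); // explicit dtype, default device

One consequence visible later in the diff: call sites that passed a Device as the fourth positional argument, such as Linear m1(8, 4, true, infinicore::Device()), no longer match the signature, which is why the test call sites drop the explicit Device.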
include/infinicore/nn/module.hpp

@@ -125,13 +125,13 @@ private:
 // Declare a parameter member variable
 #define INFINICORE_NN_PARAMETER(name) \
-    Parameter name##_
+    infinicore::nn::Parameter name##_
 
 // Initialize a parameter in constructor
 // Usage: INFINICORE_NN_PARAMETER_INIT(name, (shape, dtype, device))
 // Example: INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, DataType::F32, device))
 #define INFINICORE_NN_PARAMETER_INIT(name, args) \
-    name##_ = Parameter args; \
+    name##_ = infinicore::nn::Parameter args; \
     this->register_parameter(#name, name##_)
 
 } // namespace infinicore::nn
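For a concrete parameter name, the updated macros expand roughly as follows (manual preprocessor expansion, shown for illustration):

// INFINICORE_NN_PARAMETER(weight);
infinicore::nn::Parameter weight_;

// INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, DataType::F32, device));
weight_ = infinicore::nn::Parameter({out_features, in_features}, DataType::F32, device);
this->register_parameter("weight", weight_);

Fully qualifying infinicore::nn::Parameter is what makes the macros usable from classes declared outside the infinicore::nn namespace, such as MockLinearModule in infinicore::test further down.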
include/infinicore/nn/rmsnorm.hpp

@@ -36,10 +36,12 @@ public:
      *
      * @param normalized_shape Size of the feature dimension to normalize (typically hidden_size)
      * @param eps Small constant for numerical stability (default: 1e-6)
+     * @param dtype Data type for the weight (default: DataType::F32)
      * @param device Device to create the weight on
      */
-    RMSNorm(size_t normalized_shape, double eps = 1e-6, const Device &device = Device());
+    RMSNorm(size_t normalized_shape, double eps = 1e-6, const DataType &dtype = DataType::F32, const Device &device = Device());
 
     /**

@@ -58,6 +60,7 @@ public:
     // Module information
     size_t normalized_shape() const { return normalized_shape_; }
     double eps() const { return eps_; }
+    DataType dtype() const { return dtype_; }
 
     // String representation
     std::string extra_repr() const;

@@ -67,11 +70,12 @@ public:
 protected:
     // Parameters
-    Parameter weight_;
+    INFINICORE_NN_PARAMETER(weight);
 
 private:
     size_t normalized_shape_; // Size of the feature dimension
     double eps_;              // Epsilon for numerical stability
+    DataType dtype_;          // Data type for weight
 };
 
 } // namespace infinicore::nn
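As with Linear, the defaulted dtype keeps the short constructors valid; a usage sketch (assuming the defaults declared above):

infinicore::nn::RMSNorm norm_default(768);                                // eps = 1e-6, dtype = F32, default device
infinicore::nn::RMSNorm norm_bf16(768, 1e-6, infinicore::DataType::BF16); // explicit dtype
// RMSNorm(768, 1e-6, infinicore::Device()) no longer compiles, hence the updated test call sites below.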
src/infinicore-test/test_nn_module.cc

@@ -394,7 +394,7 @@ TestResult NNModuleTest::testModuleLinear() {
     try {
         // Test with bias
        spdlog::info("Testing Linear module with bias (8->4 features)");
-        infinicore::nn::Linear m1(8, 4, true, infinicore::Device());
+        infinicore::nn::Linear m1(8, 4, true);
         auto sd1 = m1.state_dict();
         if (sd1.find("weight") == sd1.end()) {
             spdlog::error("weight missing");
@@ -440,7 +440,7 @@ TestResult NNModuleTest::testModuleLinear() {
         // Test without bias
         spdlog::info("Testing Linear module without bias (16->3 features)");
-        infinicore::nn::Linear m2(16, 3, false, infinicore::Device());
+        infinicore::nn::Linear m2(16, 3, false);
         auto sd2 = m2.state_dict();
         if (sd2.find("weight") == sd2.end()) {
             spdlog::error("weight missing (no-bias)");
@@ -834,7 +834,7 @@ TestResult NNModuleTest::testModuleRMSNorm() {
         // Test 1: Basic RMSNorm creation
         spdlog::info("Test 1: Basic RMSNorm creation (hidden_size=768)");
-        infinicore::nn::RMSNorm norm1(768, 1e-6, infinicore::Device());
+        infinicore::nn::RMSNorm norm1(768);
         auto state1 = norm1.state_dict();
         if (state1.find("weight") == state1.end()) {
@@ -925,8 +925,8 @@ TestResult NNModuleTest::testModuleRMSNorm() {
         // Test 7: Different hidden sizes
         spdlog::info("Test 7: Testing different hidden sizes");
-        infinicore::nn::RMSNorm norm_small(128, 1e-5, infinicore::Device());
-        infinicore::nn::RMSNorm norm_large(4096, 1e-6, infinicore::Device());
+        infinicore::nn::RMSNorm norm_small(128, 1e-5);
+        infinicore::nn::RMSNorm norm_large(4096);
         auto input_small = infinicore::Tensor::ones({2, 128}, infinicore::DataType::F32, infinicore::Device());
         auto output_small = norm_small.forward(input_small);
@@ -956,7 +956,130 @@ TestResult NNModuleTest::testModuleRMSNorm() {
     });
 }
 
-// Test 8: Comprehensive Tiny-Llama model test (construction + weight loading + validation)
+// Test 8: Dtype assertion test
+TestResult NNModuleTest::testDtypeAssertion() {
+    return measureTime("DtypeAssertionTest", [this]() {
+        try {
+            spdlog::info("Testing dtype assertions when loading parameters");
+
+            // Test 1: Successful load with matching dtype
+            spdlog::info("Test 1: Successful load with matching dtype (F32)");
+            infinicore::nn::Linear linear1(8, 4, true);
+            auto matching_weight = infinicore::Tensor::ones({4, 8}, infinicore::DataType::F32, infinicore::Device());
+            auto matching_bias = infinicore::Tensor::ones({4}, infinicore::DataType::F32, infinicore::Device());
+            std::unordered_map<std::string, infinicore::Tensor> matching_state;
+            matching_state.emplace("weight", matching_weight);
+            matching_state.emplace("bias", matching_bias);
+
+            // This should succeed without throwing
+            linear1.load_state_dict(matching_state);
+            spdlog::debug("✓ Matching dtype load succeeded");
+
+            // Test 2: Failed load with mismatched dtype (load_parameter)
+            spdlog::info("Test 2: Failed load_parameter with mismatched dtype");
+            infinicore::nn::Linear linear2(8, 4, true);
+            auto mismatched_weight = infinicore::Tensor::ones({4, 8}, infinicore::DataType::BF16, infinicore::Device());
+            bool exception_thrown = false;
+            try {
+                linear2.load_parameter("weight", mismatched_weight);
+            } catch (const std::runtime_error &e) {
+                exception_thrown = true;
+                std::string error_msg = e.what();
+                if (error_msg.find("dtype mismatch") == std::string::npos) {
+                    spdlog::error("Exception message doesn't contain 'dtype mismatch'");
+                    return false;
+                }
+                spdlog::debug("✓ Mismatched dtype exception caught: {}", error_msg);
+            }
+            if (!exception_thrown) {
+                spdlog::error("Expected exception for dtype mismatch in load_parameter");
+                return false;
+            }
+
+            // Test 3: Failed load with mismatched dtype (load_state_dict)
+            spdlog::info("Test 3: Failed load_state_dict with mismatched dtype");
+            infinicore::nn::Embedding embedding1(100, 64);
+            auto mismatched_embed_weight = infinicore::Tensor::ones({100, 64}, infinicore::DataType::BF16, infinicore::Device());
+            std::unordered_map<std::string, infinicore::Tensor> mismatched_state;
+            mismatched_state.emplace("weight", mismatched_embed_weight);
+            exception_thrown = false;
+            try {
+                embedding1.load_state_dict(mismatched_state);
+            } catch (const std::runtime_error &e) {
+                exception_thrown = true;
+                std::string error_msg = e.what();
+                if (error_msg.find("dtype mismatch") == std::string::npos) {
+                    spdlog::error("Exception message doesn't contain 'dtype mismatch'");
+                    return false;
+                }
+                if (error_msg.find("weight") == std::string::npos) {
+                    spdlog::error("Exception message doesn't contain parameter name 'weight'");
+                    return false;
+                }
+                spdlog::debug("✓ Mismatched dtype exception caught: {}", error_msg);
+            }
+            if (!exception_thrown) {
+                spdlog::error("Expected exception for dtype mismatch in load_state_dict");
+                return false;
+            }
+
+            // Test 4: Failed load with mismatched dtype (RMSNorm)
+            spdlog::info("Test 4: Failed load_state_dict with mismatched dtype (RMSNorm)");
+            infinicore::nn::RMSNorm norm1(768);
+            auto mismatched_norm_weight = infinicore::Tensor::ones({768}, infinicore::DataType::BF16, infinicore::Device());
+            std::unordered_map<std::string, infinicore::Tensor> mismatched_norm_state;
+            mismatched_norm_state.emplace("weight", mismatched_norm_weight);
+            exception_thrown = false;
+            try {
+                norm1.load_state_dict(mismatched_norm_state);
+            } catch (const std::runtime_error &e) {
+                exception_thrown = true;
+                std::string error_msg = e.what();
+                if (error_msg.find("dtype mismatch") == std::string::npos) {
+                    spdlog::error("Exception message doesn't contain 'dtype mismatch'");
+                    return false;
+                }
+                spdlog::debug("✓ Mismatched dtype exception caught for RMSNorm: {}", error_msg);
+            }
+            if (!exception_thrown) {
+                spdlog::error("Expected exception for dtype mismatch in RMSNorm load_state_dict");
+                return false;
+            }
+
+            // Test 5: Successful load with different module dtypes
+            spdlog::info("Test 5: Successful load with BF16 dtype (module created with BF16)");
+            infinicore::nn::Linear linear3(8, 4, true, infinicore::DataType::BF16);
+            auto bf16_weight = infinicore::Tensor::ones({4, 8}, infinicore::DataType::BF16, infinicore::Device());
+            auto bf16_bias = infinicore::Tensor::ones({4}, infinicore::DataType::BF16, infinicore::Device());
+            std::unordered_map<std::string, infinicore::Tensor> bf16_state;
+            bf16_state.emplace("weight", bf16_weight);
+            bf16_state.emplace("bias", bf16_bias);
+
+            // This should succeed
+            linear3.load_state_dict(bf16_state);
+            spdlog::debug("✓ BF16 dtype load succeeded");
+
+            spdlog::info("All dtype assertion tests passed!");
+            return true;
+        } catch (const std::exception &e) {
+            spdlog::error("Exception in testDtypeAssertion: {}", e.what());
+            return false;
+        }
+    });
+}
+
+// Test 9: Comprehensive Tiny-Llama model test (construction + weight loading + validation)
 TestResult NNModuleTest::testTinyLlamaConstruction() {
     return measureTime("TinyLlamaModelTest", [this]() {
         try {
@@ -1007,10 +1130,10 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
         INFINICORE_NN_MODULE(infinicore::nn::Linear, o_proj);
 
         SelfAttn(size_t hidden_size, size_t kv_dim, const infinicore::Device &device) {
-            INFINICORE_NN_MODULE_INIT(q_proj, hidden_size, hidden_size, false, device);
-            INFINICORE_NN_MODULE_INIT(k_proj, hidden_size, kv_dim, false, device);
-            INFINICORE_NN_MODULE_INIT(v_proj, hidden_size, kv_dim, false, device);
-            INFINICORE_NN_MODULE_INIT(o_proj, hidden_size, hidden_size, false, device);
+            INFINICORE_NN_MODULE_INIT(q_proj, hidden_size, hidden_size, false, infinicore::DataType::F32, device);
+            INFINICORE_NN_MODULE_INIT(k_proj, hidden_size, kv_dim, false, infinicore::DataType::F32, device);
+            INFINICORE_NN_MODULE_INIT(v_proj, hidden_size, kv_dim, false, infinicore::DataType::F32, device);
+            INFINICORE_NN_MODULE_INIT(o_proj, hidden_size, hidden_size, false, infinicore::DataType::F32, device);
         }
     };
@@ -1021,9 +1144,9 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
         INFINICORE_NN_MODULE(infinicore::nn::Linear, down_proj);
 
         MLP(size_t hidden_size, size_t intermediate_size, const infinicore::Device &device) {
-            INFINICORE_NN_MODULE_INIT(gate_proj, hidden_size, intermediate_size, false, device);
-            INFINICORE_NN_MODULE_INIT(up_proj, hidden_size, intermediate_size, false, device);
-            INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size, hidden_size, false, device);
+            INFINICORE_NN_MODULE_INIT(gate_proj, hidden_size, intermediate_size, false, infinicore::DataType::F32, device);
+            INFINICORE_NN_MODULE_INIT(up_proj, hidden_size, intermediate_size, false, infinicore::DataType::F32, device);
+            INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size, hidden_size, false, infinicore::DataType::F32, device);
         }
     };
@@ -1036,9 +1159,9 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
         Block(const TinyLlamaConfig &cfg, const infinicore::Device &device) {
             size_t kv_dim = cfg.hidden_size * cfg.num_key_value_heads / cfg.num_attention_heads;
-            INFINICORE_NN_MODULE_INIT(input_layernorm, cfg.hidden_size, cfg.rms_norm_eps, device);
+            INFINICORE_NN_MODULE_INIT(input_layernorm, cfg.hidden_size, cfg.rms_norm_eps, infinicore::DataType::F32, device);
             INFINICORE_NN_MODULE_INIT(self_attn, cfg.hidden_size, kv_dim, device);
-            INFINICORE_NN_MODULE_INIT(post_attention_layernorm, cfg.hidden_size, cfg.rms_norm_eps, device);
+            INFINICORE_NN_MODULE_INIT(post_attention_layernorm, cfg.hidden_size, cfg.rms_norm_eps, infinicore::DataType::F32, device);
             INFINICORE_NN_MODULE_INIT(mlp, cfg.hidden_size, cfg.intermediate_size, device);
         }
     };
@@ -1051,7 +1174,7 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
         TinyLlamaModel(const TinyLlamaConfig &config, const infinicore::Device &device) {
             INFINICORE_NN_MODULE_INIT(embed_tokens, config.vocab_size, config.hidden_size, std::nullopt, infinicore::DataType::F32, device);
             INFINICORE_NN_MODULE_VEC_INIT(layers, config.num_hidden_layers, Block, config, device);
-            INFINICORE_NN_MODULE_INIT(norm, config.hidden_size, config.rms_norm_eps, device);
+            INFINICORE_NN_MODULE_INIT(norm, config.hidden_size, config.rms_norm_eps, infinicore::DataType::F32, device);
         }
     };
@@ -1259,6 +1382,7 @@ TestResult NNModuleTest::run() {
     results.push_back(testModuleLinear());    // Linear module comprehensive test
     results.push_back(testModuleEmbedding()); // Embedding module test
     results.push_back(testModuleRMSNorm());   // RMSNorm module test
+    results.push_back(testDtypeAssertion());  // Dtype assertion test
     results.push_back(testTinyLlamaConstruction()); // Comprehensive: TinyLlama model test
 
     // Check if all tests passed
src/infinicore-test/test_nn_module.h

@@ -21,16 +21,21 @@ namespace infinicore::test {
 // Simple test module that mimics torch.nn.Linear
 class MockLinearModule : public infinicore::nn::Module {
 public:
+    // Declare parameters using macros (torch-like style)
+    INFINICORE_NN_PARAMETER(weight);
+    INFINICORE_NN_PARAMETER(bias);
+
     MockLinearModule(int input_size, int output_size, const infinicore::Device &device)
         : input_size_(input_size), output_size_(output_size), device_(device) {
-        // Initialize weight parameter (similar to torch.nn.Linear.weight)
-        register_parameter("weight", infinicore::nn::Parameter({static_cast<size_t>(output_size), static_cast<size_t>(input_size)}, infinicore::DataType::F32, device));
-        // Initialize bias parameter (similar to torch.nn.Linear.bias)
-        register_parameter("bias", infinicore::nn::Parameter({static_cast<size_t>(output_size)}, infinicore::DataType::F32, device));
+        // Initialize parameters using macros
+        INFINICORE_NN_PARAMETER_INIT(weight, ({static_cast<size_t>(output_size), static_cast<size_t>(input_size)}, infinicore::DataType::F32, device));
+        INFINICORE_NN_PARAMETER_INIT(bias, ({static_cast<size_t>(output_size)}, infinicore::DataType::F32, device));
     }
 
     // Simple forward pass (conceptual - would need actual matrix operations)

@@ -77,6 +82,7 @@ private:
     TestResult testModuleLinear();    // Comprehensive Linear module test
     TestResult testModuleEmbedding(); // Embedding module test
     TestResult testModuleRMSNorm();   // RMSNorm module test
+    TestResult testDtypeAssertion();  // Test dtype assertions when loading parameters
     TestResult testTinyLlamaConstruction(); // Comprehensive: construction + weight loading + validation
 };
src/infinicore/nn/linear.cc

@@ -4,25 +4,26 @@
 namespace infinicore::nn {
 
-Linear::Linear(size_t in_features, size_t out_features, bool bias, const Device &device)
+Linear::Linear(size_t in_features, size_t out_features, bool bias, const DataType &dtype, const Device &device)
     : in_features_(in_features), out_features_(out_features),
-      has_bias_(bias) {
+      has_bias_(bias), dtype_(dtype) {
     device_ = device;
 
     // Initialize parameters using macro
-    INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, DataType::F32, device));
+    INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device));
 
     // Register bias parameter if requested
     if (bias) {
-        INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, DataType::F32, device));
+        INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device));
     } else {
         bias_ = Parameter(); // Default constructed empty parameter
     }
 
-    spdlog::debug("Created Linear module: in_features={}, out_features={}, bias={}", in_features, out_features, bias);
+    spdlog::debug("Created Linear module: in_features={}, out_features={}, bias={}, dtype={}", in_features, out_features, bias, static_cast<int>(dtype_));
 }
 
 Tensor Linear::compute_linear(Tensor &input) const {
@@ -41,12 +42,9 @@ Tensor Linear::compute_linear(Tensor &input) const {
         strides.push_back(bias_->stride(0));
         auto bias_view = bias_->as_strided(output->shape(), strides);
 
-        // First set output to bias (broadcasted)
-        infinicore::op::rearrange_(output, bias_view);
-        // Compute matmul result separately, then add to output
-        auto matmul_result = infinicore::op::matmul(input, weight_t);
-        infinicore::op::add_(output, output, matmul_result);
+        infinicore::op::matmul_(output, input, weight_t);
+        infinicore::op::add_(output, output, bias_view);
     } else {
         // No bias: just compute output = input @ weight_t
         infinicore::op::matmul_(output, input, weight_t);
@@ -69,7 +67,7 @@ Tensor Linear::forward(Tensor &input, Tensor &residual) const {
 }
 
 std::string Linear::extra_repr() const {
-    return "Linear(in_features=" + std::to_string(in_features_) + ", out_features=" + std::to_string(out_features_) + ", bias=" + (has_bias_ ? "true" : "false") + ")";
+    return "Linear(in_features=" + std::to_string(in_features_) + ", out_features=" + std::to_string(out_features_) + ", bias=" + (has_bias_ ? "true" : "false") + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
 }
 
 } // namespace infinicore::nn
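The compute_linear hunk above also simplifies the bias path: instead of rearranging the broadcast bias into output, running matmul into a temporary, and adding the temporary, the new code runs the in-place matmul_ straight into output and adds the strided bias_view, saving one temporary tensor. The bias_view works because as_strided can repeat the bias across rows; a self-contained sketch of that broadcast idea in plain C++ (illustrative only, no infinicore types, and the zero row stride is an assumption about how the strides vector is built):

#include <cstdio>
#include <vector>

// Broadcast a length-n bias over an m x n output by giving the bias a
// row stride of 0: every "row" of the view aliases the same n values,
// which is the effect bias_->as_strided(output->shape(), strides) relies on.
int main() {
    const int m = 2, n = 4;
    std::vector<float> out(m * n, 1.0f); // stand-in for input @ weight^T
    std::vector<float> bias = {10, 20, 30, 40};

    const int row_stride = 0; // the broadcast trick: do not advance per row
    const int col_stride = 1;
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            out[i * n + j] += bias[i * row_stride + j * col_stride];
        }
    }

    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            std::printf("%g ", out[i * n + j]); // prints "11 21 31 41" twice
        }
        std::printf("\n");
    }
    return 0;
}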
src/infinicore/nn/module.cc

 #include "infinicore/nn/module.hpp"
+#include <stdexcept>
 
 namespace infinicore::nn {
 
 const std::unordered_map<std::string, Parameter> &Module::state_dict() const {

@@ -20,13 +21,28 @@ void Module::load_state_dict(const std::unordered_map<std::string, Tensor> &_sta
         // Look up the corresponding tensor in the input state dict using the full name
         auto it = _state_dict.find(param_full_name);
         if (it != _state_dict.end()) {
+            // Assert dtype matches
+            if (param->dtype() != it->second->dtype()) {
+                throw std::runtime_error("dtype mismatch for parameter '" + param_full_name + "': "
+                                         "expected " + std::to_string(static_cast<int>(param->dtype())) +
+                                         ", got " + std::to_string(static_cast<int>(it->second->dtype())));
+            }
             param->copy_from(it->second);
         }
     }
 }
 
 void Module::load_parameter(const std::string &name, const Tensor &param) {
-    parameters_[name]->copy_from(param);
+    auto existing_param = parameters_[name];
+
+    // Assert dtype matches
+    if (existing_param->dtype() != param->dtype()) {
+        throw std::runtime_error("dtype mismatch for parameter '" + name + "': "
+                                 "expected " + std::to_string(static_cast<int>(existing_param->dtype())) +
+                                 ", got " + std::to_string(static_cast<int>(param->dtype())));
+    }
+    existing_param->copy_from(param);
 }
 
 void Module::load_parameter_from_blob(const std::string &name, const void *data) {
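The assertion pattern in isolation, as a minimal compilable sketch with stand-in types (DataType, Param, and the map are simplified stand-ins for the real infinicore classes):

#include <cstdio>
#include <stdexcept>
#include <string>
#include <unordered_map>

enum class DataType { F32, BF16 };

struct Param {
    DataType dtype;
};

// Mirrors Module::load_parameter: refuse to copy when dtypes differ.
void load_parameter(std::unordered_map<std::string, Param> &params,
                    const std::string &name, const Param &incoming) {
    const Param &existing = params.at(name);
    if (existing.dtype != incoming.dtype) {
        throw std::runtime_error("dtype mismatch for parameter '" + name + "': "
                                 "expected " + std::to_string(static_cast<int>(existing.dtype)) +
                                 ", got " + std::to_string(static_cast<int>(incoming.dtype)));
    }
    // ... the real code then does existing_param->copy_from(param) ...
}

int main() {
    std::unordered_map<std::string, Param> params{{"weight", {DataType::F32}}};
    try {
        load_parameter(params, "weight", {DataType::BF16});
    } catch (const std::runtime_error &e) {
        std::printf("caught: %s\n", e.what()); // "dtype mismatch for parameter 'weight': ..."
    }
    return 0;
}

One deliberate difference: the sketch looks the name up with at(), while the real load_parameter indexes parameters_ with operator[], which would default-construct an entry for an unknown name rather than fail.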
src/infinicore/nn/rmsnorm.cc

@@ -6,21 +6,22 @@
 namespace infinicore::nn {
 
-RMSNorm::RMSNorm(size_t normalized_shape, double eps, const Device &device)
+RMSNorm::RMSNorm(size_t normalized_shape, double eps, const DataType &dtype, const Device &device)
     : normalized_shape_(normalized_shape),
-      eps_(eps) {
+      eps_(eps), dtype_(dtype) {
     device_ = device;
 
     // Initialize parameter using macro
-    INFINICORE_NN_PARAMETER_INIT(weight, ({normalized_shape}, DataType::F32, device));
+    INFINICORE_NN_PARAMETER_INIT(weight, ({normalized_shape}, dtype_, device));
 
     // Initialize weight to ones (standard practice for RMSNorm)
-    auto ones_tensor = Tensor::ones({normalized_shape}, DataType::F32, device);
+    auto ones_tensor = Tensor::ones({normalized_shape}, dtype_, device);
     weight_->copy_from(ones_tensor);
 
-    spdlog::debug("Created RMSNorm module: normalized_shape={}, eps={}", normalized_shape, eps);
+    spdlog::debug("Created RMSNorm module: normalized_shape={}, eps={}, dtype={}", normalized_shape, eps, static_cast<int>(dtype_));
 }
 
 Tensor RMSNorm::forward(const Tensor &x) const {

@@ -37,7 +38,7 @@ Tensor RMSNorm::forward(const Tensor &x) const {
 }
 
 std::string RMSNorm::extra_repr() const {
-    return "RMSNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ")";
+    return "RMSNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
 }
 
 } // namespace infinicore::nn