jerrrrry / infinilm · Commit 571e0ba0
Authored Jan 22, 2026 by wangpengcheng

issue/199 - Support the qwen3 model

Parent: c73ff203
Showing 6 changed files with 32 additions and 3 deletions:
    csrc/models/llama/llama_attention.cpp                  +18  -1
    csrc/models/llama/llama_attention.hpp                    +4  -1
    csrc/models/llama/llama_config.hpp                       +1  -0
    csrc/pybind11/models/llama.hpp                           +2  -0
    python/infinilm/auto_config.py                           +3  -1
    python/infinilm/models/llama/configuration_llama.py      +4  -0
csrc/models/llama/llama_attention.cpp
@@ -29,6 +29,7 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
       kv_dim_(config.kv_dim()),
       use_bias_(config.attention_bias),
       use_output_bias_(config.attention_output_bias),
+      use_qk_norm_(config.qk_norm),
       max_position_embeddings_(config.max_position_embeddings),
       rank_info_(rank_info) {
    const auto &dtype{config.dtype};
...
@@ -50,8 +51,14 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
    INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj",
                             hidden_size_, head_dim_,
                             config.num_attention_heads, config.num_key_value_heads,
                             use_bias_, dtype, device, rank_info);

    // Output projection uses attention_output_bias (can be different from qkv)
-   INFINICORE_NN_MODULE_INIT(o_proj, hidden_size_, hidden_size_, use_output_bias_,
+   INFINICORE_NN_MODULE_INIT(o_proj, num_attention_heads * head_dim_, hidden_size_, use_output_bias_,
                              dtype, device, tp_rank, tp_size, rank_info.comm);

+   // Initialize qk RMSNorm
+   if (use_qk_norm_) {
+       INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, config.rms_norm_eps, dtype, device);
+       INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, config.rms_norm_eps, dtype, device);
+   }
}

infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states,
...
@@ -68,6 +75,11 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
    // 1. Project Q, K, V
    auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);

+   if (use_qk_norm_) {
+       q = q_norm_->forward(q->view({batch_size * seq_len, num_attention_heads_, head_dim_}));
+       k = k_norm_->forward(k->view({batch_size * seq_len, num_key_value_heads_, head_dim_}));
+   }
+
    // 2. Reshape for multi-head attention
    // Reshape Q, K, V to include batch dimension
    // Python: query_states = self.q_proj(hidden_states).view(querys_shape)
...
@@ -172,6 +184,11 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
    auto k_reshaped = k->view({seq_len, num_key_value_heads_, head_dim_});
    auto v_reshaped = v->view({seq_len, num_key_value_heads_, head_dim_});

+   if (use_qk_norm_) {
+       q_reshaped = q_norm_->forward(q_reshaped);
+       k_reshaped = k_norm_->forward(k_reshaped);
+   }
+
    // 3. Prepare position_ids for RoPE - align with Python pattern
    auto pos_shape = position_ids->shape();
    infinicore::Tensor pos_ids_for_rope = position_ids;
...
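For reference, a minimal NumPy sketch of the per-head QK RMSNorm that these C++ changes enable for qwen3. The shapes, epsilon, and weight values below are illustrative assumptions; the real normalization is performed by infinicore's RMSNorm module with config.rms_norm_eps.

import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Root-mean-square normalization over the last axis (head_dim), scaled by a learned weight.
    variance = np.mean(x.astype(np.float32) ** 2, axis=-1, keepdims=True)
    return (x / np.sqrt(variance + eps)) * weight

# Hypothetical shapes: q is [tokens, num_attention_heads, head_dim],
# k is [tokens, num_key_value_heads, head_dim], matching the view() calls above.
tokens, num_heads, num_kv_heads, head_dim = 4, 8, 2, 64
q = np.random.randn(tokens, num_heads, head_dim).astype(np.float32)
k = np.random.randn(tokens, num_kv_heads, head_dim).astype(np.float32)
q_weight = np.ones(head_dim, dtype=np.float32)  # stands in for the learned q_norm weight
k_weight = np.ones(head_dim, dtype=np.float32)  # stands in for the learned k_norm weight

# Mirrors the guarded branch in forward_: only applied when use_qk_norm_ is set.
q = rms_norm(q, q_weight)
k = rms_norm(k, k_weight)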
csrc/models/llama/llama_attention.hpp
@@ -7,6 +7,7 @@
 #include "infinicore/nn/linear.hpp"
 #include "infinicore/nn/module.hpp"
+#include "infinicore/nn/rmsnorm.hpp"
 #include "infinicore/nn/rope.hpp"
 #include "infinicore/tensor.hpp"
 #include "llama_config.hpp"
...
@@ -92,7 +93,8 @@ protected:
    // Projection layers
    INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj);
    INFINICORE_NN_MODULE(infinicore::nn::RowParallelLinear, o_proj);
+   INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, q_norm);
+   INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, k_norm);

    engine::distributed::RankInfo rank_info_;

    // Shared Rotary Position Embeddings (RoPE)
...
@@ -107,6 +109,7 @@ private:
    size_t kv_dim_;
    bool use_bias_;                   // Bias for Q/K/V projections
    bool use_output_bias_;            // Bias for output projection (o_proj)
+   bool use_qk_norm_;                // Whether to use QK RMSNorm
    size_t max_position_embeddings_;  // For cache initialization (deprecated, kept for compatibility)
    float scaling_;
...
csrc/models/llama/llama_config.hpp
@@ -51,6 +51,7 @@ struct LlamaConfig : public InfinilmModel::Config {
    bool attention_output_bias = false; // Whether to use bias in output projection (o_proj)
    bool mlp_bias = false;              // Whether to use bias in MLP projections
    bool tie_word_embeddings = false;   // Whether to tie input/output embeddings
+   bool qk_norm = false;               // Whether to use QK RMSNorm

    // Training/initialization parameters
    double attention_dropout = 0.0;     // Dropout ratio for attention probabilities
...
csrc/pybind11/models/llama.hpp
@@ -64,6 +64,7 @@ inline void bind_llama(py::module &m) {
        .def_readwrite("attention_output_bias", &LlamaConfig::attention_output_bias)
        .def_readwrite("mlp_bias", &LlamaConfig::mlp_bias)
        .def_readwrite("tie_word_embeddings", &LlamaConfig::tie_word_embeddings)
+       .def_readwrite("qk_norm", &LlamaConfig::qk_norm)
        .def_readwrite("use_cache", &LlamaConfig::use_cache)
        .def_readwrite("attention_dropout", &LlamaConfig::attention_dropout)
        .def_readwrite("initializer_range", &LlamaConfig::initializer_range)
...
@@ -196,6 +197,7 @@ inline void bind_llama(py::module &m) {
        dir_list.append("attention_output_bias");
        dir_list.append("mlp_bias");
        dir_list.append("tie_word_embeddings");
+       dir_list.append("qk_norm");
        dir_list.append("use_cache");
        dir_list.append("attention_dropout");
        dir_list.append("initializer_range");
...
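Because the flag is registered with def_readwrite and appended to dir_list, it becomes an ordinary read/write attribute on the bound config object. A minimal sketch, assuming the extension module is importable as infinilm._infinilm (the name used by configuration_llama.py below) and that the bound LlamaConfig is default-constructible:

from infinilm import _infinilm  # assumed module layout

cfg = _infinilm.LlamaConfig()
print(cfg.qk_norm)   # False by default (see llama_config.hpp)
cfg.qk_norm = True   # writable thanks to def_readwrite
# "qk_norm" should also appear in the dir() listing via the dir_list entry above.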
python/infinilm/auto_config.py
@@ -21,7 +21,9 @@ class AutoConfig:
        if config_dict["model_type"] == "llama":
            return LlamaConfig(**config_dict)
-       elif config_dict["model_type"] == "qwen2":
+       elif (
+           config_dict["model_type"] == "qwen2" or config_dict["model_type"] == "qwen3"
+       ):
            return LlamaConfig(**config_dict)

        raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.")
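The effect of the change is that a config declaring model_type "qwen3" now takes the same path as "qwen2" and "llama". A minimal sketch restating the dispatch with a hypothetical helper (resolve_config_class is not part of the repository; a plain string stands in for the real LlamaConfig class):

def resolve_config_class(config_dict):
    if config_dict["model_type"] == "llama":
        return "LlamaConfig"
    elif (config_dict["model_type"] == "qwen2"
          or config_dict["model_type"] == "qwen3"):
        return "LlamaConfig"  # qwen2 and qwen3 reuse the Llama config path
    raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.")

print(resolve_config_class({"model_type": "qwen3"}))  # LlamaConfig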
python/infinilm/models/llama/configuration_llama.py
@@ -186,6 +186,10 @@ class LlamaConfig(PretrainedConfig, _infinilm.LlamaConfig):
    ):
        _infinilm.LlamaConfig.__init__(self)

+       original_model_type = kwargs.get("model_type", None)
+       if original_model_type == "qwen3":
+           self.qk_norm = True
+
        # ---
        self.model_type = "llama"
        self.name_or_path = ""
...
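In other words, a qwen3 checkpoint is loaded through the existing Llama configuration: the original model_type is remembered only long enough to switch qk_norm on, then overwritten with "llama". A minimal sketch of that mapping, using a hypothetical helper rather than the real class:

def normalize_model_type(kwargs):
    # qwen3 checkpoints use QK RMSNorm; everything is otherwise treated as llama.
    qk_norm = kwargs.get("model_type", None) == "qwen3"
    return {"model_type": "llama", "qk_norm": qk_norm}

print(normalize_model_type({"model_type": "qwen3"}))  # {'model_type': 'llama', 'qk_norm': True}
print(normalize_model_type({"model_type": "llama"}))  # {'model_type': 'llama', 'qk_norm': False}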