Commit 571e0ba0 authored by wangpengcheng's avatar wangpengcheng
Browse files

issue/199 - Support the Qwen3 model

parent c73ff203
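Note (context for the diff below): Qwen3 differs from Qwen2/Llama attention mainly by applying a per-head RMSNorm to the query and key projections before RoPE; this commit threads a `qk_norm` flag through LlamaConfig, LlamaAttention, and the Python bindings. A minimal sketch of that normalization, assuming numpy, an illustrative eps of 1e-6 (the real module reads `config.rms_norm_eps`), and unit gains in place of the learned q_norm/k_norm weights:

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Normalize over the last axis (head_dim); q_norm/k_norm are sized head_dim in the diff.
    var = np.mean(x.astype(np.float32) ** 2, axis=-1, keepdims=True)
    return (x / np.sqrt(var + eps)) * weight

# Shapes follow the views used in the forward paths: [seq_len, num_heads, head_dim]
seq_len, n_heads, n_kv_heads, head_dim = 4, 8, 2, 64
q = np.random.randn(seq_len, n_heads, head_dim).astype(np.float32)
k = np.random.randn(seq_len, n_kv_heads, head_dim).astype(np.float32)
q_gain = np.ones(head_dim, dtype=np.float32)  # stand-in for the learned q_norm weight
k_gain = np.ones(head_dim, dtype=np.float32)  # stand-in for the learned k_norm weight

q = rms_norm(q, q_gain)  # applied per head, before the rotary embedding
k = rms_norm(k, k_gain)
```

In the diff, this normalization is applied to the per-head views of Q and K in both the dense and paged forward paths, always before RoPE.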
......@@ -29,6 +29,7 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
kv_dim_(config.kv_dim()),
use_bias_(config.attention_bias),
use_output_bias_(config.attention_output_bias),
use_qk_norm_(config.qk_norm),
max_position_embeddings_(config.max_position_embeddings), rank_info_(rank_info) {
const auto &dtype{config.dtype};
......@@ -50,8 +51,14 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, config.num_attention_heads, config.num_key_value_heads, use_bias_,
dtype, device, rank_info);
// Output projection uses attention_output_bias (can be different from qkv)
INFINICORE_NN_MODULE_INIT(o_proj, hidden_size_, hidden_size_, use_output_bias_,
INFINICORE_NN_MODULE_INIT(o_proj, num_attention_heads * head_dim_, hidden_size_, use_output_bias_,
dtype, device, tp_rank, tp_size, rank_info.comm);
// Initialize qk RMSNorm
if (use_qk_norm_) {
INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, config.rms_norm_eps, dtype, device);
INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, config.rms_norm_eps, dtype, device);
}
}
infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states,
......@@ -68,6 +75,11 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
// 1. Project Q, K, V
auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);
if (use_qk_norm_) {
q = q_norm_->forward(q->view({batch_size * seq_len, num_attention_heads_, head_dim_}));
k = k_norm_->forward(k->view({batch_size * seq_len, num_key_value_heads_, head_dim_}));
}
// 2. Reshape for multi-head attention
// Reshape Q, K, V to include batch dimension
// Python: query_states = self.q_proj(hidden_states).view(querys_shape)
......@@ -172,6 +184,11 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
auto k_reshaped = k->view({seq_len, num_key_value_heads_, head_dim_});
auto v_reshaped = v->view({seq_len, num_key_value_heads_, head_dim_});
if (use_qk_norm_) {
q_reshaped = q_norm_->forward(q_reshaped);
k_reshaped = k_norm_->forward(k_reshaped);
}
// 3. Prepare position_ids for RoPE - align with Python pattern
auto pos_shape = position_ids->shape();
infinicore::Tensor pos_ids_for_rope = position_ids;
......
......@@ -7,6 +7,7 @@
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/tensor.hpp"
#include "llama_config.hpp"
......@@ -92,7 +93,8 @@ protected:
// Projection layers
INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj);
INFINICORE_NN_MODULE(infinicore::nn::RowParallelLinear, o_proj);
INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, q_norm);
INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, k_norm);
engine::distributed::RankInfo rank_info_;
// Shared Rotary Position Embeddings (RoPE)
......@@ -107,6 +109,7 @@ private:
size_t kv_dim_;
bool use_bias_; // Bias for Q/K/V projections
bool use_output_bias_; // Bias for output projection (o_proj)
bool use_qk_norm_; // Whether to use QK RMSNorm
size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility)
float scaling_;
......
......@@ -51,6 +51,7 @@ struct LlamaConfig : public InfinilmModel::Config {
bool attention_output_bias = false; // Whether to use bias in output projection (o_proj)
bool mlp_bias = false; // Whether to use bias in MLP projections
bool tie_word_embeddings = false; // Whether to tie input/output embeddings
bool qk_norm = false; // Whether to use QK RMSNorm
// Training/initialization parameters
double attention_dropout = 0.0; // Dropout ratio for attention probabilities
......
......@@ -64,6 +64,7 @@ inline void bind_llama(py::module &m) {
.def_readwrite("attention_output_bias", &LlamaConfig::attention_output_bias)
.def_readwrite("mlp_bias", &LlamaConfig::mlp_bias)
.def_readwrite("tie_word_embeddings", &LlamaConfig::tie_word_embeddings)
.def_readwrite("qk_norm", &LlamaConfig::qk_norm)
.def_readwrite("use_cache", &LlamaConfig::use_cache)
.def_readwrite("attention_dropout", &LlamaConfig::attention_dropout)
.def_readwrite("initializer_range", &LlamaConfig::initializer_range)
......@@ -196,6 +197,7 @@ inline void bind_llama(py::module &m) {
dir_list.append("attention_output_bias");
dir_list.append("mlp_bias");
dir_list.append("tie_word_embeddings");
dir_list.append("qk_norm");
dir_list.append("use_cache");
dir_list.append("attention_dropout");
dir_list.append("initializer_range");
......
......@@ -21,7 +21,9 @@ class AutoConfig:
if config_dict["model_type"] == "llama":
return LlamaConfig(**config_dict)
elif config_dict["model_type"] == "qwen2":
elif (
config_dict["model_type"] == "qwen2" or config_dict["model_type"] == "qwen3"
):
return LlamaConfig(**config_dict)
raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.")
......@@ -186,6 +186,10 @@ class LlamaConfig(PretrainedConfig, _infinilm.LlamaConfig):
):
_infinilm.LlamaConfig.__init__(self)
original_model_type = kwargs.get("model_type", None)
if original_model_type == "qwen3":
self.qk_norm = True
# ---
self.model_type = "llama"
self.name_or_path = ""
......
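Taken together, the Python-side changes route a Qwen3 checkpoint to LlamaConfig with `qk_norm` switched on automatically. A small self-contained sketch that mirrors this dispatch logic (plain Python, not the actual AutoConfig/LlamaConfig classes):

```python
def resolve_qk_norm(config_dict: dict) -> bool:
    """Mirror the dispatch above: llama/qwen2/qwen3 all map to LlamaConfig,
    and qwen3 additionally enables QK RMSNorm."""
    model_type = config_dict.get("model_type")
    if model_type not in ("llama", "qwen2", "qwen3"):
        raise ValueError(f"Unsupported model type `{model_type}`.")
    return model_type == "qwen3"

assert resolve_qk_norm({"model_type": "qwen3"}) is True
assert resolve_qk_norm({"model_type": "qwen2"}) is False
```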