Unverified Commit 3747f7f3 authored by PanZezhong1725, committed by GitHub

Merge pull request #200 from pengcheng888/issue/199

issue/199 - Support the Qwen3 model
parents c73ff203 8b6fb721
...@@ -29,6 +29,7 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
       kv_dim_(config.kv_dim()),
       use_bias_(config.attention_bias),
       use_output_bias_(config.attention_output_bias),
+      use_qk_norm_(config.qk_norm),
       max_position_embeddings_(config.max_position_embeddings), rank_info_(rank_info) {
     const auto &dtype{config.dtype};
...@@ -50,8 +51,14 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
     INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, config.num_attention_heads, config.num_key_value_heads, use_bias_,
                              dtype, device, rank_info);
     // Output projection uses attention_output_bias (can be different from qkv)
-    INFINICORE_NN_MODULE_INIT(o_proj, hidden_size_, hidden_size_, use_output_bias_,
+    INFINICORE_NN_MODULE_INIT(o_proj, num_attention_heads * head_dim_, hidden_size_, use_output_bias_,
                               dtype, device, tp_rank, tp_size, rank_info.comm);
+    // Initialize qk RMSNorm
+    if (use_qk_norm_) {
+        INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, config.rms_norm_eps, dtype, device);
+        INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, config.rms_norm_eps, dtype, device);
+    }
 }
 infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states,
...@@ -68,6 +75,11 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
     // 1. Project Q, K, V
     auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);
+    if (use_qk_norm_) {
+        q = q_norm_->forward(q->view({batch_size * seq_len, num_attention_heads_, head_dim_}));
+        k = k_norm_->forward(k->view({batch_size * seq_len, num_key_value_heads_, head_dim_}));
+    }
     // 2. Reshape for multi-head attention
     // Reshape Q, K, V to include batch dimension
     // Python: query_states = self.q_proj(hidden_states).view(querys_shape)
...@@ -172,6 +184,11 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
     auto k_reshaped = k->view({seq_len, num_key_value_heads_, head_dim_});
     auto v_reshaped = v->view({seq_len, num_key_value_heads_, head_dim_});
+    if (use_qk_norm_) {
+        q_reshaped = q_norm_->forward(q_reshaped);
+        k_reshaped = k_norm_->forward(k_reshaped);
+    }
     // 3. Prepare position_ids for RoPE - align with Python pattern
     auto pos_shape = position_ids->shape();
     infinicore::Tensor pos_ids_for_rope = position_ids;
...
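Two functional changes stand out in this file. First, `o_proj` now takes `num_attention_heads * head_dim_` as its input dimension rather than `hidden_size_`; Qwen3 sets `head_dim` explicitly, and it need not equal `hidden_size / num_attention_heads`. Second, when `qk_norm` is enabled, a per-head RMSNorm is applied to Q and K right after the QKV projection and before RoPE. A minimal NumPy sketch of that normalization, with illustrative shapes and names (not the InfiniLM API):

```python
# Sketch of Qwen3-style per-head QK RMSNorm (illustrative, not the InfiniLM API).
# Each head's (head_dim,) vector is normalized independently; a single learned
# weight vector of shape (head_dim,) is shared across heads.
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Normalize over the last axis (head_dim), then apply the learned scale.
    variance = np.mean(x.astype(np.float32) ** 2, axis=-1, keepdims=True)
    return (x / np.sqrt(variance + eps)) * weight

tokens, n_heads, n_kv_heads, head_dim = 4, 8, 2, 64  # made-up shapes
q = np.random.randn(tokens, n_heads, head_dim)       # as after q_proj + view
k = np.random.randn(tokens, n_kv_heads, head_dim)    # as after k_proj + view

q = rms_norm(q, np.ones(head_dim))  # q_norm weights (ones here for the sketch)
k = rms_norm(k, np.ones(head_dim))  # k_norm weights
```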
...@@ -7,6 +7,7 @@
 #include "infinicore/nn/linear.hpp"
 #include "infinicore/nn/module.hpp"
+#include "infinicore/nn/rmsnorm.hpp"
 #include "infinicore/nn/rope.hpp"
 #include "infinicore/tensor.hpp"
 #include "llama_config.hpp"
...@@ -92,7 +93,8 @@ protected:
     // Projection layers
     INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj);
     INFINICORE_NN_MODULE(infinicore::nn::RowParallelLinear, o_proj);
+    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, q_norm);
+    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, k_norm);
     engine::distributed::RankInfo rank_info_;
     // Shared Rotary Position Embeddings (RoPE)
...@@ -107,6 +109,7 @@ private:
     size_t kv_dim_;
     bool use_bias_;        // Bias for Q/K/V projections
     bool use_output_bias_; // Bias for output projection (o_proj)
+    bool use_qk_norm_;     // Whether to use QK RMSNorm
     size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility)
     float scaling_;
...
...@@ -51,6 +51,7 @@ struct LlamaConfig : public InfinilmModel::Config {
     bool attention_output_bias = false; // Whether to use bias in output projection (o_proj)
     bool mlp_bias = false;              // Whether to use bias in MLP projections
     bool tie_word_embeddings = false;   // Whether to tie input/output embeddings
+    bool qk_norm = false;               // Whether to use QK RMSNorm
     // Training/initialization parameters
     double attention_dropout = 0.0; // Dropout ratio for attention probabilities
...
...@@ -64,6 +64,7 @@ inline void bind_llama(py::module &m) {
         .def_readwrite("attention_output_bias", &LlamaConfig::attention_output_bias)
         .def_readwrite("mlp_bias", &LlamaConfig::mlp_bias)
         .def_readwrite("tie_word_embeddings", &LlamaConfig::tie_word_embeddings)
+        .def_readwrite("qk_norm", &LlamaConfig::qk_norm)
         .def_readwrite("use_cache", &LlamaConfig::use_cache)
         .def_readwrite("attention_dropout", &LlamaConfig::attention_dropout)
         .def_readwrite("initializer_range", &LlamaConfig::initializer_range)
...@@ -196,6 +197,7 @@ inline void bind_llama(py::module &m) {
         dir_list.append("attention_output_bias");
         dir_list.append("mlp_bias");
         dir_list.append("tie_word_embeddings");
+        dir_list.append("qk_norm");
         dir_list.append("use_cache");
         dir_list.append("attention_dropout");
         dir_list.append("initializer_range");
...
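With both the `def_readwrite` binding and the `dir_list` entry in place, the flag is visible and settable from Python. A hedged usage sketch (the import form is an assumption; the Python layer in this diff only shows `_infinilm.LlamaConfig`):

```python
import _infinilm  # assumed import; the Python code below references _infinilm.LlamaConfig

config = _infinilm.LlamaConfig()
config.qk_norm = True            # enable Qwen3-style QK RMSNorm
assert "qk_norm" in dir(config)  # exposed via the dir_list addition above
```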
...@@ -21,7 +21,9 @@ class AutoConfig:
         if config_dict["model_type"] == "llama":
             return LlamaConfig(**config_dict)
-        elif config_dict["model_type"] == "qwen2":
+        elif (
+            config_dict["model_type"] == "qwen2" or config_dict["model_type"] == "qwen3"
+        ):
             return LlamaConfig(**config_dict)
         raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.")
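A note on the new branch: since all three arms return `LlamaConfig(**config_dict)`, the dispatch could be collapsed further. An equivalent, more compact form (a style suggestion, not what the diff uses):

```python
if config_dict["model_type"] in ("llama", "qwen2", "qwen3"):
    return LlamaConfig(**config_dict)
raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.")
```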
...@@ -223,11 +223,11 @@ class InferEngine(_infinilm.InferEngine):
             f" Batchsize={initial_batch_size} Per_Batch_Input_Len={initial_seqlen} Per_Batch_New_Tokens={len(time_measurements)}\n"
         )
         print(
-            f" Prefill TTFT: {round(time_measurements[0], 2)}ms Throughput: {round((initial_batch_size * initial_seqlen) / time_measurements[0], 2)}tok/s\n",
+            f" Prefill TTFT: {round(time_measurements[0] * 1000, 2)} ms Throughput: {round((initial_batch_size * initial_seqlen) / time_measurements[0], 2)} tok/s\n",
         )
         if len(time_measurements) > 1:
             print(
-                f" Decode Avg ITL: {round(sum(time_measurements[1:]) * 1000 / (len(time_measurements) - 1), 2)}ms Throughput: {round((initial_batch_size * (len(time_measurements) - 1)) / sum(time_measurements[1:]), 2)}tok/s\n",
+                f" Decode Avg ITL: {round(sum(time_measurements[1:]) * 1000 / (len(time_measurements) - 1), 2)} ms Throughput: {round((initial_batch_size * (len(time_measurements) - 1)) / sum(time_measurements[1:]), 2)} tok/s\n",
             )
         return output_ids
...
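The substantive fix here is a unit bug: `time_measurements` holds per-step latencies in seconds, and the decode line already multiplied by 1000 before printing "ms", but the prefill TTFT did not. Throughput is unaffected, since tok/s divides by seconds directly. A quick worked example with made-up numbers:

```python
# Made-up measurements, in seconds (matching the engine code above).
time_measurements = [0.125, 0.010, 0.011, 0.009]
initial_batch_size, initial_seqlen = 4, 512

ttft_ms = time_measurements[0] * 1000  # 125.0 ms (without * 1000 this printed "0.12ms")
prefill_tok_s = initial_batch_size * initial_seqlen / time_measurements[0]  # 16384.0 tok/s
itl_ms = sum(time_measurements[1:]) * 1000 / (len(time_measurements) - 1)   # 10.0 ms
print(f"TTFT: {ttft_ms} ms, prefill throughput: {prefill_tok_s} tok/s, ITL: {itl_ms} ms")
```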
...@@ -186,6 +186,10 @@ class LlamaConfig(PretrainedConfig, _infinilm.LlamaConfig):
     ):
         _infinilm.LlamaConfig.__init__(self)
+        original_model_type = kwargs.get("model_type", None)
+        if original_model_type == "qwen3":
+            self.qk_norm = True
         # ---
         self.model_type = "llama"
         self.name_or_path = ""
...
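End to end, loading a Qwen3 checkpoint now flips `qk_norm` automatically: `AutoConfig` accepts `model_type == "qwen3"`, and `LlamaConfig.__init__` inspects the original model type before overwriting it with `"llama"`. A hedged usage sketch (the `from_pretrained` entry point and the checkpoint path are assumptions, not shown in this diff):

```python
# Assumed entry point and placeholder path, shown only to illustrate the flow in this PR.
config = AutoConfig.from_pretrained("/models/Qwen3-8B")

assert config.qk_norm is True        # set because the original model_type was "qwen3"
assert config.model_type == "llama"  # Qwen3 reuses the Llama model implementation
```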