"vscode:/vscode.git/clone" did not exist on "15cbe5f70adaade1a8a11afc37601fc6606e7e0d"
Commit fc97bbd8 authored by qinyiqun

Fix: Add lifecycle management to AWQ linear function

parent f692d681
@@ -207,7 +207,7 @@ private:
 #define INFINILM_QKV_LINEAR_W4A16AWQ_INIT(name, q_name, k_name, v_name, ...) \
     name##_ = std::make_shared<layers::QKVParallelLinear>(__VA_ARGS__); \
-    auto awq_ptr = std::static_pointer_cast<infinicore::quantization::AWQ>(this->quantization_); \
+    auto awq_ptr = std::static_pointer_cast<infinicore::quantization::AWQ>(name##_->get_quantization()); \
     int packing_num = awq_ptr->get_packing_num(); \
     this->register_parameter(std::string(q_name) + ".qweight", name##_->get_q_weight_awq(packing_num)); \
     this->register_parameter(std::string(q_name) + ".qzeros", name##_->get_q_weight_zeros_awq(packing_num)); \
...
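
The substance of the fix is in the hunk above: the AWQ handle is now taken from the freshly constructed layer (`name##_->get_quantization()`) rather than from the enclosing module's `this->quantization_`, so the casted pointer shares ownership with the layer that actually uses the quantized weights. Below is a minimal sketch of that ownership pattern; the `AWQ`, `get_quantization()`, and `get_packing_num()` names mirror the diff, while the stand-in types and `main` are hypothetical:

```cpp
#include <iostream>
#include <memory>

// Hypothetical stand-ins for infinicore::quantization::Quantization / AWQ.
struct Quantization {
    virtual ~Quantization() = default;
};

struct AWQ : Quantization {
    int get_packing_num() const { return 8; } // e.g. eight 4-bit values per int32
};

// Hypothetical stand-in for layers::QKVParallelLinear: the layer owns its
// quantization config via shared_ptr.
struct QKVParallelLinear {
    explicit QKVParallelLinear(std::shared_ptr<Quantization> q)
        : quantization_(std::move(q)) {}
    const std::shared_ptr<Quantization> &get_quantization() const { return quantization_; }

private:
    std::shared_ptr<Quantization> quantization_;
};

int main() {
    auto layer = std::make_shared<QKVParallelLinear>(std::make_shared<AWQ>());

    // As in the fixed macro: cast the layer's own handle, not the parent's.
    // static_pointer_cast shares ownership, so awq_ptr keeps the AWQ config
    // alive as long as it is used, tying its lifetime to the layer.
    auto awq_ptr = std::static_pointer_cast<AWQ>(layer->get_quantization());
    std::cout << awq_ptr->get_packing_num() << '\n'; // prints 8
}
```

Because `std::static_pointer_cast` copies the control block, `awq_ptr` keeps the AWQ object alive even if another handle to it is reset in the meantime.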
@@ -112,12 +112,13 @@ LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> mo
                                  dtype, device, tp_rank, tp_size, rank_info.comm);
         break;
-    case infinicore::quantization::QuantScheme::AWQ_W4A16:
+    case infinicore::quantization::QuantScheme::AWQ_W4A16: {
         INFINILM_QKV_LINEAR_W4A16AWQ_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_,
                                           dtype, device, rank_info);
         INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_,
                                   dtype, device, tp_rank, tp_size, rank_info.comm);
         break;
+    }
     default:
         INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_,
                                  dtype, device, rank_info);
...
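
The braces added in this hunk are required by a C++ scoping rule rather than by the lifetime fix itself: `INFINILM_QKV_LINEAR_W4A16AWQ_INIT` expands to declarations (`awq_ptr`, `packing_num`) directly inside the `case`, and a jump to a later label may not cross the initialization of a variable that is still in scope. A self-contained reproduction of the pattern, with a hypothetical enum and helper:

```cpp
#include <cstdio>

enum class Scheme { AWQ_W4A16, NONE }; // illustrative, not the repo's enum

int get_packing_num() { return 8; }

// Without the braces, the AWQ_W4A16 case would declare and initialize a
// variable whose scope extends to the 'default' label, and the compiler
// rejects the jump: "jump to case label crosses initialization of
// 'int packing_num'". The macro expands to exactly such declarations,
// hence the added { } around the case body.
void init(Scheme scheme) {
    switch (scheme) {
    case Scheme::AWQ_W4A16: { // braces open a scope for the declarations
        int packing_num = get_packing_num();
        std::printf("AWQ packing num: %d\n", packing_num);
        break;
    } // declarations end here, so later labels no longer bypass them
    default:
        std::printf("unquantized path\n");
        break;
    }
}

int main() {
    init(Scheme::AWQ_W4A16);
    init(Scheme::NONE);
}
```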