Commit f692d681 authored by qinyiqun's avatar qinyiqun
Browse files

Issue/243:支持w4a16 awq fp16推理

parent e76bb324
......@@ -170,6 +170,58 @@ infinicore::nn::Parameter QKVParallelLinear::get_v_weight_scale() const {
0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_q_weight_awq(int scaling_factor) const {
    // Q occupies the leading slice of the fused QKV packed weight along dim 1;
    // its extent shrinks by the AWQ packing factor.
    const auto q_cols = q_out_size_ / scaling_factor;
    auto q_slice = weight_->narrow({{1, 0, q_cols}});
    return infinicore::nn::Parameter(q_slice, 1, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_k_weight_awq(int scaling_factor) const {
    // K follows Q in the fused packed weight; offset and extent are both
    // divided by the AWQ packing factor.
    const auto k_start = q_out_size_ / scaling_factor;
    const auto k_cols = k_out_size_ / scaling_factor;
    return infinicore::nn::Parameter(weight_->narrow({{1, k_start, k_cols}}), 1, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_v_weight_awq(int scaling_factor) const {
    // V sits after the Q and K slices in the fused packed weight.
    const auto v_start = (q_out_size_ + k_out_size_) / scaling_factor;
    const auto v_cols = v_out_size_ / scaling_factor;
    return infinicore::nn::Parameter(weight_->narrow({{1, v_start, v_cols}}), 1, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_q_weight_scale_awq(int scaling_factor) const {
    // Leading slice of the fused QKV scale tensor along dim 1.
    const auto q_cols = q_out_size_ / scaling_factor;
    auto q_slice = weight_scale_->narrow({{1, 0, q_cols}});
    return infinicore::nn::Parameter(q_slice, 1, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_k_weight_scale_awq(int scaling_factor) const {
    // K slice of the fused QKV scale tensor, after the Q columns.
    const auto k_start = q_out_size_ / scaling_factor;
    const auto k_cols = k_out_size_ / scaling_factor;
    return infinicore::nn::Parameter(weight_scale_->narrow({{1, k_start, k_cols}}), 1, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_v_weight_scale_awq(int scaling_factor) const {
    // V slice of the fused QKV scale tensor, after the Q and K columns.
    const auto v_start = (q_out_size_ + k_out_size_) / scaling_factor;
    const auto v_cols = v_out_size_ / scaling_factor;
    return infinicore::nn::Parameter(weight_scale_->narrow({{1, v_start, v_cols}}), 1, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_q_weight_zeros_awq(int scaling_factor) const {
    // Leading slice of the fused QKV zero-point tensor along dim 1,
    // shrunk by the AWQ packing factor.
    const auto q_cols = q_out_size_ / scaling_factor;
    auto q_slice = weight_zeros_->narrow({{1, 0, q_cols}});
    return infinicore::nn::Parameter(q_slice, 1, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_k_weight_zeros_awq(int scaling_factor) const {
    // K slice of the fused QKV zero-point tensor, after the Q columns.
    const auto k_start = q_out_size_ / scaling_factor;
    const auto k_cols = k_out_size_ / scaling_factor;
    return infinicore::nn::Parameter(weight_zeros_->narrow({{1, k_start, k_cols}}), 1, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_v_weight_zeros_awq(int scaling_factor) const {
    // V slice of the fused QKV zero-point tensor, after the Q and K columns.
    const auto v_start = (q_out_size_ + k_out_size_) / scaling_factor;
    const auto v_cols = v_out_size_ / scaling_factor;
    return infinicore::nn::Parameter(weight_zeros_->narrow({{1, v_start, v_cols}}), 1, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_q_weight_zeros() const {
return infinicore::nn::Parameter(
weight_zeros_->narrow({{0, 0, q_out_size_}}), 0, tp_rank_, tp_size_);
......@@ -320,4 +372,29 @@ bool GateUpParallelLinear::has_gate_bias() const {
bool GateUpParallelLinear::has_up_bias() const {
    // True when an up-projection bias was configured.
    return static_cast<bool>(up_bias_);
}
infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_awq() const {
    // Gate is the first half of the fused gate/up packed weight along dim 1.
    const auto half = weight_->size(1) / 2;
    return infinicore::nn::Parameter(weight_->narrow({{1, 0, half}}), 1, tp_rank_, tp_size_);
}
infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_awq() const {
    // Up is the second half of the fused gate/up packed weight along dim 1.
    const auto half = weight_->size(1) / 2;
    return infinicore::nn::Parameter(weight_->narrow({{1, half, half}}), 1, tp_rank_, tp_size_);
}
infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_scale_awq() const {
    // Gate half of the fused gate/up scale tensor along dim 1.
    const auto half = weight_scale_->size(1) / 2;
    return infinicore::nn::Parameter(weight_scale_->narrow({{1, 0, half}}), 1, tp_rank_, tp_size_);
}
infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_scale_awq() const {
    // Up half of the fused gate/up scale tensor along dim 1.
    const auto half = weight_scale_->size(1) / 2;
    return infinicore::nn::Parameter(weight_scale_->narrow({{1, half, half}}), 1, tp_rank_, tp_size_);
}
infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_zeros_awq() const {
    // Gate half of the fused gate/up zero-point tensor along dim 1.
    const auto half = weight_zeros_->size(1) / 2;
    return infinicore::nn::Parameter(weight_zeros_->narrow({{1, 0, half}}), 1, tp_rank_, tp_size_);
}
infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_zeros_awq() const {
    // Up half of the fused gate/up zero-point tensor along dim 1.
    const auto half = weight_zeros_->size(1) / 2;
    return infinicore::nn::Parameter(weight_zeros_->narrow({{1, half, half}}), 1, tp_rank_, tp_size_);
}
} // namespace infinilm::layers
......@@ -58,6 +58,21 @@ public:
infinicore::nn::Parameter get_k_weight_zeros() const;
infinicore::nn::Parameter get_v_weight_zeros() const;
// AWQ (w4a16) split accessors for the fused QKV tensors.
// `scaling_factor` is the AWQ packing factor: the number of low-bit values
// packed into one high-bit container element (e.g. int4 -> int32 gives
// 32 bits / 4 bits = 8 int4 values per int32). Slice offsets/extents along
// dim 1 are divided by it, so pass the packing factor for packed tensors
// (qweight, qzeros) and 1 for unpacked ones (scales).
infinicore::nn::Parameter get_q_weight_awq(int scaling_factor) const;
infinicore::nn::Parameter get_k_weight_awq(int scaling_factor) const;
infinicore::nn::Parameter get_v_weight_awq(int scaling_factor) const;
infinicore::nn::Parameter get_q_weight_scale_awq(int scaling_factor) const;
infinicore::nn::Parameter get_k_weight_scale_awq(int scaling_factor) const;
infinicore::nn::Parameter get_v_weight_scale_awq(int scaling_factor) const;
infinicore::nn::Parameter get_q_weight_zeros_awq(int scaling_factor) const;
infinicore::nn::Parameter get_k_weight_zeros_awq(int scaling_factor) const;
infinicore::nn::Parameter get_v_weight_zeros_awq(int scaling_factor) const;
// Per-projection bias accessors.
infinicore::nn::Parameter get_q_bias() const;
infinicore::nn::Parameter get_k_bias() const;
infinicore::nn::Parameter get_v_bias() const;
......@@ -132,6 +147,18 @@ public:
infinicore::nn::Parameter get_up_bias() const;
// AWQ (w4a16) split accessors: each returns one half of the fused gate/up
// tensor sliced along dim 1 (first half = gate, second half = up).
infinicore::nn::Parameter get_gate_weight_awq() const;
infinicore::nn::Parameter get_up_weight_awq() const;
infinicore::nn::Parameter get_up_weight_scale_awq() const;
infinicore::nn::Parameter get_up_weight_zeros_awq() const;
infinicore::nn::Parameter get_gate_weight_scale_awq() const;
infinicore::nn::Parameter get_gate_weight_zeros_awq() const;
// Presence checks for the optional biases.
bool has_gate_bias() const;
bool has_up_bias() const;
......@@ -180,15 +207,17 @@ private:
#define INFINILM_QKV_LINEAR_W4A16AWQ_INIT(name, q_name, k_name, v_name, ...) \
name##_ = std::make_shared<layers::QKVParallelLinear>(__VA_ARGS__); \
this->register_parameter(std::string(q_name) + ".qweight", name##_->get_q_weight()); \
this->register_parameter(std::string(q_name) + ".qzeros", name##_->get_q_weight_zeros()); \
this->register_parameter(std::string(q_name) + ".scales", name##_->get_q_weight_scale()); \
this->register_parameter(std::string(k_name) + ".qweight", name##_->get_k_weight()); \
this->register_parameter(std::string(k_name) + ".qzeros", name##_->get_k_weight_zeros()); \
this->register_parameter(std::string(k_name) + ".scales", name##_->get_k_weight_scale()); \
this->register_parameter(std::string(v_name) + ".qweight", name##_->get_v_weight()); \
this->register_parameter(std::string(v_name) + ".qzeros", name##_->get_v_weight_zeros()); \
this->register_parameter(std::string(v_name) + ".scales", name##_->get_v_weight_scale()); \
auto awq_ptr = std::static_pointer_cast<infinicore::quantization::AWQ>(this->quantization_); \
int packing_num = awq_ptr->get_packing_num(); \
this->register_parameter(std::string(q_name) + ".qweight", name##_->get_q_weight_awq(packing_num)); \
this->register_parameter(std::string(q_name) + ".qzeros", name##_->get_q_weight_zeros_awq(packing_num)); \
this->register_parameter(std::string(q_name) + ".scales", name##_->get_q_weight_scale_awq(1)); \
this->register_parameter(std::string(k_name) + ".qweight", name##_->get_k_weight_awq(packing_num)); \
this->register_parameter(std::string(k_name) + ".qzeros", name##_->get_k_weight_zeros_awq(packing_num)); \
this->register_parameter(std::string(k_name) + ".scales", name##_->get_k_weight_scale_awq(1)); \
this->register_parameter(std::string(v_name) + ".qweight", name##_->get_v_weight_awq(packing_num)); \
this->register_parameter(std::string(v_name) + ".qzeros", name##_->get_v_weight_zeros_awq(packing_num)); \
this->register_parameter(std::string(v_name) + ".scales", name##_->get_v_weight_scale_awq(1)); \
if (name##_->has_q_bias()) \
this->register_parameter(std::string(q_name) + ".bias", name##_->get_q_bias()); \
if (name##_->has_k_bias()) \
......@@ -210,12 +239,12 @@ private:
#define INFINILM_GATE_UP_LINEAR_W4A16AWQ_INIT(name, gate_name, up_name, ...) \
name##_ = std::make_shared<layers::GateUpParallelLinear>(__VA_ARGS__); \
this->register_parameter(std::string(gate_name) + ".qweight", name##_->get_gate_weight()); \
this->register_parameter(std::string(gate_name) + ".scales", name##_->get_gate_weight_scale()); \
this->register_parameter(std::string(gate_name) + ".qzeros", name##_->get_gate_weight_zeros()); \
this->register_parameter(std::string(up_name) + ".qweight", name##_->get_up_weight()); \
this->register_parameter(std::string(up_name) + ".scales", name##_->get_up_weight_scale()); \
this->register_parameter(std::string(up_name) + ".qzeros", name##_->get_up_weight_zeros()); \
this->register_parameter(std::string(gate_name) + ".qweight", name##_->get_gate_weight_awq()); \
this->register_parameter(std::string(gate_name) + ".qzeros", name##_->get_gate_weight_zeros_awq()); \
this->register_parameter(std::string(gate_name) + ".scales", name##_->get_gate_weight_scale_awq()); \
this->register_parameter(std::string(up_name) + ".qweight", name##_->get_up_weight_awq()); \
this->register_parameter(std::string(up_name) + ".qzeros", name##_->get_up_weight_zeros_awq()); \
this->register_parameter(std::string(up_name) + ".scales", name##_->get_up_weight_scale_awq()); \
if (name##_->has_gate_bias()) \
this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias()); \
if (name##_->has_up_bias()) \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment