jerrrrry / infinilm

Commit 5fb56f97 (Unverified)
Authored Mar 11, 2026 by thatPepe; committed by GitHub on Mar 11, 2026

Merge pull request #249 from InfiniTensor/Issue/243

Issue/243: Support w4a16 AWQ fp16 inference

Parents: a256e8d9, fc97bbd8
Showing 3 changed files with 135 additions and 28 deletions.
csrc/layers/fused_linear.cpp (+77, -0)
csrc/layers/fused_linear.hpp (+56, -27)
csrc/models/llama/llama_attention.cpp (+2, -1)
csrc/layers/fused_linear.cpp
@@ -170,6 +170,58 @@ infinicore::nn::Parameter QKVParallelLinear::get_v_weight_scale() const {
        0, tp_rank_, tp_size_);
}

infinicore::nn::Parameter QKVParallelLinear::get_q_weight_awq(int scaling_factor) const {
    return infinicore::nn::Parameter(weight_->narrow({{1, 0, q_out_size_ / scaling_factor}}),
                                     1, tp_rank_, tp_size_);
}

infinicore::nn::Parameter QKVParallelLinear::get_k_weight_awq(int scaling_factor) const {
    return infinicore::nn::Parameter(weight_->narrow({{1, q_out_size_ / scaling_factor, k_out_size_ / scaling_factor}}),
                                     1, tp_rank_, tp_size_);
}

infinicore::nn::Parameter QKVParallelLinear::get_v_weight_awq(int scaling_factor) const {
    return infinicore::nn::Parameter(weight_->narrow({{1, (q_out_size_ + k_out_size_) / scaling_factor, v_out_size_ / scaling_factor}}),
                                     1, tp_rank_, tp_size_);
}

infinicore::nn::Parameter QKVParallelLinear::get_q_weight_scale_awq(int scaling_factor) const {
    return infinicore::nn::Parameter(weight_scale_->narrow({{1, 0, q_out_size_ / scaling_factor}}),
                                     1, tp_rank_, tp_size_);
}

infinicore::nn::Parameter QKVParallelLinear::get_k_weight_scale_awq(int scaling_factor) const {
    return infinicore::nn::Parameter(weight_scale_->narrow({{1, q_out_size_ / scaling_factor, k_out_size_ / scaling_factor}}),
                                     1, tp_rank_, tp_size_);
}

infinicore::nn::Parameter QKVParallelLinear::get_v_weight_scale_awq(int scaling_factor) const {
    return infinicore::nn::Parameter(weight_scale_->narrow({{1, (q_out_size_ + k_out_size_) / scaling_factor, v_out_size_ / scaling_factor}}),
                                     1, tp_rank_, tp_size_);
}

infinicore::nn::Parameter QKVParallelLinear::get_q_weight_zeros_awq(int scaling_factor) const {
    return infinicore::nn::Parameter(weight_zeros_->narrow({{1, 0, q_out_size_ / scaling_factor}}),
                                     1, tp_rank_, tp_size_);
}

infinicore::nn::Parameter QKVParallelLinear::get_k_weight_zeros_awq(int scaling_factor) const {
    return infinicore::nn::Parameter(weight_zeros_->narrow({{1, q_out_size_ / scaling_factor, k_out_size_ / scaling_factor}}),
                                     1, tp_rank_, tp_size_);
}

infinicore::nn::Parameter QKVParallelLinear::get_v_weight_zeros_awq(int scaling_factor) const {
    return infinicore::nn::Parameter(weight_zeros_->narrow({{1, (q_out_size_ + k_out_size_) / scaling_factor, v_out_size_ / scaling_factor}}),
                                     1, tp_rank_, tp_size_);
}

infinicore::nn::Parameter QKVParallelLinear::get_q_weight_zeros() const {
    return infinicore::nn::Parameter(weight_zeros_->narrow({{0, 0, q_out_size_}}),
                                     0, tp_rank_, tp_size_);
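All nine new AWQ getters follow one pattern: slice the fused QKV tensor along dimension 1, with both the offset and the length divided by scaling_factor, because the packed qweight/qzeros tensors hold several 4-bit values per int32 column. A minimal, self-contained sketch of that offset arithmetic (the Slice struct and all sizes are hypothetical, not repository code):

#include <cassert>
#include <cstdio>

// Illustrative only: the column ranges the *_awq getters select from the
// fused packed tensor. Sizes are made up for the example.
struct Slice { int offset, length; };

int main() {
    const int pack = 32 / 4;                          // int4 in int32 -> 8
    const int q_out = 4096, k_out = 1024, v_out = 1024;
    Slice q = {0, q_out / pack};                      // {0, 512}
    Slice k = {q_out / pack, k_out / pack};           // {512, 128}
    Slice v = {(q_out + k_out) / pack, v_out / pack}; // {640, 128}
    // The three slices tile the packed fused dimension exactly:
    assert(q.length + k.length + v.length == (q_out + k_out + v_out) / pack);
    std::printf("q:[%d,%d) k:[%d,%d) v:[%d,%d)\n",
                q.offset, q.offset + q.length,
                k.offset, k.offset + k.length,
                v.offset, v.offset + v.length);
    return 0;
}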
...
@@ -320,4 +372,29 @@ bool GateUpParallelLinear::has_gate_bias() const {
bool GateUpParallelLinear::has_up_bias() const {
    return up_bias_;
}

infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_awq() const {
    return infinicore::nn::Parameter(weight_->narrow({{1, 0, weight_->size(1) / 2}}),
                                     1, tp_rank_, tp_size_);
}

infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_awq() const {
    return infinicore::nn::Parameter(weight_->narrow({{1, weight_->size(1) / 2, weight_->size(1) / 2}}),
                                     1, tp_rank_, tp_size_);
}

infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_scale_awq() const {
    return infinicore::nn::Parameter(weight_scale_->narrow({{1, 0, weight_scale_->size(1) / 2}}),
                                     1, tp_rank_, tp_size_);
}

infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_scale_awq() const {
    return infinicore::nn::Parameter(weight_scale_->narrow({{1, weight_scale_->size(1) / 2, weight_scale_->size(1) / 2}}),
                                     1, tp_rank_, tp_size_);
}

infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_zeros_awq() const {
    return infinicore::nn::Parameter(weight_zeros_->narrow({{1, 0, weight_zeros_->size(1) / 2}}),
                                     1, tp_rank_, tp_size_);
}

infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_zeros_awq() const {
    return infinicore::nn::Parameter(weight_zeros_->narrow({{1, weight_zeros_->size(1) / 2, weight_zeros_->size(1) / 2}}),
                                     1, tp_rank_, tp_size_);
}

} // namespace infinilm::layers
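Unlike the QKV getters, the GateUp variants take no packing factor: the gate and up halves are the same width, so the midpoint of the already-packed tensor, size(1) / 2, is the split point for qweight, scales, and zeros alike. A toy check of that invariant (hypothetical sizes, not repository code):

#include <cassert>

int main() {
    const int pack = 8;                               // int4 packed into int32
    const int intermediate = 11008;                   // e.g. a Llama-7B-like width
    const int fused_packed = 2 * intermediate / pack; // packed size(1)
    // gate occupies [0, half), up occupies [half, 2 * half):
    const int half = fused_packed / 2;
    assert(half == intermediate / pack);              // each half is one projection
    return 0;
}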
csrc/layers/fused_linear.hpp
@@ -58,6 +58,21 @@ public:
    infinicore::nn::Parameter get_k_weight_zeros() const;
    infinicore::nn::Parameter get_v_weight_zeros() const;

    // scaling_factor is the AWQ packing factor: the number of low-bit elements
    // packed into a single high-bit container element. For example, int4 packed
    // into int32 gives a packing factor of 8 (32 bits / 4 bits = 8 int4 values per int32).
    infinicore::nn::Parameter get_q_weight_awq(int scaling_factor) const;
    infinicore::nn::Parameter get_k_weight_awq(int scaling_factor) const;
    infinicore::nn::Parameter get_v_weight_awq(int scaling_factor) const;
    infinicore::nn::Parameter get_q_weight_scale_awq(int scaling_factor) const;
    infinicore::nn::Parameter get_k_weight_scale_awq(int scaling_factor) const;
    infinicore::nn::Parameter get_v_weight_scale_awq(int scaling_factor) const;
    infinicore::nn::Parameter get_q_weight_zeros_awq(int scaling_factor) const;
    infinicore::nn::Parameter get_k_weight_zeros_awq(int scaling_factor) const;
    infinicore::nn::Parameter get_v_weight_zeros_awq(int scaling_factor) const;

    infinicore::nn::Parameter get_q_bias() const;
    infinicore::nn::Parameter get_k_bias() const;
    infinicore::nn::Parameter get_v_bias() const;
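The comment above pins down what scaling_factor means; here is a hedged sketch of the arithmetic plus one possible 4-bit packing routine. The sequential nibble order is purely illustrative; the repository's actual AWQ kernel layout may differ:

#include <cassert>
#include <cstdint>

constexpr int packing_factor(int container_bits, int quant_bits) {
    return container_bits / quant_bits;   // e.g. 32 / 4 == 8
}
static_assert(packing_factor(32, 4) == 8, "int4 -> int32 packs 8 values");

// One way to pack 8 unsigned 4-bit values into a uint32_t (sequential
// nibble order; illustrative, not the repository's kernel layout).
uint32_t pack_int4x8(const uint8_t v[8]) {
    uint32_t out = 0;
    for (int i = 0; i < 8; ++i)
        out |= static_cast<uint32_t>(v[i] & 0xF) << (4 * i);
    return out;
}

int main() {
    const uint8_t vals[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    assert(pack_int4x8(vals) == 0x87654321u);
    return 0;
}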
...
@@ -132,6 +147,18 @@ public:
    infinicore::nn::Parameter get_up_bias() const;

    infinicore::nn::Parameter get_gate_weight_awq() const;
    infinicore::nn::Parameter get_up_weight_awq() const;
    infinicore::nn::Parameter get_up_weight_scale_awq() const;
    infinicore::nn::Parameter get_up_weight_zeros_awq() const;
    infinicore::nn::Parameter get_gate_weight_scale_awq() const;
    infinicore::nn::Parameter get_gate_weight_zeros_awq() const;

    bool has_gate_bias() const;
    bool has_up_bias() const;
...
@@ -180,15 +207,17 @@ private:
#define INFINILM_QKV_LINEAR_W4A16AWQ_INIT(name, q_name, k_name, v_name, ...) \
name##_ = std::make_shared<layers::QKVParallelLinear>(__VA_ARGS__); \
-    this->register_parameter(std::string(q_name) + ".qweight", name##_->get_q_weight()); \
-    this->register_parameter(std::string(q_name) + ".qzeros", name##_->get_q_weight_zeros()); \
-    this->register_parameter(std::string(q_name) + ".scales", name##_->get_q_weight_scale()); \
-    this->register_parameter(std::string(k_name) + ".qweight", name##_->get_k_weight()); \
-    this->register_parameter(std::string(k_name) + ".qzeros", name##_->get_k_weight_zeros()); \
-    this->register_parameter(std::string(k_name) + ".scales", name##_->get_k_weight_scale()); \
-    this->register_parameter(std::string(v_name) + ".qweight", name##_->get_v_weight()); \
-    this->register_parameter(std::string(v_name) + ".qzeros", name##_->get_v_weight_zeros()); \
-    this->register_parameter(std::string(v_name) + ".scales", name##_->get_v_weight_scale()); \
+    auto awq_ptr = std::static_pointer_cast<infinicore::quantization::AWQ>(name##_->get_quantization()); \
+    int packing_num = awq_ptr->get_packing_num(); \
+    this->register_parameter(std::string(q_name) + ".qweight", name##_->get_q_weight_awq(packing_num)); \
+    this->register_parameter(std::string(q_name) + ".qzeros", name##_->get_q_weight_zeros_awq(packing_num)); \
+    this->register_parameter(std::string(q_name) + ".scales", name##_->get_q_weight_scale_awq(1)); \
+    this->register_parameter(std::string(k_name) + ".qweight", name##_->get_k_weight_awq(packing_num)); \
+    this->register_parameter(std::string(k_name) + ".qzeros", name##_->get_k_weight_zeros_awq(packing_num)); \
+    this->register_parameter(std::string(k_name) + ".scales", name##_->get_k_weight_scale_awq(1)); \
+    this->register_parameter(std::string(v_name) + ".qweight", name##_->get_v_weight_awq(packing_num)); \
+    this->register_parameter(std::string(v_name) + ".qzeros", name##_->get_v_weight_zeros_awq(packing_num)); \
+    this->register_parameter(std::string(v_name) + ".scales", name##_->get_v_weight_scale_awq(1)); \
if (name##_->has_q_bias()) \
this->register_parameter(std::string(q_name) + ".bias", name##_->get_q_bias()); \
if (name##_->has_k_bias()) \
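Note the asymmetry in the new registrations above: qweight and qzeros are int32-packed, so their column counts shrink by packing_num (obtained from AWQ::get_packing_num()), while the fp16 scales tensor is unpacked and is sliced with a factor of 1. A toy check of the widths the macro ends up registering (hypothetical sizes; not repository code):

#include <cassert>

int main() {
    const int packing_num = 8;          // int4 packed into int32
    const int q_out = 4096;
    assert(q_out / packing_num == 512); // qweight / qzeros column count
    assert(q_out / 1 == 4096);          // scales column count (factor 1)
    return 0;
}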
...
@@ -210,12 +239,12 @@ private:
#define INFINILM_GATE_UP_LINEAR_W4A16AWQ_INIT(name, gate_name, up_name, ...) \
name##_ = std::make_shared<layers::GateUpParallelLinear>(__VA_ARGS__); \
-    this->register_parameter(std::string(gate_name) + ".qweight", name##_->get_gate_weight()); \
-    this->register_parameter(std::string(gate_name) + ".scales", name##_->get_gate_weight_scale()); \
-    this->register_parameter(std::string(gate_name) + ".qzeros", name##_->get_gate_weight_zeros()); \
-    this->register_parameter(std::string(up_name) + ".qweight", name##_->get_up_weight()); \
-    this->register_parameter(std::string(up_name) + ".scales", name##_->get_up_weight_scale()); \
-    this->register_parameter(std::string(up_name) + ".qzeros", name##_->get_up_weight_zeros()); \
+    this->register_parameter(std::string(gate_name) + ".qweight", name##_->get_gate_weight_awq()); \
+    this->register_parameter(std::string(gate_name) + ".qzeros", name##_->get_gate_weight_zeros_awq()); \
+    this->register_parameter(std::string(gate_name) + ".scales", name##_->get_gate_weight_scale_awq()); \
+    this->register_parameter(std::string(up_name) + ".qweight", name##_->get_up_weight_awq()); \
+    this->register_parameter(std::string(up_name) + ".qzeros", name##_->get_up_weight_zeros_awq()); \
+    this->register_parameter(std::string(up_name) + ".scales", name##_->get_up_weight_scale_awq()); \
if (name##_->has_gate_bias()) \
this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias()); \
if (name##_->has_up_bias()) \
...
csrc/models/llama/llama_attention.cpp
@@ -119,12 +119,13 @@ LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> mo
                                  dtype, device, tp_rank, tp_size, rank_info.comm);
        break;
-    case infinicore::quantization::QuantScheme::AWQ_W4A16:
+    case infinicore::quantization::QuantScheme::AWQ_W4A16: {
        INFINILM_QKV_LINEAR_W4A16AWQ_INIT(qkv_proj, "q_proj", "k_proj", "v_proj",
                                          hidden_size_, head_dim_,
                                          model_config_->get<size_t>("num_attention_heads"),
                                          model_config_->get<size_t>("num_key_value_heads"),
                                          this->model_config_->get_quantization_method(),
                                          use_bias_, dtype, device, rank_info);
        INFINICORE_NN_MODULE_INIT(o_proj,
                                  model_config_->get<size_t>("num_attention_heads") * head_dim_,
                                  hidden_size_,
                                  this->model_config_->get_quantization_method(),
                                  use_output_bias_, dtype, device, tp_rank, tp_size, rank_info.comm);
        break;
+    }
    default:
        INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj",
                                 hidden_size_, head_dim_,
                                 model_config_->get<size_t>("num_attention_heads"),
                                 model_config_->get<size_t>("num_key_value_heads"),
                                 this->model_config_->get_quantization_method(),
                                 use_bias_, dtype, device, rank_info);
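The new braces around the AWQ_W4A16 case are required, not cosmetic: INFINILM_QKV_LINEAR_W4A16AWQ_INIT now expands to local declarations (awq_ptr, packing_num), and C++ rejects a switch in which another case label can be reached by jumping past such an initialization. A minimal reproduction of the rule (illustrative, not repository code):

int f(int mode) {
    switch (mode) {
    case 0: {                 // braces scope the local to this case
        int packing_num = 8;  // without the braces: "jump to case label
        return packing_num;   // crosses initialization" compile error
    }
    default:
        return 1;
    }
}

int main() { return f(0) == 8 ? 0 : 1; }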
...