Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
e60985dc
Unverified
Commit
e60985dc
authored
Mar 09, 2026
by
thatPepe
Committed by
GitHub
Mar 09, 2026
Browse files
Merge pull request #1040 from InfiniTensor/Issue/1030
Issue/1030: Nvidia 支持w4a16推理
parents
58771213
63233f9b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
144 additions
and
10 deletions
+144
-10
include/infinicore/quantization/awq.hpp
include/infinicore/quantization/awq.hpp
+12
-1
include/infinicore/quantization/base_quantization.hpp
include/infinicore/quantization/base_quantization.hpp
+25
-1
src/infinicore/nn/linear.cc
src/infinicore/nn/linear.cc
+95
-0
src/infinicore/ops/linear_w4a16_awq/linear_w4a16_awq.cc
src/infinicore/ops/linear_w4a16_awq/linear_w4a16_awq.cc
+12
-8
No files found.
include/infinicore/quantization/awq.hpp
View file @
e60985dc
...
@@ -8,12 +8,23 @@ class AWQ : public BaseQuantization {
...
@@ -8,12 +8,23 @@ class AWQ : public BaseQuantization {
// information and support multiple quantization schemes.
// information and support multiple quantization schemes.
public:
public:
explicit
AWQ
(
const
nlohmann
::
json
&
quant_config
)
explicit
AWQ
(
const
nlohmann
::
json
&
quant_config
)
:
BaseQuantization
(
quant_config
)
{};
:
BaseQuantization
(
quant_config
){};
infinicore
::
quantization
::
QuantScheme
infinicore
::
quantization
::
QuantScheme
get_quant_scheme
()
const
override
{
get_quant_scheme
()
const
override
{
return
infinicore
::
quantization
::
QuantScheme
::
AWQ_W4A16
;
return
infinicore
::
quantization
::
QuantScheme
::
AWQ_W4A16
;
};
};
int
get_packing_num
()
const
{
// For AWQ, we pack 8 int4 weights into a single int32 value.
return
32
/
this
->
get_or
<
int
>
(
"bits"
,
4
);
// Default to 8 if not specified in config
}
int
get_group_size
()
const
{
// For simplicity, we return a fixed group size here. In a more complete implementation,
// this could be extracted from quant_config_ to support different group sizes.
return
this
->
get_or
<
int
>
(
"group_size"
,
128
);
// Standard AWQ group size
}
};
};
}
// namespace infinicore::quantization
}
// namespace infinicore::quantization
include/infinicore/quantization/base_quantization.hpp
View file @
e60985dc
...
@@ -6,10 +6,34 @@ namespace infinicore::quantization {
...
@@ -6,10 +6,34 @@ namespace infinicore::quantization {
class
BaseQuantization
{
class
BaseQuantization
{
// Base class for quantization schemes. Intended to be extended to support various quantization methods.
// Base class for quantization schemes. Intended to be extended to support various quantization methods.
public:
public:
explicit
BaseQuantization
(
const
nlohmann
::
json
&
quant_config
)
:
quant_config_
(
quant_config
)
{};
explicit
BaseQuantization
(
const
nlohmann
::
json
&
quant_config
)
:
quant_config_
(
quant_config
){};
virtual
~
BaseQuantization
()
=
default
;
virtual
~
BaseQuantization
()
=
default
;
virtual
infinicore
::
quantization
::
QuantScheme
get_quant_scheme
()
const
=
0
;
virtual
infinicore
::
quantization
::
QuantScheme
get_quant_scheme
()
const
=
0
;
template
<
typename
T
>
T
get
(
const
std
::
string
&
key
)
const
{
if
(
!
quant_config_
.
contains
(
key
))
{
throw
std
::
out_of_range
(
"Key '"
+
key
+
"' not found in config."
);
}
try
{
return
quant_config_
.
at
(
key
).
get
<
T
>
();
}
catch
(
const
nlohmann
::
json
::
type_error
&
e
)
{
throw
std
::
runtime_error
(
"Type conversion failed for key '"
+
key
+
"': "
+
std
::
string
(
e
.
what
()));
}
}
template
<
typename
T
>
T
get_or
(
const
std
::
string
&
key
,
const
T
&
default_value
)
const
{
if
(
!
quant_config_
.
contains
(
key
)
||
quant_config_
.
at
(
key
).
is_null
())
{
return
default_value
;
}
try
{
return
quant_config_
.
at
(
key
).
get
<
T
>
();
}
catch
(
const
nlohmann
::
json
::
type_error
&
)
{
// If type conversion fails, return default value
return
default_value
;
}
}
protected:
protected:
nlohmann
::
json
quant_config_
;
nlohmann
::
json
quant_config_
;
...
...
src/infinicore/nn/linear.cc
View file @
e60985dc
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include "infinicore/ops.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/ops/distributed/allreduce.hpp"
#include "infinicore/ops/distributed/allreduce.hpp"
#include "infinicore/ops/linear.hpp"
#include "infinicore/ops/linear.hpp"
#include "infinicore/ops/linear_w4a16_awq.hpp"
#include "infinicore/ops/linear_w8a8i8.hpp"
#include "infinicore/ops/linear_w8a8i8.hpp"
#include <optional>
#include <optional>
#include <spdlog/spdlog.h>
#include <spdlog/spdlog.h>
...
@@ -43,6 +44,15 @@ Tensor BaseLinear::compute_linear(Tensor &input) const {
...
@@ -43,6 +44,15 @@ Tensor BaseLinear::compute_linear(Tensor &input) const {
auto
output
=
infinicore
::
op
::
linear_w8a8i8
(
input_contiguous
->
contiguous
(),
weight_packed_tensor
,
weight_scale_tensor
,
bias_opt
);
auto
output
=
infinicore
::
op
::
linear_w8a8i8
(
input_contiguous
->
contiguous
(),
weight_packed_tensor
,
weight_scale_tensor
,
bias_opt
);
return
output
;
return
output
;
}
}
case
infinicore
::
quantization
::
QuantScheme
::
AWQ_W4A16
:
{
Tensor
input_contiguous
=
input
->
is_contiguous
()
?
input
:
input
->
contiguous
();
Tensor
qweight
=
static_cast
<
const
Tensor
&>
(
weight_
);
Tensor
qzeros
=
static_cast
<
const
Tensor
&>
(
weight_zeros_
);
Tensor
scales
=
static_cast
<
const
Tensor
&>
(
weight_scale_
);
std
::
optional
<
Tensor
>
bias_opt
=
has_bias_
?
std
::
make_optional
<
Tensor
>
(
static_cast
<
const
Tensor
&>
(
bias_
))
:
std
::
nullopt
;
auto
output
=
infinicore
::
op
::
linear_w4a16_awq
(
input_contiguous
->
contiguous
(),
qweight
,
scales
,
qzeros
,
bias_opt
);
return
output
;
}
default:
{
default:
{
// Ensure input is contiguous before creating views (required for matmul)
// Ensure input is contiguous before creating views (required for matmul)
// This prevents hanging when input tensor has non-contiguous memory layout
// This prevents hanging when input tensor has non-contiguous memory layout
...
@@ -116,6 +126,20 @@ Linear::Linear(size_t in_features, size_t out_features,
...
@@ -116,6 +126,20 @@ Linear::Linear(size_t in_features, size_t out_features,
}
}
break
;
break
;
}
}
case
infinicore
::
quantization
::
QuantScheme
::
AWQ_W4A16
:
{
weight_
=
infinicore
::
nn
::
Parameter
({
out_features
,
in_features
},
infinicore
::
DataType
::
I32
,
device
);
this
->
register_parameter
(
"qweight"
,
weight_
);
weight_zeros_
=
infinicore
::
nn
::
Parameter
({
out_features
,
in_features
},
infinicore
::
DataType
::
I32
,
device
);
this
->
register_parameter
(
"qzeros"
,
weight_zeros_
);
weight_scale_
=
infinicore
::
nn
::
Parameter
({
out_features
,
in_features
},
dtype_
,
device
);
this
->
register_parameter
(
"scales"
,
weight_scale_
);
if
(
bias
)
{
INFINICORE_NN_PARAMETER_INIT
(
bias
,
({
out_features
},
dtype_
,
device
));
}
else
{
bias_
=
Parameter
();
}
break
;
}
default:
{
default:
{
// Initialize parameters using macro
// Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT
(
weight
,
({
out_features
,
in_features
},
dtype_
,
device
));
INFINICORE_NN_PARAMETER_INIT
(
weight
,
({
out_features
,
in_features
},
dtype_
,
device
));
...
@@ -190,6 +214,39 @@ ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_featur
...
@@ -190,6 +214,39 @@ ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_featur
}
}
break
;
break
;
}
}
case
infinicore
::
quantization
::
QuantScheme
::
AWQ_W4A16
:
{
auto
awq_ptr
=
std
::
static_pointer_cast
<
infinicore
::
quantization
::
AWQ
>
(
this
->
quantization_
);
int
group_size
=
awq_ptr
->
get_group_size
();
int
packing_num
=
awq_ptr
->
get_packing_num
();
weight_
=
infinicore
::
nn
::
Parameter
({
in_features
,
out_features
/
packing_num
},
infinicore
::
DataType
::
I32
,
device
,
1
,
tp_rank_
,
tp_size_
);
this
->
register_parameter
(
"qweight"
,
weight_
);
// Weight scale: [out_features, in_features / group_size]
// One FP32 scale per group of weights (group_size=128)
weight_scale_
=
infinicore
::
nn
::
Parameter
({
in_features
/
group_size
,
out_features
},
dtype_
,
device
,
1
,
tp_rank_
,
tp_size_
);
this
->
register_parameter
(
"scales"
,
weight_scale_
);
// Weight zeros (zero points): [out_features, in_features / group_size]
// AWQ implementations (e.g., AutoAWQ) typically store zero points as I32
// for symmetric/asymmetric quantization support
weight_zeros_
=
infinicore
::
nn
::
Parameter
({
in_features
/
group_size
,
out_features
/
packing_num
},
infinicore
::
DataType
::
I32
,
device
,
1
,
tp_rank_
,
tp_size_
);
this
->
register_parameter
(
"qzeros"
,
weight_zeros_
);
if
(
bias
)
{
INFINICORE_NN_PARAMETER_INIT
(
bias
,
({
out_features
},
dtype_
,
device
,
0
,
0
,
1
));
}
else
{
bias_
=
Parameter
();
}
break
;
}
default:
{
default:
{
// Initialize parameters using macro
// Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT
(
weight
,
({
out_features
,
in_features
},
dtype_
,
device
,
INFINICORE_NN_PARAMETER_INIT
(
weight
,
({
out_features
,
in_features
},
dtype_
,
device
,
...
@@ -261,6 +318,44 @@ RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, st
...
@@ -261,6 +318,44 @@ RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, st
}
}
break
;
break
;
}
}
case
infinicore
::
quantization
::
QuantScheme
::
AWQ_W4A16
:
{
// AWQ W4A16 for RowParallelLinear:切分维度为 in_features(权重矩阵的第1维)
// - Weight: packed int4 in I32 containers (8 int4 per I32)
// - Group-wise quantization with group_size=128
// - Scale and zero points stored per group along in_features dimension
auto
awq_ptr
=
std
::
static_pointer_cast
<
infinicore
::
quantization
::
AWQ
>
(
this
->
quantization_
);
int
group_size
=
awq_ptr
->
get_group_size
();
int
packing_num
=
awq_ptr
->
get_packing_num
();
// Packed weight: [out_features, in_features / 8]
weight_
=
infinicore
::
nn
::
Parameter
({
in_features
,
out_features
/
packing_num
},
infinicore
::
DataType
::
I32
,
device
,
0
,
tp_rank_
,
tp_size_
);
this
->
register_parameter
(
"qweight"
,
weight_
);
// Weight scale: [out_features, in_features / group_size]
weight_scale_
=
infinicore
::
nn
::
Parameter
({
in_features
/
group_size
,
out_features
},
dtype_
,
device
,
0
,
tp_rank_
,
tp_size_
);
this
->
register_parameter
(
"scales"
,
weight_scale_
);
// Weight zeros (zero points): [out_features, in_features / group_size]
weight_zeros_
=
infinicore
::
nn
::
Parameter
({
in_features
/
group_size
,
out_features
/
packing_num
},
infinicore
::
DataType
::
I32
,
device
,
0
,
tp_rank_
,
tp_size_
);
this
->
register_parameter
(
"qzeros"
,
weight_zeros_
);
// Bias handling in RowParallelLinear:
// - Only rank 0 holds the full bias (after all-reduce on output)
// - Other ranks have empty bias parameter
if
(
bias
&&
(
0
==
tp_rank_
))
{
INFINICORE_NN_PARAMETER_INIT
(
bias
,
({
out_features
},
dtype_
,
device
,
0
,
0
,
1
));
}
else
{
bias_
=
Parameter
();
}
break
;
}
default:
{
default:
{
// Initialize parameters using macro
// Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT
(
weight
,
({
out_features
,
in_features
},
dtype_
,
device
,
INFINICORE_NN_PARAMETER_INIT
(
weight
,
({
out_features
,
in_features
},
dtype_
,
device
,
...
...
src/infinicore/ops/linear_w4a16_awq/linear_w4a16_awq.cc
View file @
e60985dc
#include "infinicore/ops/linear_w4a16_awq.hpp"
#include "infinicore/ops/linear_w4a16_awq.hpp"
#include "infinicore/ops/dequantize_awq.hpp"
#include "infinicore/ops/dequantize_awq.hpp"
#include "infinicore/ops/gemm.hpp"
#include "infinicore/ops/gemm.hpp"
#include "infinicore/ops/rearrange.hpp"
namespace
infinicore
::
op
{
namespace
infinicore
::
op
{
Tensor
linear_w4a16_awq
(
Tensor
input
,
Tensor
linear_w4a16_awq
(
Tensor
input
,
...
@@ -12,7 +12,8 @@ Tensor linear_w4a16_awq(Tensor input,
...
@@ -12,7 +12,8 @@ Tensor linear_w4a16_awq(Tensor input,
// Input is of shape [M, K], Weight_packed is of shape [N, K], strides is [N, 1]
// Input is of shape [M, K], Weight_packed is of shape [N, K], strides is [N, 1]
Size
ndim
=
input
->
ndim
();
Size
ndim
=
input
->
ndim
();
Size
out_features
=
weight_packed
->
shape
()[
0
];
Size
element_size
=
weight_packed
->
element_size
();
Size
out_features
=
weight_packed
->
shape
()[
1
]
*
element_size
*
2
;
// Assign memory to out variables
// Assign memory to out variables
auto
output_shape
=
input
->
shape
();
auto
output_shape
=
input
->
shape
();
...
@@ -33,7 +34,7 @@ void linear_w4a16_awq_(Tensor out,
...
@@ -33,7 +34,7 @@ void linear_w4a16_awq_(Tensor out,
auto
weight_packed_shape
=
weight_packed
->
shape
();
auto
weight_packed_shape
=
weight_packed
->
shape
();
Size
out_features
=
weight_packed_shape
[
0
];
Size
out_features
=
weight_packed_shape
[
0
];
Size
in_features
=
weight_packed_shape
[
1
];
Size
in_features
=
weight_packed_shape
[
1
]
*
8
;
Size
ndim
=
input
->
ndim
();
Size
ndim
=
input
->
ndim
();
assert
(
out
->
ndim
()
==
ndim
);
assert
(
out
->
ndim
()
==
ndim
);
...
@@ -43,7 +44,6 @@ void linear_w4a16_awq_(Tensor out,
...
@@ -43,7 +44,6 @@ void linear_w4a16_awq_(Tensor out,
for
(
size_t
i
=
0
;
i
<
ndim
-
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
ndim
-
1
;
++
i
)
{
N
*=
input_shape
[
i
];
N
*=
input_shape
[
i
];
}
}
auto
weight
=
Tensor
::
empty
(
auto
weight
=
Tensor
::
empty
(
{
out_features
,
in_features
},
{
out_features
,
in_features
},
out
->
dtype
(),
out
->
dtype
(),
...
@@ -51,10 +51,14 @@ void linear_w4a16_awq_(Tensor out,
...
@@ -51,10 +51,14 @@ void linear_w4a16_awq_(Tensor out,
float
alpha
=
1.0
f
;
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
float
beta
=
0.0
f
;
op
::
dequantize_awq_
(
weight
,
weight_packed
,
weight_scale
,
weight_zeros
);
op
::
dequantize_awq_
(
weight
,
weight_packed
,
weight_scale
,
weight_zeros
);
bias
=
std
::
make_optional
(
bias
.
value
()
->
as_strided
({
N
,
out_features
},
{
0
,
1
}));
if
(
bias
.
has_value
())
{
gemm_
(
out
->
view
({
N
,
out_features
}),
rearrange_
(
out
,
input
->
view
({
N
,
in_features
}),
bias
.
value
()
->
as_strided
({
N
,
in_features
},
{
0
,
1
}));
weight
->
permute
({
1
,
0
}),
alpha
,
beta
);
beta
=
1.0
f
;
}
gemm_
(
out
,
input
->
view
({
N
,
out_features
}),
weight
,
alpha
,
beta
);
}
}
}
// namespace infinicore::op
}
// namespace infinicore::op
yanzy
@yanzy
mentioned in commit
def22a08
·
Apr 21, 2026
mentioned in commit
def22a08
mentioned in commit def22a08ee170cba3fafe2b04d8ad2ce728b07ed
Toggle commit list
yanzy
@yanzy
mentioned in commit
cb7f0b7d
·
Apr 21, 2026
mentioned in commit
cb7f0b7d
mentioned in commit cb7f0b7d282d652a37b810abb06f6c1d4250bc78
Toggle commit list
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment