Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f09daea2
Unverified
Commit
f09daea2
authored
Mar 31, 2026
by
Yintong Lu
Committed by
GitHub
Mar 31, 2026
Browse files
[CPU] Support int8 compute mode in CPU AWQ (#35697)
Signed-off-by:
Yintong Lu
<
yintong.lu@intel.com
>
parent
42318c84
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1197 additions
and
11 deletions
+1197
-11
.buildkite/hardware_tests/cpu.yaml
.buildkite/hardware_tests/cpu.yaml
+3
-1
cmake/cpu_extension.cmake
cmake/cpu_extension.cmake
+1
-0
csrc/cpu/sgl-kernels/common.h
csrc/cpu/sgl-kernels/common.h
+8
-0
csrc/cpu/sgl-kernels/gemm.h
csrc/cpu/sgl-kernels/gemm.h
+36
-3
csrc/cpu/sgl-kernels/gemm_int4.cpp
csrc/cpu/sgl-kernels/gemm_int4.cpp
+755
-0
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+20
-0
tests/kernels/test_awq_int4_to_int8.py
tests/kernels/test_awq_int4_to_int8.py
+281
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+32
-0
vllm/envs.py
vllm/envs.py
+3
-0
vllm/model_executor/layers/quantization/cpu_wna16.py
vllm/model_executor/layers/quantization/cpu_wna16.py
+58
-7
No files found.
.buildkite/hardware_tests/cpu.yaml
View file @
f09daea2
...
...
@@ -13,12 +13,14 @@ steps:
-
tests/kernels/attention/test_cpu_attn.py
-
tests/kernels/moe/test_cpu_fused_moe.py
-
tests/kernels/test_onednn.py
-
tests/kernels/test_awq_int4_to_int8.py
commands
:
-
|
bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
pytest -x -v -s tests/kernels/moe/test_cpu_fused_moe.py
pytest -x -v -s tests/kernels/test_onednn.py"
pytest -x -v -s tests/kernels/test_onednn.py
pytest -x -v -s tests/kernels/test_awq_int4_to_int8.py"
-
label
:
CPU-Compatibility Tests
depends_on
:
[]
...
...
cmake/cpu_extension.cmake
View file @
f09daea2
...
...
@@ -373,6 +373,7 @@ if (ENABLE_X86_ISA)
"csrc/cpu/sgl-kernels/gemm.cpp"
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
"csrc/cpu/sgl-kernels/gemm_int4.cpp"
"csrc/cpu/sgl-kernels/moe.cpp"
"csrc/cpu/sgl-kernels/moe_int8.cpp"
"csrc/cpu/sgl-kernels/moe_fp8.cpp"
)
...
...
csrc/cpu/sgl-kernels/common.h
View file @
f09daea2
...
...
@@ -117,6 +117,14 @@ inline void parallel_for(int n, const func_t& f) {
#endif
}
inline
int
get_thread_num
()
{
#if defined(_OPENMP)
return
omp_get_thread_num
();
#else
return
0
;
#endif
}
// for 1d parallel, use `actual_nth`
// for 2d parallel, use even nths, e.g. 43->42
int
inline
adjust_num_threads
(
int
m
)
{
...
...
csrc/cpu/sgl-kernels/gemm.h
View file @
f09daea2
...
...
@@ -17,8 +17,8 @@ constexpr int block_size_n() { return 2 * TILE_N; }
template
<
typename
T
>
inline
bool
can_use_brgemm
(
int
M
);
template
<
>
inline
bool
can_use_brgemm
<
at
::
BFloat16
>
(
int
M
)
{
return
M
>
4
;
}
template
<
>
inline
bool
can_use_brgemm
<
at
::
Half
>
(
int
M
)
{
return
true
;
}
// TODO: add u8s8 brgemm, this requires PyTorch 2.7
template
<
>
inline
bool
can_use_brgemm
<
int8_t
>
(
int
M
)
{
return
false
;
}
template
<
>
inline
bool
can_use_brgemm
<
int8_t
>
(
int
M
)
{
return
M
>
4
;
}
template
<
>
inline
bool
can_use_brgemm
<
u
int8_t
>
(
int
M
)
{
return
M
>
4
;
}
template
<
>
inline
bool
can_use_brgemm
<
at
::
Float8_e4m3fn
>
(
int
M
)
{
return
M
>
4
;
}
template
<
>
inline
bool
can_use_brgemm
<
at
::
quint4x2
>
(
int
M
)
{
return
M
>
4
;
}
...
...
@@ -40,9 +40,17 @@ inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
return
use_int8_w8a8
?
K
+
sizeof
(
int32_t
)
:
K
;
}
// pack weight to vnni format
inline
int64_t
get_4bit_block_k_size
(
int64_t
group_size
)
{
return
group_size
>
128
?
128
:
group_size
;
}
// pack weight into vnni format
at
::
Tensor
convert_weight_packed
(
at
::
Tensor
&
weight
);
// pack weight to vnni format for int4 (adapted from sglang)
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
at
::
Tensor
>
convert_weight_packed_scale_zp
(
at
::
Tensor
qweight
,
at
::
Tensor
qzeros
,
at
::
Tensor
scales
);
// moe implementations for int8 w8a8
template
<
typename
scalar_t
>
void
fused_experts_int8_kernel_impl
(
...
...
@@ -233,6 +241,31 @@ void tinygemm_kernel(
int64_t
strideBs
,
bool
brg
);
// int4 scaled GEMM (adapted from sglang)
at
::
Tensor
int4_scaled_mm_cpu
(
at
::
Tensor
&
x
,
at
::
Tensor
&
w
,
at
::
Tensor
&
w_zeros
,
at
::
Tensor
&
w_scales
,
std
::
optional
<
at
::
Tensor
>
bias
);
// int4 tinygemm kernel interface(adapted from sglang)
template
<
typename
scalar_t
>
void
tinygemm_kernel
(
scalar_t
*
C
,
float
*
C_temp
,
const
uint8_t
*
A
,
const
float
*
scales_a
,
const
int32_t
*
qzeros_a
,
const
uint8_t
*
B
,
const
float
*
scales_b
,
const
int8_t
*
qzeros_b
,
const
int32_t
*
compensation
,
int8_t
*
dqB_tmp
,
int64_t
M
,
int64_t
K
,
int64_t
lda
,
int64_t
ldc_f
,
int64_t
ldc_s
,
bool
store_out
,
bool
use_brgemm
);
// TODO: debug print, remove me later
inline
void
print_16x32i
(
const
__m512i
x
)
{
int32_t
a
[
16
];
...
...
csrc/cpu/sgl-kernels/gemm_int4.cpp
0 → 100644
View file @
f09daea2
// SPDX-License-Identifier: Apache-2.0
// Adapted from sgl-project/sglang
// https://github.com/sgl-project/sglang/pull/8226
#include <ATen/ATen.h>
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace
{
#define BLOCK_N block_size_n()
#define BLOCK_M 128
template
<
bool
sym_quant_act
>
struct
ActDtype
;
template
<
>
struct
ActDtype
<
true
>
{
using
type
=
int8_t
;
};
template
<
>
struct
ActDtype
<
false
>
{
using
type
=
uint8_t
;
};
struct
alignas
(
32
)
m256i_wrapper
{
__m256i
data
;
};
#if defined(CPU_CAPABILITY_AVX512)
inline
std
::
array
<
m256i_wrapper
,
2
>
load_zps_4vnni
(
const
int8_t
*
__restrict__
zps
)
{
__m256i
vzps_low
=
_mm256_set1_epi64x
(
*
reinterpret_cast
<
const
int64_t
*>
(
zps
));
__m256i
vzps_high
=
_mm256_set1_epi64x
(
*
reinterpret_cast
<
const
int64_t
*>
(
zps
+
8
));
__m256i
shuffle_mask
=
_mm256_set_epi8
(
7
,
7
,
7
,
7
,
6
,
6
,
6
,
6
,
5
,
5
,
5
,
5
,
4
,
4
,
4
,
4
,
3
,
3
,
3
,
3
,
2
,
2
,
2
,
2
,
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
);
vzps_low
=
_mm256_shuffle_epi8
(
vzps_low
,
shuffle_mask
);
vzps_high
=
_mm256_shuffle_epi8
(
vzps_high
,
shuffle_mask
);
m256i_wrapper
vzps_low_wp
,
vzps_high_wp
;
vzps_low_wp
.
data
=
vzps_low
;
vzps_high_wp
.
data
=
vzps_high
;
return
{
vzps_low_wp
,
vzps_high_wp
};
}
inline
std
::
array
<
m256i_wrapper
,
2
>
load_uint4_as_int8
(
const
uint8_t
*
__restrict__
qB
)
{
__m256i
packed
=
_mm256_loadu_si256
(
reinterpret_cast
<
const
__m256i
*>
(
qB
));
const
__m256i
low_mask
=
_mm256_set1_epi8
(
0x0f
);
__m256i
high
=
_mm256_srli_epi16
(
packed
,
4
);
high
=
_mm256_and_si256
(
high
,
low_mask
);
__m256i
low
=
_mm256_and_si256
(
packed
,
low_mask
);
m256i_wrapper
low_wp
,
high_wp
;
low_wp
.
data
=
low
;
high_wp
.
data
=
high
;
return
{
low_wp
,
high_wp
};
}
template
<
int
N
,
int
ldb
>
void
_dequant_weight_zp_only
(
const
uint8_t
*
__restrict__
B
,
int8_t
*
dqB
,
const
int8_t
*
__restrict__
qzeros
,
int64_t
K
)
{
#pragma GCC unroll 2
for
(
int
n
=
0
;
n
<
N
;
n
+=
16
)
{
auto
[
zps_low_wp
,
zps_high_wp
]
=
load_zps_4vnni
(
&
qzeros
[
n
]);
auto
zps_low
=
zps_low_wp
.
data
;
auto
zps_high
=
zps_high_wp
.
data
;
for
(
int
k
=
0
;
k
<
K
;
k
+=
4
)
{
auto
[
vb_low_wp
,
vb_high_wp
]
=
load_uint4_as_int8
(
B
+
ldb
*
k
+
n
/
2
*
4
);
auto
vb_low
=
vb_low_wp
.
data
;
auto
vb_high
=
vb_high_wp
.
data
;
vb_high
=
_mm256_sub_epi8
(
vb_high
,
zps_high
);
vb_low
=
_mm256_sub_epi8
(
vb_low
,
zps_low
);
_mm256_storeu_si256
(
reinterpret_cast
<
__m256i_u
*>
(
dqB
+
N
*
k
+
n
*
4
),
vb_low
);
_mm256_storeu_si256
(
reinterpret_cast
<
__m256i_u
*>
(
dqB
+
N
*
k
+
(
n
+
8
)
*
4
),
vb_high
);
}
}
}
template
<
bool
sym_quant_act
,
int
N
,
bool
accum
>
void
_dequant_and_store
(
float
*
__restrict__
output
,
const
int32_t
*
__restrict__
input
,
const
float
*
__restrict__
scale_a
,
const
int32_t
*
__restrict__
zp_a
,
const
float
*
__restrict__
scale_b
,
const
int32_t
*
__restrict__
comp_b
,
int
M
,
int
ldi
,
int
ldo
,
int
ldsa
=
1
)
{
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
float
a_scale
=
*
(
scale_a
+
m
*
ldsa
);
__m512
va_scale
=
_mm512_set1_ps
(
a_scale
);
int32_t
a_zp
;
__m512i
va_zp
;
if
constexpr
(
!
sym_quant_act
)
{
a_zp
=
*
(
zp_a
+
m
*
ldsa
);
va_zp
=
_mm512_set1_epi32
(
a_zp
);
}
int
n
=
0
;
#pragma GCC unroll 2
for
(;
n
<
N
;
n
+=
16
)
{
__m512i
vc
=
_mm512_loadu_si512
(
input
+
m
*
ldi
+
n
);
if
constexpr
(
!
sym_quant_act
)
{
__m512i
vb_comp
=
_mm512_loadu_si512
(
comp_b
+
n
);
vc
=
_mm512_sub_epi32
(
vc
,
_mm512_mullo_epi32
(
vb_comp
,
va_zp
));
}
__m512
vc_f
=
_mm512_cvtepi32_ps
(
vc
);
__m512
vc_f_mul
=
_mm512_mul_ps
(
vc_f
,
va_scale
);
__m512
vb_s
=
_mm512_loadu_ps
(
scale_b
+
n
);
vc_f_mul
=
_mm512_mul_ps
(
vc_f_mul
,
vb_s
);
if
constexpr
(
accum
)
{
__m512
vo
=
_mm512_loadu_ps
(
output
+
m
*
ldo
+
n
);
_mm512_storeu_ps
(
output
+
m
*
ldo
+
n
,
_mm512_add_ps
(
vo
,
vc_f_mul
));
}
else
{
_mm512_storeu_ps
(
output
+
m
*
ldo
+
n
,
vc_f_mul
);
}
}
for
(;
n
<
N
;
++
n
)
{
float
dq_val
;
if
constexpr
(
sym_quant_act
)
{
dq_val
=
(
float
)
input
[
m
*
ldi
+
n
]
*
a_scale
*
scale_b
[
n
];
}
else
{
dq_val
=
(
float
)(
input
[
m
*
ldi
+
n
]
-
a_zp
*
comp_b
[
n
])
*
a_scale
*
scale_b
[
n
];
}
if
constexpr
(
accum
)
{
output
[
m
*
ldo
+
n
]
+=
dq_val
;
}
else
{
output
[
m
*
ldo
+
n
]
=
dq_val
;
}
}
}
}
#else
template
<
int
N
,
int
ldb
>
void
_dequant_weight_zp_only
(
const
uint8_t
*
B
,
int8_t
*
dqB
,
const
int8_t
*
qzeros
,
int64_t
K
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
for
(
int
n
=
0
;
n
<
N
/
2
;
++
n
)
{
int32_t
b
=
(
int32_t
)
B
[
k
*
ldb
+
n
];
dqB
[
k
*
N
+
n
*
2
]
=
(
b
&
0xf
)
-
qzeros
[
n
];
dqB
[
k
*
N
+
n
*
2
+
1
]
=
(
b
>>
4
)
-
qzeros
[
n
];
}
}
}
#endif
#if defined(CPU_CAPABILITY_AVX512)
inline
__m512i
combine_m256i
(
__m256i
a
,
__m256i
b
)
{
__m512i
c
=
_mm512_castsi256_si512
(
a
);
return
_mm512_inserti64x4
(
c
,
b
,
1
);
}
inline
__m512i
combine_m256i
(
std
::
array
<
m256i_wrapper
,
2
>
two_256
)
{
return
combine_m256i
(
two_256
[
0
].
data
,
two_256
[
1
].
data
);
}
static
inline
__m512i
_mm512_sign_epi8
(
__m512i
a
,
__m512i
b
)
{
__m512i
zero
=
_mm512_setzero_si512
();
__mmask64
blt0
=
_mm512_movepi8_mask
(
b
);
return
_mm512_mask_sub_epi8
(
a
,
blt0
,
zero
,
a
);
}
template
<
bool
sym_quant_act
,
int
M
,
int
N
,
int
ldb
>
void
_dequant_gemm_accum_small_M
(
float
*
__restrict__
C
,
const
uint8_t
*
A
,
const
float
*
scales_a
,
const
int32_t
*
qzeros_a
,
const
uint8_t
*
B
,
const
float
*
scales_b
,
const
int8_t
*
qzeros_b
,
int64_t
K
,
int64_t
lda
,
int64_t
ldc
)
{
constexpr
int
COLS
=
N
/
16
;
__m512i
ones
=
_mm512_set1_epi8
(
1
);
__m512i
va
;
__m512i
vb
[
COLS
];
__m512i
vc
[
M
*
COLS
];
__m512
vscales
[
COLS
];
__m512i
vzps
[
COLS
];
__m512i
vcompensate
[
COLS
];
Unroll
<
COLS
>
{}([
&
](
auto
i
)
{
vscales
[
i
]
=
_mm512_loadu_ps
(
scales_b
+
i
*
16
);
vzps
[
i
]
=
combine_m256i
(
load_zps_4vnni
(
qzeros_b
+
i
*
16
));
if
constexpr
(
!
sym_quant_act
)
{
vcompensate
[
i
]
=
_mm512_setzero_epi32
();
}
});
Unroll
<
M
*
COLS
>
{}([
&
](
auto
i
)
{
vc
[
i
]
=
_mm512_setzero_epi32
();
});
auto
compute
=
[
&
](
auto
i
,
int
k
)
{
constexpr
const
int
row
=
i
/
COLS
;
constexpr
const
int
col
=
i
%
COLS
;
if
constexpr
(
col
==
0
)
{
va
=
_mm512_set1_epi32
(
*
(
int32_t
*
)(
A
+
row
*
lda
+
k
));
}
if
constexpr
(
row
==
0
)
{
int
B_offset
=
k
*
ldb
+
col
*
16
*
2
;
vb
[
col
]
=
combine_m256i
(
load_uint4_as_int8
(
B
+
B_offset
));
vb
[
col
]
=
_mm512_sub_epi8
(
vb
[
col
],
vzps
[
col
]);
if
constexpr
(
!
sym_quant_act
)
{
vcompensate
[
col
]
=
_mm512_dpbusd_epi32
(
vcompensate
[
col
],
ones
,
vb
[
col
]);
}
_mm_prefetch
(
B
+
B_offset
+
128
*
ldb
,
_MM_HINT_T0
);
}
if
constexpr
(
sym_quant_act
)
{
auto
vsb
=
_mm512_sign_epi8
(
vb
[
col
],
va
);
auto
vabsa
=
_mm512_sign_epi8
(
va
,
va
);
vc
[
i
]
=
_mm512_dpbusds_epi32
(
vc
[
i
],
vabsa
,
vsb
);
}
else
{
vc
[
i
]
=
_mm512_dpbusd_epi32
(
vc
[
i
],
va
,
vb
[
col
]);
}
};
constexpr
const
int
unroll
=
4
;
int
k
=
0
;
for
(;
k
<
K
/
4
/
unroll
;
k
++
)
{
Unroll
<
unroll
>
{}(
[
&
](
auto
i
)
{
Unroll
<
M
*
COLS
>
{}(
compute
,
4
*
(
k
*
unroll
+
i
));
});
}
k
*=
4
*
unroll
;
for
(;
k
<
K
;
k
+=
4
)
{
Unroll
<
M
*
COLS
>
{}(
compute
,
k
);
}
auto
store
=
[
&
](
auto
i
)
{
constexpr
const
int
row
=
i
/
COLS
;
constexpr
const
int
col
=
i
%
COLS
;
__m512
vc_float
;
if
constexpr
(
!
sym_quant_act
)
{
vc
[
i
]
=
_mm512_sub_epi32
(
vc
[
i
],
_mm512_mullo_epi32
(
vcompensate
[
col
],
_mm512_set1_epi32
(
*
(
qzeros_a
+
row
))));
}
vc_float
=
_mm512_cvtepi32_ps
(
vc
[
i
]);
vc_float
=
_mm512_mul_ps
(
vc_float
,
_mm512_set1_ps
(
*
(
scales_a
+
row
)));
vc_float
=
_mm512_mul_ps
(
vc_float
,
vscales
[
col
]);
auto
vc_old
=
_mm512_loadu_ps
(
C
+
row
*
ldc
+
col
*
16
);
vc_float
=
_mm512_add_ps
(
vc_float
,
vc_old
);
_mm512_storeu_ps
(
C
+
row
*
ldc
+
col
*
16
,
vc_float
);
};
Unroll
<
M
*
COLS
>
{}(
store
);
}
#define CALL_DEQUANT_GEMM_ACCUM_SMALL_M(M) \
_dequant_gemm_accum_small_M<sym_quant_act, M, N, ldb>( \
C, A, scales_a, qzeros_a, B, scales_b, qzeros_b, K, lda, ldc);
#endif
template
<
bool
sym_quant_act
,
int
N
,
int
ldb
>
void
_dequant_gemm_accum
(
float
*
C
,
const
uint8_t
*
A
,
const
float
*
scales_a
,
const
int32_t
*
qzeros_a
,
const
uint8_t
*
B
,
const
float
*
scales_b
,
const
int8_t
*
qzeros_b
,
const
int32_t
*
compensation
,
int8_t
*
dqB
,
int64_t
M
,
int64_t
K
,
int64_t
lda
,
int64_t
ldc
,
bool
use_brgemm
)
{
#if defined(CPU_CAPABILITY_AVX512)
if
(
!
use_brgemm
)
{
switch
(
M
)
{
case
1
:
CALL_DEQUANT_GEMM_ACCUM_SMALL_M
(
1
);
break
;
case
2
:
CALL_DEQUANT_GEMM_ACCUM_SMALL_M
(
2
);
break
;
case
3
:
CALL_DEQUANT_GEMM_ACCUM_SMALL_M
(
3
);
break
;
case
4
:
CALL_DEQUANT_GEMM_ACCUM_SMALL_M
(
4
);
break
;
default:
TORCH_CHECK
(
false
,
"tinygemm_kernel: unexpected M for AVX path!"
);
}
return
;
}
_dequant_weight_zp_only
<
N
,
ldb
>
(
B
,
dqB
,
qzeros_b
,
K
);
using
Tin
=
typename
ActDtype
<
sym_quant_act
>::
type
;
Tin
*
A_ptr
=
(
Tin
*
)
A
;
if
(
use_brgemm
)
{
int32_t
C_i32
[
M
*
N
];
at
::
native
::
cpublas
::
brgemm
(
M
,
N
,
K
,
lda
,
N
/*ldb*/
,
N
/*ldc*/
,
false
/* add_C */
,
A_ptr
,
dqB
,
C_i32
,
true
/* is_vnni */
);
_mm_prefetch
(
B
+
N
*
K
/
2
,
_MM_HINT_T0
);
_mm_prefetch
(
A
+
K
,
_MM_HINT_T0
);
_dequant_and_store
<
sym_quant_act
,
N
,
true
>
(
C
,
C_i32
,
scales_a
,
qzeros_a
,
scales_b
,
compensation
,
M
,
N
/*ldi*/
,
ldc
,
1
/*ldsa*/
);
}
else
#endif
{
TORCH_CHECK
(
false
,
"tinygemm_kernel: scalar path not implemented!"
);
}
}
template
<
int
N
>
inline
void
copy_bias
(
const
float
*
bias_ptr
,
float
*
y_buf
,
int64_t
m
)
{
if
(
bias_ptr
)
{
for
(
int
i
=
0
;
i
<
m
;
++
i
)
{
int
j
=
0
;
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for
(;
j
<
N
;
j
+=
16
)
{
__m512
bias_vec
=
_mm512_loadu_ps
(
bias_ptr
+
j
);
_mm512_storeu_ps
(
y_buf
+
i
*
N
+
j
,
bias_vec
);
}
#endif
for
(;
j
<
N
;
++
j
)
{
y_buf
[
i
*
N
+
j
]
=
bias_ptr
[
j
];
}
}
}
else
{
for
(
int
i
=
0
;
i
<
m
;
++
i
)
{
int
j
=
0
;
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for
(;
j
<
N
;
j
+=
16
)
{
__m512
zero_vec
=
_mm512_setzero_ps
();
_mm512_storeu_ps
(
y_buf
+
i
*
N
+
j
,
zero_vec
);
}
#endif
for
(;
j
<
N
;
++
j
)
{
y_buf
[
i
*
N
+
j
]
=
0
;
}
}
}
}
template
<
int
N
,
typename
out_dtype
>
inline
void
store_out
(
const
float
*
y_buf
,
out_dtype
*
c_ptr
,
int64_t
m
,
int64_t
lda
)
{
for
(
int
i
=
0
;
i
<
m
;
++
i
)
{
int
j
=
0
;
if
constexpr
(
std
::
is_same
<
out_dtype
,
float
>::
value
)
{
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for
(;
j
<
N
;
j
+=
16
)
{
__m512
y_vec
=
_mm512_loadu_ps
(
y_buf
+
i
*
N
+
j
);
_mm512_storeu_ps
(
c_ptr
+
i
*
lda
+
j
,
y_vec
);
}
#endif
for
(;
j
<
N
;
++
j
)
{
c_ptr
[
i
*
lda
+
j
]
=
y_buf
[
i
*
N
+
j
];
}
}
else
if
constexpr
(
std
::
is_same
<
out_dtype
,
at
::
BFloat16
>::
value
)
{
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for
(;
j
<
N
;
j
+=
16
)
{
__m512
y_vec
=
_mm512_loadu_ps
(
y_buf
+
i
*
N
+
j
);
__m256i
y_bf16_vec
=
at
::
vec
::
cvtfp32_bf16
(
y_vec
);
_mm256_storeu_si256
(
reinterpret_cast
<
__m256i
*>
(
c_ptr
+
i
*
lda
+
j
),
y_bf16_vec
);
}
#endif
for
(;
j
<
N
;
++
j
)
{
c_ptr
[
i
*
lda
+
j
]
=
at
::
BFloat16
(
y_buf
[
i
*
N
+
j
]);
}
}
else
if
constexpr
(
std
::
is_same
<
out_dtype
,
at
::
Half
>::
value
)
{
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for
(;
j
<
N
;
j
+=
16
)
{
__m512
y_vec
=
_mm512_loadu_ps
(
y_buf
+
i
*
N
+
j
);
__m256i
y_fp16_vec
=
at
::
vec
::
cvtfp32_fp16
(
y_vec
);
_mm256_storeu_si256
(
reinterpret_cast
<
__m256i
*>
(
c_ptr
+
i
*
lda
+
j
),
y_fp16_vec
);
}
#endif
for
(;
j
<
N
;
++
j
)
{
c_ptr
[
i
*
lda
+
j
]
=
at
::
Half
(
y_buf
[
i
*
N
+
j
]);
}
}
else
{
TORCH_CHECK
(
false
,
"Unsupported output dtype"
);
}
}
}
void
fill_val_stub
(
int32_t
*
__restrict__
output
,
int32_t
value
,
int64_t
size
)
{
using
iVec
=
at
::
vec
::
Vectorized
<
int32_t
>
;
constexpr
int
VecSize
=
iVec
::
size
();
const
iVec
fill_val_vec
=
iVec
(
value
);
int64_t
d
;
#pragma GCC unroll 4
for
(
d
=
0
;
d
<=
size
-
VecSize
;
d
+=
VecSize
)
{
fill_val_vec
.
store
(
output
+
d
);
}
for
(;
d
<
size
;
++
d
)
{
output
[
d
]
=
value
;
}
}
template
<
bool
sym_quant_act
,
typename
act_dtype
,
typename
out_dtype
>
void
_da8w4_linear_impl
(
act_dtype
*
__restrict__
input
,
const
float
*
__restrict__
input_scales
,
const
int32_t
*
__restrict__
input_qzeros
,
const
uint8_t
*
__restrict__
weight
,
const
float
*
__restrict__
weight_scales
,
const
int8_t
*
__restrict__
weight_qzeros
,
const
float
*
__restrict__
bias
,
out_dtype
*
__restrict__
output
,
float
*
__restrict__
output_temp
,
int8_t
*
__restrict__
dequant_weight_temp
,
int64_t
M
,
int64_t
N
,
int64_t
K
,
int64_t
num_groups
)
{
const
bool
use_brgemm
=
can_use_brgemm
<
act_dtype
>
(
M
);
int64_t
block_m
=
[
&
]()
->
long
{
if
(
M
<=
48
)
{
return
M
;
}
else
if
(
M
<
64
)
{
return
32
;
}
else
if
(
M
<
96
)
{
return
64
;
}
else
{
return
128
;
}
}();
int64_t
Mc
=
div_up
(
M
,
block_m
);
bool
parallel_on_M
=
M
>
128
;
int64_t
Nc
=
N
/
BLOCK_N
;
int64_t
num_blocks
=
parallel_on_M
?
Mc
*
Nc
:
Nc
;
int64_t
group_size
=
div_up
(
K
,
num_groups
);
int64_t
_block_k
=
get_4bit_block_k_size
(
group_size
);
int64_t
Kc
=
K
/
_block_k
;
int64_t
block_per_group
=
group_size
/
_block_k
;
at
::
parallel_for
(
0
,
num_blocks
,
1
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
int
tid
=
get_thread_num
();
float
*
C_tmp
=
output_temp
+
tid
*
block_m
*
BLOCK_N
;
int8_t
*
dqB_tmp
=
dequant_weight_temp
+
tid
*
_block_k
*
BLOCK_N
;
for
(
const
auto
i
:
c10
::
irange
(
begin
,
end
))
{
int64_t
mc
=
parallel_on_M
?
i
/
Nc
:
0
;
int64_t
nc
=
parallel_on_M
?
i
%
Nc
:
i
;
int64_t
mc_end
=
parallel_on_M
?
mc
+
1
:
Mc
;
for
(
int
mci
=
mc
;
mci
<
mc_end
;
++
mci
)
{
int64_t
m_size
=
mci
*
block_m
+
block_m
>
M
?
M
-
mci
*
block_m
:
block_m
;
auto
bias_data
=
bias
?
bias
+
nc
*
BLOCK_N
:
nullptr
;
copy_bias
<
BLOCK_N
>
(
bias_data
,
C_tmp
,
m_size
);
for
(
int
kci
=
0
;
kci
<
Kc
;
++
kci
)
{
int32_t
*
compensation_ptr
=
sym_quant_act
?
nullptr
:
(
int32_t
*
)(
void
*
)(
weight
+
(
nc
*
Kc
+
kci
)
*
(
BLOCK_N
*
(
_block_k
/
2
+
sizeof
(
int32_t
)))
+
_block_k
*
BLOCK_N
/
2
);
_dequant_gemm_accum
<
sym_quant_act
,
BLOCK_N
,
BLOCK_N
/
2
>
(
/*C*/
C_tmp
,
/*A*/
(
uint8_t
*
)
input
+
mci
*
block_m
*
K
+
kci
*
_block_k
,
/*scales_a*/
input_scales
+
mci
*
block_m
,
/*qzeros_a*/
input_qzeros
+
mci
*
block_m
,
/*B*/
weight
+
(
nc
*
Kc
+
kci
)
*
(
BLOCK_N
*
(
_block_k
/
2
+
sizeof
(
int32_t
))),
/*scales_b*/
weight_scales
+
nc
*
BLOCK_N
*
num_groups
+
kci
/
block_per_group
*
BLOCK_N
,
/*qzeros_b*/
weight_qzeros
+
nc
*
BLOCK_N
*
num_groups
+
kci
/
block_per_group
*
BLOCK_N
,
/*Bcomp*/
compensation_ptr
,
/*dqB_tmp*/
dqB_tmp
,
/*M*/
m_size
,
/*K*/
_block_k
,
/*lda*/
K
,
/*ldc*/
BLOCK_N
,
/*use_brgemm*/
use_brgemm
);
}
store_out
<
BLOCK_N
>
(
C_tmp
,
output
+
mci
*
block_m
*
N
+
nc
*
BLOCK_N
,
m_size
,
N
/*lda*/
);
}
}
if
(
use_brgemm
)
{
at
::
native
::
cpublas
::
brgemm_release
();
}
});
}
}
// anonymous namespace
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
at
::
Tensor
>
convert_int4_weight_packed_with_compensation
(
const
at
::
Tensor
&
weight
,
const
at
::
Tensor
&
scales
,
const
at
::
Tensor
&
qzeros
)
{
TORCH_CHECK
(
weight
.
dim
()
==
2
,
"DA8W4 CPU: Weight should be a 2D tensor for packing"
);
TORCH_CHECK
(
weight
.
size
(
1
)
%
2
==
0
,
"DA8W4 CPU: Weight should have even number of columns for packing"
);
auto
new_scales
=
scales
;
auto
new_qzeros
=
qzeros
;
if
(
new_scales
.
dim
()
==
1
)
{
new_scales
.
unsqueeze_
(
1
);
}
new_scales
=
new_scales
.
to
(
at
::
kFloat
);
if
(
new_qzeros
.
dim
()
==
1
)
{
new_qzeros
.
unsqueeze_
(
1
);
}
new_qzeros
=
new_qzeros
.
to
(
at
::
kChar
);
int64_t
N
=
weight
.
size
(
0
);
int64_t
K
=
weight
.
size
(
1
);
int64_t
G
=
scales
.
size
(
1
);
int64_t
group_size
=
K
/
G
;
int64_t
_block_k
=
get_4bit_block_k_size
(
group_size
);
constexpr
int
block_n
=
block_size_n
();
int64_t
Nc
=
N
/
block_n
;
int64_t
Kc
=
K
/
_block_k
;
auto
weight_view
=
weight
.
view
({
Nc
,
block_n
,
Kc
,
_block_k
});
at
::
Tensor
weight_reordered
=
weight_view
.
permute
({
0
,
2
,
3
,
1
}).
contiguous
();
at
::
Tensor
blocked_weight
;
at
::
Tensor
blocked_scales
=
new_scales
.
view
({
Nc
,
block_n
,
G
}).
permute
({
0
,
2
,
1
}).
contiguous
();
at
::
Tensor
blocked_qzeros
=
new_qzeros
.
view
({
Nc
,
block_n
,
G
}).
permute
({
0
,
2
,
1
}).
contiguous
();
auto
weight_sub_qzero
=
weight
.
view
({
Nc
,
block_n
,
G
,
-
1
}).
to
(
at
::
kInt
)
-
new_qzeros
.
view
({
Nc
,
block_n
,
G
,
-
1
});
weight_sub_qzero
=
weight_sub_qzero
.
view
({
Nc
,
block_n
,
Kc
,
_block_k
});
at
::
Tensor
compensation
=
weight_sub_qzero
.
sum
(
-
1
);
compensation
=
compensation
.
permute
({
0
,
2
,
1
}).
contiguous
().
to
(
at
::
kInt
);
int64_t
buffer_size_nbytes
=
_block_k
*
block_n
/
2
+
block_n
*
sizeof
(
int32_t
);
blocked_weight
=
at
::
empty
({
Nc
,
Kc
,
buffer_size_nbytes
},
weight
.
options
());
auto
weight_ptr
=
weight_reordered
.
data_ptr
<
uint8_t
>
();
auto
compensation_ptr
=
compensation
.
data_ptr
<
int32_t
>
();
auto
blocked_weight_ptr
=
blocked_weight
.
data_ptr
<
uint8_t
>
();
int64_t
num_blocks
=
Nc
*
Kc
;
at
::
parallel_for
(
0
,
num_blocks
,
1
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
for
(
const
auto
i
:
c10
::
irange
(
begin
,
end
))
{
auto
in_ptr
=
weight_ptr
+
i
*
_block_k
*
block_n
;
auto
out_ptr
=
blocked_weight_ptr
+
i
*
block_n
*
(
_block_k
/
2
+
sizeof
(
int32_t
));
int32_t
*
comp_in_prt
=
compensation_ptr
+
i
*
block_n
;
int32_t
*
comp_out_prt
=
(
int32_t
*
)(
void
*
)(
blocked_weight_ptr
+
i
*
block_n
*
(
_block_k
/
2
+
sizeof
(
int32_t
))
+
_block_k
*
block_n
/
2
);
constexpr
int
n_group_size
=
8
;
constexpr
int
vnni_size
=
4
;
constexpr
int
n_group
=
block_n
/
n_group_size
;
for
(
int
nb
=
0
;
nb
<
n_group
;
nb
+=
2
)
{
for
(
int
k
=
0
;
k
<
_block_k
;
k
+=
vnni_size
)
{
for
(
int
ni
=
0
;
ni
<
n_group_size
;
++
ni
)
{
for
(
int
ki
=
0
;
ki
<
vnni_size
;
++
ki
)
{
int
src_idx_1
=
nb
*
n_group_size
+
ni
+
(
k
+
ki
)
*
block_n
;
int
src_idx_2
=
(
nb
+
1
)
*
n_group_size
+
ni
+
(
k
+
ki
)
*
block_n
;
int
dst_idx
=
(
nb
/
2
*
n_group_size
+
ni
)
*
vnni_size
+
k
*
block_n
/
2
+
ki
;
uint8_t
src_1
=
*
(
in_ptr
+
src_idx_1
);
uint8_t
src_2
=
*
(
in_ptr
+
src_idx_2
);
uint8_t
dst
=
(
src_1
&
0x0f
)
|
((
src_2
&
0x0f
)
<<
4
);
*
(
out_ptr
+
dst_idx
)
=
dst
;
}
}
}
}
for
(
int
nb
=
0
;
nb
<
block_n
;
nb
++
)
{
*
(
comp_out_prt
+
nb
)
=
*
(
comp_in_prt
+
nb
);
}
}
});
return
std
::
make_tuple
(
std
::
move
(
blocked_weight
),
std
::
move
(
blocked_scales
),
std
::
move
(
blocked_qzeros
));
}
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
autoawq_to_int4pack
(
at
::
Tensor
qweight
,
at
::
Tensor
qzeros
)
{
auto
bitshifts
=
at
::
tensor
({
0
,
4
,
1
,
5
,
2
,
6
,
3
,
7
},
at
::
kInt
)
*
4
;
auto
qweight_unsq
=
qweight
.
unsqueeze
(
-
1
);
auto
unpacked
=
at
::
bitwise_right_shift
(
qweight_unsq
,
bitshifts
)
&
0xF
;
auto
qweight_final
=
unpacked
.
flatten
(
-
2
).
transpose
(
-
1
,
-
2
).
to
(
at
::
kByte
);
auto
qzeros_unsq
=
qzeros
.
unsqueeze
(
-
1
);
auto
qzeros_unpacked
=
at
::
bitwise_right_shift
(
qzeros_unsq
,
bitshifts
)
&
0xF
;
auto
qzeros_final
=
qzeros_unpacked
.
flatten
(
-
2
).
to
(
at
::
kByte
);
return
std
::
make_tuple
(
qweight_final
,
qzeros_final
);
}
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
at
::
Tensor
>
convert_weight_packed_scale_zp
(
at
::
Tensor
qweight
,
at
::
Tensor
qzeros
,
at
::
Tensor
scales
)
{
auto
res
=
autoawq_to_int4pack
(
qweight
,
qzeros
);
auto
_qweight
=
std
::
get
<
0
>
(
res
);
auto
_qzeros
=
std
::
get
<
1
>
(
res
);
auto
_scales
=
scales
;
_qzeros
=
_qzeros
.
transpose
(
-
2
,
-
1
).
contiguous
();
_scales
=
_scales
.
transpose
(
-
2
,
-
1
).
contiguous
();
if
(
_qweight
.
dim
()
==
3
)
{
int64_t
E
=
_qweight
.
size
(
0
);
int64_t
K
=
_qweight
.
size
(
2
);
int64_t
G
=
_scales
.
size
(
2
);
int64_t
group_size
=
K
/
G
;
int64_t
_block_k
=
get_4bit_block_k_size
(
group_size
);
int64_t
block_n
=
block_size_n
();
int64_t
Nc
=
_qweight
.
size
(
1
)
/
block_n
;
int64_t
Kc
=
K
/
_block_k
;
int64_t
buffer_size_nbytes
=
_block_k
*
block_n
/
2
+
block_n
*
sizeof
(
int32_t
);
auto
blocked_weight
=
at
::
empty
({
E
,
Nc
,
Kc
,
buffer_size_nbytes
},
_qweight
.
options
());
auto
blocked_scales
=
at
::
empty
({
E
,
Nc
,
G
,
block_n
},
_scales
.
options
()).
to
(
at
::
kFloat
);
auto
blocked_qzeros
=
at
::
empty
({
E
,
Nc
,
G
,
block_n
},
_qzeros
.
options
()).
to
(
at
::
kChar
);
for
(
int
i
=
0
;
i
<
_qweight
.
size
(
0
);
i
++
)
{
auto
res_
=
convert_int4_weight_packed_with_compensation
(
_qweight
[
i
],
_scales
[
i
],
_qzeros
[
i
]);
blocked_weight
[
i
]
=
std
::
get
<
0
>
(
res_
);
blocked_scales
[
i
]
=
std
::
get
<
1
>
(
res_
);
blocked_qzeros
[
i
]
=
std
::
get
<
2
>
(
res_
);
}
_qweight
=
blocked_weight
;
_scales
=
blocked_scales
;
_qzeros
=
blocked_qzeros
;
}
else
{
auto
res_
=
convert_int4_weight_packed_with_compensation
(
_qweight
,
_scales
,
_qzeros
);
_qweight
=
std
::
get
<
0
>
(
res_
);
_scales
=
std
::
get
<
1
>
(
res_
);
_qzeros
=
std
::
get
<
2
>
(
res_
);
}
return
std
::
make_tuple
(
_qweight
,
_qzeros
,
_scales
);
}
at
::
Tensor
int4_scaled_mm_cpu_with_quant
(
const
at
::
Tensor
&
input
,
const
at
::
Tensor
&
weight
,
const
at
::
Tensor
&
weight_scales
,
const
at
::
Tensor
&
weight_qzeros
,
const
std
::
optional
<
at
::
Tensor
>&
bias
,
at
::
ScalarType
output_dtype
)
{
RECORD_FUNCTION
(
"vllm::int4_scaled_mm_cpu_with_quant"
,
std
::
vector
<
c10
::
IValue
>
({
input
,
weight
}));
int64_t
M_a
=
input
.
size
(
0
);
int64_t
K_a
=
input
.
size
(
1
);
int64_t
lda
=
input
.
stride
(
0
);
const
auto
st
=
input
.
scalar_type
();
TORCH_CHECK
(
st
==
at
::
kBFloat16
||
st
==
at
::
kHalf
,
"int4_scaled_mm_cpu_with_quant: expect A to be bfloat16 or half."
);
constexpr
bool
sym_quant_act
=
false
;
using
Tin
=
typename
ActDtype
<
sym_quant_act
>::
type
;
int64_t
act_buffer_size
=
M_a
*
K_a
+
M_a
*
sizeof
(
float
)
+
M_a
*
sizeof
(
int32_t
);
auto
act_buffer
=
at
::
empty
({
act_buffer_size
},
input
.
options
().
dtype
(
at
::
kByte
));
auto
Aq_data
=
act_buffer
.
data_ptr
<
uint8_t
>
();
auto
As_data
=
reinterpret_cast
<
float
*>
(
Aq_data
+
M_a
*
K_a
);
auto
Azp_data
=
reinterpret_cast
<
int32_t
*>
(
As_data
+
M_a
);
fill_val_stub
(
Azp_data
,
128
,
M_a
);
auto
out_sizes
=
input
.
sizes
().
vec
();
int64_t
N
=
weight_scales
.
size
(
0
)
*
weight_scales
.
size
(
-
1
);
out_sizes
.
back
()
=
N
;
auto
output
=
at
::
empty
(
out_sizes
,
input
.
options
());
int64_t
Nc
=
weight
.
size
(
0
);
int64_t
Kc
=
weight
.
size
(
1
);
int64_t
_block_k
=
K_a
/
Kc
;
TORCH_CHECK
(
N
==
Nc
*
BLOCK_N
,
"DA8W4: weight and input shapes mismatch"
);
int64_t
num_groups
=
weight_scales
.
size
(
1
);
const
uint8_t
*
b_ptr
=
weight
.
data_ptr
<
uint8_t
>
();
const
float
*
b_scales_ptr
=
weight_scales
.
data_ptr
<
float
>
();
const
int8_t
*
b_qzeros_ptr
=
weight_qzeros
.
data_ptr
<
int8_t
>
();
const
float
*
bias_ptr
=
bias
.
has_value
()
?
bias
.
value
().
data_ptr
<
float
>
()
:
nullptr
;
int
num_threads
=
at
::
get_num_threads
();
int64_t
temp_buffer_size
=
num_threads
*
BLOCK_M
*
BLOCK_N
*
sizeof
(
float
)
+
num_threads
*
_block_k
*
BLOCK_N
;
auto
c_temp_buffer
=
at
::
empty
({
temp_buffer_size
},
input
.
options
().
dtype
(
at
::
kChar
));
float
*
c_temp_ptr
=
(
float
*
)((
void
*
)(
c_temp_buffer
.
data_ptr
<
int8_t
>
()));
int8_t
*
dqB_temp_ptr
=
(
int8_t
*
)((
void
*
)(
c_temp_ptr
+
num_threads
*
BLOCK_M
*
BLOCK_N
));
#define LAUNCH_DA8W4_LINEAR_WITH_QUANT_IMPL(sym_quant_act) \
AT_DISPATCH_FLOATING_TYPES_AND2( \
at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype, \
"int4_scaled_mm_cpu", [&] { \
const scalar_t* __restrict__ A_data = input.data_ptr<scalar_t>(); \
scalar_t* __restrict__ c_ptr = output.data_ptr<scalar_t>(); \
at::parallel_for(0, M_a, 0, [&](int64_t begin, int64_t end) { \
for (int64_t m = begin; m < end; ++m) { \
quantize_row_int8<scalar_t>(Aq_data + m * K_a, As_data[m], \
A_data + m * lda, K_a); \
} \
}); \
_da8w4_linear_impl<sym_quant_act, Tin, scalar_t>( \
Aq_data, As_data, Azp_data, b_ptr, b_scales_ptr, b_qzeros_ptr, \
bias_ptr, c_ptr, c_temp_ptr, dqB_temp_ptr, M_a, N, K_a, \
num_groups); \
});
LAUNCH_DA8W4_LINEAR_WITH_QUANT_IMPL
(
sym_quant_act
);
return
output
;
}
namespace
{
template
<
typename
scalar_t
>
inline
void
copy_stub
(
scalar_t
*
__restrict__
out
,
const
float
*
__restrict__
input
,
int64_t
size
)
{
using
Vec
=
at
::
vec
::
Vectorized
<
scalar_t
>
;
using
fVec
=
at
::
vec
::
Vectorized
<
float
>
;
#pragma GCC unroll 4
for
(
int64_t
d
=
0
;
d
<
size
;
d
+=
Vec
::
size
())
{
fVec
x0
=
fVec
::
loadu
(
input
+
d
);
fVec
x1
=
fVec
::
loadu
(
input
+
d
+
fVec
::
size
());
Vec
res
=
convert_from_float_ext
<
scalar_t
>
(
x0
,
x1
);
res
.
store
(
out
+
d
);
}
}
}
// anonymous namespace
template
<
typename
scalar_t
>
void
tinygemm_kernel
(
scalar_t
*
C
,
float
*
C_temp
,
const
uint8_t
*
A
,
const
float
*
scales_a
,
const
int32_t
*
qzeros_a
,
const
uint8_t
*
B
,
const
float
*
scales_b
,
const
int8_t
*
qzeros_b
,
const
int32_t
*
compensation
,
int8_t
*
dqB_tmp
,
int64_t
M
,
int64_t
K
,
int64_t
lda
,
int64_t
ldc_f
,
int64_t
ldc_s
,
bool
store_out
,
bool
use_brgemm
)
{
_dequant_gemm_accum
<
false
,
BLOCK_N
,
BLOCK_N
/
2
>
(
C_temp
,
A
,
scales_a
,
qzeros_a
,
B
,
scales_b
,
qzeros_b
,
compensation
,
dqB_tmp
,
M
,
K
,
lda
,
ldc_f
,
use_brgemm
);
if
(
store_out
)
{
for
(
int64_t
m
=
0
;
m
<
M
;
++
m
)
{
copy_stub
<
scalar_t
>
(
C
+
m
*
ldc_s
,
C_temp
+
m
*
ldc_f
,
BLOCK_N
);
}
}
}
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
template void tinygemm_kernel<TYPE>( \
TYPE * C, float* C_temp, const uint8_t* A, const float* scales_a, \
const int32_t* qzeros_a, const uint8_t* B, const float* scales_b, \
const int8_t* qzeros_b, const int32_t* compensation, int8_t* dqB_tmp, \
int64_t M, int64_t K, int64_t lda, int64_t ldc_f, int64_t ldc_s, \
bool store_out, bool use_brgemm)
INSTANTIATE_TINYGEMM_TEMPLATE
(
at
::
BFloat16
);
INSTANTIATE_TINYGEMM_TEMPLATE
(
at
::
Half
);
at
::
Tensor
int4_scaled_mm_cpu
(
at
::
Tensor
&
x
,
at
::
Tensor
&
w
,
at
::
Tensor
&
w_zeros
,
at
::
Tensor
&
w_scales
,
std
::
optional
<
at
::
Tensor
>
bias
)
{
return
int4_scaled_mm_cpu_with_quant
(
x
,
w
,
w_scales
,
w_zeros
,
bias
,
x
.
scalar_type
());
}
csrc/cpu/torch_bindings.cpp
View file @
f09daea2
...
...
@@ -79,6 +79,14 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
const
std
::
optional
<
at
::
Tensor
>&
bias
,
at
::
ScalarType
out_dtype
,
bool
is_vnni
);
// Adapted from sglang: INT4 W4A8 kernels
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
at
::
Tensor
>
convert_weight_packed_scale_zp
(
at
::
Tensor
qweight
,
at
::
Tensor
qzeros
,
at
::
Tensor
scales
);
at
::
Tensor
int4_scaled_mm_cpu
(
at
::
Tensor
&
x
,
at
::
Tensor
&
w
,
at
::
Tensor
&
w_zeros
,
at
::
Tensor
&
w_scales
,
std
::
optional
<
at
::
Tensor
>
bias
);
torch
::
Tensor
get_scheduler_metadata
(
const
int64_t
num_req
,
const
int64_t
num_heads_q
,
const
int64_t
num_heads_kv
,
const
int64_t
head_dim
,
...
...
@@ -285,6 +293,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor"
);
ops
.
impl
(
"int8_scaled_mm_with_quant"
,
torch
::
kCPU
,
&
int8_scaled_mm_with_quant
);
// Adapted from sglang: INT4 W4A8 kernels
ops
.
def
(
"convert_weight_packed_scale_zp(Tensor qweight, Tensor qzeros, "
"Tensor scales) -> (Tensor, Tensor, Tensor)"
);
ops
.
impl
(
"convert_weight_packed_scale_zp"
,
torch
::
kCPU
,
&
convert_weight_packed_scale_zp
);
ops
.
def
(
"int4_scaled_mm_cpu(Tensor(a0!) x, Tensor(a1!) w, Tensor(a2!) w_zeros, "
"Tensor(a3!) w_scales, Tensor? bias) -> Tensor"
);
ops
.
impl
(
"int4_scaled_mm_cpu"
,
torch
::
kCPU
,
&
int4_scaled_mm_cpu
);
#endif
// CPU attention kernels
...
...
tests/kernels/test_awq_int4_to_int8.py
0 → 100644
View file @
f09daea2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for AWQ INT4 W4A8 GEMM pipeline (SGLang kernel migration).
Part 1: Weight packing tests
- convert_weight_packed_scale_zp correctness
Part 2: INT4 W4A8 GEMM tests
- int4_scaled_mm_cpu correctness w.r.t. float reference
- Bias, 3D input, various shapes
Part 3: create_weights shapes
cmd:
VLLM_CPU_INT4_W4A8=1 python -m pytest tests/kernels/test_awq_int4_to_int8.py -v -s
"""
import
numpy
as
np
import
pytest
import
torch
from
vllm._custom_ops
import
_supports_cpu_w4a8_int8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
pack_cols
,
)
from
vllm.platforms
import
current_platform
if
not
current_platform
.
is_cpu
():
pytest
.
skip
(
"skipping CPU-only tests"
,
allow_module_level
=
True
)
requires_cpu_w4a8_int8
=
pytest
.
mark
.
skipif
(
not
_supports_cpu_w4a8_int8
,
reason
=
"Requires vLLM CPU build with SGLang INT4 W4A8 kernels"
,
)
def
make_awq_checkpoint_data
(
K
,
N
,
group_size
,
seed
=
42
):
"""Create synthetic AWQ checkpoint data in packed int32 format.
Returns:
packed_qweight: [K, N//8] int32 (AWQ interleaved + packed)
packed_qzeros: [num_groups, N//8] int32 (AWQ interleaved + packed)
scales: [num_groups, N] float32
float_ref: [K, N] float32, reference dequantized weights
weight_int4_orig: [K, N] int32, original int4 values (0-15)
zeros_int4_orig: [num_groups, N] int32, original zero points (0-15)
"""
rng
=
np
.
random
.
RandomState
(
seed
)
num_groups
=
K
//
group_size
weight_int4_orig
=
torch
.
from_numpy
(
rng
.
randint
(
0
,
16
,
size
=
(
K
,
N
)).
astype
(
np
.
int32
)
)
zeros_int4_orig
=
torch
.
from_numpy
(
rng
.
randint
(
0
,
16
,
size
=
(
num_groups
,
N
)).
astype
(
np
.
int32
)
)
scales
=
torch
.
from_numpy
((
rng
.
randn
(
num_groups
,
N
)
*
0.05
).
astype
(
np
.
float32
))
scales_exp
=
scales
.
repeat_interleave
(
group_size
,
dim
=
0
)
zeros_exp
=
zeros_int4_orig
.
repeat_interleave
(
group_size
,
dim
=
0
)
float_ref
=
(
weight_int4_orig
.
float
()
-
zeros_exp
.
float
())
*
scales_exp
awq_interleave
=
[
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
]
weight_interleaved
=
(
weight_int4_orig
.
reshape
(
-
1
,
8
)[:,
awq_interleave
].
reshape
(
K
,
N
).
contiguous
()
)
packed_qweight
=
pack_cols
(
weight_interleaved
,
4
,
K
,
N
)
zeros_interleaved
=
(
zeros_int4_orig
.
reshape
(
-
1
,
8
)[:,
awq_interleave
]
.
reshape
(
num_groups
,
N
)
.
contiguous
()
)
packed_qzeros
=
pack_cols
(
zeros_interleaved
,
4
,
num_groups
,
N
)
return
(
packed_qweight
,
packed_qzeros
,
scales
,
float_ref
,
weight_int4_orig
,
zeros_int4_orig
,
)
class
TestConvertWeightPackedScaleZp
:
"""Tests for convert_weight_packed_scale_zp weightpacking."""
@
requires_cpu_w4a8_int8
@
pytest
.
mark
.
parametrize
(
"K,N,group_size"
,
[
(
128
,
128
,
128
),
(
256
,
256
,
128
),
(
512
,
256
,
64
),
],
)
def
test_packing_output_shapes
(
self
,
K
,
N
,
group_size
):
"""Packed outputs should have expected shapes."""
(
packed_qweight
,
packed_qzeros
,
scales
,
_
,
_
,
_
)
=
make_awq_checkpoint_data
(
K
,
N
,
group_size
)
blocked_w
,
blocked_zp
,
blocked_s
=
torch
.
ops
.
_C
.
convert_weight_packed_scale_zp
(
packed_qweight
,
packed_qzeros
,
scales
)
block_n
=
32
Nc
=
N
//
block_n
assert
blocked_w
.
dim
()
>=
2
,
(
f
"blocked_w should have >= 2 dims, got
{
blocked_w
.
dim
()
}
"
)
assert
blocked_s
.
size
(
0
)
==
Nc
,
(
f
"Expected Nc=
{
Nc
}
scale blocks, got
{
blocked_s
.
size
(
0
)
}
"
)
assert
blocked_zp
.
size
(
0
)
==
Nc
,
(
f
"Expected Nc=
{
Nc
}
qzeros blocks, got
{
blocked_zp
.
size
(
0
)
}
"
)
print
(
f
" [PASS] packing shapes K=
{
K
}
, N=
{
N
}
, gs=
{
group_size
}
: "
f
"blocked_w=
{
list
(
blocked_w
.
shape
)
}
, "
f
"blocked_s=
{
list
(
blocked_s
.
shape
)
}
, blocked_zp=
{
list
(
blocked_zp
.
shape
)
}
"
)
class
TestInt4ScaledMmCpu
:
"""Tests for int4_scaled_mm_cpu GEMM kernel."""
@
requires_cpu_w4a8_int8
@
pytest
.
mark
.
parametrize
(
"M,K,N,group_size"
,
[
(
1
,
128
,
128
,
128
),
(
4
,
256
,
256
,
128
),
(
16
,
512
,
256
,
64
),
(
32
,
256
,
512
,
128
),
(
64
,
512
,
512
,
128
),
],
)
def
test_gemm_vs_float_reference
(
self
,
M
,
K
,
N
,
group_size
):
"""INT4 W4A8 GEMM should approximate float matmul."""
(
packed_qweight
,
packed_qzeros
,
scales
,
float_ref
,
_
,
_
)
=
(
make_awq_checkpoint_data
(
K
,
N
,
group_size
)
)
blocked_w
,
blocked_zp
,
blocked_s
=
torch
.
ops
.
_C
.
convert_weight_packed_scale_zp
(
packed_qweight
,
packed_qzeros
,
scales
)
x
=
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
bfloat16
)
out
=
torch
.
ops
.
_C
.
int4_scaled_mm_cpu
(
x
,
blocked_w
,
blocked_zp
,
blocked_s
,
None
)
ref_out
=
torch
.
mm
(
x
.
float
(),
float_ref
)
abs_diff
=
(
out
.
float
()
-
ref_out
).
abs
()
mean_abs
=
abs_diff
.
mean
().
item
()
pct95
=
torch
.
quantile
(
abs_diff
,
0.95
).
item
()
ref_mag
=
ref_out
.
abs
().
mean
().
item
()
+
1e-6
mean_rel
=
mean_abs
/
ref_mag
assert
mean_rel
<
0.05
,
(
f
"Mean relative error
{
mean_rel
:.
4
f
}
exceeds 5% threshold"
)
assert
pct95
<
ref_mag
*
0.15
,
(
f
"95th-pctile abs_diff
{
pct95
:.
4
f
}
exceeds 15% of ref magnitude"
)
print
(
f
" [PASS] INT4 GEMM correct: M=
{
M
}
, K=
{
K
}
, N=
{
N
}
"
)
@
requires_cpu_w4a8_int8
@
pytest
.
mark
.
parametrize
(
"M"
,
[
1
,
8
,
32
])
def
test_gemm_with_bias
(
self
,
M
):
"""INT4 W4A8 GEMM with bias should match reference."""
K
,
N
,
group_size
=
256
,
128
,
128
(
packed_qweight
,
packed_qzeros
,
scales
,
float_ref
,
_
,
_
)
=
(
make_awq_checkpoint_data
(
K
,
N
,
group_size
)
)
blocked_w
,
blocked_zp
,
blocked_s
=
torch
.
ops
.
_C
.
convert_weight_packed_scale_zp
(
packed_qweight
,
packed_qzeros
,
scales
)
bias
=
torch
.
randn
(
N
,
dtype
=
torch
.
float32
)
x
=
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
bfloat16
)
out
=
torch
.
ops
.
_C
.
int4_scaled_mm_cpu
(
x
,
blocked_w
,
blocked_zp
,
blocked_s
,
bias
)
ref_out
=
torch
.
mm
(
x
.
float
(),
float_ref
)
+
bias
abs_diff
=
(
out
.
float
()
-
ref_out
).
abs
()
mean_abs
=
abs_diff
.
mean
().
item
()
ref_mag
=
ref_out
.
abs
().
mean
().
item
()
+
1e-6
mean_rel
=
mean_abs
/
ref_mag
assert
mean_rel
<
0.05
,
(
f
"Mean relative error
{
mean_rel
:.
4
f
}
with bias exceeds 5%"
)
print
(
f
" [PASS] INT4 GEMM with bias: M=
{
M
}
"
)
@
requires_cpu_w4a8_int8
def
test_gemm_3d_input
(
self
):
"""apply() reshapes 3D input [B, S, K] -> [B*S, K] -> back to 3D."""
K
,
N
,
group_size
=
256
,
128
,
128
(
packed_qweight
,
packed_qzeros
,
scales
,
float_ref
,
_
,
_
)
=
(
make_awq_checkpoint_data
(
K
,
N
,
group_size
)
)
blocked_w
,
blocked_zp
,
blocked_s
=
torch
.
ops
.
_C
.
convert_weight_packed_scale_zp
(
packed_qweight
,
packed_qzeros
,
scales
)
B
,
S
=
2
,
8
x_3d
=
torch
.
randn
(
B
,
S
,
K
,
dtype
=
torch
.
bfloat16
)
x_2d
=
x_3d
.
reshape
(
-
1
,
K
)
out_2d
=
torch
.
ops
.
_C
.
int4_scaled_mm_cpu
(
x_2d
,
blocked_w
,
blocked_zp
,
blocked_s
,
None
)
out_3d
=
out_2d
.
reshape
(
B
,
S
,
N
)
ref_out
=
torch
.
mm
(
x_2d
.
float
(),
float_ref
).
reshape
(
B
,
S
,
N
)
assert
out_3d
.
shape
==
(
B
,
S
,
N
)
abs_diff
=
(
out_3d
.
float
()
-
ref_out
).
abs
()
mean_abs
=
abs_diff
.
mean
().
item
()
ref_mag
=
ref_out
.
abs
().
mean
().
item
()
+
1e-6
mean_rel
=
mean_abs
/
ref_mag
assert
mean_rel
<
0.05
,
f
"Mean relative error
{
mean_rel
:.
4
f
}
for 3D exceeds 5%"
print
(
f
" [PASS] 3D input [
{
B
}
,
{
S
}
,
{
K
}
] -> output [
{
B
}
,
{
S
}
,
{
N
}
]"
)
@
requires_cpu_w4a8_int8
def
test_gemm_fp16_input
(
self
):
"""INT4 GEMM should also work with fp16 input."""
K
,
N
,
group_size
,
M
=
256
,
256
,
128
,
8
(
packed_qweight
,
packed_qzeros
,
scales
,
float_ref
,
_
,
_
)
=
(
make_awq_checkpoint_data
(
K
,
N
,
group_size
)
)
blocked_w
,
blocked_zp
,
blocked_s
=
torch
.
ops
.
_C
.
convert_weight_packed_scale_zp
(
packed_qweight
,
packed_qzeros
,
scales
)
x
=
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
)
out
=
torch
.
ops
.
_C
.
int4_scaled_mm_cpu
(
x
,
blocked_w
,
blocked_zp
,
blocked_s
,
None
)
ref_out
=
torch
.
mm
(
x
.
float
(),
float_ref
)
abs_diff
=
(
out
.
float
()
-
ref_out
).
abs
()
ref_mag
=
ref_out
.
abs
().
mean
().
item
()
+
1e-6
mean_rel
=
abs_diff
.
mean
().
item
()
/
ref_mag
assert
mean_rel
<
0.05
,
(
f
"Mean relative error
{
mean_rel
:.
4
f
}
for fp16 exceeds 5%"
)
print
(
f
" [PASS] fp16 input M=
{
M
}
, K=
{
K
}
, N=
{
N
}
"
)
class
TestCreateWeightsUnchanged
:
"""Create_weights should still produce correct int4 placeholder shapes."""
@
pytest
.
mark
.
parametrize
(
"K,N,group_size"
,
[
(
128
,
128
,
128
),
(
256
,
256
,
128
),
(
512
,
256
,
64
),
],
)
def
test_int4_placeholder_shapes
(
self
,
K
,
N
,
group_size
):
"""Verify qweight, qzeros, scales shapes."""
pack_factor
=
8
num_groups
=
K
//
group_size
qweight
=
torch
.
empty
(
K
,
N
//
pack_factor
,
dtype
=
torch
.
int32
)
qzeros
=
torch
.
empty
(
num_groups
,
N
//
pack_factor
,
dtype
=
torch
.
int32
)
scales
=
torch
.
empty
(
num_groups
,
N
,
dtype
=
torch
.
bfloat16
)
assert
qweight
.
shape
==
(
K
,
N
//
pack_factor
)
assert
qzeros
.
shape
==
(
num_groups
,
N
//
pack_factor
)
assert
scales
.
shape
==
(
num_groups
,
N
)
print
(
f
" [PASS] create_weights shapes: K=
{
K
}
, N=
{
N
}
, gs=
{
group_size
}
"
)
vllm/_custom_ops.py
View file @
f09daea2
...
...
@@ -2967,6 +2967,38 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
return
torch
.
empty
((
M
,
N
),
dtype
=
out_dtype
)
if
hasattr
(
torch
.
ops
.
_C
,
"convert_weight_packed_scale_zp"
):
@
register_fake
(
"_C::convert_weight_packed_scale_zp"
)
def
convert_weight_packed_scale_zp_fake
(
qweight
:
torch
.
Tensor
,
qzeros
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
return
(
torch
.
empty_like
(
qweight
),
torch
.
empty_like
(
qzeros
),
torch
.
empty_like
(
scales
),
)
if
hasattr
(
torch
.
ops
.
_C
,
"int4_scaled_mm_cpu"
):
@
register_fake
(
"_C::int4_scaled_mm_cpu"
)
def
int4_scaled_mm_cpu_fake
(
x
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
w_zeros
:
torch
.
Tensor
,
w_scales
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
N
=
w_scales
.
size
(
0
)
*
w_scales
.
size
(
-
1
)
return
torch
.
empty
((
x
.
size
(
0
),
N
),
dtype
=
x
.
dtype
,
device
=
x
.
device
)
_supports_cpu_w4a8_int8
=
bool
(
hasattr
(
torch
.
ops
.
_C
,
"convert_weight_packed_scale_zp"
))
class
CPUDNNLGEMMHandler
:
def
__init__
(
self
)
->
None
:
self
.
handler_tensor
:
torch
.
Tensor
|
None
=
None
...
...
vllm/envs.py
View file @
f09daea2
...
...
@@ -52,6 +52,7 @@ if TYPE_CHECKING:
VLLM_CPU_NUM_OF_RESERVED_CPU
:
int
|
None
=
None
VLLM_CPU_SGL_KERNEL
:
bool
=
False
VLLM_ZENTORCH_WEIGHT_PREPACK
:
bool
=
True
VLLM_CPU_INT4_W4A8
:
bool
=
True
VLLM_XLA_CACHE_PATH
:
str
=
os
.
path
.
join
(
VLLM_CACHE_ROOT
,
"xla_cache"
)
VLLM_XLA_CHECK_RECOMPILATION
:
bool
=
False
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
:
Literal
[
"auto"
,
"nccl"
,
"shm"
]
=
"auto"
...
...
@@ -728,6 +729,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ZENTORCH_WEIGHT_PREPACK"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ZENTORCH_WEIGHT_PREPACK"
,
"1"
))
),
# (CPU backend only) whether to use SGLang INT4 W4A8 kernels for AWQ.
"VLLM_CPU_INT4_W4A8"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_CPU_INT4_W4A8"
,
"1"
))),
# If the env var is set, Ray Compiled Graph uses the specified
# channel type to communicate between workers belonging to
# different pipeline-parallel stages.
...
...
vllm/model_executor/layers/quantization/cpu_wna16.py
View file @
f09daea2
...
...
@@ -7,9 +7,8 @@ import torch
from
safetensors.torch
import
_TYPES
as
_SAFETENSORS_TO_TORCH_DTYPE
from
transformers
import
PretrainedConfig
from
vllm._custom_ops
import
(
cpu_gemm_wna16
,
)
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
...
...
@@ -230,7 +229,14 @@ class CPUAWQLinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"scales"
,
scales
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
torch
.
set_printoptions
(
profile
=
"full"
,
linewidth
=
5000
,
sci_mode
=
False
)
layer
.
use_w4a8
=
envs
.
VLLM_CPU_INT4_W4A8
and
torch
.
cpu
.
_is_amx_tile_supported
()
if
layer
.
use_w4a8
:
self
.
_process_weights_sglang_int4
(
layer
)
else
:
self
.
_process_weights_woq
(
layer
)
def
_process_weights_woq
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Original WOQ int4 repack path."""
packed_weight
=
layer
.
qweight
.
data
packed_zeros
=
layer
.
qzeros
.
data
group_num
=
packed_zeros
.
size
(
0
)
...
...
@@ -266,8 +272,6 @@ class CPUAWQLinearMethod(LinearMethodBase):
)
zeros
=
pack_cols
(
zeros
,
bits
,
group_num
,
output_size
).
contiguous
()
# make 16 output channel as a block and transpose to
# the make the block contiguous
weight
=
pack_cols
(
weight
,
bits
,
input_size
,
output_size
)
weight
=
(
weight
.
view
(
input_size
,
-
1
,
16
//
pack_factor
)
...
...
@@ -278,13 +282,40 @@ class CPUAWQLinearMethod(LinearMethodBase):
layer
.
qweight
.
data
=
weight
layer
.
qzeros
.
data
=
zeros
def
_process_weights_sglang_int4
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""SGLang INT4 W4A8 path: pack int4 weights with VNNI reordering."""
packed_weight
=
layer
.
qweight
.
data
packed_zeros
=
layer
.
qzeros
.
data
scales
=
layer
.
scales
.
data
blocked_w
,
blocked_zp
,
blocked_s
=
torch
.
ops
.
_C
.
convert_weight_packed_scale_zp
(
packed_weight
,
packed_zeros
,
scales
)
layer
.
packed_weight
=
blocked_w
layer
.
packed_qzeros
=
blocked_zp
layer
.
packed_scales
=
blocked_s
layer
.
qweight
=
None
layer
.
qzeros
=
None
layer
.
scales
=
None
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
x
=
cpu_gemm_wna16
(
if
layer
.
use_w4a8
:
return
self
.
_apply_sglang_int4
(
layer
,
x
,
bias
)
return
self
.
_apply_woq
(
layer
,
x
,
bias
)
def
_apply_woq
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
"""Original WOQ int4 GEMM path."""
x
=
ops
.
cpu_gemm_wna16
(
input
=
x
,
q_weight
=
layer
.
qweight
,
scales
=
layer
.
scales
,
...
...
@@ -296,6 +327,26 @@ class CPUAWQLinearMethod(LinearMethodBase):
)
return
x
def
_apply_sglang_int4
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
"""SGLang INT4 W4A8 GEMM path."""
x_shape
=
x
.
shape
x_2d
=
x
.
reshape
(
-
1
,
x_shape
[
-
1
])
if
len
(
x_shape
)
>
2
else
x
out
=
torch
.
ops
.
_C
.
int4_scaled_mm_cpu
(
x_2d
,
layer
.
packed_weight
,
layer
.
packed_qzeros
,
layer
.
packed_scales
,
bias
,
)
out
=
out
.
reshape
(
x_shape
[:
-
1
]
+
(
out
.
size
(
-
1
),))
if
len
(
x_shape
)
>
2
else
out
return
out
def
_get_isa_hint
(
dtype
:
torch
.
dtype
)
->
str
:
supports_amx
=
torch
.
cpu
.
_is_amx_tile_supported
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment