Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5eda21e7
Unverified
Commit
5eda21e7
authored
Oct 18, 2024
by
Li, Jiang
Committed by
GitHub
Oct 17, 2024
Browse files
[Hardware][CPU] compressed-tensor INT8 W8A8 AZP support (#9344)
parent
8e1cddcd
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
452 additions
and
96 deletions
+452
-96
.buildkite/run-cpu-test.sh
.buildkite/run-cpu-test.sh
+4
-4
Dockerfile.cpu
Dockerfile.cpu
+0
-13
cmake/cpu_extension.cmake
cmake/cpu_extension.cmake
+34
-6
csrc/cpu/cpu_types_x86.hpp
csrc/cpu/cpu_types_x86.hpp
+39
-2
csrc/cpu/quant.cpp
csrc/cpu/quant.cpp
+360
-57
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+15
-0
docs/source/getting_started/cpu-installation.rst
docs/source/getting_started/cpu-installation.rst
+0
-14
No files found.
.buildkite/run-cpu-test.sh
View file @
5eda21e7
...
@@ -32,10 +32,10 @@ docker exec cpu-test bash -c "
...
@@ -32,10 +32,10 @@ docker exec cpu-test bash -c "
--ignore=tests/models/decoder_only/language/test_danube3_4b.py"
# Mamba and Danube3-4B on CPU is not supported
--ignore=tests/models/decoder_only/language/test_danube3_4b.py"
# Mamba and Danube3-4B on CPU is not supported
# Run compressed-tensor test
# Run compressed-tensor test
#
docker exec cpu-test bash -c "
docker
exec
cpu-test bash
-c
"
#
pytest -s -v \
pytest -s -v
\
#
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup
\
#
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dyna
n
mic_per_token"
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"
# Run AWQ test
# Run AWQ test
docker
exec
cpu-test bash
-c
"
docker
exec
cpu-test bash
-c
"
...
...
Dockerfile.cpu
View file @
5eda21e7
...
@@ -33,19 +33,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
...
@@ -33,19 +33,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
pip install --upgrade pip && \
pip install --upgrade pip && \
pip install -r requirements-build.txt
pip install -r requirements-build.txt
# install oneDNN
RUN git clone -b rls-v3.5 https://github.com/oneapi-src/oneDNN.git
RUN --mount=type=cache,target=/root/.cache/ccache \
cmake -B ./oneDNN/build -S ./oneDNN -G Ninja -DONEDNN_LIBRARY_TYPE=STATIC \
-DONEDNN_BUILD_DOC=OFF \
-DONEDNN_BUILD_EXAMPLES=OFF \
-DONEDNN_BUILD_TESTS=OFF \
-DONEDNN_BUILD_GRAPH=OFF \
-DONEDNN_ENABLE_WORKLOAD=INFERENCE \
-DONEDNN_ENABLE_PRIMITIVE=MATMUL && \
cmake --build ./oneDNN/build --target install --config Release
FROM cpu-test-1 AS build
FROM cpu-test-1 AS build
WORKDIR /workspace/vllm
WORKDIR /workspace/vllm
...
...
cmake/cpu_extension.cmake
View file @
5eda21e7
include
(
FetchContent
)
set
(
CMAKE_CXX_STANDARD_REQUIRED ON
)
set
(
CMAKE_CXX_EXTENSIONS ON
)
set
(
CMAKE_EXPORT_COMPILE_COMMANDS ON
)
set
(
CMAKE_EXPORT_COMPILE_COMMANDS ON
)
set
(
CMAKE_CXX_STANDARD 17
)
#
#
# Define environment variables for special configurations
# Define environment variables for special configurations
...
@@ -82,15 +85,40 @@ else()
...
@@ -82,15 +85,40 @@ else()
message
(
FATAL_ERROR
"vLLM CPU backend requires AVX512 or AVX2 or Power9+ ISA support."
)
message
(
FATAL_ERROR
"vLLM CPU backend requires AVX512 or AVX2 or Power9+ ISA support."
)
endif
()
endif
()
#
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms)
#
if
(
AVX512_FOUND AND NOT AVX512_DISABLED
)
FetchContent_Declare
(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
GIT_TAG v3.5.3
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
set
(
ONEDNN_LIBRARY_TYPE
"STATIC"
)
set
(
ONEDNN_BUILD_DOC
"OFF"
)
set
(
ONEDNN_BUILD_EXAMPLES
"OFF"
)
set
(
ONEDNN_BUILD_TESTS
"OFF"
)
set
(
ONEDNN_ENABLE_WORKLOAD
"INFERENCE"
)
set
(
ONEDNN_ENABLE_PRIMITIVE
"MATMUL;REORDER"
)
set
(
ONEDNN_BUILD_GRAPH
"OFF"
)
set
(
ONEDNN_ENABLE_JIT_PROFILING
"OFF"
)
set
(
ONEDNN_ENABLE_ITT_TASKS
"OFF"
)
set
(
ONEDNN_ENABLE_MAX_CPU_ISA
"OFF"
)
set
(
ONEDNN_ENABLE_CPU_ISA_HINTS
"OFF"
)
set
(
CMAKE_POLICY_DEFAULT_CMP0077 NEW
)
FetchContent_MakeAvailable
(
oneDNN
)
list
(
APPEND LIBS dnnl
)
endif
()
message
(
STATUS
"CPU extension compile flags:
${
CXX_COMPILE_FLAGS
}
"
)
message
(
STATUS
"CPU extension compile flags:
${
CXX_COMPILE_FLAGS
}
"
)
list
(
APPEND LIBS numa
)
list
(
APPEND LIBS numa
)
# Appending the dnnl library for the AVX2 and AVX512, as it is not utilized by Power architecture.
if
(
AVX2_FOUND OR AVX512_FOUND
)
list
(
APPEND LIBS dnnl
)
endif
()
#
#
# _C extension
# _C extension
#
#
...
...
csrc/cpu/cpu_types_x86.hpp
View file @
5eda21e7
...
@@ -265,6 +265,30 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
...
@@ -265,6 +265,30 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
void
save
(
float
*
ptr
)
const
{
_mm256_storeu_ps
(
ptr
,
reg
);
}
void
save
(
float
*
ptr
)
const
{
_mm256_storeu_ps
(
ptr
,
reg
);
}
};
};
#ifdef __AVX512F__
struct
INT32Vec16
:
public
Vec
<
INT32Vec16
>
{
constexpr
static
int
VEC_ELEM_NUM
=
16
;
union
AliasReg
{
__m512i
reg
;
int32_t
values
[
VEC_ELEM_NUM
];
};
__m512i
reg
;
explicit
INT32Vec16
(
const
void
*
data_ptr
)
:
reg
(
_mm512_loadu_epi32
(
data_ptr
))
{}
void
save
(
int32_t
*
ptr
)
const
{
_mm512_storeu_epi32
(
ptr
,
reg
);
}
void
save
(
int32_t
*
ptr
,
const
int
elem_num
)
const
{
constexpr
uint32_t
M
=
0xFFFFFFFF
;
__mmask16
mask
=
_cvtu32_mask16
(
M
>>
(
32
-
elem_num
));
_mm512_mask_storeu_epi32
(
ptr
,
mask
,
reg
);
}
};
#endif
#ifdef __AVX512F__
#ifdef __AVX512F__
struct
FP32Vec16
:
public
Vec
<
FP32Vec16
>
{
struct
FP32Vec16
:
public
Vec
<
FP32Vec16
>
{
constexpr
static
int
VEC_ELEM_NUM
=
16
;
constexpr
static
int
VEC_ELEM_NUM
=
16
;
...
@@ -283,8 +307,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
...
@@ -283,8 +307,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit
FP32Vec16
(
__m512
data
)
:
reg
(
data
)
{}
explicit
FP32Vec16
(
__m512
data
)
:
reg
(
data
)
{}
explicit
FP32Vec16
(
const
FP32Vec16
&
data
)
:
reg
(
data
.
reg
)
{}
explicit
FP32Vec16
(
const
FP32Vec4
&
data
)
explicit
FP32Vec16
(
const
FP32Vec4
&
data
)
:
reg
((
__m512
)
_mm512_inserti32x4
(
:
reg
((
__m512
)
_mm512_inserti32x4
(
_mm512_inserti32x4
(
_mm512_inserti32x4
(
...
@@ -303,6 +325,9 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
...
@@ -303,6 +325,9 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit
FP32Vec16
(
const
BF16Vec8
&
v
)
:
FP32Vec16
(
FP32Vec8
(
v
))
{}
explicit
FP32Vec16
(
const
BF16Vec8
&
v
)
:
FP32Vec16
(
FP32Vec8
(
v
))
{}
explicit
FP32Vec16
(
const
INT32Vec16
&
v
)
:
reg
(
_mm512_cvt_roundepi32_ps
(
v
.
reg
,
_MM_FROUND_TO_NEAREST_INT
|
_MM_FROUND_NO_EXC
))
{}
FP32Vec16
operator
*
(
const
FP32Vec16
&
b
)
const
{
FP32Vec16
operator
*
(
const
FP32Vec16
&
b
)
const
{
return
FP32Vec16
(
_mm512_mul_ps
(
reg
,
b
.
reg
));
return
FP32Vec16
(
_mm512_mul_ps
(
reg
,
b
.
reg
));
}
}
...
@@ -333,6 +358,16 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
...
@@ -333,6 +358,16 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
return
FP32Vec16
(
_mm512_mask_max_ps
(
reg
,
mask
,
reg
,
b
.
reg
));
return
FP32Vec16
(
_mm512_mask_max_ps
(
reg
,
mask
,
reg
,
b
.
reg
));
}
}
FP32Vec16
min
(
const
FP32Vec16
&
b
)
const
{
return
FP32Vec16
(
_mm512_min_ps
(
reg
,
b
.
reg
));
}
FP32Vec16
min
(
const
FP32Vec16
&
b
,
const
int
elem_num
)
const
{
constexpr
uint32_t
M
=
0xFFFFFFFF
;
__mmask16
mask
=
_cvtu32_mask16
(
M
>>
(
32
-
elem_num
));
return
FP32Vec16
(
_mm512_mask_min_ps
(
reg
,
mask
,
reg
,
b
.
reg
));
}
FP32Vec16
abs
()
const
{
FP32Vec16
abs
()
const
{
return
FP32Vec16
(
_mm512_abs_ps
(
reg
));
return
FP32Vec16
(
_mm512_abs_ps
(
reg
));
}
}
...
@@ -341,6 +376,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
...
@@ -341,6 +376,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
float
reduce_max
()
const
{
return
_mm512_reduce_max_ps
(
reg
);
}
float
reduce_max
()
const
{
return
_mm512_reduce_max_ps
(
reg
);
}
float
reduce_min
()
const
{
return
_mm512_reduce_min_ps
(
reg
);
}
template
<
int
group_size
>
float
reduce_sub_sum
(
int
idx
)
{
template
<
int
group_size
>
float
reduce_sub_sum
(
int
idx
)
{
static_assert
(
VEC_ELEM_NUM
%
group_size
==
0
);
static_assert
(
VEC_ELEM_NUM
%
group_size
==
0
);
constexpr
uint32_t
base_mask
=
(
0xFFFF
>>
(
16
-
group_size
));
constexpr
uint32_t
base_mask
=
(
0xFFFF
>>
(
16
-
group_size
));
...
...
csrc/cpu/quant.cpp
View file @
5eda21e7
...
@@ -5,25 +5,29 @@ namespace {
...
@@ -5,25 +5,29 @@ namespace {
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
struct
KernelVecType
{
struct
KernelVecType
{
using
load_vec_type
=
void
;
using
load_vec_type
=
void
;
using
azp_adj_load_vec_type
=
void
;
using
cvt_vec_type
=
void
;
using
cvt_vec_type
=
void
;
};
};
template
<
>
template
<
>
struct
KernelVecType
<
float
>
{
struct
KernelVecType
<
float
>
{
using
load_vec_type
=
vec_op
::
FP32Vec16
;
using
load_vec_type
=
vec_op
::
FP32Vec16
;
using
azp_adj_load_vec_type
=
vec_op
::
INT32Vec16
;
using
cvt_vec_type
=
vec_op
::
FP32Vec16
;
using
cvt_vec_type
=
vec_op
::
FP32Vec16
;
};
};
template
<
>
template
<
>
struct
KernelVecType
<
c10
::
BFloat16
>
{
struct
KernelVecType
<
c10
::
BFloat16
>
{
using
load_vec_type
=
vec_op
::
BF16Vec16
;
using
load_vec_type
=
vec_op
::
BF16Vec16
;
using
azp_adj_load_vec_type
=
vec_op
::
INT32Vec16
;
using
cvt_vec_type
=
vec_op
::
FP32Vec16
;
using
cvt_vec_type
=
vec_op
::
FP32Vec16
;
};
};
#ifdef __AVX512F__
#ifdef __AVX512F__
template
<
typename
scalar_t
>
template
<
bool
AZP
,
typename
scalar_t
>
void
static_scaled_int8_quant_impl
(
const
scalar_t
*
input
,
int8_t
*
output
,
void
static_scaled_int8_quant_impl
(
const
scalar_t
*
input
,
int8_t
*
output
,
const
float
*
scale
,
const
int
num_tokens
,
const
float
*
scale
,
const
int32_t
*
azp
,
const
int
num_tokens
,
const
int
hidden_size
)
{
const
int
hidden_size
)
{
using
load_vec_t
=
typename
KernelVecType
<
scalar_t
>::
load_vec_type
;
using
load_vec_t
=
typename
KernelVecType
<
scalar_t
>::
load_vec_type
;
using
cvt_vec_t
=
typename
KernelVecType
<
scalar_t
>::
cvt_vec_type
;
using
cvt_vec_t
=
typename
KernelVecType
<
scalar_t
>::
cvt_vec_type
;
...
@@ -37,62 +41,110 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
...
@@ -37,62 +41,110 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
const
cvt_vec_t
i8_min_vec
(
i8_min
);
const
cvt_vec_t
i8_min_vec
(
i8_min
);
const
cvt_vec_t
i8_max_vec
(
i8_max
);
const
cvt_vec_t
i8_max_vec
(
i8_max
);
cvt_vec_t
zp_vec
;
if
constexpr
(
AZP
)
{
zp_vec
=
cvt_vec_t
(
static_cast
<
float
>
(
*
azp
));
}
#pragma omp parallel for
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
num_tokens
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_tokens
;
++
i
)
{
int
j
=
0
;
int
j
=
0
;
for
(;
j
<
hidden_size
-
vec_elem_num
;
j
+=
vec_elem_num
)
{
for
(;
j
<
hidden_size
-
vec_elem_num
;
j
+=
vec_elem_num
)
{
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
cvt_vec_t
elems_fp32
(
elems
);
cvt_vec_t
elems_fp32
(
elems
);
elems_fp32
=
(
elems_fp32
*
inv_scale
).
clamp
(
i8_min_vec
,
i8_max_vec
);
elems_fp32
=
elems_fp32
*
inv_scale
;
if
constexpr
(
AZP
)
{
elems_fp32
=
elems_fp32
+
zp_vec
;
}
elems_fp32
=
elems_fp32
.
clamp
(
i8_min_vec
,
i8_max_vec
);
vec_op
::
INT8Vec16
elems_int8
(
elems_fp32
);
vec_op
::
INT8Vec16
elems_int8
(
elems_fp32
);
elems_int8
.
save
(
output
+
i
*
hidden_size
+
j
);
elems_int8
.
save
(
output
+
i
*
hidden_size
+
j
);
}
}
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
cvt_vec_t
elems_fp32
(
elems
);
cvt_vec_t
elems_fp32
(
elems
);
elems_fp32
=
(
elems_fp32
*
inv_scale
).
clamp
(
i8_min_vec
,
i8_max_vec
);
elems_fp32
=
elems_fp32
*
inv_scale
;
vec_op
::
INT8Vec16
elems_int8
(
elems_fp32
);
if
(
j
+
vec_elem_num
==
hidden_size
)
{
if
constexpr
(
AZP
)
{
elems_int8
.
save
(
output
+
i
*
hidden_size
+
j
);
elems_fp32
=
elems_fp32
+
zp_vec
;
}
else
{
elems_int8
.
save
(
output
+
i
*
hidden_size
+
j
,
hidden_size
-
j
);
}
}
elems_fp32
=
elems_fp32
.
clamp
(
i8_min_vec
,
i8_max_vec
);
vec_op
::
INT8Vec16
elems_int8
(
elems_fp32
);
elems_int8
.
save
(
output
+
i
*
hidden_size
+
j
,
hidden_size
-
j
);
}
}
}
}
template
<
typename
scalar_t
>
template
<
bool
AZP
,
typename
scalar_t
>
void
dynamic_scaled_int8_quant_impl
(
const
scalar_t
*
input
,
int8_t
*
output
,
void
dynamic_scaled_int8_quant_impl
(
const
scalar_t
*
input
,
int8_t
*
output
,
float
*
scale
,
const
int
num_tokens
,
float
*
scale
,
int32_t
*
azp
,
const
int
num_tokens
,
const
int
hidden_size
)
{
const
int
hidden_size
)
{
using
load_vec_t
=
typename
KernelVecType
<
scalar_t
>::
load_vec_type
;
using
load_vec_t
=
typename
KernelVecType
<
scalar_t
>::
load_vec_type
;
using
cvt_vec_t
=
typename
KernelVecType
<
scalar_t
>::
cvt_vec_type
;
using
cvt_vec_t
=
typename
KernelVecType
<
scalar_t
>::
cvt_vec_type
;
constexpr
int
vec_elem_num
=
load_vec_t
::
VEC_ELEM_NUM
;
constexpr
int
vec_elem_num
=
load_vec_t
::
VEC_ELEM_NUM
;
constexpr
float
i8_min
=
static_cast
<
float
>
(
std
::
numeric_limits
<
int8_t
>::
min
());
constexpr
float
i8_max
=
static_cast
<
float
>
(
std
::
numeric_limits
<
int8_t
>::
max
());
const
cvt_vec_t
i8_min_vec
(
i8_min
);
const
cvt_vec_t
i8_max_vec
(
i8_max
);
#pragma omp parallel for
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
num_tokens
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_tokens
;
++
i
)
{
cvt_vec_t
max_abs
(
0.0
);
cvt_vec_t
max_value
(
std
::
numeric_limits
<
float
>::
lowest
());
cvt_vec_t
min_value
(
std
::
numeric_limits
<
float
>::
max
());
{
{
int
j
=
0
;
int
j
=
0
;
for
(;
j
<
hidden_size
-
vec_elem_num
;
j
+=
vec_elem_num
)
{
for
(;
j
<
hidden_size
-
vec_elem_num
;
j
+=
vec_elem_num
)
{
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
cvt_vec_t
elems_fp32
(
elems
);
cvt_vec_t
elems_fp32
(
elems
);
max_abs
=
max_abs
.
max
(
elems_fp32
.
abs
());
if
constexpr
(
AZP
)
{
max_value
=
max_value
.
max
(
elems_fp32
);
min_value
=
min_value
.
min
(
elems_fp32
);
}
else
{
max_value
=
max_value
.
max
(
elems_fp32
.
abs
());
}
}
}
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
cvt_vec_t
elems_fp32
(
elems
);
cvt_vec_t
elems_fp32
(
elems
);
if
(
j
+
vec_elem_num
==
hidden_size
)
{
if
(
j
+
vec_elem_num
==
hidden_size
)
{
max_abs
=
max_abs
.
max
(
elems_fp32
.
abs
());
if
constexpr
(
AZP
)
{
max_value
=
max_value
.
max
(
elems_fp32
);
min_value
=
min_value
.
min
(
elems_fp32
);
}
else
{
max_value
=
max_value
.
max
(
elems_fp32
.
abs
());
}
}
else
{
}
else
{
max_abs
=
max_abs
.
max
(
elems_fp32
.
abs
(),
hidden_size
-
j
);
if
constexpr
(
AZP
)
{
max_value
=
max_value
.
max
(
elems_fp32
,
hidden_size
-
j
);
min_value
=
min_value
.
min
(
elems_fp32
,
hidden_size
-
j
);
}
else
{
max_value
=
max_value
.
max
(
elems_fp32
.
abs
(),
hidden_size
-
j
);
}
}
}
}
}
float
scale_val
=
max_abs
.
reduce_max
()
/
127.0
f
;
float
scale_val
,
azp_val
;
scale
[
i
]
=
scale_val
;
if
constexpr
(
AZP
)
{
float
max_scalar
=
max_value
.
reduce_max
();
float
min_scalar
=
min_value
.
reduce_min
();
scale_val
=
(
max_scalar
-
min_scalar
)
/
255.0
f
;
azp_val
=
std
::
nearbyint
(
-
128.0
f
-
min_scalar
/
scale_val
);
azp
[
i
]
=
static_cast
<
int32_t
>
(
azp_val
);
scale
[
i
]
=
scale_val
;
}
else
{
scale_val
=
max_value
.
reduce_max
()
/
127.0
f
;
scale
[
i
]
=
scale_val
;
}
const
cvt_vec_t
inv_scale
(
1.0
/
scale_val
);
const
cvt_vec_t
inv_scale
(
1.0
/
scale_val
);
const
cvt_vec_t
azp_vec
(
azp_val
);
{
{
int
j
=
0
;
int
j
=
0
;
...
@@ -100,6 +152,11 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
...
@@ -100,6 +152,11 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
cvt_vec_t
elems_fp32
(
elems
);
cvt_vec_t
elems_fp32
(
elems
);
elems_fp32
=
(
elems_fp32
*
inv_scale
);
elems_fp32
=
(
elems_fp32
*
inv_scale
);
if
constexpr
(
AZP
)
{
elems_fp32
=
elems_fp32
+
azp_vec
;
}
elems_fp32
=
elems_fp32
.
clamp
(
i8_min_vec
,
i8_max_vec
);
vec_op
::
INT8Vec16
elems_int8
(
elems_fp32
);
vec_op
::
INT8Vec16
elems_int8
(
elems_fp32
);
elems_int8
.
save
(
output
+
i
*
hidden_size
+
j
);
elems_int8
.
save
(
output
+
i
*
hidden_size
+
j
);
}
}
...
@@ -107,34 +164,111 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
...
@@ -107,34 +164,111 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
cvt_vec_t
elems_fp32
(
elems
);
cvt_vec_t
elems_fp32
(
elems
);
elems_fp32
=
(
elems_fp32
*
inv_scale
);
elems_fp32
=
(
elems_fp32
*
inv_scale
);
vec_op
::
INT8Vec16
elems_int8
(
elems_fp32
);
if
(
j
+
vec_elem_num
==
hidden_size
)
{
if
constexpr
(
AZP
)
{
elems_int8
.
save
(
output
+
i
*
hidden_size
+
j
);
elems_fp32
=
elems_fp32
+
azp_vec
;
}
else
{
elems_int8
.
save
(
output
+
i
*
hidden_size
+
j
,
hidden_size
-
j
);
}
}
elems_fp32
=
elems_fp32
.
clamp
(
i8_min_vec
,
i8_max_vec
);
vec_op
::
INT8Vec16
elems_int8
(
elems_fp32
);
elems_int8
.
save
(
output
+
i
*
hidden_size
+
j
,
hidden_size
-
j
);
}
}
}
}
}
}
template
<
bool
Bias
,
typename
scalar_t
>
template
<
bool
PerChannel
,
typename
scalar_t
>
void
dynamic_output_scale_impl
(
const
float
*
input
,
scalar_t
*
output
,
void
static_quant_epilogue
(
const
float
*
input
,
scalar_t
*
output
,
const
float
*
scale
,
const
scalar_t
*
bias
,
const
float
a_scale
,
const
float
*
b_scale
,
const
int
num_tokens
,
const
int
hidden_size
)
{
const
int32_t
*
azp_with_adj
,
const
int
num_tokens
,
const
int
hidden_size
)
{
CPU_KERNEL_GUARD_IN
(
dynamic_output_scale_impl
)
CPU_KERNEL_GUARD_IN
(
dynamic_output_scale_impl
)
using
load_vec_t
=
typename
KernelVecType
<
scalar_t
>::
load_vec_type
;
using
load_vec_t
=
typename
KernelVecType
<
scalar_t
>::
load_vec_type
;
using
azp_adj_load_vec_t
=
typename
KernelVecType
<
scalar_t
>::
azp_adj_load_vec_type
;
using
cvt_vec_t
=
typename
KernelVecType
<
scalar_t
>::
cvt_vec_type
;
constexpr
int
vec_elem_num
=
load_vec_t
::
VEC_ELEM_NUM
;
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
num_tokens
;
++
i
)
{
cvt_vec_t
a_scale_vec
(
a_scale
);
cvt_vec_t
b_scale_vec
(
*
b_scale
);
cvt_vec_t
scale_vec
=
a_scale_vec
*
b_scale_vec
;
int
j
=
0
;
for
(;
j
<
hidden_size
-
vec_elem_num
;
j
+=
vec_elem_num
)
{
cvt_vec_t
elems_fp32
(
input
+
i
*
hidden_size
+
j
);
azp_adj_load_vec_t
azp_adj_vec
(
azp_with_adj
+
j
);
cvt_vec_t
azp_adj_fp32
(
azp_adj_vec
);
if
constexpr
(
PerChannel
)
{
b_scale_vec
=
cvt_vec_t
(
b_scale
+
j
);
scale_vec
=
b_scale_vec
*
a_scale_vec
;
}
elems_fp32
=
elems_fp32
-
scale_vec
*
azp_adj_fp32
;
load_vec_t
elems_out
(
elems_fp32
);
elems_out
.
save
(
output
+
i
*
hidden_size
+
j
);
}
cvt_vec_t
elems_fp32
(
input
+
i
*
hidden_size
+
j
);
azp_adj_load_vec_t
azp_adj_vec
(
azp_with_adj
+
j
);
cvt_vec_t
azp_adj_fp32
(
azp_adj_vec
);
if
constexpr
(
PerChannel
)
{
b_scale_vec
=
cvt_vec_t
(
b_scale
+
j
);
scale_vec
=
b_scale_vec
*
a_scale_vec
;
}
elems_fp32
=
elems_fp32
-
scale_vec
*
azp_adj_fp32
;
load_vec_t
elems_out
(
elems_fp32
);
elems_out
.
save
(
output
+
i
*
hidden_size
+
j
,
hidden_size
-
j
);
}
}
template
<
bool
AZP
,
bool
PerChannel
,
bool
Bias
,
typename
scalar_t
>
void
dynamic_quant_epilogue
(
const
float
*
input
,
scalar_t
*
output
,
const
float
*
a_scale
,
const
float
*
b_scale
,
const
int32_t
*
azp
,
const
int32_t
*
azp_adj
,
const
scalar_t
*
bias
,
const
int
num_tokens
,
const
int
hidden_size
)
{
CPU_KERNEL_GUARD_IN
(
dynamic_quant_epilogue
)
using
load_vec_t
=
typename
KernelVecType
<
scalar_t
>::
load_vec_type
;
using
azp_adj_load_vec_t
=
typename
KernelVecType
<
scalar_t
>::
azp_adj_load_vec_type
;
using
cvt_vec_t
=
typename
KernelVecType
<
scalar_t
>::
cvt_vec_type
;
using
cvt_vec_t
=
typename
KernelVecType
<
scalar_t
>::
cvt_vec_type
;
constexpr
int
vec_elem_num
=
load_vec_t
::
VEC_ELEM_NUM
;
constexpr
int
vec_elem_num
=
load_vec_t
::
VEC_ELEM_NUM
;
#pragma omp parallel for
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
num_tokens
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_tokens
;
++
i
)
{
int
j
=
0
;
int
j
=
0
;
cvt_vec_t
token_scale_vec
(
scale
[
i
]);
cvt_vec_t
token_scale_vec
(
a_scale
[
i
]);
cvt_vec_t
token_zp_scale_vec
;
if
constexpr
(
AZP
)
{
float
zp_scale_val
=
a_scale
[
i
]
*
static_cast
<
float
>
(
azp
[
i
]);
if
constexpr
(
!
PerChannel
)
{
zp_scale_val
*=
*
b_scale
;
}
token_zp_scale_vec
=
cvt_vec_t
(
zp_scale_val
);
}
for
(;
j
<
hidden_size
-
vec_elem_num
;
j
+=
vec_elem_num
)
{
for
(;
j
<
hidden_size
-
vec_elem_num
;
j
+=
vec_elem_num
)
{
cvt_vec_t
elems_fp32
(
input
+
i
*
hidden_size
+
j
);
cvt_vec_t
elems_fp32
(
input
+
i
*
hidden_size
+
j
);
elems_fp32
=
elems_fp32
*
token_scale_vec
;
elems_fp32
=
elems_fp32
*
token_scale_vec
;
if
constexpr
(
AZP
)
{
azp_adj_load_vec_t
azp_adj_vec
(
azp_adj
+
j
);
cvt_vec_t
azp_adj_fp32
(
azp_adj_vec
);
azp_adj_fp32
=
azp_adj_fp32
*
token_zp_scale_vec
;
if
constexpr
(
PerChannel
)
{
cvt_vec_t
b_scale_vec
(
b_scale
+
j
);
azp_adj_fp32
=
azp_adj_fp32
*
b_scale_vec
;
}
elems_fp32
=
elems_fp32
-
azp_adj_fp32
;
}
if
constexpr
(
Bias
)
{
if
constexpr
(
Bias
)
{
load_vec_t
bias_vec
(
bias
+
j
);
load_vec_t
bias_vec
(
bias
+
j
);
cvt_vec_t
bias_vec_fp32
(
bias_vec
);
cvt_vec_t
bias_vec_fp32
(
bias_vec
);
...
@@ -148,6 +282,19 @@ void dynamic_output_scale_impl(const float* input, scalar_t* output,
...
@@ -148,6 +282,19 @@ void dynamic_output_scale_impl(const float* input, scalar_t* output,
cvt_vec_t
elems_fp32
(
input
+
i
*
hidden_size
+
j
);
cvt_vec_t
elems_fp32
(
input
+
i
*
hidden_size
+
j
);
elems_fp32
=
elems_fp32
*
token_scale_vec
;
elems_fp32
=
elems_fp32
*
token_scale_vec
;
if
constexpr
(
AZP
)
{
azp_adj_load_vec_t
azp_adj_vec
(
azp_adj
+
j
);
cvt_vec_t
azp_adj_fp32
(
azp_adj_vec
);
azp_adj_fp32
=
azp_adj_fp32
*
token_zp_scale_vec
;
if
constexpr
(
PerChannel
)
{
cvt_vec_t
b_scale_vec
(
b_scale
+
j
);
azp_adj_fp32
=
azp_adj_fp32
*
b_scale_vec
;
}
elems_fp32
=
elems_fp32
-
azp_adj_fp32
;
}
if
constexpr
(
Bias
)
{
if
constexpr
(
Bias
)
{
load_vec_t
bias_vec
(
bias
+
j
);
load_vec_t
bias_vec
(
bias
+
j
);
cvt_vec_t
bias_vec_fp32
(
bias_vec
);
cvt_vec_t
bias_vec_fp32
(
bias_vec
);
...
@@ -155,32 +302,41 @@ void dynamic_output_scale_impl(const float* input, scalar_t* output,
...
@@ -155,32 +302,41 @@ void dynamic_output_scale_impl(const float* input, scalar_t* output,
}
}
load_vec_t
elems_out
(
elems_fp32
);
load_vec_t
elems_out
(
elems_fp32
);
elems_out
.
save
(
output
+
i
*
hidden_size
+
j
,
hidden_size
-
j
);
if
(
j
+
vec_elem_num
==
hidden_size
)
{
elems_out
.
save
(
output
+
i
*
hidden_size
+
j
);
}
else
{
elems_out
.
save
(
output
+
i
*
hidden_size
+
j
,
hidden_size
-
j
);
}
}
}
}
}
#else
#else
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
static_scaled_int8_quant_impl
(
const
scalar_t
*
input
,
int8_t
*
output
,
void
static_scaled_int8_quant_impl
(
const
scalar_t
*
input
,
int8_t
*
output
,
const
float
*
scale
,
const
int
num_tokens
,
const
float
*
scale
,
const
int32_t
*
azp
,
const
int
num_tokens
,
const
int
hidden_size
)
{
const
int
hidden_size
)
{
TORCH_CHECK
(
false
,
"static_scaled_int8_quant_impl requires AVX512 support."
)
TORCH_CHECK
(
false
,
"static_scaled_int8_quant_impl requires AVX512 support."
)
}
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
dynamic_scaled_int8_quant_impl
(
const
scalar_t
*
input
,
int8_t
*
output
,
void
dynamic_scaled_int8_quant_impl
(
const
scalar_t
*
input
,
int8_t
*
output
,
float
*
scale
,
const
int
num_tokens
,
float
*
scale
,
int32_t
*
azp
,
const
int
num_tokens
,
const
int
hidden_size
)
{
const
int
hidden_size
)
{
TORCH_CHECK
(
false
,
"dynamic_scaled_int8_quant_impl requires AVX512 support."
)
TORCH_CHECK
(
false
,
"dynamic_scaled_int8_quant_impl requires AVX512 support."
)
}
}
template
<
bool
PerChannel
,
typename
scalar_t
>
void
static_quant_epilogue
(
const
float
*
input
,
scalar_t
*
output
,
const
float
a_scale
,
const
float
*
b_scale
,
const
int32_t
*
azp_with_adj
,
const
int
num_tokens
,
const
int
hidden_size
)
{
TORCH_CHECK
(
false
,
"static_quant_epilogue requires AVX512 support."
)
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
dynamic_output_scale_impl
()
{
void
dynamic_quant_epilogue
(
const
float
*
input
,
scalar_t
*
output
,
TORCH_CHECK
(
false
,
"dynamic_output_scale_impl requires AVX512 support."
)
const
float
*
a_scale
,
const
float
*
b_scale
,
const
int32_t
*
azp
,
const
int32_t
*
azp_with_adj
,
const
scalar_t
*
bias
,
const
int
num_tokens
,
const
int
hidden_size
)
{
TORCH_CHECK
(
false
,
"dynamic_quant_epilogue requires AVX512 support."
)
}
}
#endif
#endif
}
// namespace
}
// namespace
...
@@ -214,39 +370,52 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
...
@@ -214,39 +370,52 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
bias
->
dim
()
==
1
);
bias
->
dim
()
==
1
);
}
}
VLLM_DISPATCH_FLOATING_TYPES
(
c
.
scalar_type
(),
"
cutlass
_scaled_mm"
,
[
&
]
{
VLLM_DISPATCH_FLOATING_TYPES
(
c
.
scalar_type
(),
"
int8
_scaled_mm"
,
[
&
]
{
if
(
a_scales
.
numel
()
!=
1
)
{
if
(
a_scales
.
numel
()
!=
1
)
{
// per-token
// per-token
// Note: oneDNN doesn't support per-token activation quantization
// Note: oneDNN doesn't support per-token activation quantization
// Ideally we want to fuse the GEMM and the scale procedure with oneDNN
// JIT, the intermediate data is cached in registers or L1. But for now
// the oneDNN GEMM code generation only supports two quantization
// patterns: per-tensor or per-output-channel of weight.
// So we have to apply the per-token scale with a 'epilogue'. In C=s_a *
// s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN
// GEMM, then the per-token scale (and bias) is applied with the epilogue
// C=s_a * C_inter + bias.
torch
::
Tensor
tmp_fp32_out
=
torch
::
Tensor
tmp_fp32_out
=
torch
::
empty_like
(
c
,
::
at
::
ScalarType
::
Float
);
torch
::
empty_like
(
c
,
::
at
::
ScalarType
::
Float
);
DNNLPrimitiveHelper
<
true
>::
gemm_s8s8_jit
(
// Compute C_inter=s_b * (A@B)
DNNLPrimitiveHelper
<
true
>::
gemm_s8s8_jit
<
float
,
void
>
(
a
.
data_ptr
<
int8_t
>
(),
b
.
data_ptr
<
int8_t
>
(),
a
.
data_ptr
<
int8_t
>
(),
b
.
data_ptr
<
int8_t
>
(),
tmp_fp32_out
.
data_ptr
<
float
>
(),
(
void
*
)(
0
),
a
.
size
(
0
),
b
.
size
(
1
),
tmp_fp32_out
.
data_ptr
<
float
>
(),
nullptr
,
a
.
size
(
0
),
b
.
size
(
1
),
a
.
size
(
1
),
(
float
*
)(
0
),
b_scales
.
data_ptr
<
float
>
(),
0
,
a
.
size
(
1
),
nullptr
,
b_scales
.
data_ptr
<
float
>
(),
0
,
b_scales
.
numel
());
b_scales
.
numel
());
if
(
bias
.
has_value
())
{
if
(
bias
.
has_value
())
{
dynamic_output_scale_impl
<
true
>
(
// Compute C=s_a * C_inter + bias
dynamic_quant_epilogue
<
false
,
true
,
true
>
(
tmp_fp32_out
.
data_ptr
<
float
>
(),
c
.
data_ptr
<
scalar_t
>
(),
tmp_fp32_out
.
data_ptr
<
float
>
(),
c
.
data_ptr
<
scalar_t
>
(),
a_scales
.
data_ptr
<
float
>
(),
bias
->
data_ptr
<
scalar_t
>
(),
c
.
size
(
0
)
,
a_scales
.
data_ptr
<
float
>
(),
nullptr
,
nullptr
,
nullptr
,
c
.
size
(
1
));
bias
->
data_ptr
<
scalar_t
>
(),
c
.
size
(
0
),
c
.
size
(
1
));
}
else
{
}
else
{
dynamic_output_scale_impl
<
false
>
(
// Compute C=s_a * C_inter
dynamic_quant_epilogue
<
false
,
true
,
false
,
scalar_t
>
(
tmp_fp32_out
.
data_ptr
<
float
>
(),
c
.
data_ptr
<
scalar_t
>
(),
tmp_fp32_out
.
data_ptr
<
float
>
(),
c
.
data_ptr
<
scalar_t
>
(),
a_scales
.
data_ptr
<
float
>
(),
(
scalar_t
*
)(
0
),
c
.
size
(
0
),
c
.
size
(
1
));
a_scales
.
data_ptr
<
float
>
(),
nullptr
,
nullptr
,
nullptr
,
nullptr
,
c
.
size
(
0
),
c
.
size
(
1
));
}
}
}
else
{
}
else
{
// per-tensor
// per-tensor
if
(
bias
.
has_value
())
{
if
(
bias
.
has_value
())
{
// Compute C=s_a * s_b * (A@B) + bias
DNNLPrimitiveHelper
<
false
>::
gemm_s8s8_jit
(
DNNLPrimitiveHelper
<
false
>::
gemm_s8s8_jit
(
a
.
data_ptr
<
int8_t
>
(),
b
.
data_ptr
<
int8_t
>
(),
c
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
int8_t
>
(),
b
.
data_ptr
<
int8_t
>
(),
c
.
data_ptr
<
scalar_t
>
(),
bias
->
data_ptr
<
scalar_t
>
(),
a
.
size
(
0
),
b
.
size
(
1
),
a
.
size
(
1
),
bias
->
data_ptr
<
scalar_t
>
(),
a
.
size
(
0
),
b
.
size
(
1
),
a
.
size
(
1
),
a_scales
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
float
>
(),
a_scales
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
(),
b_scales
.
numel
());
a_scales
.
numel
(),
b_scales
.
numel
());
}
else
{
}
else
{
DNNLPrimitiveHelper
<
false
>::
gemm_s8s8_jit
(
// Compute C=s_a * s_b * (A@B)
DNNLPrimitiveHelper
<
false
>::
gemm_s8s8_jit
<
scalar_t
,
void
>
(
a
.
data_ptr
<
int8_t
>
(),
b
.
data_ptr
<
int8_t
>
(),
c
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
int8_t
>
(),
b
.
data_ptr
<
int8_t
>
(),
c
.
data_ptr
<
scalar_t
>
(),
(
void
*
)(
0
)
,
a
.
size
(
0
),
b
.
size
(
1
),
a
.
size
(
1
),
nullptr
,
a
.
size
(
0
),
b
.
size
(
1
),
a
.
size
(
1
),
a_scales
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
float
>
(),
a_scales
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
(),
b_scales
.
numel
());
a_scales
.
numel
(),
b_scales
.
numel
());
}
}
...
@@ -254,6 +423,127 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
...
@@ -254,6 +423,127 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
});
});
}
}
void
int8_scaled_mm_azp
(
torch
::
Tensor
&
c
,
// [M, OC], row-major
const
torch
::
Tensor
&
a
,
// [M, IC], row-major
const
torch
::
Tensor
&
b
,
// [IC, OC], column-major
const
torch
::
Tensor
&
a_scales
,
// [1] or [M]
const
torch
::
Tensor
&
b_scales
,
// [1] or [OC]
const
torch
::
Tensor
&
azp_adj
,
// [OC]
const
c10
::
optional
<
torch
::
Tensor
>&
azp
,
// [1] or [M]
const
c10
::
optional
<
torch
::
Tensor
>&
bias
// [OC]
)
{
CPU_KERNEL_GUARD_IN
(
cutlass_scaled_mm_azp
)
// Checks for conformality
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
&&
b
.
dtype
()
==
torch
::
kInt8
,
"int8_scaled_mm_azp only supports INT8 inputs."
)
TORCH_CHECK
(
a
.
dim
()
==
2
&&
b
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
c
.
size
(
0
)
==
a
.
size
(
0
)
&&
a
.
size
(
1
)
==
b
.
size
(
0
)
&&
b
.
size
(
1
)
==
c
.
size
(
1
));
TORCH_CHECK
(
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
));
TORCH_CHECK
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
));
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
&&
c
.
stride
(
1
)
==
1
);
// Row-major
TORCH_CHECK
(
b
.
stride
(
0
)
==
1
);
// Column-major
TORCH_CHECK
(
c
.
stride
(
0
)
%
16
==
0
&&
b
.
stride
(
1
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
bias
)
{
TORCH_CHECK
(
bias
->
numel
()
==
b
.
size
(
1
)
&&
bias
->
is_contiguous
());
}
if
(
azp
)
{
TORCH_CHECK
(
azp
->
numel
()
==
a
.
size
(
0
)
&&
azp
->
is_contiguous
());
}
TORCH_CHECK
(
azp_adj
.
numel
()
==
b
.
size
(
1
)
&&
azp_adj
.
is_contiguous
());
// azp & bias types
TORCH_CHECK
(
azp_adj
.
dtype
()
==
torch
::
kInt32
);
TORCH_CHECK
(
!
azp
||
azp
->
dtype
()
==
torch
::
kInt32
);
TORCH_CHECK
(
!
bias
||
bias
->
dtype
()
==
c
.
dtype
(),
"currently bias dtype must match output dtype "
,
c
.
dtype
());
VLLM_DISPATCH_FLOATING_TYPES
(
c
.
scalar_type
(),
"int8_scaled_mm_azp"
,
[
&
]
{
torch
::
Tensor
tmp_fp32_out
=
torch
::
empty_like
(
c
,
::
at
::
ScalarType
::
Float
);
if
(
a_scales
.
numel
()
!=
1
)
{
// per-token
// Note: oneDNN doesn't support per-token activation quantization
// Compute C_inter=s_b * (A@B)
DNNLPrimitiveHelper
<
true
>::
gemm_s8s8_jit
<
float
,
void
>
(
a
.
data_ptr
<
int8_t
>
(),
b
.
data_ptr
<
int8_t
>
(),
tmp_fp32_out
.
data_ptr
<
float
>
(),
nullptr
,
a
.
size
(
0
),
b
.
size
(
1
),
a
.
size
(
1
),
nullptr
,
b_scales
.
data_ptr
<
float
>
(),
0
,
b_scales
.
numel
());
if
(
bias
.
has_value
())
{
// Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias
if
(
b_scales
.
numel
()
!=
1
)
{
// Per-Channel
dynamic_quant_epilogue
<
true
,
true
,
true
>
(
tmp_fp32_out
.
data_ptr
<
float
>
(),
c
.
data_ptr
<
scalar_t
>
(),
a_scales
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
float
>
(),
azp
->
data_ptr
<
int32_t
>
(),
azp_adj
.
data_ptr
<
int32_t
>
(),
bias
->
data_ptr
<
scalar_t
>
(),
c
.
size
(
0
),
c
.
size
(
1
));
}
else
{
// Per-Tensor
dynamic_quant_epilogue
<
true
,
false
,
true
>
(
tmp_fp32_out
.
data_ptr
<
float
>
(),
c
.
data_ptr
<
scalar_t
>
(),
a_scales
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
float
>
(),
azp
->
data_ptr
<
int32_t
>
(),
azp_adj
.
data_ptr
<
int32_t
>
(),
bias
->
data_ptr
<
scalar_t
>
(),
c
.
size
(
0
),
c
.
size
(
1
));
}
}
else
{
// Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj
if
(
b_scales
.
numel
()
!=
1
)
{
// Per-Channel
dynamic_quant_epilogue
<
true
,
true
,
false
,
scalar_t
>
(
tmp_fp32_out
.
data_ptr
<
float
>
(),
c
.
data_ptr
<
scalar_t
>
(),
a_scales
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
float
>
(),
azp
->
data_ptr
<
int32_t
>
(),
azp_adj
.
data_ptr
<
int32_t
>
(),
nullptr
,
c
.
size
(
0
),
c
.
size
(
1
));
}
else
{
// Per-Tensor
dynamic_quant_epilogue
<
true
,
false
,
false
,
scalar_t
>
(
tmp_fp32_out
.
data_ptr
<
float
>
(),
c
.
data_ptr
<
scalar_t
>
(),
a_scales
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
float
>
(),
azp
->
data_ptr
<
int32_t
>
(),
azp_adj
.
data_ptr
<
int32_t
>
(),
nullptr
,
c
.
size
(
0
),
c
.
size
(
1
));
}
}
}
else
{
// per-tensor
if
(
bias
.
has_value
())
{
// Compute C_inter=s_a * s_b * (A@B) + bias
DNNLPrimitiveHelper
<
false
>::
gemm_s8s8_jit
(
a
.
data_ptr
<
int8_t
>
(),
b
.
data_ptr
<
int8_t
>
(),
tmp_fp32_out
.
data_ptr
<
float
>
(),
bias
->
data_ptr
<
scalar_t
>
(),
a
.
size
(
0
),
b
.
size
(
1
),
a
.
size
(
1
),
a_scales
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
(),
b_scales
.
numel
());
}
else
{
// Compute C_inter=s_a * s_b * (A@B)
DNNLPrimitiveHelper
<
false
>::
gemm_s8s8_jit
<
float
,
void
>
(
a
.
data_ptr
<
int8_t
>
(),
b
.
data_ptr
<
int8_t
>
(),
tmp_fp32_out
.
data_ptr
<
float
>
(),
nullptr
,
a
.
size
(
0
),
b
.
size
(
1
),
a
.
size
(
1
),
a_scales
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
(),
b_scales
.
numel
());
}
// Compute C=C_inter - s_a * s_b * azp_adj
if
(
b_scales
.
numel
()
!=
1
)
{
// Per-Channel
static_quant_epilogue
<
true
>
(
tmp_fp32_out
.
data_ptr
<
float
>
(),
c
.
data_ptr
<
scalar_t
>
(),
*
a_scales
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
float
>
(),
azp_adj
.
data_ptr
<
int32_t
>
(),
a
.
size
(
0
),
b
.
size
(
1
));
}
else
{
// Per-Tensor
static_quant_epilogue
<
false
>
(
tmp_fp32_out
.
data_ptr
<
float
>
(),
c
.
data_ptr
<
scalar_t
>
(),
*
a_scales
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
float
>
(),
azp_adj
.
data_ptr
<
int32_t
>
(),
a
.
size
(
0
),
b
.
size
(
1
));
}
}
});
}
// static-per-tensor quantization.
// static-per-tensor quantization.
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
const
torch
::
Tensor
&
input
,
// [..., hidden_size]
const
torch
::
Tensor
&
input
,
// [..., hidden_size]
...
@@ -263,15 +553,22 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
...
@@ -263,15 +553,22 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
scale
.
numel
()
==
1
);
TORCH_CHECK
(
scale
.
numel
()
==
1
);
TORCH_CHECK
(
!
azp
.
has_value
()
,
"Zero point is not supported on CPU."
);
TORCH_CHECK
(
!
azp
.
has_value
()
||
azp
->
numel
()
==
1
);
const
int
hidden_size
=
input
.
size
(
-
1
);
const
int
hidden_size
=
input
.
size
(
-
1
);
const
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
const
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
VLLM_DISPATCH_FLOATING_TYPES
(
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"static_scaled_int8_quant_impl"
,
[
&
]
{
input
.
scalar_type
(),
"static_scaled_int8_quant_impl"
,
[
&
]
{
static_scaled_int8_quant_impl
(
if
(
azp
.
has_value
())
{
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
static_scaled_int8_quant_impl
<
true
>
(
scale
.
data_ptr
<
float
>
(),
num_tokens
,
hidden_size
);
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scale
.
data_ptr
<
float
>
(),
azp
->
data_ptr
<
int32_t
>
(),
num_tokens
,
hidden_size
);
}
else
{
static_scaled_int8_quant_impl
<
false
>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scale
.
data_ptr
<
float
>
(),
nullptr
,
num_tokens
,
hidden_size
);
}
});
});
}
}
...
@@ -284,14 +581,20 @@ void dynamic_scaled_int8_quant(
...
@@ -284,14 +581,20 @@ void dynamic_scaled_int8_quant(
CPU_KERNEL_GUARD_IN
(
dynamic_scaled_int8_quant
)
CPU_KERNEL_GUARD_IN
(
dynamic_scaled_int8_quant
)
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
!
azp
.
has_value
(),
"Zero point is not supported on CPU."
);
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
VLLM_DISPATCH_FLOATING_TYPES
(
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"dynamic_scaled_int8_quant_impl"
,
[
&
]
{
input
.
scalar_type
(),
"dynamic_scaled_int8_quant_impl"
,
[
&
]
{
dynamic_scaled_int8_quant_impl
(
if
(
azp
.
has_value
())
{
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
dynamic_scaled_int8_quant_impl
<
true
>
(
scale
.
data_ptr
<
float
>
(),
num_tokens
,
hidden_size
);
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scale
.
data_ptr
<
float
>
(),
azp
->
data_ptr
<
int32_t
>
(),
num_tokens
,
hidden_size
);
}
else
{
dynamic_scaled_int8_quant_impl
<
false
>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scale
.
data_ptr
<
float
>
(),
nullptr
,
num_tokens
,
hidden_size
);
}
});
});
}
}
csrc/cpu/torch_bindings.cpp
View file @
5eda21e7
...
@@ -11,6 +11,13 @@ void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
...
@@ -11,6 +11,13 @@ void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
const
torch
::
Tensor
&
b_scales
,
const
torch
::
Tensor
&
b_scales
,
const
c10
::
optional
<
torch
::
Tensor
>&
bias
);
const
c10
::
optional
<
torch
::
Tensor
>&
bias
);
void
int8_scaled_mm_azp
(
torch
::
Tensor
&
c
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
a_scales
,
const
torch
::
Tensor
&
b_scales
,
const
torch
::
Tensor
&
azp_adj
,
const
c10
::
optional
<
torch
::
Tensor
>&
azp
,
const
c10
::
optional
<
torch
::
Tensor
>&
bias
);
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
ops
)
{
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
ops
)
{
// vLLM custom ops
// vLLM custom ops
...
@@ -111,6 +118,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -111,6 +118,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor b, Tensor a_scales,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()"
);
" Tensor b_scales, Tensor? bias) -> ()"
);
ops
.
impl
(
"cutlass_scaled_mm"
,
torch
::
kCPU
,
&
int8_scaled_mm
);
ops
.
impl
(
"cutlass_scaled_mm"
,
torch
::
kCPU
,
&
int8_scaled_mm
);
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops
.
def
(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()"
);
ops
.
impl
(
"cutlass_scaled_mm_azp"
,
torch
::
kCPU
,
&
int8_scaled_mm_azp
);
#endif
#endif
}
}
...
...
docs/source/getting_started/cpu-installation.rst
View file @
5eda21e7
...
@@ -59,20 +59,6 @@ Build from source
...
@@ -59,20 +59,6 @@ Build from source
$ pip install cmake>=3.26 wheel packaging ninja "setuptools-scm>=8" numpy
$ pip install cmake>=3.26 wheel packaging ninja "setuptools-scm>=8" numpy
$ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
$ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
- Third, build and install oneDNN library from source:
.. code-block:: console
$ git clone -b rls-v3.5 https://github.com/oneapi-src/oneDNN.git
$ cmake -B ./oneDNN/build -S ./oneDNN -G Ninja -DONEDNN_LIBRARY_TYPE=STATIC \
-DONEDNN_BUILD_DOC=OFF \
-DONEDNN_BUILD_EXAMPLES=OFF \
-DONEDNN_BUILD_TESTS=OFF \
-DONEDNN_BUILD_GRAPH=OFF \
-DONEDNN_ENABLE_WORKLOAD=INFERENCE \
-DONEDNN_ENABLE_PRIMITIVE=MATMUL
$ cmake --build ./oneDNN/build --target install --config Release
- Finally, build and install vLLM CPU backend:
- Finally, build and install vLLM CPU backend:
.. code-block:: console
.. code-block:: console
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment