Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5d73ae49
Unverified
Commit
5d73ae49
authored
Sep 16, 2024
by
Luka Govedič
Committed by
GitHub
Sep 16, 2024
Browse files
[Kernel] AQ AZP 3/4: Asymmetric quantization kernels (#7270)
parent
781e3b9a
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
339 additions
and
57 deletions
+339
-57
csrc/cpu/quant.cpp
csrc/cpu/quant.cpp
+6
-3
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+5
-4
csrc/ops.h
csrc/ops.h
+4
-2
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+160
-13
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+4
-4
tests/kernels/test_int8_quant.py
tests/kernels/test_int8_quant.py
+138
-20
vllm/_custom_ops.py
vllm/_custom_ops.py
+20
-9
vllm/model_executor/layers/quantization/qqq.py
vllm/model_executor/layers/quantization/qqq.py
+1
-1
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+1
-1
No files found.
csrc/cpu/quant.cpp
View file @
5d73ae49
...
...
@@ -257,11 +257,13 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
// static-per-tensor quantization.
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
const
torch
::
Tensor
&
input
,
// [..., hidden_size]
const
torch
::
Tensor
&
scale
)
{
const
torch
::
Tensor
&
scale
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
)
{
CPU_KERNEL_GUARD_IN
(
static_scaled_int8_quant
)
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
scale
.
numel
()
==
1
);
TORCH_CHECK
(
!
azp
.
has_value
(),
"Zero point is not supported on CPU."
);
const
int
hidden_size
=
input
.
size
(
-
1
);
const
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
...
...
@@ -277,11 +279,12 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
void
dynamic_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
const
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
scale
// [..., 1]
)
{
torch
::
Tensor
&
scale
,
// [..., 1]
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
)
{
CPU_KERNEL_GUARD_IN
(
dynamic_scaled_int8_quant
)
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
!
azp
.
has_value
(),
"Zero point is not supported on CPU."
);
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
...
...
csrc/cpu/torch_bindings.cpp
View file @
5d73ae49
...
...
@@ -94,13 +94,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#ifdef __AVX512F__
// Compute int8 quantized tensor for given scaling factor.
ops
.
def
(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale
) ->
"
"()"
);
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale
,
"
"
Tensor? azp) ->
()"
);
ops
.
impl
(
"static_scaled_int8_quant"
,
torch
::
kCPU
,
&
static_scaled_int8_quant
);
// Compute int8 quantized tensor and scaling factor
ops
.
def
(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale
) ->
"
"()"
);
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale
,
"
"
Tensor!? azp) ->
()"
);
ops
.
impl
(
"dynamic_scaled_int8_quant"
,
torch
::
kCPU
,
&
dynamic_scaled_int8_quant
);
// W8A8 GEMM, supporting symmetric per-tensor or per-row/column
...
...
csrc/ops.h
View file @
5d73ae49
...
...
@@ -184,10 +184,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
scale
);
torch
::
Tensor
const
&
scale
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
);
void
dynamic_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
scales
);
torch
::
Tensor
&
scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
);
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_gptq_qzeros
,
...
...
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
View file @
5d73ae49
...
...
@@ -14,12 +14,17 @@
static
inline
__device__
int8_t
float_to_int8_rn
(
float
x
)
{
#ifdef USE_ROCM
static
const
float
i8_min
=
static
const
expr
auto
i8_min
=
static_cast
<
float
>
(
std
::
numeric_limits
<
int8_t
>::
min
());
static
const
float
i8_max
=
static
const
expr
auto
i8_max
=
static_cast
<
float
>
(
std
::
numeric_limits
<
int8_t
>::
max
());
// round
// To match the rounding mode of CUDA, we use nearbyint.
// It uses the current rounding mode, which is always FE_TONEAREST on HIP.
// If that changes in the future, we may need to set the rounding mode
// explicitly, either at runtime or compile time.
float
dst
=
std
::
nearbyint
(
x
);
// saturate
dst
=
std
::
clamp
(
dst
,
i8_min
,
i8_max
);
return
static_cast
<
int8_t
>
(
dst
);
...
...
@@ -31,6 +36,59 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
#endif
}
// Converts a float to int32_t by rounding to the nearest integer and
// saturating at the int32_t limits instead of overflowing.
static inline __device__ int32_t float_to_int32_rn(float x) {
#ifdef USE_ROCM
  // The integer limits are what gets returned on saturation; their float
  // counterparts are the thresholds the rounded value is compared against.
  // int32_max cannot be represented exactly as a float, so the upper bound is
  // handled by hand. int32_min is exactly representable, but it is treated
  // the same way to keep both ends symmetric.
  static constexpr auto lower_bound = std::numeric_limits<int32_t>::min();
  static constexpr auto upper_bound = std::numeric_limits<int32_t>::max();
  static constexpr auto lower_bound_f = static_cast<float>(lower_bound);
  static constexpr auto upper_bound_f = static_cast<float>(upper_bound);

  // nearbyint follows the current rounding mode, which is always FE_TONEAREST
  // on HIP and therefore matches the CUDA conversion below. Should that ever
  // change, the rounding mode has to be pinned explicitly (at runtime or at
  // compile time).
  float const rounded = std::nearbyint(x);

  // Saturate at the upper limit.
  if (rounded >= upper_bound_f) {
    return upper_bound;
  }
  // Saturate at the lower limit.
  if (rounded <= lower_bound_f) {
    return lower_bound;
  }
  return static_cast<int32_t>(rounded);
#else
  // CUDA path: one PTX conversion that rounds to nearest integer (rni) and
  // saturates (sat).
  uint32_t raw;
  asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(raw) : "f"(x));
  return reinterpret_cast<const int32_t&>(raw);
#endif
}
// Narrows an int32_t to int8_t, saturating at the int8_t limits.
static inline __device__ int8_t int32_to_int8(int32_t x) {
#ifdef USE_ROCM
  // int8 limits widened to int32 so the comparison happens in the input type.
  static constexpr auto lower_bound =
      static_cast<int32_t>(std::numeric_limits<int8_t>::min());
  static constexpr auto upper_bound =
      static_cast<int32_t>(std::numeric_limits<int8_t>::max());

  // Saturate: values outside [lower_bound, upper_bound] snap to the nearest
  // bound, everything else passes through unchanged.
  int32_t const saturated =
      (x < lower_bound) ? lower_bound : ((upper_bound < x) ? upper_bound : x);
  return static_cast<int8_t>(saturated);
#else
  // CUDA path: one saturating PTX conversion from s32 to s8.
  uint32_t raw;
  asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(raw) : "r"(x));
  return reinterpret_cast<const int8_t&>(raw);
#endif
}
namespace
vllm
{
template
<
typename
scalar_t
,
typename
scale_type
>
...
...
@@ -47,6 +105,23 @@ __global__ void static_scaled_int8_quant_kernel(
}
}
// Asymmetric int8 quantization with a single scale and a single zero point,
// both read once from `scale_ptr` / `azp_ptr`.
//
// Each block processes the row selected by blockIdx.x; the threads of the
// block walk the `hidden_size` elements of that row in steps of blockDim.x.
// Every element is quantized as
//   out = int32_to_int8(float_to_int32_rn(input / scale) + azp)
template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void static_scaled_int8_azp_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type const* scale_ptr, azp_type const* azp_ptr,
    const int hidden_size) {
  int const first_col = threadIdx.x;
  int const row = blockIdx.x;

  // Quantization parameters are loaded once per thread.
  scale_type const scale = *scale_ptr;
  azp_type const azp = *azp_ptr;

  for (int col = first_col; col < hidden_size; col += blockDim.x) {
    auto const elem = static_cast<float>(input[row * hidden_size + col]);
    // Round to int32 first, shift by the zero point, then saturate to int8.
    auto const quantized = int32_to_int8(float_to_int32_rn(elem / scale) + azp);
    out[row * hidden_size + col] = quantized;
  }
}
template
<
typename
scalar_t
,
typename
scale_type
>
__global__
void
dynamic_scaled_int8_quant_kernel
(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
...
...
@@ -80,14 +155,68 @@ __global__ void dynamic_scaled_int8_quant_kernel(
}
}
// Dynamic (per-token) asymmetric int8 quantization.
//
// Each block processes the token (row) selected by blockIdx.x:
//   1. every thread scans a strided subset of the row for its min and max,
//   2. the per-thread extrema are reduced across the block,
//   3. thread 0 derives the scale and zero point, writes them to
//      scale[token_idx] / azp[token_idx] and shares them through shared
//      memory,
//   4. every thread quantizes its elements as
//        out = int32_to_int8(float_to_int32_rn(input / scale) + azp)
//
// The reduction is instantiated for a block size of 1024 and is told that
// blockDim.x of the entries are valid.
// NOTE(review): when every element of a row is identical, max == min and the
// computed scale is 0, so the divisions below are by zero -- confirm whether
// callers can produce such rows.
template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type* scale, azp_type* azp, const int hidden_size) {
  int const token_idx = blockIdx.x;

  // Scan for the min and max value for this token.
  // The running maximum must start at the most negative finite float.
  // numeric_limits<float>::min() is the smallest *positive* normal value, so
  // using it would report a wrong maximum for rows that are entirely negative.
  float max_val = std::numeric_limits<float>::lowest();
  float min_val = std::numeric_limits<float>::max();
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    auto val = static_cast<float>(input[token_idx * hidden_size + i]);
    max_val = std::max(max_val, val);
    min_val = std::min(min_val, val);
  }

  // Reduce the max and min values across the block
  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStorage;
  max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
  __syncthreads();  // Make sure min doesn't mess with max shared memory
  min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);

  __shared__ scale_type scale_sh;
  __shared__ azp_type azp_sh;

  // Compute the scale and zero point and store them, only on the first thread
  // (the reduced min/max are only consumed here).
  if (threadIdx.x == 0) {
    // Map the [min, max] range of the row onto the 255 steps of int8.
    float const scale_val = (max_val - min_val) / 255.0f;
    // Use rounding to even (same as torch.round)
    auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
    auto const azp_val = static_cast<azp_type>(azp_float);

    // Store the scale and azp into shared and global
    scale[token_idx] = scale_sh = scale_val;
    azp[token_idx] = azp_sh = azp_val;
  }

  // Wait for the scale and azp to be computed
  __syncthreads();

  float const scale_val = scale_sh;
  azp_type const azp_val = azp_sh;

  // Quantize the values
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
    auto const quant_val =
        int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
    out[token_idx * hidden_size + i] = quant_val;
  }
}
}
// namespace vllm
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
const
&
input
,
// [..., hidden_size]
torch
::
Tensor
const
&
scale
)
{
torch
::
Tensor
const
&
scale
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
)
{
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
scale
.
numel
()
==
1
);
TORCH_CHECK
(
!
azp
||
azp
->
numel
()
==
1
);
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
...
...
@@ -96,19 +225,29 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"static_scaled_int8_quant_kernel"
,
[
&
]
{
if
(
!
azp
)
{
vllm
::
static_scaled_int8_quant_kernel
<
scalar_t
,
float
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scale
.
data_ptr
<
float
>
(),
hidden_size
);
}
else
{
vllm
::
static_scaled_int8_azp_quant_kernel
<
scalar_t
,
float
,
int32_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scale
.
data_ptr
<
float
>
(),
azp
->
data_ptr
<
int32_t
>
(),
hidden_size
);
}
});
}
void
dynamic_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
const
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
scales
)
{
torch
::
Tensor
&
scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
)
{
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
scales
.
is_contiguous
());
TORCH_CHECK
(
!
azp
||
azp
->
is_contiguous
());
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
...
...
@@ -117,9 +256,17 @@ void dynamic_scaled_int8_quant(
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"dynamic_scaled_int8_quant_kernel"
,
[
&
]
{
if
(
!
azp
)
{
vllm
::
dynamic_scaled_int8_quant_kernel
<
scalar_t
,
float
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scales
.
data_ptr
<
float
>
(),
hidden_size
);
}
else
{
vllm
::
dynamic_scaled_int8_azp_quant_kernel
<
scalar_t
,
float
,
int32_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scales
.
data_ptr
<
float
>
(),
azp
->
data_ptr
<
int32_t
>
(),
hidden_size
);
}
});
}
csrc/torch_bindings.cpp
View file @
5d73ae49
...
...
@@ -336,14 +336,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Compute int8 quantized tensor for given scaling factor.
ops
.
def
(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale
) ->
"
"()"
);
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale
,
"
"
Tensor? azp) ->
()"
);
ops
.
impl
(
"static_scaled_int8_quant"
,
torch
::
kCUDA
,
&
static_scaled_int8_quant
);
// Compute int8 quantized tensor and scaling factor
ops
.
def
(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale
) ->
"
"()"
);
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale
,
"
"
Tensor!? azp) ->
()"
);
ops
.
impl
(
"dynamic_scaled_int8_quant"
,
torch
::
kCUDA
,
&
dynamic_scaled_int8_quant
);
}
...
...
tests/kernels/test_int8_quant.py
View file @
5d73ae49
...
...
@@ -13,14 +13,28 @@ SEEDS = [0]
SCALE
=
[
0.1
,
0.5
,
0.8
,
1.2
,
2.1
]
def
opcheck_int8_quant
(
output
,
input
,
scale
=
None
):
if
scale
is
not
None
:
opcheck
(
torch
.
ops
.
_C
.
static_scaled_int8_quant
,
(
output
,
input
,
scale
))
def opcheck_int8_quant_static(output, input, scale, azp=None):
    """Run `opcheck` on the static int8 quantization op.

    `azp` is forwarded unchanged as the fourth op argument; leaving it as
    `None` exercises the call without a zero point.
    """
    op_args = (output, input, scale, azp)
    opcheck(torch.ops._C.static_scaled_int8_quant, op_args)
def opcheck_int8_quant_dynamic(output, input, symmetric=True):
    """Run `opcheck` on the dynamic int8 quantization op.

    Args:
        output: int8 tensor the op writes the quantized values into.
        input: tensor to quantize; one scale is allocated per row of the
            flattened `[-1, input.shape[-1]]` view.
        symmetric: when True the op is checked with `azp=None`; otherwise an
            int32 zero-point buffer with one entry per row is passed as well.
    """
    num_rows = input.numel() // input.shape[-1]
    scale = torch.empty((num_rows, 1),
                        device=input.device,
                        dtype=torch.float32)
    # The op always takes four arguments (out, input, scale, azp), so the
    # former three-argument check was dropped.
    if symmetric:
        opcheck(torch.ops._C.dynamic_scaled_int8_quant,
                (output, input, scale, None))
    else:
        azp = torch.empty((num_rows, 1),
                          device=input.device,
                          dtype=torch.int32)
        opcheck(torch.ops._C.dynamic_scaled_int8_quant,
                (output, input, scale, azp))
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
...
@@ -38,14 +52,56 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
# reference
ref_out
,
ref_scales
=
ref_dynamic_per_token_quant
(
x
,
torch
.
int8
)
# kernel
ops_out
,
ops_scales
=
scaled_int8_quant
(
x
)
ops_out
,
ops_scales
,
_
=
scaled_int8_quant
(
x
)
torch
.
testing
.
assert_close
(
ops_scales
,
ref_scales
)
torch
.
testing
.
assert_close
(
ops_out
,
ref_out
,
atol
=
1
,
rtol
=
0.0
)
# big atol to account for rounding errors
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
ops_out
,
ref_out
,
atol
=
1
,
rtol
=
0.0
)
opcheck_int8_quant
(
ops_out
,
x
)
opcheck_int8_quant_dynamic
(
ops_out
,
x
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
                                       dtype: torch.dtype, seed: int) -> None:
    """Check the dynamic asymmetric int8 kernel against a torch reference."""
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    int8_traits = torch.iinfo(torch.int8)

    x = torch.rand(num_tokens, hidden_size, dtype=dtype,
                   device="cuda") * 1000 - 300

    # Per-token extrema, taken in float32.
    x_f32 = x.to(dtype=torch.float32)
    x_token_max = x_f32.max(dim=1, keepdim=True).values
    x_token_min = x_f32.min(dim=1, keepdim=True).values

    # Reference scale and zero point: spread [min, max] over the 255 int8
    # steps and shift so that min lands on -128.
    scales = (x_token_max - x_token_min) / torch.tensor(255.0)
    azps = torch.round(torch.tensor(-128.0) -
                       x_token_min / scales).to(torch.int32)

    shifted = (x / scales).round() + azps
    torch_out = shifted.clamp(int8_traits.min,
                              int8_traits.max).to(torch.int8)
    assert (torch_out.min() >= int8_traits.min
            and torch_out.max() <= int8_traits.max)

    # Output buffers filled in by the kernel.
    ops_out = torch.empty_like(x, dtype=torch.int8)
    scales_out = torch.empty_like(scales, dtype=torch.float32)
    azp_out = torch.empty_like(azps, dtype=torch.int32)
    torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out, azp_out)

    # On a scale mismatch, print the index of the worst offender before the
    # assertion below fails.
    if not torch.allclose(scales_out, scales):
        print(torch.argmax(torch.abs(scales_out - scales)))
    torch.testing.assert_close(scales_out, scales)
    # big atol to account for rounding errors
    torch.testing.assert_close(azp_out, azps, atol=1, rtol=0.0)
    # if AZP is off by 1, after rounding-to-even, the output may be off by 2
    torch.testing.assert_close(ops_out, torch_out, atol=2, rtol=0.0)

    opcheck_int8_quant_dynamic(ops_out, x, False)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
...
@@ -62,14 +118,76 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
int8_traits
=
torch
.
iinfo
(
torch
.
int8
)
x
=
torch
.
rand
(
num_tokens
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
*
1000
scale
=
torch
.
tensor
([
scale
],
dtype
=
torch
.
float32
,
device
=
"cuda"
)
scale
_arg
=
torch
.
tensor
([
scale
],
dtype
=
torch
.
float32
,
device
=
"cuda"
)
out1
=
(
x
/
scale
).
round
().
clamp
(
int8_traits
.
min
,
out1
=
(
x
/
scale
_arg
).
round
().
clamp
(
int8_traits
.
min
,
int8_traits
.
max
).
to
(
torch
.
int8
)
out2
,
_
=
scaled_int8_quant
(
x
,
scale
)
out2
,
_
,
_
=
scaled_int8_quant
(
x
,
scale_arg
)
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
opcheck_int8_quant_static
(
out2
,
x
,
scale_arg
)
torch
.
testing
.
assert_close
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
# big atol to account for rounding errors
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("scale", SCALE[2:])  # Reduce test time
@pytest.mark.parametrize("azp", [-255, 54])
@torch.inference_mode()
def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
                                      dtype: torch.dtype, seed: int,
                                      scale: float, azp: int) -> None:
    """Check the static asymmetric int8 kernel against a torch reference."""
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    int8_traits = torch.iinfo(torch.int8)

    x = torch.rand(num_tokens, hidden_size, dtype=dtype,
                   device="cuda") * 1000 - 300

    # Reference result: round, shift by the zero point, saturate to int8.
    shifted = (x / scale).round() + azp
    out1 = shifted.clamp(int8_traits.min, int8_traits.max).to(torch.int8)

    # Kernel result, written into a preallocated int8 buffer.
    out2 = torch.empty_like(x, dtype=torch.int8)
    scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
    azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")
    torch.ops._C.static_scaled_int8_quant(out2, x, scale_arg, azp_arg)

    # big atol to account for rounding errors
    torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)

    opcheck_int8_quant_static(out2, x, scale_arg, azp_arg)
@pytest.mark.parametrize("is_max", [True, False])
@torch.inference_mode()
def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
    """Check saturation for inputs at and around the int32 limit.

    Args:
        is_max: probe the neighbourhood of int32 max when True, of int32 min
            otherwise.
    """
    # Test that the saturating cast works correctly for values near i32 max/min
    from numpy import inf, nextafter

    int32_traits = torch.iinfo(torch.int32)
    val = float(int32_traits.max if is_max else int32_traits.min)

    # Values just above, at, and just below the chosen int32 limit.
    x_vals = [[
        nextafter(val, inf), val + 1, val, val - 1,
        nextafter(val, -inf)
    ]]
    x = torch.tensor(x_vals, dtype=torch.float32, device="cuda")

    # The calculation in the kernel is: cast<int8>(cast<int32>(x / scale) + azp)
    # where cast<T> is a saturating cast to type T.
    # Scale is set to 1.0 so that the input values are the ones that are cast.
    # AZP is set to 0 to make sure the int8 saturating cast is tested as well.
    scale = torch.scalar_tensor(1.0, dtype=torch.float32, device="cuda")
    azp = torch.scalar_tensor(0, dtype=torch.int32, device="cuda")

    # Every probed value is expected to saturate to the matching int8 limit.
    int8_traits = torch.iinfo(torch.int8)
    val_i8 = int8_traits.max if is_max else int8_traits.min
    expected = torch.full((1, 5), val_i8, dtype=torch.int8, device="cuda")

    out = torch.empty_like(expected)
    torch.ops._C.static_scaled_int8_quant(out, x, scale, azp)
    torch.testing.assert_close(expected, out, atol=0, rtol=0)
vllm/_custom_ops.py
View file @
5d73ae49
...
...
@@ -685,31 +685,42 @@ def scaled_fp8_quant(
# int8
def
scaled_int8_quant
(
input
:
torch
.
Tensor
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
azp
:
Optional
[
torch
.
Tensor
]
=
None
,
symmetric
:
bool
=
True
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""
Quantize the input tensor to int8 and return the quantized tensor and scale.
Quantize the input tensor to int8 and return the quantized tensor and scale
, and maybe azp
.
Args:
input: The input tensor to be quantized to int8.
scale: Optional scaling factor for the int8 quantization.
When not provided, we invoke dynamic-per-token quantization.
azp: Optional zero-point for the int8 quantization.
Must be provided for asymmetric quantization if `scale` is provided.
symmetric: Whether to use symmetric quantization (scale only, azp ignored).
Returns:
Tuple[
T
orch.Tensor,
T
orch.Tensor] : Output int8 tensor
and scales
.
Tuple[
t
orch.Tensor,
torch.Tensor, Optional[t
orch.Tensor]
]
: Output int8 tensor
, scales, and optionally azp
.
"""
output
=
torch
.
empty_like
(
input
,
dtype
=
torch
.
int8
)
if
scale
is
not
None
:
# static-per-tensor quantization.
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
output
,
input
,
scale
)
return
output
,
scale
assert
symmetric
==
(
azp
is
None
),
"azp must only be provided for asymmetric quantization."
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
output
,
input
,
scale
,
azp
)
return
output
,
scale
,
None
# dynamic-per-token quantization.
input_scales
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
device
=
input
.
device
,
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
(
output
,
input
,
input_scales
)
return
output
,
input_scales
input_azp
=
None
if
symmetric
else
torch
.
empty_like
(
input_scales
,
dtype
=
torch
.
int32
)
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
(
output
,
input
,
input_scales
,
input_azp
)
return
output
,
input_scales
,
input_azp
# qqq ops
...
...
vllm/model_executor/layers/quantization/qqq.py
View file @
5d73ae49
...
...
@@ -260,7 +260,7 @@ class QQQLinearMethod(LinearMethodBase):
size_k
=
x_2d
.
shape
[
1
]
size_n
=
s_ch
.
shape
[
1
]
x_int8
,
s_tok
=
ops
.
scaled_int8_quant
(
x_2d
)
x_int8
,
s_tok
,
_
=
ops
.
scaled_int8_quant
(
x_2d
)
output_2d
=
ops
.
marlin_qqq_gemm
(
x_int8
,
qweight
,
s_tok
,
s_ch
,
s_group
,
workspace
,
size_m
,
size_n
,
size_k
)
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
5d73ae49
...
...
@@ -188,7 +188,7 @@ def apply_int8_linear(
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
x_q
,
x_scale
=
ops
.
scaled_int8_quant
(
input
,
input_scale
)
x_q
,
x_scale
,
_
=
ops
.
scaled_int8_quant
(
input
,
input_scale
)
return
ops
.
cutlass_scaled_mm
(
x_q
,
weight
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment