Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e76e2335
Unverified
Commit
e76e2335
authored
Aug 24, 2025
by
czhu-cohere
Committed by
GitHub
Aug 24, 2025
Browse files
[kernel] Support W4A8 on Hopper (#23198)
Signed-off-by:
czhu-cohere
<
conway.zhu@cohere.com
>
parent
a7527728
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
1128 additions
and
7 deletions
+1128
-7
CMakeLists.txt
CMakeLists.txt
+27
-0
benchmarks/kernels/benchmark_machete.py
benchmarks/kernels/benchmark_machete.py
+33
-0
benchmarks/kernels/weight_shapes.py
benchmarks/kernels/weight_shapes.py
+6
-0
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
+418
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+20
-0
tests/kernels/quantization/test_cutlass_w4a8.py
tests/kernels/quantization/test_cutlass_w4a8.py
+259
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+48
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+37
-6
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
...ayers/quantization/compressed_tensors/schemes/__init__.py
+3
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py
+160
-0
vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
...r/layers/quantization/kernels/mixed_precision/__init__.py
+3
-0
vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py
...or/layers/quantization/kernels/mixed_precision/cutlass.py
+114
-0
No files found.
CMakeLists.txt
View file @
e76e2335
...
@@ -750,6 +750,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -750,6 +750,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"found in CUDA target architectures"
)
"found in CUDA target architectures"
)
endif
()
endif
()
endif
()
endif
()
# Only build W4A8 kernels if we are building for something compatible with
# sm90a. The kernels additionally require a CUDA 12.0+ compiler.
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
  set(SRCS
     "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu")

  set_gencode_flags_for_srcs(
    SRCS "${SRCS}"
    CUDA_ARCHS "${W4A8_ARCHS}")

  list(APPEND VLLM_EXT_SRC "${SRCS}")

  message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
else()
  # Explain which of the two requirements (compiler version / target arch)
  # caused the kernels to be skipped.
  if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
      AND W4A8_ARCHS)
    message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
                   "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
                   "later if you intend on running w4a8 quantized models on "
                   "Hopper.")
  else()
    message(STATUS "Not building W4A8 kernels as no compatible archs "
                   "found in CUDA target architectures")
  endif()
endif()
# if CUDA endif
# if CUDA endif
endif
()
endif
()
...
...
benchmarks/kernels/benchmark_machete.py
View file @
e76e2335
...
@@ -284,6 +284,25 @@ def machete_create_bench_fn(
...
@@ -284,6 +284,25 @@ def machete_create_bench_fn(
)
)
def cutlass_w4a8_create_bench_fn(
    bt: BenchmarkTensors, out_type=torch.dtype, schedule=None
) -> Callable:
    """Build a zero-argument callable that runs the CUTLASS W4A8 GEMM on ``bt``.

    The weight reordering and scale packing are done once, up front, so the
    returned callable only measures the matmul itself.

    Args:
        bt: Benchmark tensors; ``a``, ``w_q``, ``w_g_s``, ``w_ch_s``,
            ``w_tok_s`` and ``group_size`` are read from it.
        out_type: Accepted for signature parity; not used by this function.
        schedule: Optional schedule name forwarded as ``maybe_schedule``.

    Returns:
        A callable that invokes ``ops.cutlass_w4a8_mm`` with the prepared
        operands.
    """
    # A transpose/contiguous/transpose round trip yields column-major storage.
    col_major_w_q = bt.w_q.t().contiguous().t()
    reordered_w_q = ops.cutlass_encode_and_reorder_int4b(col_major_w_q)

    # The packing op expects fp8 scales.
    fp8_group_scales = bt.w_g_s.to(torch.float8_e4m3fn)
    packed_group_scales = ops.cutlass_pack_scale_fp8(fp8_group_scales)

    def run():
        return ops.cutlass_w4a8_mm(
            a=bt.a,
            b_q=reordered_w_q,
            b_group_scales=packed_group_scales,
            b_group_size=bt.group_size,
            b_channel_scales=bt.w_ch_s,
            a_token_scales=bt.w_tok_s,
            maybe_schedule=schedule,
        )

    return run
# impl
# impl
# bench
# bench
...
@@ -385,6 +404,20 @@ def bench(
...
@@ -385,6 +404,20 @@ def bench(
)
)
)
)
# cutlass w4a8
if
types
.
act_type
==
torch
.
float8_e4m3fn
and
group_size
==
128
:
timers
.
append
(
bench_fns
(
label
,
sub_label
,
f
"cutlass w4a8 (
{
name_type_string
}
)"
,
[
cutlass_w4a8_create_bench_fn
(
bt
,
out_type
=
types
.
output_type
)
for
bt
in
benchmark_tensors
],
)
)
if
sweep_schedules
:
if
sweep_schedules
:
global
_SWEEP_SCHEDULES_RESULTS
global
_SWEEP_SCHEDULES_RESULTS
...
...
benchmarks/kernels/weight_shapes.py
View file @
e76e2335
...
@@ -95,4 +95,10 @@ WEIGHT_SHAPES = {
...
@@ -95,4 +95,10 @@ WEIGHT_SHAPES = {
([
2048
,
2816
],
1
),
([
2048
,
2816
],
1
),
([
1408
,
2048
],
0
),
([
1408
,
2048
],
0
),
],
],
"CohereLabs/c4ai-command-a-03-2025"
:
[
([
12288
,
14336
],
1
),
([
12288
,
12288
],
0
),
([
12288
,
73728
],
1
),
([
36864
,
12288
],
0
),
],
}
}
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
0 → 100644
View file @
e76e2335
//
// Based off of:
// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
//
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/torch_utils.hpp"
#include "core/registration.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/mixed_dtype_utils.hpp"
#include "cutlass_extensions/common.hpp"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
::
cutlass_w4a8
{
using namespace cute;

// -------------------------------------------------------------------------------------
// Static configuration shared across all instantiations
// -------------------------------------------------------------------------------------
using MmaType = cutlass::float_e4m3_t;  // A/scale element type
using QuantType = cutlass::int4b_t;     // B element type (packed int4)

// K extent of every tile, in MmaType elements: 128 bytes' worth of elements.
static int constexpr TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
static int constexpr ScalePackSize = 8;  // pack 8 scale elements together
static int constexpr PackFactor = 8;     // 8 4-bit packed into int32

// A matrix configuration
using ElementA = MmaType;                   // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor;  // Layout type for A matrix operand
using LayoutA_Transpose =
    typename cutlass::layout::LayoutTranspose<LayoutA>::type;
// Memory access granularity/alignment of A matrix in units of elements
// (up to 16 bytes)
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;

// B matrix configuration
using ElementB = QuantType;  // Element type for B matrix operand
using LayoutB =
    cutlass::layout::ColumnMajor;  // Layout type for B matrix operand
using LayoutB_Transpose =
    typename cutlass::layout::LayoutTranspose<LayoutB>::type;
// Memory access granularity/alignment of B matrix in units of elements
// (up to 16 bytes)
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;

// Define the CuTe layout for reordered quantized tensor B
// LayoutAtomQuant places values that will be read by the same thread in
// contiguous locations in global memory. It specifies the reordering within a
// single warp's fragment
using LayoutAtomQuant =
    decltype(cutlass::compute_memory_reordering_atom<MmaType>());
// Type of the atom tiled over a dynamic (int, int, int) shape; the concrete
// runtime layout is built per call with cute::tile_to_shape.
using LayoutB_Reordered = decltype(cute::tile_to_shape(
    LayoutAtomQuant{}, Layout<Shape<int, int, int>, StrideB>{}));

// Group-wise scales
using ElementScale = MmaType;
using LayoutScale = cutlass::layout::RowMajor;

// Per-tok, per-chan scales
using ElementSChannel = float;

// C/D matrix configuration
using ElementC =
    cutlass::bfloat16_t;  // Element type for C and D matrix operands
using LayoutC =
    cutlass::layout::RowMajor;  // Layout type for C and D matrix operands
// Memory access granularity/alignment of C matrix in units of elements
// (up to 16 bytes)
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

// D uses the same element type and layout as C.
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

// Core kernel configurations
using ElementAccumulator = float;  // Element type for internal accumulation
using ElementCompute = float;      // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm90;  // Tag indicating the minimum SM that
                                      // supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
using KernelSchedule =
    cutlass::gemm::KernelTmaWarpSpecializedCooperative;  // Kernel to launch
                                                         // based on the default
                                                         // setting in the
                                                         // Collective Builder
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
// ----------------------------------------------------------------------------
// Kernel template — Tile/Cluster shapes
// ----------------------------------------------------------------------------
// One W4A8 GEMM instantiation for a given (M, N) tile shape and cluster shape.
// The K tile extent is fixed by TileShapeK. The static `mm` entry point
// allocates the output, builds the CUTLASS arguments and launches the kernel
// on the current CUDA stream of A's device.
template <class TileShape_MN, class ClusterShape_MNK>
struct W4A8GemmKernel {
  using TileShape =
      decltype(cute::append(TileShape_MN{}, cute::Int<TileShapeK>{}));
  using ClusterShape = ClusterShape_MNK;

  // Epilogue per-tok, per-chan scales
  using ChTokScalesEpilogue =
      typename vllm::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
                                         TileShape>;
  using EVTCompute = typename ChTokScalesEpilogue::EVTCompute;
  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
          ElementAccumulator, ElementSChannel,
          // Transpose layout of D here since we use explicit swap + transpose.
          // A void type for C would tell the builder to allocate 0 smem for
          // the C matrix; that could be enabled when beta == 0 by changing
          // ElementC to void below.
          ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type,
          AlignmentC, ElementD,
          typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD,
          EpilogueSchedule,  // This is the only epi supporting the required
                             // swap + transpose.
          EVTCompute>::CollectiveOp;

  // The Scale information must get paired with the operand that will be scaled.
  // Here B is scaled, so we make a tuple of B's information and the scale
  // information. Note that B is passed in the first operand slot and A in the
  // second (swapped operands).
  using CollectiveMainloopShuffled =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass,
          cute::tuple<ElementB, cutlass::Array<ElementScale, ScalePackSize>>,
          LayoutB_Reordered, AlignmentB, ElementA, LayoutA_Transpose,
          AlignmentA, ElementAccumulator, TileShape, ClusterShape,
          // Stage count is chosen automatically after carving out the
          // epilogue's shared storage.
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloopShuffled, CollectiveEpilogue>;
  using GemmShuffled =
      cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;

  using StrideC = typename GemmKernelShuffled::StrideC;
  using StrideD = typename GemmKernelShuffled::StrideD;
  using StrideS = typename CollectiveMainloopShuffled::StrideScale;

  // Runs the GEMM and returns a freshly allocated (m, n) tensor of ElementD.
  //
  // m and k are taken from A's first two sizes and n from B.size(1).
  // NOTE(review): maybe_out_type is accepted but never read in this function;
  // the output dtype is always the torch equivalent of ElementD.
  // NOTE(review): no validation of dtypes, contiguity or devices is done here
  // (see the TODO below); callers are presumably expected to pass operands
  // produced by the pre-processing utilities in this file - confirm.
  static torch::Tensor mm(
      torch::Tensor const& A,
      torch::Tensor const& B,             // already packed
      torch::Tensor const& group_scales,  // already packed
      int64_t group_size, torch::Tensor const& channel_scales,
      torch::Tensor const& token_scales,
      std::optional<at::ScalarType> const& maybe_out_type) {
    // TODO: param validation
    int m = A.size(0);
    int k = A.size(1);
    int n = B.size(1);

    // Allocate output
    const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
    auto device = A.device();
    auto stream = at::cuda::getCurrentCUDAStream(device.index());
    torch::Tensor D =
        torch::empty({m, n}, torch::TensorOptions()
                                 .dtype(equivalent_scalar_type_v<ElementD>)
                                 .device(device));
    // prepare arg pointers
    auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr());
    auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
    auto D_ptr = static_cast<ElementD*>(D.data_ptr());
    // Group scales are viewed as arrays of ScalePackSize packed elements.
    // TODO: can we avoid hardcoding the pack size of 8 here?
    auto S_ptr =
        static_cast<cutlass::Array<ElementScale, ScalePackSize> const*>(
            group_scales.const_data_ptr());

    // runtime layout for B
    auto shape_B = cute::make_shape(n, k, 1);
    LayoutB_Reordered layout_B_reordered =
        cute::tile_to_shape(LayoutAtomQuant{}, shape_B);

    // strides
    // Number of scale groups along k (rounded up).
    int const scale_k = cutlass::ceil_div(k, group_size);
    StrideA stride_A =
        cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
    // Reverse stride here due to swap and transpose
    StrideD stride_D =
        cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1));
    StrideS stride_S = cutlass::make_cute_packed_stride(
        StrideS{}, cute::make_shape(n, scale_k, 1));

    // Populate the Gemm::Arguments structure from the given arguments.
    // The A and B tensors, as well as the problem shape, are swapped here.
    using Args = typename GemmShuffled::Arguments;
    using MainloopArguments = typename GemmKernelShuffled::MainloopArguments;
    using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments;

    MainloopArguments mainloop_arguments{
        B_ptr, layout_B_reordered, A_ptr, stride_A,
        S_ptr, stride_S,           group_size};

    EpilogueArguments epilogue_arguments{
        ChTokScalesEpilogue::prepare_args(channel_scales, token_scales),
        nullptr,
        {},  // no C
        D_ptr,
        stride_D};

    Args arguments{cutlass::gemm::GemmUniversalMode::kGemm,
                   {n, m, k, 1},  // shape
                   mainloop_arguments,
                   epilogue_arguments};

    // Workspace
    size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
    torch::Tensor workspace =
        torch::empty(workspace_size,
                     torch::TensorOptions().dtype(torch::kU8).device(device));

    // Run GEMM
    GemmShuffled gemm;
    CUTLASS_CHECK(gemm.can_implement(arguments));
    CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
    CUTLASS_CHECK(gemm.run(stream));

    return D;
  }
};
// ----------------------------------------------------------------------------
// Kernel instantiations and dispatch logic
// ----------------------------------------------------------------------------
// Alias names follow the pattern Kernel_<TileM>x<TileN>_<ClusterM>x<ClusterN>x<ClusterK>,
// matching the template arguments on the right-hand side. The same
// "<TileM>x<TileN>_<Cluster>" strings are the schedule names accepted by
// mm_dispatch.
using Kernel_256x128_1x1x1 =
    W4A8GemmKernel<Shape<_256, _128>, Shape<_1, _1, _1>>;
using Kernel_256x64_1x1x1 =
    W4A8GemmKernel<Shape<_256, _64>, Shape<_1, _1, _1>>;
using Kernel_256x32_1x1x1 =
    W4A8GemmKernel<Shape<_256, _32>, Shape<_1, _1, _1>>;
using Kernel_256x16_1x1x1 =
    W4A8GemmKernel<Shape<_256, _16>, Shape<_1, _1, _1>>;
using Kernel_128x256_2x1x1 =
    W4A8GemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>>;
using Kernel_128x256_1x1x1 =
    W4A8GemmKernel<Shape<_128, _256>, Shape<_1, _1, _1>>;
using Kernel_128x128_1x1x1 =
    W4A8GemmKernel<Shape<_128, _128>, Shape<_1, _1, _1>>;
using Kernel_128x64_1x1x1 =
    W4A8GemmKernel<Shape<_128, _64>, Shape<_1, _1, _1>>;
using Kernel_128x32_1x1x1 =
    W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>;
using Kernel_128x16_1x1x1 =
    W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>;
// Runs the kernel instantiation whose name equals `schedule`, forwarding all
// other arguments unchanged. Fails with a TORCH_CHECK error for an unknown
// schedule name.
torch::Tensor mm_dispatch(torch::Tensor const& A,
                          torch::Tensor const& B,             // already packed
                          torch::Tensor const& group_scales,  // already packed
                          int64_t group_size,
                          torch::Tensor const& channel_scales,
                          torch::Tensor const& token_scales,
                          std::optional<at::ScalarType> const& maybe_out_type,
                          const std::string& schedule) {
  // Every instantiation exposes a static `mm` with this exact signature.
  using MmFn = torch::Tensor (*)(torch::Tensor const&, torch::Tensor const&,
                                 torch::Tensor const&, int64_t,
                                 torch::Tensor const&, torch::Tensor const&,
                                 std::optional<at::ScalarType> const&);
  struct ScheduleEntry {
    char const* name;
    MmFn run;
  };

  // Schedule name -> kernel entry point, in the original lookup order.
  static ScheduleEntry const kSchedules[] = {
      {"256x128_1x1x1", &Kernel_256x128_1x1x1::mm},
      {"256x64_1x1x1", &Kernel_256x64_1x1x1::mm},
      {"256x32_1x1x1", &Kernel_256x32_1x1x1::mm},
      {"256x16_1x1x1", &Kernel_256x16_1x1x1::mm},
      {"128x256_2x1x1", &Kernel_128x256_2x1x1::mm},
      {"128x256_1x1x1", &Kernel_128x256_1x1x1::mm},
      {"128x128_1x1x1", &Kernel_128x128_1x1x1::mm},
      {"128x64_1x1x1", &Kernel_128x64_1x1x1::mm},
      {"128x32_1x1x1", &Kernel_128x32_1x1x1::mm},
      {"128x16_1x1x1", &Kernel_128x16_1x1x1::mm},
  };

  for (ScheduleEntry const& entry : kSchedules) {
    if (schedule == entry.name) {
      return entry.run(A, B, group_scales, group_size, channel_scales,
                       token_scales, maybe_out_type);
    }
  }

  TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
  return {};
}
// Public entry point for the W4A8 GEMM. When `maybe_schedule` is set, that
// schedule is used as-is; otherwise a schedule is chosen from the problem
// sizes (M = A.size(0), K = A.size(1), N = B.size(1)) and dispatched.
torch::Tensor mm(torch::Tensor const& A,
                 torch::Tensor const& B,             // already packed
                 torch::Tensor const& group_scales,  // already packed
                 int64_t group_size, torch::Tensor const& channel_scales,
                 torch::Tensor const& token_scales,
                 std::optional<at::ScalarType> const& maybe_out_type,
                 std::optional<std::string> maybe_schedule) {
  // requested a specific schedule
  if (maybe_schedule) {
    return mm_dispatch(A, B, group_scales, group_size, channel_scales,
                       token_scales, maybe_out_type, *maybe_schedule);
  }

  int M = A.size(0);
  int K = A.size(1);
  int N = B.size(1);

  // One specific problem size gets the 256-row tile variants for small M.
  bool const is_16384x18432 = (K == 16384 && N == 18432);

  // heuristic
  auto const pick_schedule = [&]() -> char const* {
    if (M <= 16) {
      return is_16384x18432 ? "256x16_1x1x1" : "128x16_1x1x1";
    }
    if (M <= 32) {
      return is_16384x18432 ? "256x32_1x1x1" : "128x32_1x1x1";
    }
    if (M <= 64) {
      if (is_16384x18432) return "256x64_1x1x1";
      return (N <= 8192 && K <= 8192) ? "128x32_1x1x1" : "128x64_1x1x1";
    }
    if (M <= 128) {
      if (is_16384x18432) return "256x128_1x1x1";
      return (N <= 8192) ? "128x64_1x1x1" : "128x128_1x1x1";
    }
    if (M <= 256) {
      if (N <= 4096) return "128x64_1x1x1";
      return (N <= 8192) ? "128x128_1x1x1" : "128x256_1x1x1";
    }
    if (M <= 512 && N <= 4096) {
      return "128x128_1x1x1";
    }
    if (M <= 1024) {
      return "128x256_1x1x1";
    }
    return "128x256_2x1x1";
  };

  std::string schedule = pick_schedule();
  return mm_dispatch(A, B, group_scales, group_size, channel_scales,
                     token_scales, maybe_out_type, schedule);
}
// ----------------------------------------------------------------------------
// Pre-processing utils
// ----------------------------------------------------------------------------
// Packs fp8 group scales into the layout consumed by the W4A8 mainloop.
// Requires a contiguous CUDA tensor of dtype float8_e4m3fn. Returns a new 1-D
// tensor of the same dtype and device holding numel * ScalePackSize elements.
torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
  TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(scales.is_contiguous());
  TORCH_CHECK(scales.is_cuda());

  using PackedScale = cutlass::Array<ElementScale, ScalePackSize>;

  auto out_options =
      torch::TensorOptions().dtype(scales.dtype()).device(scales.device());
  auto packed_scales =
      torch::empty({scales.numel() * ScalePackSize}, out_options);

  auto src = static_cast<MmaType const*>(scales.const_data_ptr());
  auto dst = static_cast<PackedScale*>(packed_scales.data_ptr());

  cutlass::pack_scale_fp8(src, dst, scales.numel());
  return packed_scales;
}
// Encodes and reorders a packed int4 weight matrix for the W4A8 kernel.
// B must be a 2-D int32 tensor; each int32 holds PackFactor 4-bit values, so
// the logical k is B.size(0) * PackFactor and n is B.size(1). Returns a new
// tensor shaped like B; B itself is only read.
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
  TORCH_CHECK(B.dtype() == torch::kInt32);
  TORCH_CHECK(B.dim() == 2);

  int n = B.size(1);
  int k = B.size(0) * PackFactor;  // logical k

  torch::Tensor B_packed = torch::empty_like(B);
  auto src = static_cast<QuantType const*>(B.const_data_ptr());
  auto dst = static_cast<QuantType*>(B_packed.data_ptr());

  auto problem_shape = cute::make_shape(n, k, 1);
  auto layout_src = make_layout(problem_shape, LayoutRight{});  // row major
  LayoutB_Reordered layout_dst =
      cute::tile_to_shape(LayoutAtomQuant{}, problem_shape);

  // Encode first (src -> dst), then reorder dst in place into layout_dst.
  cutlass::unified_encode_int4b(src, dst, n * k);
  cutlass::reorder_tensor(dst, layout_src, layout_dst);

  return B_packed;
}
// Register the CUDA implementations of the three W4A8 ops under the
// extension's library name. The registration lives in this translation unit
// because the file is only compiled when the W4A8 kernels are enabled.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("cutlass_w4a8_mm", &mm);
  m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8);
  m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b);
}
}
// namespace vllm::cutlass_w4a8
\ No newline at end of file
csrc/torch_bindings.cpp
View file @
e76e2335
...
@@ -309,6 +309,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -309,6 +309,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"SymInt size_n, int num_bits) -> Tensor"
);
"SymInt size_n, int num_bits) -> Tensor"
);
// conditionally compiled so impl registrations are in source file
// conditionally compiled so impl registrations are in source file
// CUTLASS w4a8 GEMM
ops
.
def
(
"cutlass_w4a8_mm("
" Tensor A,"
" Tensor B,"
" Tensor group_scales,"
" int group_size,"
" Tensor channel_scales,"
" Tensor token_scales,"
" ScalarType? out_type,"
" str? maybe_schedule"
") -> Tensor"
,
{
stride_tag
});
// pack scales
ops
.
def
(
"cutlass_pack_scale_fp8(Tensor scales) -> Tensor"
);
// encode and reorder weight matrix
ops
.
def
(
"cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"
);
// conditionally compiled so impl registration is in source file
#endif
#endif
// Dequantization for GGML.
// Dequantization for GGML.
...
...
tests/kernels/quantization/test_cutlass_w4a8.py
0 → 100644
View file @
e76e2335
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the CUTLASS W4A8 kernel.
Run `pytest tests/kernels/quantization/test_cutlass_w4a8.py`.
"""
from
dataclasses
import
dataclass
from
typing
import
Optional
import
pytest
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
pack_rows
,
quantize_weights
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
# NOTE: the GPU-capability flag (IS_SUPPORTED_BY_GPU) is defined once, further
# down in this module. An earlier duplicate that indexed
# `current_platform.get_device_capability()` at import time was dropped: it
# was always overwritten before use.

# (m, n, k) problem sizes exercised by the correctness test.
MNK_SHAPES = [
    (1, 128, 128),
    (1, 512, 1024),
    (1, 4096, 4096),
    (1, 8192, 28672),
    (13, 8192, 4096),
    (26, 4096, 8192),
    (64, 4096, 4096),
    (64, 8192, 28672),
    (257, 128, 4096),
    (257, 4096, 4096),
    (1024, 4096, 8192),
    (1024, 8192, 4096),
]

# Schedule names understood by the kernel dispatcher.
# TODO(czhu): get supported schedules from fn
SCHEDULES = [
    '128x16_1x1x1',
    '256x16_1x1x1',
    '128x32_1x1x1',
    '256x32_1x1x1',
    '128x64_1x1x1',
    '256x64_1x1x1',
    '128x128_1x1x1',
    '256x128_1x1x1',
    '128x256_1x1x1',
    '128x256_2x1x1',
]
@dataclass
class TypeConfig:
    """Dtype configuration for one W4A8 kernel test case."""
    # dtype of the activations; also used to cast the reference operands
    act_type: torch.dtype
    # quantized weight type handed to the weight quantizer
    weight_type: ScalarType
    # output dtype requested from the reference matmul
    output_type: Optional[torch.dtype]
    # dtype the random weights are cast to before quantization
    group_scale_type: Optional[torch.dtype]
    # dtype of the per-channel scale tensor
    channel_scale_type: Optional[torch.dtype]
    # dtype of the per-token scale tensor
    token_scale_type: Optional[torch.dtype]
@dataclass
class Tensors:
    """Bundle of operands and references produced by `create_test_tensors`."""
    # reference (dequantized) weights, stored as float32
    w_ref: torch.Tensor
    # float32 copy of the activations
    a_ref: torch.Tensor
    # fp8 activations passed to the kernel
    a: torch.Tensor
    # encoded and reordered int4 weights
    w_q: torch.Tensor
    # packed fp8 group scales
    w_g_s: torch.Tensor
    # per-channel scales (all ones in these tests)
    w_ch_s: torch.Tensor
    # per-token scales
    w_tok_s: torch.Tensor
# Legacy tuple alias describing a test type configuration.
# NOTE(review): the historical field list for this alias named seven entries
# (act, weight, output, scale, zero-points, channel-scale and token-scale
# types) but the alias declares five; it does not appear to be used in this
# module - confirm before relying on it.
TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype],
                      Optional[torch.dtype], bool]

# One TypeConfig per (weight type, output type) combination under test.
# TODO(czhu): fp16 out type
TEST_TYPES = [
    TypeConfig(
        act_type=torch.float8_e4m3fn,
        weight_type=w_type,
        output_type=o_type,
        group_scale_type=torch.float8_e4m3fn,
        channel_scale_type=torch.float32,
        token_scale_type=torch.float32,
    ) for w_type in [scalar_types.int4] for o_type in [torch.bfloat16]
]

# TODO: in a future PR refactor this and `is_quant_method_supported` in the
#  kernel unit tests to a common utility function. Currently the use of
#  `is_quant_method_supported` conflates kernels with quantization methods,
#  an assumption which is breaking down as quantization methods can have
#  multiple kernels and some kernels support multiple quantization methods.
IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)
# For testing quantized linear kernels
def
to_fp8
(
tensor
:
torch
.
Tensor
):
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
return
tensor
.
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
).
to
(
dtype
=
torch
.
float8_e4m3fn
)
def cutlass_quantize_and_pack(atype: torch.dtype,
                              w: torch.Tensor,
                              wtype: ScalarType,
                              stype: Optional[torch.dtype],
                              group_size: Optional[int],
                              zero_points: bool = False):
    """Quantize ``w`` and prepare the operands expected by the W4A8 kernel.

    Args:
        atype: Activation dtype; the group scales are cast to it.
        w: Weights to quantize.
        wtype: Quantized weight type (must be an integer type).
        stype: Not used by this helper.
        group_size: Quantization group size.
        zero_points: Forwarded to ``quantize_weights``.

    Returns:
        ``(w_ref, w_q_packed, w_s_packed, w_zp)``: the reference weights in
        ``atype``, the encoded/reordered weights, the packed scales and the
        zero points returned by ``quantize_weights``.
    """
    assert wtype.is_integer(), "TODO: support floating point weights"

    # The reference returned by quantize_weights is discarded and rebuilt
    # below from the scales as the kernel will see them.
    _, w_q, w_s, w_zp = quantize_weights(w,
                                         wtype,
                                         group_size=group_size,
                                         zero_points=zero_points)

    # since scales are cast to fp8, we need to compute w_ref this way
    scales_as_act = w_s.to(atype)
    per_row_scales = scales_as_act.to(torch.float32).repeat_interleave(
        group_size, dim=0)
    w_ref = (w_q.to(torch.float32) * per_row_scales).to(atype)

    # bit mask prevents sign extending int4 when packing
    low_nibbles = w_q & 0x0F
    w_q = pack_rows(low_nibbles, wtype.size_bits, *w_q.shape)
    w_q = w_q.t().contiguous().t()  # convert to col major

    w_q_packed = ops.cutlass_encode_and_reorder_int4b(w_q)
    w_s_packed = ops.cutlass_pack_scale_fp8(scales_as_act)

    return w_ref, w_q_packed, w_s_packed, w_zp
def create_test_tensors(shape: tuple[int, int, int], types: TypeConfig,
                        group_size: Optional[int]) -> Tensors:
    """Create random CUDA operands and references for an ``(m, n, k)`` problem.

    Args:
        shape: ``(m, n, k)`` problem size.
        types: Dtype configuration for the test case.
        group_size: Weight quantization group size.

    Returns:
        A ``Tensors`` bundle holding the kernel operands and the float32
        references.
    """
    m, n, k = shape

    print("create_test_tensors, shape:", shape, "types:", types, "group_size:",
          group_size)

    # Random draws happen in the order: activations, weights, token scales.
    a = to_fp8(torch.randn((m, k), device="cuda"))
    w = to_fp8(torch.randn((k, n), device="cuda"))

    scale_dtype = types.group_scale_type
    if scale_dtype is not None:
        w = w.to(scale_dtype)
    # One-byte dtypes are widened to fp16 before being quantized.
    if w.dtype.itemsize == 1:
        w = w.to(torch.float16)

    w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
        a.dtype, w, types.weight_type, scale_dtype, group_size, False)

    # for the practical use case we need per-tok scales for fp8 activations
    w_tok_s = torch.randn((m, ), device='cuda', dtype=types.token_scale_type)
    # weights are already per-group quantized, use placeholder here
    w_ch_s = torch.ones((n, ), device='cuda', dtype=types.channel_scale_type)

    return Tensors(
        w_ref=w_ref.to(torch.float32),
        a_ref=a.to(torch.float32),
        a=a,
        w_q=w_q_packed,
        w_g_s=w_s,
        w_ch_s=w_ch_s,
        w_tok_s=w_tok_s,
    )
def mm_test_helper(types: TypeConfig,
                   tensors: Tensors,
                   group_size: Optional[int] = None,
                   schedule: Optional[str] = None):
    """Compare the CUTLASS W4A8 kernel against a ``torch._scaled_mm`` reference.

    Args:
        types: Dtype configuration; ``act_type`` and ``output_type`` are used
            for the reference computation.
        tensors: Operands and references from ``create_test_tensors``.
        group_size: Group size forwarded to the kernel as ``b_group_size``.
        schedule: Kernel schedule to run. ``None`` lets the op choose one
            itself.

    Raises:
        AssertionError: If the kernel output is not close to the reference.
    """
    # CUTLASS upstream uses fp8 with fastaccum as reference
    # https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406
    output_ref = torch._scaled_mm(
        tensors.a_ref.to(types.act_type),
        tensors.w_ref.to(types.act_type).t().contiguous().t(),  # col major
        tensors.w_tok_s.unsqueeze(1),
        tensors.w_ch_s.unsqueeze(0),
        out_dtype=types.output_type,
        use_fast_accum=True)

    # Forward the requested schedule; without this every parametrized
    # schedule would exercise the same heuristic-selected kernel.
    output = ops.cutlass_w4a8_mm(
        a=tensors.a,
        b_q=tensors.w_q,
        b_group_scales=tensors.w_g_s,
        b_group_size=group_size,
        b_channel_scales=tensors.w_ch_s,
        a_token_scales=tensors.w_tok_s,
        maybe_schedule=schedule,
    )

    torch.testing.assert_close(output,
                               output_ref.to(output.dtype),
                               rtol=1e-3,
                               atol=1e-3)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="CUTLASS W4A8 is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
                         MNK_SHAPES,
                         ids=lambda x: "x".join(map(str, x)))
@pytest.mark.parametrize("types", TEST_TYPES)
@pytest.mark.parametrize("schedule", SCHEDULES)
def test_cutlass_w4a8(shape, types: TypeConfig, schedule):
    """Check kernel output against the reference for each shape/type/schedule."""
    # Only a group size of 128 is exercised.
    for group_size in (128, ):
        mm_test_helper(types, create_test_tensors(shape, types, group_size),
                       group_size, schedule)
# Test to make sure cuda graphs work
class W4A8Layer(torch.nn.Module):
    """Minimal module whose forward pass is a single W4A8 kernel call."""

    def __init__(self, **kwargs):
        """Store the fixed keyword arguments passed to the kernel on each call."""
        super().__init__()
        self.kwargs = kwargs

    def forward(self, a):
        """Run `ops.cutlass_w4a8_mm` on `a` with the stored keyword arguments."""
        return ops.cutlass_w4a8_mm(a=a, **self.kwargs)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="CUTLASS W4A8 is not supported on this GPU type.")
def test_w4a8_cuda_graph():
    """Capture the W4A8 kernel in a CUDA graph and check the replayed output."""
    m, n, k = 512, 4096, 4096

    a = to_fp8(torch.randn((m, k), device="cuda"))
    b = to_fp8(torch.randn((k, n), device="cuda"))

    wtype = scalar_types.int4
    stype = torch.float8_e4m3fn
    group_size = 128
    zero_points = False

    # The weights are widened to fp16 before being quantized and packed.
    w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
        a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points)

    w_tok_s = torch.randn((m, ), device='cuda', dtype=torch.float32)
    w_ch_s = torch.ones((n, ), device='cuda', dtype=torch.float32)

    # Construct a trivial model with a single layer that calls the kernel
    model = W4A8Layer(
        b_q=w_q_packed,
        b_group_scales=w_s,
        b_group_size=group_size,
        b_channel_scales=w_ch_s,
        a_token_scales=w_tok_s,
    )

    output_ref = torch._scaled_mm(
        a,
        w_ref.to(a.dtype).t().contiguous().t(),  # col major
        w_tok_s.unsqueeze(1),
        w_ch_s.unsqueeze(0),
        out_dtype=torch.bfloat16,
        use_fast_accum=True)

    # Run the model with a cuda graph
    # NOTE(review): the nesting of the statements below was reconstructed
    # from a flattened source; confirm which statements belong inside the
    # stream context.
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            output = model(a)
    # Clear the captured output so the assertion only passes if the replay
    # actually rewrites it.
    output.zero_()
    g.replay()

    torch.testing.assert_close(output, output_ref, rtol=1e-3, atol=1e-3)
vllm/_custom_ops.py
View file @
e76e2335
...
@@ -474,6 +474,30 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
...
@@ -474,6 +474,30 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
return
torch
.
empty_like
(
b_q_weight
,
return
torch
.
empty_like
(
b_q_weight
,
memory_format
=
torch
.
contiguous_format
)
memory_format
=
torch
.
contiguous_format
)
@register_fake("_C::cutlass_w4a8_mm")
def cutlass_w4a8_mm_fake(
        a: torch.Tensor,
        # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b
        b_q: torch.Tensor,
        b_group_scales: torch.Tensor,
        b_group_size: int,
        b_channel_scales: torch.Tensor,
        a_token_scales: torch.Tensor,
        out_type: Optional[torch.dtype] = None,
        maybe_schedule: Optional[str] = None) -> torch.Tensor:
    """Fake implementation of ``_C::cutlass_w4a8_mm``.

    Returns an uninitialized ``(a.size(0), b_q.size(1))`` tensor on ``a``'s
    device; the dtype is ``out_type`` when given, otherwise bfloat16.
    """
    result_dtype = torch.bfloat16 if out_type is None else out_type
    result_shape = (a.size(0), b_q.size(1))
    return torch.empty(result_shape, device=a.device, dtype=result_dtype)
@register_fake("_C::cutlass_pack_scale_fp8")
def cutlass_pack_scale_fp8_fake(scales: torch.Tensor) -> torch.Tensor:
    """Fake implementation registered for ``_C::cutlass_pack_scale_fp8``.

    Returns an uninitialized, contiguous tensor with the same shape, dtype
    and device as ``scales``.
    """
    # NOTE(review): this mirrors the input shape exactly; confirm that the
    # real op does not change the shape of the scales when packing them.
    packed_like = torch.empty_like(scales,
                                   memory_format=torch.contiguous_format)
    return packed_like
@register_fake("_C::cutlass_encode_and_reorder_int4b")
def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor:
    """Fake implementation of ``_C::cutlass_encode_and_reorder_int4b``.

    Returns an uninitialized, contiguous tensor with the same shape, dtype
    and device as ``b``.
    """
    reordered_like = torch.empty_like(b,
                                      memory_format=torch.contiguous_format)
    return reordered_like
if
hasattr
(
torch
.
ops
.
_C
,
"allspark_w8a16_gemm"
):
if
hasattr
(
torch
.
ops
.
_C
,
"allspark_w8a16_gemm"
):
...
@@ -1032,6 +1056,30 @@ def machete_prepack_B(
...
@@ -1032,6 +1056,30 @@ def machete_prepack_B(
group_scales_type
)
group_scales_type
)
# CUTLASS W4A8
def cutlass_w4a8_mm(
        a: torch.Tensor,
        # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b
        b_q: torch.Tensor,
        b_group_scales: torch.Tensor,
        b_group_size: int,
        b_channel_scales: torch.Tensor,
        a_token_scales: torch.Tensor,
        out_type: Optional[torch.dtype] = None,
        maybe_schedule: Optional[str] = None) -> torch.Tensor:
    """Dispatch to the ``_C.cutlass_w4a8_mm`` custom op.

    All arguments are forwarded positionally and unchanged, in the order
    they are declared here; the op's result is returned as-is.
    """
    w4a8_mm_op = torch.ops._C.cutlass_w4a8_mm
    return w4a8_mm_op(
        a,
        b_q,
        b_group_scales,
        b_group_size,
        b_channel_scales,
        a_token_scales,
        out_type,
        maybe_schedule,
    )
def cutlass_pack_scale_fp8(scales: torch.Tensor) -> torch.Tensor:
    """Forward ``scales`` to the ``_C.cutlass_pack_scale_fp8`` custom op."""
    pack_scale_op = torch.ops._C.cutlass_pack_scale_fp8
    return pack_scale_op(scales)
def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor:
    """Forward ``b`` to the ``_C.cutlass_encode_and_reorder_int4b`` op."""
    encode_reorder_op = torch.ops._C.cutlass_encode_and_reorder_int4b
    return encode_reorder_op(b)
if
hasattr
(
torch
.
ops
.
_C
,
"permute_cols"
):
if
hasattr
(
torch
.
ops
.
_C
,
"permute_cols"
):
@
register_fake
(
"_C::permute_cols"
)
@
register_fake
(
"_C::permute_cols"
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
e76e2335
...
@@ -26,10 +26,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
...
@@ -26,10 +26,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
CompressedTensors24
,
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
CompressedTensors24
,
CompressedTensorsScheme
,
CompressedTensorsW4A4Fp4
,
CompressedTensorsScheme
,
CompressedTensorsW4A4Fp4
,
CompressedTensorsW4A8
Int
,
CompressedTensorsW4A
16Fp4
,
CompressedTensorsW4A8
Fp8
,
CompressedTensorsW4A
8Int
,
CompressedTensorsW4A16
Sparse2
4
,
CompressedTensorsW
8A8Fp8
,
CompressedTensorsW4A16
Fp
4
,
CompressedTensorsW
4A16Sparse24
,
CompressedTensorsW8A8
Int
8
,
CompressedTensorsW8A
16Fp
8
,
CompressedTensorsW8A8
Fp
8
,
CompressedTensorsW8A
8Int
8
,
CompressedTensorsWNA16
)
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
find_matched_target
,
is_activation_quantization_format
,
find_matched_target
,
is_activation_quantization_format
,
should_ignore_layer
)
should_ignore_layer
)
...
@@ -200,8 +200,10 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -200,8 +200,10 @@ class CompressedTensorsConfig(QuantizationConfig):
format
format
)
if
format
is
not
None
else
is_activation_quantization_format
(
)
if
format
is
not
None
else
is_activation_quantization_format
(
quant_format
)
quant_format
)
if
act_quant_format
:
# TODO(czhu): w4a8fp8 is in packed-quantized format
input_activations
=
quant_config
.
get
(
"input_activations"
)
# but needs input activation quantization
input_activations
=
quant_config
.
get
(
"input_activations"
)
if
act_quant_format
or
input_activations
:
# The only case where we have activation quant supported
# The only case where we have activation quant supported
# but no input_activations provided in the config
# but no input_activations provided in the config
# should be w8a16fp8 w8a16fp8 can also run for cases where
# should be w8a16fp8 w8a16fp8 can also run for cases where
...
@@ -352,6 +354,28 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -352,6 +354,28 @@ class CompressedTensorsConfig(QuantizationConfig):
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
)
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
)
return
is_symmetric_activation
and
is_per_tensor_activation
return
is_symmetric_activation
and
is_per_tensor_activation
def _is_fp8_w4a8(self, weight_quant: BaseModel,
                 input_quant: BaseModel) -> bool:
    """Check whether the weight/activation configs form a W4A8 scheme.

    The scheme matches only when weights are 4-bit, group-quantized, static
    and symmetric, and activations are 8-bit, per-token, dynamic and
    symmetric. A missing (falsy) config on either side never matches.
    """
    # Both halves of the scheme have to be configured.
    if not (weight_quant and input_quant):
        return False

    four_bit_weights = weight_quant.num_bits == 4
    eight_bit_activations = input_quant.num_bits == 8

    grouped_weights = (
        weight_quant.strategy == QuantizationStrategy.GROUP.value)
    # The activation strategy is only inspected for group-wise weights.
    group_and_token = (grouped_weights and input_quant.strategy
                       == QuantizationStrategy.TOKEN.value)

    # Weights must be static while activations are quantized dynamically.
    static_w_dynamic_a = not weight_quant.dynamic and input_quant.dynamic
    both_symmetric = weight_quant.symmetric and input_quant.symmetric

    # Only per-group symmetric weight (4bit)
    # + per-tok symmetric activation (8bit) quantization supported.
    return (four_bit_weights and eight_bit_activations and group_and_token
            and both_symmetric and static_w_dynamic_a)
def _is_fp8_w4a8_sm90(self, weight_quant: BaseModel,
                      input_quant: BaseModel) -> bool:
    """Check for the W4A8 scheme gated on an exact capability-90 match.

    The capability check is made with ``error=False`` so an unsupported
    device yields a falsy result instead of raising.
    """
    capability_ok = self._check_scheme_supported(90,
                                                 error=False,
                                                 match_exact=True)
    if not capability_ok:
        # Hand back the falsy capability result itself.
        return capability_ok
    return self._is_fp8_w4a8(weight_quant, input_quant)
def
_is_fp8_w8a8_sm90
(
self
,
weight_quant
:
BaseModel
,
def
_is_fp8_w8a8_sm90
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
input_quant
:
BaseModel
)
->
bool
:
return
(
self
.
_check_scheme_supported
(
90
,
error
=
False
,
match_exact
=
True
)
return
(
self
.
_check_scheme_supported
(
90
,
error
=
False
,
match_exact
=
True
)
...
@@ -405,6 +429,13 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -405,6 +429,13 @@ class CompressedTensorsConfig(QuantizationConfig):
if
self
.
_is_fp4a16_nvfp4
(
weight_quant
,
input_quant
):
if
self
.
_is_fp4a16_nvfp4
(
weight_quant
,
input_quant
):
return
CompressedTensorsW4A16Fp4
()
return
CompressedTensorsW4A16Fp4
()
if
self
.
_is_fp8_w4a8_sm90
(
weight_quant
,
input_quant
):
return
CompressedTensorsW4A8Fp8
(
num_bits
=
weight_quant
.
num_bits
,
strategy
=
weight_quant
.
strategy
,
symmetric
=
weight_quant
.
symmetric
,
group_size
=
weight_quant
.
group_size
,
actorder
=
weight_quant
.
actorder
)
if
self
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
if
self
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
if
(
self
.
quant_format
==
CompressionFormat
.
marlin_24
.
value
if
(
self
.
quant_format
==
CompressionFormat
.
marlin_24
.
value
and
weight_quant
.
num_bits
in
W4A16SPARSE24_SUPPORTED_BITS
):
and
weight_quant
.
num_bits
in
W4A16SPARSE24_SUPPORTED_BITS
):
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
View file @
e76e2335
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
from
.compressed_tensors_scheme
import
CompressedTensorsScheme
from
.compressed_tensors_scheme
import
CompressedTensorsScheme
from
.compressed_tensors_w4a4_nvfp4
import
CompressedTensorsW4A4Fp4
from
.compressed_tensors_w4a4_nvfp4
import
CompressedTensorsW4A4Fp4
from
.compressed_tensors_w4a8_fp8
import
CompressedTensorsW4A8Fp8
from
.compressed_tensors_w4a8_int
import
CompressedTensorsW4A8Int
from
.compressed_tensors_w4a8_int
import
CompressedTensorsW4A8Int
from
.compressed_tensors_w4a16_24
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
from
.compressed_tensors_w4a16_24
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
CompressedTensorsW4A16Sparse24
)
CompressedTensorsW4A16Sparse24
)
...
@@ -21,5 +22,6 @@ __all__ = [
...
@@ -21,5 +22,6 @@ __all__ = [
"CompressedTensorsW8A8Int8"
,
"CompressedTensorsW8A8Fp8"
,
"CompressedTensorsW8A8Int8"
,
"CompressedTensorsW8A8Fp8"
,
"WNA16_SUPPORTED_BITS"
,
"W4A16SPARSE24_SUPPORTED_BITS"
,
"WNA16_SUPPORTED_BITS"
,
"W4A16SPARSE24_SUPPORTED_BITS"
,
"CompressedTensors24"
,
"CompressedTensorsW4A16Fp4"
,
"CompressedTensors24"
,
"CompressedTensorsW4A16Fp4"
,
"CompressedTensorsW4A4Fp4"
,
"CompressedTensorsW4A8Int"
"CompressedTensorsW4A4Fp4"
,
"CompressedTensorsW4A8Int"
,
"CompressedTensorsW4A8Fp8"
]
]
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py
0 → 100644
View file @
e76e2335
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Callable
,
Optional
import
torch
from
compressed_tensors.quantization
import
ActivationOrdering
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision
import
(
MPLinearLayerConfig
,
choose_mp_linear_kernel
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
marlin_repeat_scales_on_all_ranks
)
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedvLLMParameter
)
# yapf: enable
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
__all__
=
[
"CompressedTensorsW4A8Fp8"
]
W4A8_SUPPORTED_TYPES_MAP
=
{
4
:
scalar_types
.
int4
,
}
W4A8_SUPPORTED_BITS
=
list
(
W4A8_SUPPORTED_TYPES_MAP
.
keys
())
class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
    """Compressed-tensors scheme for 4-bit weights with FP8 activations.

    The scheme registers the packed weight, its group scales and the
    original weight shape on a layer, then delegates post-load processing
    and the forward computation to a mixed-precision linear kernel picked
    by ``choose_mp_linear_kernel``.
    """

    # Names of kernel classes already announced in the log; shared across
    # instances so that each backend is logged only once.
    _kernel_backends_being_used: set[str] = set()

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None,
                 symmetric: Optional[bool] = True,
                 actorder: Optional[ActivationOrdering] = None):
        """Validate and store the quantization settings.

        Args:
            strategy: Weight quantization strategy; must be ``"group"``.
            num_bits: Weight bit width; must be a key of
                ``W4A8_SUPPORTED_TYPES_MAP``.
            group_size: Quantization group size; must be 128. ``None`` is
                stored as -1 and is therefore rejected.
            symmetric: Whether the weights are symmetrically quantized.
            actorder: Activation ordering; ``ActivationOrdering.GROUP``
                sets ``has_g_idx``.

        Raises:
            ValueError: If the strategy/group size or the bit width is not
                supported.
        """
        self.strategy = strategy
        self.symmetric = symmetric
        self.group_size = -1 if group_size is None else group_size
        self.has_g_idx = actorder == ActivationOrdering.GROUP

        if self.group_size != 128 or self.strategy != "group":
            raise ValueError("W4A8 kernels require group quantization " \
                             "with group size 128")

        if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {W4A8_SUPPORTED_BITS}")

        # Derived only after validation so that an invalid bit width (for
        # example 0) surfaces as the ValueError above rather than as a
        # ZeroDivisionError.
        self.pack_factor = 32 // num_bits
        self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability for this scheme (90)."""
        # hopper
        return 90

    def create_weights(self, layer: torch.nn.Module, output_size: int,
                       input_size: int, output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Pick a kernel and register the quantized parameters on ``layer``.

        Registers three parameters:

        * ``weight_packed``: int32, shape
          ``(sum(output_partition_sizes),
          input_size_per_partition // pack_factor)``.
        * ``weight_scale``: ``params_dtype``, shape
          ``(sum(output_partition_sizes), number_of_groups)``.
        * ``weight_shape``: two int64 values describing the weight shape
          before packing.

        The chosen kernel instance is stored on ``self.kernel``.
        """
        output_size_per_partition = sum(output_partition_sizes)

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=\
                (input_size_per_partition, output_size_per_partition),
            weight_type=self.quant_type,
            act_type=torch.float8_e4m3fn,  # always use fp8(e4m3)
            group_size=self.group_size,
            zero_points=not self.symmetric,
            has_g_idx=self.has_g_idx)

        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

        # Announce each kernel backend only the first time it is selected.
        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for CompressedTensorsW4A8Fp8",
                        kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # If group_size is -1, we are in channelwise case.
        group_size = self.group_size if self.group_size != -1 else input_size
        row_parallel = (input_size != input_size_per_partition)
        partition_scales = not marlin_repeat_scales_on_all_ranks(
            self.has_g_idx, self.group_size, row_parallel)

        scales_and_zp_size = input_size // group_size

        if partition_scales:
            # Scales are split along the input dimension, so each partition
            # must hold a whole number of groups.
            assert input_size_per_partition % group_size == 0
            scales_and_zp_size = input_size_per_partition // group_size

        weight = PackedvLLMParameter(input_dim=1,
                                     output_dim=0,
                                     weight_loader=weight_loader,
                                     packed_factor=self.pack_factor,
                                     packed_dim=1,
                                     data=torch.empty(
                                         output_size_per_partition,
                                         input_size_per_partition //
                                         self.pack_factor,
                                         dtype=torch.int32,
                                     ))

        # TODO(czhu): allocate the packed fp8 scales memory here?
        # the scales will be expanded by 8x via `cutlass_pack_scale_fp8`
        weight_scale_args = {
            "weight_loader":
            weight_loader,
            "data":
            torch.empty(
                output_size_per_partition,
                scales_and_zp_size,
                dtype=params_dtype,
            )
        }

        if not partition_scales:
            weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                      **weight_scale_args)
        else:
            weight_scale = GroupQuantScaleParameter(output_dim=0,
                                                    input_dim=1,
                                                    **weight_scale_args)

        # A 2D array defining the original shape of the weights
        # before packing
        weight_shape = BasevLLMParameter(data=torch.empty(2,
                                                          dtype=torch.int64),
                                         weight_loader=weight_loader)

        layer.register_parameter("weight_packed", weight)
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_shape", weight_shape)

        self.kernel = kernel_type(mp_linear_kernel_config,
                                  w_q_param_name="weight_packed",
                                  w_s_param_name="weight_scale",
                                  w_zp_param_name="weight_zero_point",
                                  w_gidx_param_name="weight_g_idx")

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Let the selected kernel repack the loaded weights of ``layer``."""
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the selected kernel's forward pass and return its output."""
        return self.kernel.apply_weights(layer, x, bias)
vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
View file @
e76e2335
...
@@ -10,6 +10,8 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas imp
...
@@ -10,6 +10,8 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas imp
BitBLASLinearKernel
)
BitBLASLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.conch
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.conch
import
(
# noqa: E501
ConchLinearKernel
)
ConchLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass
import
(
# noqa: E501
CutlassW4A8LinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit
import
(
# noqa: E501
Dynamic4bitLinearKernel
)
Dynamic4bitLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama
import
(
# noqa: E501
...
@@ -24,6 +26,7 @@ from vllm.platforms import current_platform
...
@@ -24,6 +26,7 @@ from vllm.platforms import current_platform
# in priority/performance order (when available)
# in priority/performance order (when available)
_POSSIBLE_KERNELS
:
list
[
type
[
MPLinearKernel
]]
=
[
_POSSIBLE_KERNELS
:
list
[
type
[
MPLinearKernel
]]
=
[
CutlassW4A8LinearKernel
,
MacheteLinearKernel
,
MacheteLinearKernel
,
AllSparkLinearKernel
,
AllSparkLinearKernel
,
MarlinLinearKernel
,
MarlinLinearKernel
,
...
...
vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py
0 → 100644
View file @
e76e2335
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
permute_param_layout_
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
.MPLinearKernel
import
MPLinearKernel
,
MPLinearLayerConfig
class CutlassW4A8LinearKernel(MPLinearKernel):
    """Mixed-precision linear kernel backed by the CUTLASS W4A8 custom ops.

    Activations are quantized to FP8 per token at call time and multiplied
    with int4 group-quantized weights through ``ops.cutlass_w4a8_mm``.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the base kernel and the activation quantizer."""
        super().__init__(*args, **kwargs)
        # dynamic per-tok fp8 activation quantization
        self.quant_fp8 = QuantFP8(static=False,
                                  group_shape=GroupShape.PER_TOKEN)

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability for this kernel (90)."""
        return 90

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
        """Report whether this kernel supports the layer config ``c``.

        Returns:
            ``(True, None)`` when supported, otherwise ``(False, reason)``
            where ``reason`` describes the first failed requirement.
        """
        if not current_platform.is_cuda():
            return False, "CUTLASS only supported on CUDA"

        if not current_platform.is_device_capability(90):
            return False, "CUTLASS W4A8 requires compute capability of 90 "\
                    "(Hopper)"

        if c.act_type != torch.float8_e4m3fn:
            return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations"

        if c.has_g_idx:
            return False, "Act reordering not supported by CUTLASS W4A8"

        if c.zero_points:
            return False, "Zero points not supported by CUTLASS W4A8"

        if c.weight_type != scalar_types.int4:
            return False, f"Quant type ({c.weight_type}) not supported by "\
                           "CUTLASS W4A8, only supported int4"

        # TODO(czhu): support -1 (column-wise)
        if c.group_size != 128:
            return False, "Only group_size 128 is supported"

        in_features, out_features = c.partition_weight_shape
        if in_features % 128 or out_features % 128:
            return False, "K and N must be divisible by 128, got "\
                          f"{c.partition_weight_shape}"
        return True, None

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale`  is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Convert the loaded weight and scale parameters for the kernel.

        The packed weights are permuted to the layout noted above and passed
        through ``ops.cutlass_encode_and_reorder_int4b``; the group scales
        are permuted, cast to FP8 (e4m3) and passed through
        ``ops.cutlass_pack_scale_fp8``. A per-output-channel scale vector of
        ones is stored on ``self.w_ch_s``.
        """
        c = self.config

        # TODO(czhu): optimize speed/mem usage
        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            # t().contiguous().t() keeps the logical shape but rewrites the
            # storage so that the first dimension becomes the fastest one.
            x.data = ops.cutlass_encode_and_reorder_int4b(
                x.data.t().contiguous().t())
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous().to(torch.float8_e4m3fn)
            x.data = ops.cutlass_pack_scale_fp8(x.data)
            return x

        # Encode/reorder weights and pack scales
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

        # TODO(czhu): support loading channel scales
        self.w_ch_s = torch.ones((c.partition_weight_shape[1], ),
                                 dtype=torch.float32,
                                 device='cuda')

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Quantize ``x`` per token to FP8 and apply the W4A8 matmul.

        Args:
            layer: Module holding the processed weight and scale parameters.
            x: Input activations; all leading dimensions are flattened into
                rows for the matmul and restored on the output.
            bias: Optional bias added to the matmul result.

        Returns:
            A tensor with ``x``'s leading dimensions and a last dimension of
            ``partition_weight_shape[1]``.
        """
        c = self.config
        w_q, w_s, _, _ = self._get_weight_params(layer)

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

        x_2d, act_scales = self.quant_fp8(x_2d)
        output = ops.cutlass_w4a8_mm(a=x_2d,
                                     b_q=w_q,
                                     b_group_scales=w_s,
                                     b_group_size=c.group_size,
                                     a_token_scales=act_scales,
                                     b_channel_scales=self.w_ch_s)

        # The custom op takes no bias argument, so the bias is applied to
        # its output. An `assert bias is None` would be stripped under -O
        # and the bias would then be dropped silently.
        if bias is not None:
            output.add_(bias)

        return output.reshape(out_shape)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment