Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f6227c22
Unverified
Commit
f6227c22
authored
Dec 08, 2025
by
czhu-cohere
Committed by
GitHub
Dec 08, 2025
Browse files
[Kernel]Support W4A8 Grouped GEMM on Hopper (#29691)
Signed-off-by:
czhu-cohere
<
conway.zhu@cohere.com
>
parent
ea657f20
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1981 additions
and
96 deletions
+1981
-96
CMakeLists.txt
CMakeLists.txt
+4
-1
csrc/ops.h
csrc/ops.h
+2
-1
csrc/quantization/cutlass_w4a8/get_group_starts.cuh
csrc/quantization/cutlass_w4a8/get_group_starts.cuh
+104
-0
csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu
csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu
+483
-0
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
+3
-67
csrc/quantization/cutlass_w4a8/w4a8_utils.cu
csrc/quantization/cutlass_w4a8/w4a8_utils.cu
+90
-0
csrc/quantization/cutlass_w4a8/w4a8_utils.cuh
csrc/quantization/cutlass_w4a8/w4a8_utils.cuh
+11
-0
csrc/quantization/w8a8/cutlass/moe/moe_data.cu
csrc/quantization/w8a8/cutlass/moe/moe_data.cu
+5
-3
csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
+5
-3
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+25
-1
tests/kernels/quantization/test_cutlass_w4a8.py
tests/kernels/quantization/test_cutlass_w4a8.py
+41
-5
tests/kernels/quantization/test_cutlass_w4a8_moe.py
tests/kernels/quantization/test_cutlass_w4a8_moe.py
+340
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+89
-1
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+4
-0
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+29
-0
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+401
-0
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+1
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+1
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+339
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py
+4
-11
No files found.
CMakeLists.txt
View file @
f6227c22
...
...
@@ -874,7 +874,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
cuda_archs_loose_intersection
(
W4A8_ARCHS
"9.0a"
"
${
CUDA_ARCHS
}
"
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS
)
set
(
SRCS
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
)
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
"csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
"csrc/quantization/cutlass_w4a8/w4a8_utils.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
SRCS
}
"
...
...
csrc/ops.h
View file @
f6227c22
...
...
@@ -262,7 +262,8 @@ void get_cutlass_moe_mm_data(
void
get_cutlass_moe_mm_problem_sizes
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
,
std
::
optional
<
bool
>
force_swap_ab
=
std
::
nullopt
);
void
get_cutlass_pplx_moe_mm_data
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
...
...
csrc/quantization/cutlass_w4a8/get_group_starts.cuh
0 → 100644
View file @
f6227c22
// see csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh
#pragma once
#include <cuda.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"
// ElementB is int32 (packed int4)
// ElementGroupScale is cutlass::Array<cutlass::float_e4m3_t, 8> (packed fp8)
template
<
typename
ElementA
,
typename
ElementB
,
typename
ElementC
,
typename
ElementAccumulator
,
typename
ElementGroupScale
>
__global__
void
get_group_gemm_starts
(
int64_t
*
expert_offsets
,
ElementA
**
a_offsets
,
ElementB
**
b_offsets
,
ElementC
**
out_offsets
,
ElementAccumulator
**
a_scales_offsets
,
ElementAccumulator
**
b_scales_offsets
,
ElementGroupScale
**
b_group_scales_offsets
,
ElementA
*
a_base_as_int
,
ElementB
*
b_base_as_int
,
ElementC
*
out_base_as_int
,
ElementAccumulator
*
a_scales_base_as_int
,
ElementAccumulator
*
b_scales_base_as_int
,
ElementGroupScale
*
b_group_scales_base_as_int
,
int64_t
n
,
int64_t
k
,
int64_t
scale_k
)
{
int
expert_id
=
threadIdx
.
x
;
int64_t
expert_offset
=
expert_offsets
[
expert_id
];
// same as w8a8
a_offsets
[
expert_id
]
=
a_base_as_int
+
expert_offset
*
k
;
out_offsets
[
expert_id
]
=
out_base_as_int
+
expert_offset
*
n
;
a_scales_offsets
[
expert_id
]
=
a_scales_base_as_int
+
expert_offset
;
b_scales_offsets
[
expert_id
]
=
b_scales_base_as_int
+
(
n
*
expert_id
);
// w4a8 specific
constexpr
int
pack_factor
=
8
;
// pack 8 int4 into int32
b_offsets
[
expert_id
]
=
b_base_as_int
+
(
expert_id
*
k
*
n
/
pack_factor
);
b_group_scales_offsets
[
expert_id
]
=
b_group_scales_base_as_int
+
(
expert_id
*
scale_k
*
n
);
}
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float, \
cutlass::Array<cutlass::float_e4m3_t, 8>> \
<<<1, num_experts, 0, stream>>>( \
static_cast<int64_t*>(expert_offsets.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<int32_t**>(b_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
static_cast<float**>(a_scales_ptrs.data_ptr()), \
static_cast<float**>(b_scales_ptrs.data_ptr()), \
static_cast<cutlass::Array<cutlass::float_e4m3_t, 8>**>( \
b_group_scales_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
static_cast<int32_t*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<float*>(a_scales.data_ptr()), \
static_cast<float*>(b_scales.data_ptr()), \
static_cast<cutlass::Array<cutlass::float_e4m3_t, 8>*>( \
b_group_scales.data_ptr()), \
n, k, scale_k); \
}
namespace
{
void
run_get_group_gemm_starts
(
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
&
a_ptrs
,
torch
::
Tensor
&
b_ptrs
,
torch
::
Tensor
&
out_ptrs
,
torch
::
Tensor
&
a_scales_ptrs
,
torch
::
Tensor
&
b_scales_ptrs
,
torch
::
Tensor
&
b_group_scales_ptrs
,
torch
::
Tensor
const
&
a_tensors
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
&
out_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_group_scales
,
const
int64_t
b_group_size
)
{
TORCH_CHECK
(
a_tensors
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b_tensors
.
dtype
()
==
torch
::
kInt32
);
// int4 8x packed into int32
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_group_scales
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
// the underlying torch type is e4m3
TORCH_CHECK
(
out_tensors
.
dtype
()
==
torch
::
kBFloat16
);
// only support bf16 for now
// expect int64_t to avoid overflow during offset calculations
TORCH_CHECK
(
expert_offsets
.
dtype
()
==
torch
::
kInt64
);
int
num_experts
=
static_cast
<
int
>
(
expert_offsets
.
size
(
0
));
// logical k, n
int64_t
n
=
out_tensors
.
size
(
1
);
int64_t
k
=
a_tensors
.
size
(
1
);
int64_t
scale_k
=
cutlass
::
ceil_div
(
k
,
b_group_size
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a_tensors
.
device
().
index
());
if
(
false
)
{
}
__CALL_GET_STARTS_KERNEL
(
torch
::
kBFloat16
,
cutlass
::
bfloat16_t
)
__CALL_GET_STARTS_KERNEL
(
torch
::
kFloat16
,
half
)
else
{
TORCH_CHECK
(
false
,
"Invalid output type (must be float16 or bfloat16)"
);
}
}
}
// namespace
\ No newline at end of file
csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu
0 → 100644
View file @
f6227c22
#include <vector>
#include <tuple>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/mixed_dtype_utils.hpp"
// vllm includes
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/torch_utils.hpp"
#include "cutlass_extensions/common.hpp"
#include "core/registration.h"
#include "get_group_starts.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "w4a8_utils.cuh"
namespace
vllm
::
cutlass_w4a8_moe
{
using
namespace
cute
;
// -------------------------------------------------------------------------------------
// Static configuration shared across all instantiations
// -------------------------------------------------------------------------------------
using
ProblemShape
=
cutlass
::
gemm
::
GroupProblemShape
<
Shape
<
int
,
int
,
int
>>
;
// <M,N,K> per
// group
using
MmaType
=
cutlass
::
float_e4m3_t
;
using
QuantType
=
cutlass
::
int4b_t
;
constexpr
int
TileShapeK
=
128
*
8
/
sizeof_bits
<
MmaType
>::
value
;
static
int
constexpr
PackFactor
=
8
;
// 8 int4 packed into int32
// A matrix configuration
using
ElementA
=
MmaType
;
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
// Layout type for A matrix operand
constexpr
int
AlignmentA
=
128
/
cutlass
::
sizeof_bits
<
ElementA
>::
value
;
// Alignment of A matrix in units of
// elements (up to 16 bytes)
// B matrix configuration
using
ElementB
=
QuantType
;
// Element type for B matrix operand
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
// Layout type for B matrix operand
constexpr
int
AlignmentB
=
128
/
cutlass
::
sizeof_bits
<
ElementB
>::
value
;
// Memory access granularity/alignment of B
// matrix in units of elements (up to 16 bytes)
// This example manually swaps and transposes, so keep transpose of input
// layouts
using
LayoutA_Transpose
=
typename
cutlass
::
layout
::
LayoutTranspose
<
LayoutA
>::
type
;
using
LayoutB_Transpose
=
typename
cutlass
::
layout
::
LayoutTranspose
<
LayoutB
>::
type
;
// Need to pass a pointer type to make the 3rd dimension of Stride be _0
using
StrideA
=
cute
::
remove_pointer_t
<
cutlass
::
detail
::
TagToStrideA_t
<
LayoutA
*>>
;
using
StrideB
=
cute
::
remove_pointer_t
<
cutlass
::
detail
::
TagToStrideB_t
<
LayoutB
*>>
;
// Define the CuTe layout for reoredered quantized tensor B
// LayoutAtomQuant places values that will be read by the same thread in
// contiguous locations in global memory. It specifies the reordering within a
// single warp's fragment
using
LayoutAtomQuant
=
decltype
(
cutlass
::
compute_memory_reordering_atom
<
MmaType
>
());
using
LayoutB_Reordered
=
decltype
(
cute
::
tile_to_shape
(
LayoutAtomQuant
{},
Layout
<
Shape
<
int
,
int
,
Int
<
1
>>
,
StrideB
>
{}));
using
ElementScale
=
cutlass
::
float_e4m3_t
;
using
LayoutScale
=
cutlass
::
layout
::
RowMajor
;
// C/D matrix configuration
using
ElementC
=
cutlass
::
bfloat16_t
;
// Element type for C and D matrix operands
using
LayoutC
=
cutlass
::
layout
::
RowMajor
;
// Layout type for C and D matrix operands
constexpr
int
AlignmentC
=
128
/
cutlass
::
sizeof_bits
<
ElementC
>::
value
;
// Memory access granularity/alignment of C
// matrix in units of elements (up to 16 bytes)
// D matrix configuration
using
ElementD
=
ElementC
;
using
LayoutD
=
LayoutC
;
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
// Core kernel configurations
using
ElementAccumulator
=
float
;
// Element type for internal accumulation
using
ArchTag
=
cutlass
::
arch
::
Sm90
;
// Tag indicating the minimum SM that
// supports the intended feature
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
// Operator class tag
using
StageCountType
=
cutlass
::
gemm
::
collective
::
StageCountAuto
;
// Stage count maximized based
// on the tile size
// per-channel and per-token scales for epilogue
using
ElementSChannel
=
float
;
template
<
class
TileShape_MN
,
class
ClusterShape_MNK
,
class
KernelSchedule
,
class
EpilogueSchedule
>
struct
W4A8GroupedGemmKernel
{
using
TileShape
=
decltype
(
cute
::
append
(
TileShape_MN
{},
cute
::
Int
<
TileShapeK
>
{}));
using
ClusterShape
=
ClusterShape_MNK
;
// per-channel, per-token scales epilogue
using
ChTokScalesEpilogue
=
typename
vllm
::
c3x
::
ScaledEpilogueArray
<
ElementAccumulator
,
ElementD
,
TileShape
>
;
using
EVTCompute
=
typename
ChTokScalesEpilogue
::
EVTCompute
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
TileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAccumulator
,
ElementSChannel
,
ElementC
,
typename
cutlass
::
layout
::
LayoutTranspose
<
LayoutC
>::
type
*
,
AlignmentC
,
ElementD
,
typename
cutlass
::
layout
::
LayoutTranspose
<
LayoutD
>::
type
*
,
AlignmentD
,
EpilogueSchedule
,
EVTCompute
>::
CollectiveOp
;
// =========================================================== MIXED INPUT
// WITH SCALES
// ===========================================================================
// The Scale information must get paired with the operand that will be scaled.
// In this example, B is scaled so we make a tuple of B's information and the
// scale information.
using
CollectiveMainloopShuffled
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
cute
::
tuple
<
ElementB
,
cutlass
::
Array
<
ElementScale
,
8
>>
,
LayoutB_Reordered
*
,
AlignmentB
,
ElementA
,
LayoutA_Transpose
*
,
AlignmentA
,
ElementAccumulator
,
TileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
KernelSchedule
>::
CollectiveOp
;
using
GemmKernelShuffled
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
ProblemShape
,
CollectiveMainloopShuffled
,
CollectiveEpilogue
>
;
using
GemmShuffled
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernelShuffled
>
;
using
StrideC
=
typename
GemmKernelShuffled
::
InternalStrideC
;
using
StrideD
=
typename
GemmKernelShuffled
::
InternalStrideD
;
using
StrideC_ref
=
cutlass
::
detail
::
TagToStrideC_t
<
LayoutC
>
;
using
StrideD_ref
=
cutlass
::
detail
::
TagToStrideC_t
<
LayoutD
>
;
using
StrideS
=
typename
CollectiveMainloopShuffled
::
StrideScale
;
using
StrideS_ref
=
cutlass
::
detail
::
TagToStrideB_t
<
LayoutScale
>
;
// static asserts for passing in strides/layouts
// pack to 2x int64
static_assert
(
sizeof
(
StrideS
)
==
2
*
sizeof
(
int64_t
));
// pack to 3xint32,
static_assert
(
sizeof
(
LayoutB_Reordered
)
%
sizeof
(
int32_t
)
==
0
,
"LayoutB_Reordered size must be divisible by 4 bytes"
);
static
void
grouped_mm
(
torch
::
Tensor
&
out_tensors
,
const
torch
::
Tensor
&
a_tensors
,
const
torch
::
Tensor
&
b_tensors
,
const
torch
::
Tensor
&
a_scales
,
const
torch
::
Tensor
&
b_scales
,
const
torch
::
Tensor
&
b_group_scales
,
const
int64_t
b_group_size
,
const
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
problem_sizes_torch
,
const
torch
::
Tensor
&
a_strides
,
const
torch
::
Tensor
&
b_strides
,
const
torch
::
Tensor
&
c_strides
,
const
torch
::
Tensor
&
group_scale_strides
)
{
auto
device
=
a_tensors
.
device
();
auto
device_id
=
device
.
index
();
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
device_id
);
int
num_experts
=
static_cast
<
int
>
(
expert_offsets
.
size
(
0
));
int
n
=
static_cast
<
int
>
(
b_tensors
.
size
(
1
));
int
k
=
static_cast
<
int
>
(
b_tensors
.
size
(
2
))
*
PackFactor
;
auto
options_int
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
device
);
torch
::
Tensor
a_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
b_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
out_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
a_scales_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
b_scales_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
b_group_scales_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
// get the correct offsets to pass to gemm
run_get_group_gemm_starts
(
expert_offsets
,
a_ptrs
,
b_ptrs
,
out_ptrs
,
a_scales_ptrs
,
b_scales_ptrs
,
b_group_scales_ptrs
,
a_tensors
,
b_tensors
,
out_tensors
,
a_scales
,
b_scales
,
b_group_scales
,
b_group_size
);
// construct args
using
Args
=
typename
GemmShuffled
::
Arguments
;
using
MainloopArguments
=
typename
GemmKernelShuffled
::
MainloopArguments
;
using
EpilogueArguments
=
typename
GemmKernelShuffled
::
EpilogueArguments
;
Args
arguments
;
ProblemShape
::
UnderlyingProblemShape
*
problem_sizes_as_shapes
=
static_cast
<
ProblemShape
::
UnderlyingProblemShape
*>
(
problem_sizes_torch
.
data_ptr
());
ProblemShape
prob_shape
{
num_experts
,
problem_sizes_as_shapes
,
nullptr
};
// SwapAB so B operands come first
MainloopArguments
mainloop_arguments
{
static_cast
<
const
QuantType
**>
(
b_ptrs
.
data_ptr
()),
static_cast
<
LayoutB_Reordered
*>
(
b_strides
.
data_ptr
()),
static_cast
<
const
MmaType
**>
(
a_ptrs
.
data_ptr
()),
static_cast
<
StrideA
*>
(
a_strides
.
data_ptr
()),
static_cast
<
const
cutlass
::
Array
<
ElementScale
,
8
>**>
(
b_group_scales_ptrs
.
data_ptr
()),
static_cast
<
StrideS
*>
(
group_scale_strides
.
data_ptr
()),
static_cast
<
int
>
(
b_group_size
)};
EpilogueArguments
epilogue_arguments
{
// since we are doing SwapAB the channel scales comes first, then token
// scales
ChTokScalesEpilogue
::
prepare_args
(
// see ScaledEpilogueArray
static_cast
<
const
ElementAccumulator
**>
(
b_scales_ptrs
.
data_ptr
()),
// per-channel
static_cast
<
const
ElementAccumulator
**>
(
a_scales_ptrs
.
data_ptr
()),
// per-token
true
,
true
),
nullptr
,
// C
static_cast
<
StrideC
*>
(
c_strides
.
data_ptr
()),
// C
static_cast
<
ElementD
**>
(
out_ptrs
.
data_ptr
()),
// D
static_cast
<
StrideC
*>
(
c_strides
.
data_ptr
())
// D
};
static
const
cutlass
::
KernelHardwareInfo
hw_info
{
device_id
,
cutlass
::
KernelHardwareInfo
::
query_device_multiprocessor_count
(
device_id
)};
arguments
=
Args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGrouped
,
prob_shape
,
mainloop_arguments
,
epilogue_arguments
,
hw_info
};
// Allocate workspace
size_t
workspace_size
=
GemmShuffled
::
get_workspace_size
(
arguments
);
torch
::
Tensor
workspace
=
torch
::
empty
(
workspace_size
,
torch
::
TensorOptions
().
dtype
(
torch
::
kU8
).
device
(
device
));
// Run GEMM
GemmShuffled
gemm
;
CUTLASS_CHECK
(
gemm
.
can_implement
(
arguments
));
CUTLASS_CHECK
(
gemm
.
initialize
(
arguments
,
workspace
.
data_ptr
(),
stream
));
CUTLASS_CHECK
(
gemm
.
run
(
stream
));
}
};
// ----------------------------------------------------------------------------
// Kernel instantiations and dispatch logic
// ----------------------------------------------------------------------------
using
Coop
=
cutlass
::
gemm
::
KernelPtrArrayTmaWarpSpecializedCooperative
;
using
CoopEpi
=
cutlass
::
epilogue
::
PtrArrayTmaWarpSpecializedCooperative
;
// Kernel_TileShape_ClusterShape_Schedule
using
Kernel_128x16_1x1x1_Coop
=
W4A8GroupedGemmKernel
<
Shape
<
_128
,
_16
>
,
Shape
<
_1
,
_1
,
_1
>
,
Coop
,
CoopEpi
>
;
using
Kernel_128x16_2x1x1_Coop
=
W4A8GroupedGemmKernel
<
Shape
<
_128
,
_16
>
,
Shape
<
_2
,
_1
,
_1
>
,
Coop
,
CoopEpi
>
;
using
Kernel_256x16_1x1x1_Coop
=
W4A8GroupedGemmKernel
<
Shape
<
_256
,
_16
>
,
Shape
<
_1
,
_1
,
_1
>
,
Coop
,
CoopEpi
>
;
using
Kernel_256x16_2x1x1_Coop
=
W4A8GroupedGemmKernel
<
Shape
<
_256
,
_16
>
,
Shape
<
_2
,
_1
,
_1
>
,
Coop
,
CoopEpi
>
;
using
Kernel_256x32_1x1x1_Coop
=
W4A8GroupedGemmKernel
<
Shape
<
_256
,
_32
>
,
Shape
<
_1
,
_1
,
_1
>
,
Coop
,
CoopEpi
>
;
using
Kernel_256x32_2x1x1_Coop
=
W4A8GroupedGemmKernel
<
Shape
<
_256
,
_32
>
,
Shape
<
_2
,
_1
,
_1
>
,
Coop
,
CoopEpi
>
;
using
Kernel_256x64_1x1x1_Coop
=
W4A8GroupedGemmKernel
<
Shape
<
_256
,
_64
>
,
Shape
<
_1
,
_1
,
_1
>
,
Coop
,
CoopEpi
>
;
using
Kernel_256x64_2x1x1_Coop
=
W4A8GroupedGemmKernel
<
Shape
<
_256
,
_64
>
,
Shape
<
_2
,
_1
,
_1
>
,
Coop
,
CoopEpi
>
;
using
Kernel_256x128_1x1x1_Coop
=
W4A8GroupedGemmKernel
<
Shape
<
_256
,
_128
>
,
Shape
<
_1
,
_1
,
_1
>
,
Coop
,
CoopEpi
>
;
using
Kernel_256x128_2x1x1_Coop
=
W4A8GroupedGemmKernel
<
Shape
<
_256
,
_128
>
,
Shape
<
_2
,
_1
,
_1
>
,
Coop
,
CoopEpi
>
;
using
Kernel_128x256_2x1x1_Coop
=
W4A8GroupedGemmKernel
<
Shape
<
_128
,
_256
>
,
Shape
<
_2
,
_1
,
_1
>
,
Coop
,
CoopEpi
>
;
void
mm_dispatch
(
torch
::
Tensor
&
out_tensors
,
const
torch
::
Tensor
&
a_tensors
,
const
torch
::
Tensor
&
b_tensors
,
const
torch
::
Tensor
&
a_scales
,
const
torch
::
Tensor
&
b_scales
,
const
torch
::
Tensor
&
b_group_scales
,
const
int64_t
b_group_size
,
const
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
a_strides
,
const
torch
::
Tensor
&
b_strides
,
const
torch
::
Tensor
&
c_strides
,
const
torch
::
Tensor
&
group_scale_strides
,
const
std
::
string
&
schedule
)
{
if
(
schedule
==
"Kernel_128x16_1x1x1_Coop"
)
{
Kernel_128x16_1x1x1_Coop
::
grouped_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
b_group_scales
,
b_group_size
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
group_scale_strides
);
}
else
if
(
schedule
==
"Kernel_128x16_2x1x1_Coop"
)
{
Kernel_128x16_2x1x1_Coop
::
grouped_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
b_group_scales
,
b_group_size
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
group_scale_strides
);
}
else
if
(
schedule
==
"Kernel_256x16_1x1x1_Coop"
)
{
Kernel_256x16_1x1x1_Coop
::
grouped_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
b_group_scales
,
b_group_size
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
group_scale_strides
);
}
else
if
(
schedule
==
"Kernel_256x16_2x1x1_Coop"
)
{
Kernel_256x16_2x1x1_Coop
::
grouped_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
b_group_scales
,
b_group_size
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
group_scale_strides
);
}
else
if
(
schedule
==
"Kernel_256x32_1x1x1_Coop"
)
{
Kernel_256x32_1x1x1_Coop
::
grouped_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
b_group_scales
,
b_group_size
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
group_scale_strides
);
}
else
if
(
schedule
==
"Kernel_256x32_2x1x1_Coop"
)
{
Kernel_256x32_2x1x1_Coop
::
grouped_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
b_group_scales
,
b_group_size
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
group_scale_strides
);
}
else
if
(
schedule
==
"Kernel_256x64_1x1x1_Coop"
)
{
Kernel_256x64_1x1x1_Coop
::
grouped_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
b_group_scales
,
b_group_size
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
group_scale_strides
);
}
else
if
(
schedule
==
"Kernel_256x64_2x1x1_Coop"
)
{
Kernel_256x64_2x1x1_Coop
::
grouped_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
b_group_scales
,
b_group_size
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
group_scale_strides
);
}
else
if
(
schedule
==
"Kernel_256x128_1x1x1_Coop"
)
{
Kernel_256x128_1x1x1_Coop
::
grouped_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
b_group_scales
,
b_group_size
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
group_scale_strides
);
}
else
if
(
schedule
==
"Kernel_256x128_2x1x1_Coop"
)
{
Kernel_256x128_2x1x1_Coop
::
grouped_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
b_group_scales
,
b_group_size
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
group_scale_strides
);
}
else
if
(
schedule
==
"Kernel_128x256_2x1x1_Coop"
)
{
Kernel_128x256_2x1x1_Coop
::
grouped_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
b_group_scales
,
b_group_size
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
group_scale_strides
);
}
else
{
TORCH_CHECK
(
false
,
"cutlass_w4a8_moe_mm: unknown schedule string: "
,
schedule
);
}
}
void
mm
(
torch
::
Tensor
&
out_tensors
,
const
torch
::
Tensor
&
a_tensors
,
const
torch
::
Tensor
&
b_tensors
,
const
torch
::
Tensor
&
a_scales
,
const
torch
::
Tensor
&
b_scales
,
const
torch
::
Tensor
&
b_group_scales
,
const
int64_t
b_group_size
,
const
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
a_strides
,
const
torch
::
Tensor
&
b_strides
,
const
torch
::
Tensor
&
c_strides
,
const
torch
::
Tensor
&
group_scale_strides
,
std
::
optional
<
std
::
string
>
maybe_schedule
)
{
// user has specified a schedule
if
(
maybe_schedule
)
{
mm_dispatch
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
b_group_scales
,
b_group_size
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
group_scale_strides
,
*
maybe_schedule
);
return
;
}
// use heuristic
int
m_full
=
a_tensors
.
size
(
0
);
int
n
=
b_tensors
.
size
(
1
);
int
k
=
b_tensors
.
size
(
2
)
*
PackFactor
;
// logical k
int
num_experts
=
b_tensors
.
size
(
0
);
// per-expert batch size assuming uniform distribution
int
m_expert
=
m_full
/
num_experts
;
std
::
string
schedule
;
if
(
m_expert
<=
16
)
{
schedule
=
"Kernel_128x16_2x1x1_Coop"
;
}
else
if
(
m_expert
<=
32
)
{
schedule
=
"Kernel_256x32_1x1x1_Coop"
;
}
else
if
(
m_expert
<=
64
)
{
schedule
=
"Kernel_256x64_1x1x1_Coop"
;
}
else
if
(
m_expert
<=
128
)
{
schedule
=
"Kernel_256x128_2x1x1_Coop"
;
}
else
{
// m_expert > 128
schedule
=
"Kernel_128x256_2x1x1_Coop"
;
}
mm_dispatch
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
b_group_scales
,
b_group_size
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
group_scale_strides
,
schedule
);
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
encode_and_reorder_int4b
(
torch
::
Tensor
const
&
b_tensors
)
{
TORCH_CHECK
(
b_tensors
.
dtype
()
==
torch
::
kInt32
);
TORCH_CHECK
(
b_tensors
.
dim
()
==
3
);
// (experts, n, k)
TORCH_CHECK
(
b_tensors
.
is_contiguous
());
TORCH_CHECK
(
b_tensors
.
is_cuda
());
int
n
=
static_cast
<
int
>
(
b_tensors
.
size
(
1
));
int
k
=
static_cast
<
int
>
(
b_tensors
.
size
(
2
))
*
PackFactor
;
// logical k
// CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0.
// These misalignments cause silent OOB unless run under Compute Sanitizer.
TORCH_CHECK
(
k
%
256
==
0
,
"logical k must be divisible by 256"
);
TORCH_CHECK
(
n
%
16
==
0
,
"n must be divisible by 16"
);
// we will store the layout to an int32 tensor;
// this is the number of elements we need per layout
constexpr
size_t
layout_width
=
sizeof
(
LayoutB_Reordered
)
/
sizeof
(
int32_t
);
torch
::
Tensor
b_tensors_packed
=
torch
::
empty_like
(
b_tensors
);
int
num_experts
=
static_cast
<
int
>
(
b_tensors
.
size
(
0
));
auto
b_ptr
=
static_cast
<
QuantType
const
*>
(
b_tensors
.
const_data_ptr
());
auto
b_packed_ptr
=
static_cast
<
QuantType
*>
(
b_tensors_packed
.
data_ptr
());
// multiply by ull so result does not overflow int32
size_t
num_int4_elems
=
1ull
*
num_experts
*
n
*
k
;
bool
ok
=
vllm
::
cutlass_w4a8_utils
::
unified_encode_int4b
(
b_ptr
,
b_packed_ptr
,
num_int4_elems
);
TORCH_CHECK
(
ok
,
"unified_encode_int4b failed"
);
// construct the layout once; assumes each expert has the same layout
using
LayoutType
=
LayoutB_Reordered
;
std
::
vector
<
LayoutType
>
layout_B_reordered_host
(
num_experts
);
auto
stride_B
=
cutlass
::
make_cute_packed_stride
(
StrideB
{},
{
n
,
k
,
Int
<
1
>
{}});
auto
shape_B
=
cute
::
make_shape
(
n
,
k
,
Int
<
1
>
{});
auto
layout_B
=
make_layout
(
shape_B
,
stride_B
);
LayoutType
layout_B_reordered
=
tile_to_shape
(
LayoutAtomQuant
{},
shape_B
);
// reorder weights for each expert
for
(
int
i
=
0
;
i
<
num_experts
;
i
++
)
{
// since the storage type of int4b is 1 byte but one element is 4 bits
// we need to adjust the offset
int64_t
offset
=
1ull
*
i
*
n
*
k
*
cutlass
::
sizeof_bits
<
QuantType
>::
value
/
8
;
cutlass
::
reorder_tensor
(
b_packed_ptr
+
offset
,
layout_B
,
layout_B_reordered
);
}
// save the packed layout to torch tensor so we can re-use it
auto
cpu_opts
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCPU
);
torch
::
Tensor
layout_cpu
=
torch
::
empty
({
num_experts
,
layout_width
},
cpu_opts
);
int32_t
*
layout_data
=
layout_cpu
.
data_ptr
<
int32_t
>
();
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
std
::
memcpy
(
layout_data
+
i
*
layout_width
,
// dst (int32*)
&
layout_B_reordered
,
// src (LayoutType*)
sizeof
(
LayoutType
));
// number of bytes
}
torch
::
Tensor
packed_layout
=
layout_cpu
.
to
(
b_tensors
.
device
(),
/*non_blocking=*/
false
);
return
{
b_tensors_packed
,
packed_layout
};
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"cutlass_w4a8_moe_mm"
,
&
mm
);
m
.
impl
(
"cutlass_encode_and_reorder_int4b_grouped"
,
&
encode_and_reorder_int4b
);
}
}
// namespace vllm::cutlass_w4a8_moe
/////////////////////////////////////////////////////////////////////////////////////////////////
\ No newline at end of file
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
View file @
f6227c22
...
...
@@ -7,6 +7,7 @@
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/torch_utils.hpp"
#include "w4a8_utils.cuh"
#include "core/registration.h"
...
...
@@ -395,71 +396,6 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
return
packed_scales
;
}
/*
GPU-accelerated implementation of cutlass::unified_encode_int4b.
Constructs a lookup table in constant memory to map 8 bits
(two 4-bit values) at a time. Assumes memory is contiguous
and pointers are 16-byte aligned.
*/
__constant__
uint8_t
kNibbleLUT
[
256
];
__global__
void
unified_encode_int4b_device
(
const
uint8_t
*
in
,
uint8_t
*
out
,
size_t
nbytes
)
{
constexpr
size_t
V
=
sizeof
(
uint4
);
// 16 bytes
const
size_t
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
size_t
nthreads
=
size_t
(
gridDim
.
x
)
*
blockDim
.
x
;
const
size_t
nvec
=
nbytes
/
V
;
// 1-D grid-stride loop over 16-byte chunks
for
(
size_t
vec
=
tid
;
vec
<
nvec
;
vec
+=
nthreads
)
{
uint4
v
=
reinterpret_cast
<
const
uint4
*>
(
in
)[
vec
];
uint8_t
*
b
=
reinterpret_cast
<
uint8_t
*>
(
&
v
);
#pragma unroll
for
(
int
i
=
0
;
i
<
int
(
V
);
++
i
)
b
[
i
]
=
kNibbleLUT
[
b
[
i
]];
reinterpret_cast
<
uint4
*>
(
out
)[
vec
]
=
v
;
}
}
static
bool
upload_lut
()
{
std
::
array
<
uint8_t
,
256
>
lut
{};
auto
map_nib
=
[](
uint8_t
v
)
->
uint8_t
{
// 1..7 -> (8 - v); keep 0 and 8..15
return
(
v
==
0
||
(
v
&
0x8
))
?
v
:
uint8_t
(
8
-
v
);
};
for
(
int
b
=
0
;
b
<
256
;
++
b
)
{
uint8_t
lo
=
b
&
0xF
;
uint8_t
hi
=
(
b
>>
4
)
&
0xF
;
lut
[
b
]
=
uint8_t
((
map_nib
(
hi
)
<<
4
)
|
map_nib
(
lo
));
}
cudaError_t
e
=
cudaMemcpyToSymbol
(
kNibbleLUT
,
lut
.
data
(),
lut
.
size
(),
/*offset=*/
0
,
cudaMemcpyHostToDevice
);
return
(
e
==
cudaSuccess
);
}
static
bool
unified_encode_int4b
(
cutlass
::
int4b_t
const
*
in
,
cutlass
::
int4b_t
*
out
,
size_t
num_int4_elems
)
{
// Build/upload LUT
if
(
!
upload_lut
())
return
false
;
static_assert
(
sizeof
(
typename
cutlass
::
int4b_t
::
Storage
)
==
1
,
"int4 storage must be 1 byte"
);
const
size_t
nbytes
=
num_int4_elems
>>
1
;
auto
*
in_bytes
=
reinterpret_cast
<
uint8_t
const
*>
(
in
);
auto
*
out_bytes
=
reinterpret_cast
<
uint8_t
*>
(
out
);
// kernel launch params
constexpr
int
block
=
256
;
const
size_t
nvec
=
nbytes
/
sizeof
(
uint4
);
// # of 16B vectors
int
grid
=
int
((
nvec
+
block
-
1
)
/
block
);
if
(
grid
==
0
)
grid
=
1
;
// ensure we still cover the tail in the kernel
unified_encode_int4b_device
<<<
grid
,
block
>>>
(
in_bytes
,
out_bytes
,
nbytes
);
cudaError_t
err
=
cudaGetLastError
();
return
(
err
==
cudaSuccess
);
}
torch
::
Tensor
encode_and_reorder_int4b
(
torch
::
Tensor
const
&
B
)
{
TORCH_CHECK
(
B
.
dtype
()
==
torch
::
kInt32
);
TORCH_CHECK
(
B
.
dim
()
==
2
);
...
...
@@ -477,8 +413,8 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
LayoutB_Reordered
layout_B_reordered
=
cute
::
tile_to_shape
(
LayoutAtomQuant
{},
shape_B
);
bool
ok
=
vllm
::
cutlass_w4a8
::
unified_encode_int4b
(
B_ptr
,
B_packed_ptr
,
n
*
k
);
bool
ok
=
vllm
::
cutlass_w4a8_utils
::
unified_encode_int4b
(
B_ptr
,
B_packed_ptr
,
n
*
k
);
TORCH_CHECK
(
ok
,
"unified_encode_int4b failed"
);
cutlass
::
reorder_tensor
(
B_packed_ptr
,
layout_B
,
layout_B_reordered
);
...
...
csrc/quantization/cutlass_w4a8/w4a8_utils.cu
0 → 100644
View file @
f6227c22
#include "w4a8_utils.cuh"
#include <array>
#include <cuda_runtime.h>
#include <cstdio>
namespace
vllm
::
cutlass_w4a8_utils
{
/*
GPU-accelerated implementation of cutlass::unified_encode_int4b.
Constructs a lookup table in constant memory to map 8 bits
(two 4-bit values) at a time. Assumes memory is contiguous
and pointers are 16-byte aligned.
*/
__constant__
uint8_t
kNibbleLUT
[
256
];
__global__
void
unified_encode_int4b_device
(
const
uint8_t
*
in
,
uint8_t
*
out
,
size_t
nbytes
)
{
constexpr
size_t
V
=
sizeof
(
uint4
);
// 16 bytes
const
size_t
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
size_t
nthreads
=
size_t
(
gridDim
.
x
)
*
blockDim
.
x
;
const
size_t
nvec
=
nbytes
/
V
;
// 1-D grid-stride loop over 16-byte chunks
for
(
size_t
vec
=
tid
;
vec
<
nvec
;
vec
+=
nthreads
)
{
uint4
v
=
reinterpret_cast
<
const
uint4
*>
(
in
)[
vec
];
uint8_t
*
b
=
reinterpret_cast
<
uint8_t
*>
(
&
v
);
#pragma unroll
for
(
int
i
=
0
;
i
<
int
(
V
);
++
i
)
b
[
i
]
=
kNibbleLUT
[
b
[
i
]];
reinterpret_cast
<
uint4
*>
(
out
)[
vec
]
=
v
;
}
}
static
bool
upload_lut
()
{
std
::
array
<
uint8_t
,
256
>
lut
{};
auto
map_nib
=
[](
uint8_t
v
)
->
uint8_t
{
// 1..7 -> (8 - v); keep 0 and 8..15
return
(
v
==
0
||
(
v
&
0x8
))
?
v
:
uint8_t
(
8
-
v
);
};
for
(
int
b
=
0
;
b
<
256
;
++
b
)
{
uint8_t
lo
=
b
&
0xF
;
uint8_t
hi
=
(
b
>>
4
)
&
0xF
;
lut
[
b
]
=
uint8_t
((
map_nib
(
hi
)
<<
4
)
|
map_nib
(
lo
));
}
cudaError_t
e
=
cudaMemcpyToSymbol
(
kNibbleLUT
,
lut
.
data
(),
lut
.
size
(),
/*offset=*/
0
,
cudaMemcpyHostToDevice
);
return
(
e
==
cudaSuccess
);
}
bool
unified_encode_int4b
(
cutlass
::
int4b_t
const
*
in
,
cutlass
::
int4b_t
*
out
,
size_t
num_int4_elems
)
{
// Build/upload LUT
if
(
!
upload_lut
())
return
false
;
static_assert
(
sizeof
(
typename
cutlass
::
int4b_t
::
Storage
)
==
1
,
"int4 storage must be 1 byte"
);
const
size_t
nbytes
=
num_int4_elems
>>
1
;
auto
*
in_bytes
=
reinterpret_cast
<
uint8_t
const
*>
(
in
);
auto
*
out_bytes
=
reinterpret_cast
<
uint8_t
*>
(
out
);
// kernel launch params
constexpr
int
block
=
256
;
const
size_t
nvec
=
nbytes
/
sizeof
(
uint4
);
// # of 16B vectors
int
grid
=
int
((
nvec
+
block
-
1
)
/
block
);
if
(
grid
==
0
)
grid
=
1
;
// ensure we still cover the tail in the kernel
unified_encode_int4b_device
<<<
grid
,
block
>>>
(
in_bytes
,
out_bytes
,
nbytes
);
// launch errors
cudaError_t
err
=
cudaGetLastError
();
if
(
err
!=
cudaSuccess
)
{
printf
(
"unified_encode_int4b_device launch error: %s (%d)
\n
"
,
cudaGetErrorString
(
err
),
err
);
return
false
;
}
// runtime errors
err
=
cudaDeviceSynchronize
();
if
(
err
!=
cudaSuccess
)
{
printf
(
"unified_encode_int4b_device runtime error: %s (%d)
\n
"
,
cudaGetErrorString
(
err
),
err
);
return
false
;
}
return
true
;
}
}
// namespace vllm::cutlass_w4a8_utils
\ No newline at end of file
csrc/quantization/cutlass_w4a8/w4a8_utils.cuh
0 → 100644
View file @
f6227c22
#pragma once
#include <cstddef>
#include "cutlass/numeric_types.h"
namespace
vllm
::
cutlass_w4a8_utils
{
bool
unified_encode_int4b
(
cutlass
::
int4b_t
const
*
in
,
cutlass
::
int4b_t
*
out
,
size_t
num_int4_elems
);
}
// namespace vllm::cutlass_w4a8_utils
\ No newline at end of file
csrc/quantization/w8a8/cutlass/moe/moe_data.cu
View file @
f6227c22
...
...
@@ -136,15 +136,17 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
void
get_cutlass_moe_mm_problem_sizes_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
)
{
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
,
std
::
optional
<
bool
>
force_swap_ab
=
std
::
nullopt
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
topk_ids
.
device
().
index
());
auto
options_int32
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
).
device
(
topk_ids
.
device
());
torch
::
Tensor
atomic_buffer
=
torch
::
zeros
(
num_experts
,
options_int32
);
// Swap-AB should be disabled for FP4 path
bool
may_swap_ab
=
(
!
blockscale_offsets
.
has_value
())
&&
(
topk_ids
.
numel
()
<=
SWAP_AB_THRESHOLD
);
bool
may_swap_ab
=
force_swap_ab
.
value_or
((
!
blockscale_offsets
.
has_value
())
&&
(
topk_ids
.
numel
()
<=
SWAP_AB_THRESHOLD
));
launch_compute_problem_sizes
(
topk_ids
,
problem_sizes1
,
problem_sizes2
,
atomic_buffer
,
num_experts
,
n
,
k
,
stream
,
...
...
csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
View file @
f6227c22
...
...
@@ -80,7 +80,8 @@ void get_cutlass_moe_mm_data_caller(
void
get_cutlass_moe_mm_problem_sizes_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
,
std
::
optional
<
bool
>
force_swap_ab
=
std
::
nullopt
);
void
get_cutlass_pplx_moe_mm_data_caller
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
...
...
@@ -303,14 +304,15 @@ void get_cutlass_moe_mm_data(
void
get_cutlass_moe_mm_problem_sizes
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
)
{
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
,
std
::
optional
<
bool
>
force_swap_ab
=
std
::
nullopt
)
{
int32_t
version_num
=
get_sm_version_num
();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_problem_sizes_caller
(
topk_ids
,
problem_sizes1
,
problem_sizes2
,
num_experts
,
n
,
k
,
blockscale_offsets
);
blockscale_offsets
,
force_swap_ab
);
return
;
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
...
...
csrc/torch_bindings.cpp
View file @
f6227c22
...
...
@@ -350,6 +350,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"
);
// conditionally compiled so impl registration is in source file
// CUTLASS w4a8 grouped GEMM
ops
.
def
(
"cutlass_w4a8_moe_mm("
" Tensor! out_tensors,"
" Tensor a_tensors,"
" Tensor b_tensors,"
" Tensor a_scales,"
" Tensor b_scales,"
" Tensor b_group_scales,"
" int b_group_size,"
" Tensor expert_offsets,"
" Tensor problem_sizes,"
" Tensor a_strides,"
" Tensor b_strides,"
" Tensor c_strides,"
" Tensor group_scale_strides,"
" str? maybe_schedule"
") -> ()"
);
ops
.
def
(
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
"Tensor)"
);
// conditionally compiled so impl registration is in source file
#endif
// Dequantization for GGML.
...
...
@@ -466,7 +489,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" int num_experts, int n, int k, "
" Tensor? blockscale_offsets) -> ()"
);
" Tensor? blockscale_offsets, "
" bool? force_swap_ab) -> ()"
);
ops
.
impl
(
"get_cutlass_moe_mm_problem_sizes"
,
torch
::
kCUDA
,
&
get_cutlass_moe_mm_problem_sizes
);
...
...
tests/kernels/quantization/test_cutlass_w4a8.py
View file @
f6227c22
...
...
@@ -12,8 +12,11 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
convert_packed_uint4b8_to_signed_int4_inplace
,
pack_cols
,
pack_rows
,
quantize_weights
,
unpack_quantized_values_into_int32
,
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
...
...
@@ -167,8 +170,7 @@ def create_test_tensors(
# for the practical use case we need per-tok scales for fp8 activations
w_tok_s
=
torch
.
randn
((
m
,),
device
=
"cuda"
,
dtype
=
types
.
token_scale_type
)
# weights are already per-group quantized, use placeholder here
w_ch_s
=
torch
.
ones
((
n
,),
device
=
"cuda"
,
dtype
=
types
.
channel_scale_type
)
w_ch_s
=
torch
.
randn
((
n
,),
device
=
"cuda"
,
dtype
=
types
.
channel_scale_type
)
return
Tensors
(
w_ref
=
w_ref
,
...
...
@@ -211,7 +213,7 @@ def mm_test_helper(
print
(
output_ref
)
torch
.
testing
.
assert_close
(
output
,
output_ref
.
to
(
output
.
dtype
),
rtol
=
1e-
3
,
atol
=
1e-
3
output
,
output_ref
.
to
(
output
.
dtype
),
rtol
=
1e-
2
,
atol
=
1e-
2
)
...
...
@@ -257,7 +259,7 @@ def test_w4a8_cuda_graph():
)
w_tok_s
=
torch
.
randn
((
m
,),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
w_ch_s
=
torch
.
ones
((
n
,),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
w_ch_s
=
torch
.
randn
((
n
,),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
# Construct a trivial model with a single layer that calls the kernel
model
=
W4A8Layer
(
...
...
@@ -287,4 +289,38 @@ def test_w4a8_cuda_graph():
output
.
zero_
()
g
.
replay
()
torch
.
testing
.
assert_close
(
output
,
output_ref
,
rtol
=
1e-3
,
atol
=
1e-3
)
torch
.
testing
.
assert_close
(
output
,
output_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
@
pytest
.
mark
.
skipif
(
not
IS_SUPPORTED_BY_GPU
,
reason
=
"CUTLASS W4A8 is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"shape"
,
MNK_SHAPES
)
def
test_convert_packed_uint4b8_to_signed_int4_inplace
(
shape
):
"""
The W4A16 checkpoints encode the weights as int4b8 packed to int32.
The CUTLASS kernels expect signed int4 packed to int32.
This tests checks that the runtime int4b8 -> signed int4 conversion
matches the offline conversion step exactly.
"""
_
,
N
,
K
=
shape
# random weights packed to int32
t
=
torch
.
randint
(
low
=
torch
.
iinfo
(
torch
.
int32
).
min
,
high
=
torch
.
iinfo
(
torch
.
int32
).
max
+
1
,
size
=
(
N
,
K
//
8
),
dtype
=
torch
.
int32
,
device
=
"cuda"
,
)
# compute reference
unpacked
=
unpack_quantized_values_into_int32
(
t
.
clone
(),
scalar_types
.
uint4b8
,
packed_dim
=
1
)
unpacked
=
unpacked
-
8
# int4b8 -> signed int4
ref
=
pack_cols
(
unpacked
&
0x0F
,
4
,
*
unpacked
.
shape
)
out
=
convert_packed_uint4b8_to_signed_int4_inplace
(
t
.
clone
())
assert
torch
.
equal
(
ref
,
out
)
assert
not
torch
.
equal
(
ref
,
t
)
tests/kernels/quantization/test_cutlass_w4a8_moe.py
0 → 100644
View file @
f6227c22
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the CUTLASS-based W4A8 grouped GEMM kernel and the full MoE layer.
"""
import
random
from
dataclasses
import
dataclass
import
pytest
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
pack_rows
,
quantize_weights
,
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
IS_SUPPORTED_BY_GPU
=
current_platform
.
get_device_capability
()[
0
]
>=
9
def
to_fp8
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
return
tensor
.
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
).
to
(
dtype
=
torch
.
float8_e4m3fn
)
def
cutlass_quantize
(
atype
:
torch
.
dtype
,
w
:
torch
.
Tensor
,
wtype
:
ScalarType
,
stype
:
torch
.
dtype
|
None
,
group_size
:
int
|
None
,
zero_points
:
bool
=
False
,
):
"""
Quantize weights into W4 and compute reference dequantized weights.
Encoding/reordering of weights and packing of scales is deferred
until after all experts are combined.
"""
assert
wtype
.
is_integer
(),
"TODO: support floating point weights"
w_ref
,
w_q
,
w_s
,
w_zp
=
quantize_weights
(
w
,
wtype
,
group_size
=
group_size
,
zero_points
=
zero_points
)
# Since scales are later cast to fp8, recompute w_ref in atype here.
w_ref
=
(
w_q
.
to
(
torch
.
float32
)
*
w_s
.
to
(
atype
).
to
(
torch
.
float32
).
repeat_interleave
(
group_size
,
dim
=
0
)
).
to
(
atype
)
# Bit mask prevents sign extension of int4 when packing.
w_q
=
pack_rows
(
w_q
&
0x0F
,
wtype
.
size_bits
,
*
w_q
.
shape
)
# Make weights row-major (N, K).
w_q
=
w_q
.
t
().
contiguous
()
return
w_ref
,
w_q
,
w_s
.
to
(
atype
),
w_zp
def
cutlass_preprocess
(
w_q_experts
:
list
[
torch
.
Tensor
],
w_s_experts
:
list
[
torch
.
Tensor
]
):
"""
Reorder/encode expert weights and pack scales.
Returns:
w_q_packed: Packed/encoded int4 weights for all experts.
w_s_packed: Packed fp8 scales for all experts.
packed_layout: Layout/stride metadata for grouped GEMM.
"""
w_s_packed
=
ops
.
cutlass_pack_scale_fp8
(
torch
.
stack
(
w_s_experts
))
w_q_packed
,
packed_layout
=
ops
.
cutlass_encode_and_reorder_int4b_grouped
(
torch
.
stack
(
w_q_experts
)
)
# expects dim 3
return
w_q_packed
,
w_s_packed
,
packed_layout
GROUP_SIZE
=
128
# (num_experts, N, K)
TEST_SHAPES
=
[
(
8
,
512
,
2048
),
(
8
,
2048
,
2048
),
(
64
,
512
,
1024
),
(
64
,
2048
,
2048
),
(
4
,
2048
,
768
),
(
8
,
768
,
2048
),
(
64
,
1536
,
2048
),
(
128
,
8192
,
4096
),
# test overflow int32
]
ALIGNMENT
=
16
# torch._scaled_mm alignment for M, needed for reference check
@
dataclass
class
MoETestSetup
:
num_experts
:
int
K
:
int
N
:
int
Ms
:
list
[
int
]
M_full
:
int
a
:
torch
.
Tensor
a_ref
:
torch
.
Tensor
a_strides
:
torch
.
Tensor
out
:
torch
.
Tensor
c_strides
:
torch
.
Tensor
per_tok_scales
:
torch
.
Tensor
per_chan_scales
:
torch
.
Tensor
w_refs
:
list
[
torch
.
Tensor
]
w_q_packed
:
torch
.
Tensor
w_s_packed
:
torch
.
Tensor
problem_sizes
:
torch
.
Tensor
expert_offsets
:
torch
.
Tensor
b_strides
:
torch
.
Tensor
group_scale_strides
:
torch
.
Tensor
def
make_moe_test_setup
(
num_experts
:
int
,
K
:
int
,
N
:
int
,
*
,
alignment
:
int
=
ALIGNMENT
,
max_blocks
:
int
=
64
,
device
:
str
=
"cuda"
,
random_zero
:
bool
=
False
,
)
->
MoETestSetup
:
"""Create a full set of tensors for testing cutlass_w4a8_moe_mm."""
assert
K
%
GROUP_SIZE
==
0
# Token counts per expert (multiples of `alignment`).
Ms
=
[
alignment
*
random
.
randint
(
1
,
max_blocks
)
for
_
in
range
(
num_experts
)]
# set random experts to 0 tokens
if
random_zero
and
num_experts
>
1
:
num_zero
=
max
(
1
,
num_experts
//
8
)
zero_indices
=
random
.
sample
(
range
(
num_experts
),
k
=
num_zero
)
for
idx
in
zero_indices
:
Ms
[
idx
]
=
0
M_full
=
sum
(
Ms
)
assert
M_full
>
0
# Activations.
a
=
to_fp8
(
torch
.
randn
((
M_full
,
K
),
device
=
device
))
a_ref
=
a
.
to
(
torch
.
float32
)
a_strides
=
torch
.
full
((
num_experts
,),
K
,
dtype
=
torch
.
int64
,
device
=
device
)
# Output buffer.
out
=
torch
.
empty
((
M_full
,
N
),
dtype
=
torch
.
bfloat16
,
device
=
device
)
c_strides
=
torch
.
full
((
num_experts
,),
N
,
dtype
=
torch
.
int64
,
device
=
device
)
# Channel/token scales.
per_tok_scales
=
torch
.
randn
((
M_full
,
1
),
dtype
=
torch
.
float32
,
device
=
device
)
per_chan_scales
=
torch
.
randn
(
(
num_experts
,
N
,
1
),
dtype
=
torch
.
float32
,
device
=
device
)
# Expert weights and scales.
wtype
=
scalar_types
.
int4
atype
=
stype
=
torch
.
float8_e4m3fn
w_refs
,
w_qs
,
w_ss
=
[],
[],
[]
for
_
in
range
(
num_experts
):
b
=
to_fp8
(
torch
.
randn
((
K
,
N
),
device
=
device
))
w_ref
,
w_q
,
w_s
,
_
=
cutlass_quantize
(
atype
,
b
.
to
(
torch
.
float16
),
wtype
,
stype
,
GROUP_SIZE
,
zero_points
=
False
)
w_refs
.
append
(
w_ref
)
w_qs
.
append
(
w_q
)
w_ss
.
append
(
w_s
)
w_q_packed
,
w_s_packed
,
packed_layout
=
cutlass_preprocess
(
w_qs
,
w_ss
)
problem_sizes
=
torch
.
tensor
(
[[
N
,
M
,
K
]
for
M
in
Ms
],
dtype
=
torch
.
int32
,
device
=
device
)
expert_offsets
=
torch
.
cat
(
[
torch
.
tensor
([
0
],
dtype
=
torch
.
int64
),
torch
.
cumsum
(
torch
.
tensor
(
Ms
,
dtype
=
torch
.
int64
),
dim
=
0
)[:
-
1
],
]
).
to
(
device
=
device
)
# B strides and group scale strides.
b_strides
=
packed_layout
group_scale_strides
=
torch
.
zeros
(
(
num_experts
,
2
),
dtype
=
torch
.
int64
,
device
=
device
)
group_scale_strides
[:,
0
]
=
N
return
MoETestSetup
(
num_experts
=
num_experts
,
K
=
K
,
N
=
N
,
Ms
=
Ms
,
M_full
=
M_full
,
a
=
a
,
a_ref
=
a_ref
,
a_strides
=
a_strides
,
out
=
out
,
c_strides
=
c_strides
,
per_tok_scales
=
per_tok_scales
,
per_chan_scales
=
per_chan_scales
,
w_refs
=
w_refs
,
w_q_packed
=
w_q_packed
,
w_s_packed
=
w_s_packed
,
problem_sizes
=
problem_sizes
,
expert_offsets
=
expert_offsets
,
b_strides
=
b_strides
,
group_scale_strides
=
group_scale_strides
,
)
def
compute_moe_reference_output
(
setup
:
MoETestSetup
)
->
torch
.
Tensor
:
"""Compute reference output using torch._scaled_mm per expert."""
out_ref
=
torch
.
empty_like
(
setup
.
out
)
ends
=
torch
.
cumsum
(
torch
.
tensor
(
setup
.
Ms
),
0
).
tolist
()
starts
=
setup
.
expert_offsets
.
cpu
().
tolist
()
for
i
in
range
(
setup
.
num_experts
):
start
,
end
=
starts
[
i
],
ends
[
i
]
if
start
==
end
:
continue
out_ref_i
=
torch
.
_scaled_mm
(
setup
.
a_ref
[
start
:
end
].
to
(
torch
.
float8_e4m3fn
),
setup
.
w_refs
[
i
].
to
(
torch
.
float8_e4m3fn
).
t
().
contiguous
().
t
(),
setup
.
per_tok_scales
[
start
:
end
],
# (M, 1)
setup
.
per_chan_scales
[
i
].
reshape
(
1
,
-
1
),
# (1, N)
out_dtype
=
torch
.
bfloat16
,
use_fast_accum
=
True
,
)
out_ref
[
start
:
end
]
=
out_ref_i
return
out_ref
@
pytest
.
mark
.
skipif
(
not
IS_SUPPORTED_BY_GPU
,
reason
=
"W4A8 Grouped GEMM is not supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"shape"
,
TEST_SHAPES
)
@
pytest
.
mark
.
parametrize
(
"random_zero"
,
[
True
,
False
])
def
test_cutlass_w4a8_moe_mm_end_to_end
(
shape
,
random_zero
):
num_experts
,
N
,
K
=
shape
current_platform
.
seed_everything
(
42
)
setup
=
make_moe_test_setup
(
num_experts
=
num_experts
,
K
=
K
,
N
=
N
,
max_blocks
=
64
,
random_zero
=
random_zero
)
ops
.
cutlass_w4a8_moe_mm
(
setup
.
out
,
setup
.
a
,
setup
.
w_q_packed
,
setup
.
per_tok_scales
,
setup
.
per_chan_scales
,
setup
.
w_s_packed
,
GROUP_SIZE
,
setup
.
expert_offsets
,
setup
.
problem_sizes
,
setup
.
a_strides
,
setup
.
b_strides
,
setup
.
c_strides
,
setup
.
group_scale_strides
,
)
torch
.
cuda
.
synchronize
()
out_ref
=
compute_moe_reference_output
(
setup
)
torch
.
testing
.
assert_close
(
setup
.
out
,
out_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
class
W4A8MoELayer
(
torch
.
nn
.
Module
):
"""
Minimal wrapper module to test cuda graphs
"""
def
__init__
(
self
,
setup
:
MoETestSetup
):
super
().
__init__
()
self
.
setup
=
setup
def
forward
(
self
,
a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
s
=
self
.
setup
ops
.
cutlass_w4a8_moe_mm
(
s
.
out
,
a
,
s
.
w_q_packed
,
s
.
per_tok_scales
,
s
.
per_chan_scales
,
s
.
w_s_packed
,
GROUP_SIZE
,
s
.
expert_offsets
,
s
.
problem_sizes
,
s
.
a_strides
,
s
.
b_strides
,
s
.
c_strides
,
s
.
group_scale_strides
,
)
return
s
.
out
@
pytest
.
mark
.
skipif
(
not
IS_SUPPORTED_BY_GPU
,
reason
=
"W4A8 Grouped GEMM is not supported on this GPU type."
,
)
def
test_cutlass_w4a8_moe_mm_cuda_graph
():
current_platform
.
seed_everything
(
42
)
# Fixed config for CUDA graph test (single parameter point).
num_experts
=
8
K
=
512
N
=
2048
setup
=
make_moe_test_setup
(
num_experts
=
num_experts
,
K
=
K
,
N
=
N
,
max_blocks
=
32
,
)
# Construct model that calls the grouped GEMM kernel.
model
=
W4A8MoELayer
(
setup
)
# Build reference output once.
out_ref
=
compute_moe_reference_output
(
setup
)
# Capture and run the model in a CUDA graph.
a_static
=
setup
.
a
.
clone
()
# static input tensor for graph replay
stream
=
torch
.
cuda
.
Stream
()
with
torch
.
cuda
.
stream
(
stream
):
g
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
g
):
out_static
=
model
(
a_static
)
out_static
.
zero_
()
g
.
replay
()
torch
.
testing
.
assert_close
(
out_static
,
out_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
vllm/_custom_ops.py
View file @
f6227c22
...
...
@@ -695,6 +695,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def
cutlass_encode_and_reorder_int4b_fake
(
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
b
,
memory_format
=
torch
.
contiguous_format
)
@
register_fake
(
"_C::cutlass_encode_and_reorder_int4b_grouped"
)
def
cutlass_encode_and_reorder_int4b_grouped_fake
(
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
b
,
memory_format
=
torch
.
contiguous_format
)
if
hasattr
(
torch
.
ops
.
_C
,
"allspark_w8a16_gemm"
):
...
...
@@ -1058,6 +1062,7 @@ def get_cutlass_moe_mm_problem_sizes(
n
:
int
,
k
:
int
,
blockscale_offsets
:
torch
.
Tensor
|
None
=
None
,
force_swap_ab
:
bool
|
None
=
None
,
):
"""
Compute only the per-expert problem sizes needed by the two grouped matrix
...
...
@@ -1067,9 +1072,20 @@ def get_cutlass_moe_mm_problem_sizes(
- problem_sizes1, problem_sizes2: M×N×K sizes of each expert's
multiplication for the two grouped MMs
used in the fused MoE operation.
Optional:
- force_swap_ab: If set to True or False, explicitly enable or disable the
A/B input swap optimization. If None (default), the swap
is selected automatically based on tensor sizes.
"""
return
torch
.
ops
.
_C
.
get_cutlass_moe_mm_problem_sizes
(
topk_ids
,
problem_sizes1
,
problem_sizes2
,
num_experts
,
n
,
k
,
blockscale_offsets
topk_ids
,
problem_sizes1
,
problem_sizes2
,
num_experts
,
n
,
k
,
blockscale_offsets
,
force_swap_ab
,
)
...
...
@@ -1457,6 +1473,78 @@ def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor:
return
torch
.
ops
.
_C
.
cutlass_encode_and_reorder_int4b
(
b
)
def
cutlass_w4a8_moe_mm
(
out_tensors
:
torch
.
Tensor
,
a_tensors
:
torch
.
Tensor
,
b_tensors
:
torch
.
Tensor
,
a_scales
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_group_scales
:
torch
.
Tensor
,
b_group_size
:
int
,
expert_offsets
:
torch
.
Tensor
,
problem_sizes
:
torch
.
Tensor
,
a_strides
:
torch
.
Tensor
,
b_strides
:
torch
.
Tensor
,
c_strides
:
torch
.
Tensor
,
group_scale_strides
:
torch
.
Tensor
,
maybe_schedule
:
str
|
None
=
None
,
):
"""
Executes the CUTLASS-based fused-MoE grouped matrix multiplication for the
W4A8 quantization scheme. Uses group-wise quantization (INT4 -> FP8)
and both per-channel + per-token scaling in the epilogue.
Args:
out_tensors:
Output buffer for all experts (updated in-place).
a_tensors:
FP8 (E4M3FN) activations for all experts.
b_tensors:
INT4-packed weight matrix for all experts, packed to INT32
a_scales:
Per-token FP8 activation scales, applied in the epilogue.
b_scales:
Per-channel FP8 weight scales for each expert, applied in the epilogue.
b_group_scales:
FP8 scale values for group-wise INT4 weight blocks.
b_group_size:
Number of elements grouped under each entry of b_group_scales.
expert_offsets:
Cumulative token offsets
problem_sizes:
Per-expert (M, N, K) GEMM sizes used by the grouped GEMM launcher.
a/b/c/group_scale_strides:
Strides describing the memory layout of the input tensors.
maybe_schedule:
Optional override to choose a specific kernel or epilogue schedule.
Returns:
out_tensors updated in-place with the dequantized INT4xFP8 grouped GEMM result.
"""
return
torch
.
ops
.
_C
.
cutlass_w4a8_moe_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
b_group_scales
,
b_group_size
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
group_scale_strides
,
maybe_schedule
,
)
def
cutlass_encode_and_reorder_int4b_grouped
(
b_tensors
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
ops
.
_C
.
cutlass_encode_and_reorder_int4b_grouped
(
b_tensors
)
if
hasattr
(
torch
.
ops
.
_C
,
"permute_cols"
):
@
register_fake
(
"_C::permute_cols"
)
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
f6227c22
...
...
@@ -63,8 +63,10 @@ if HAS_TRITON:
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
CutlassBatchedExpertsFp8
,
CutlassExpertsFp8
,
CutlassExpertsW4A8Fp8
,
cutlass_moe_fp4
,
cutlass_moe_fp8
,
cutlass_moe_w4a8_fp8
,
)
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
DeepGemmExperts
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
...
...
@@ -88,8 +90,10 @@ if HAS_TRITON:
"grouped_topk"
,
"cutlass_moe_fp8"
,
"cutlass_moe_fp4"
,
"cutlass_moe_w4a8_fp8"
,
"CutlassExpertsFp8"
,
"CutlassBatchedExpertsFp8"
,
"CutlassExpertsW4A8Fp8"
,
"TritonExperts"
,
"BatchedTritonExperts"
,
"DeepGemmExperts"
,
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
f6227c22
...
...
@@ -143,6 +143,7 @@ class FusedMoEQuantDesc:
scale
:
Union
[
torch
.
Tensor
,
"PrecisionConfig"
,
None
]
=
None
# Quantization alphas or gscales, used for nvfp4 types.
# W4A8 FP8: used for per-channel scales
# TODO(bnell): put some of these in subclasses
alpha_or_gscale
:
torch
.
Tensor
|
None
=
None
...
...
@@ -442,7 +443,9 @@ class FusedMoEQuantConfig:
- a1_scale: Optional scale to be used for a1.
- a2_scale: Optional scale to be used for a2.
- g1_alphas: Optional global quantization scales for w1 (for nvfp4).
per-channel scales for w1 (for W4A8 FP8).
- g2_alphas: Optional global quantization scales for w2 (for nvfp4).
per-channel scales for w2 (for W4A8 FP8).
- a1_gscale: Optional global quantization scales for a1 (for nvfp4).
- a2_gscale: Optional global quantization scales for a2 (for nvfp4).
- w1_bias: Optional biases for w1 (GPT OSS Triton).
...
...
@@ -461,6 +464,7 @@ class FusedMoEQuantConfig:
"mxfp4"
,
"mxfp6_e3m2"
,
"mxfp6_e2m3"
,
"int4"
,
}
if
weight_dtype
is
None
:
...
...
@@ -671,6 +675,31 @@ def int8_w8a16_moe_quant_config(
)
def
int4_w4afp8_moe_quant_config
(
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
g1_alphas
:
torch
.
Tensor
,
g2_alphas
:
torch
.
Tensor
,
per_act_token_quant
:
bool
=
False
,
per_out_ch_quant
:
bool
=
False
,
block_shape
:
list
[
int
]
|
None
=
None
,
)
->
FusedMoEQuantConfig
:
"""
Construct a quant config for fp8 activations and int4 weights.
"""
return
FusedMoEQuantConfig
.
make
(
torch
.
float8_e4m3fn
,
# quant dtype for activations
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
g1_alphas
=
g1_alphas
,
g2_alphas
=
g2_alphas
,
per_act_token_quant
=
per_act_token_quant
,
per_out_ch_quant
=
per_out_ch_quant
,
block_shape
=
block_shape
,
weight_dtype
=
"int4"
,
# weight dtype for weights
)
def
biased_moe_quant_config
(
w1_bias
:
torch
.
Tensor
|
None
,
w2_bias
:
torch
.
Tensor
|
None
,
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
f6227c22
...
...
@@ -1052,3 +1052,404 @@ def run_cutlass_block_scaled_fused_experts(
return
(
c2
[
c_map
].
view
(
m
,
topk
,
k
)
*
topk_weights
.
view
(
m
,
topk
,
1
).
to
(
out_dtype
)
).
sum
(
dim
=
1
)
# W4A8
def
run_cutlass_moe_w4a8_fp8
(
output
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation_callable
:
Callable
,
global_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
w1_scale
:
torch
.
Tensor
|
None
,
w2_scale
:
torch
.
Tensor
|
None
,
a1q_scale
:
torch
.
Tensor
|
None
,
a2_scale
:
torch
.
Tensor
|
None
,
w1_chan_scale
:
torch
.
Tensor
,
w2_chan_scale
:
torch
.
Tensor
,
a_strides1
:
torch
.
Tensor
,
a_strides2
:
torch
.
Tensor
,
b_strides1
:
torch
.
Tensor
,
b_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
s_strides1
:
torch
.
Tensor
,
s_strides2
:
torch
.
Tensor
,
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_num_tokens
:
torch
.
Tensor
|
None
,
out_dtype
:
torch
.
dtype
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_batched_format
:
bool
,
topk_weights
:
torch
.
Tensor
|
None
,
group_size
:
int
,
):
a1q
=
hidden_states
M
=
a1q
.
size
(
0
)
local_E
=
w1
.
size
(
0
)
device
=
a1q
.
device
_
,
K
,
N_packed
=
w2
.
shape
N
=
N_packed
*
8
# logical N, pack 8 int4 into 1 int32
assert
per_act_token
,
"W4A8 must use per-token scales"
assert
per_out_ch
,
"W4A8 must use per-channel scales"
assert
w1_scale
is
not
None
assert
w2_scale
is
not
None
assert
w1_scale
.
dtype
==
torch
.
float8_e4m3fn
assert
w2_scale
.
dtype
==
torch
.
float8_e4m3fn
assert
w1
.
dtype
==
torch
.
int32
assert
w2
.
dtype
==
torch
.
int32
assert
w1_chan_scale
.
dtype
==
torch
.
float32
assert
w2_chan_scale
.
dtype
==
torch
.
float32
assert
w1
.
size
(
0
)
==
w2
.
size
(
0
),
"Weights expert number mismatch"
assert
a1q_scale
is
not
None
assert
a2_scale
is
None
assert
out_dtype
in
[
torch
.
bfloat16
],
f
"Invalid output dtype:
{
out_dtype
}
"
if
expert_map
is
not
None
:
assert
expert_num_tokens
is
None
assert
not
use_batched_format
,
"batched format not supported yet"
assert
group_size
==
128
,
f
"Only group size 128 supported but got
{
group_size
=
}
"
assert
global_num_experts
!=
-
1
assert
w1
.
size
(
2
)
*
8
==
K
,
(
f
"w1 hidden size mismatch: got
{
w1
.
size
(
2
)
*
8
}
, expected
{
K
=
}
"
)
# Translate info from expert_map to topk_ids
if
expert_map
is
not
None
:
local_topk_ids
=
torch
.
where
(
expert_map
[
topk_ids
]
!=
-
1
,
expert_map
[
topk_ids
],
-
1
)
else
:
local_topk_ids
=
topk_ids
topk
=
local_topk_ids
.
size
(
1
)
a1q_perm
=
_resize_cache
(
workspace2
.
view
(
dtype
=
torch
.
float8_e4m3fn
),
(
M
*
topk
,
K
))
mm1_out
=
_resize_cache
(
workspace13
,
(
M
*
topk
,
N
*
2
))
act_out
=
_resize_cache
(
workspace2
,
(
M
*
topk
,
N
))
# original workspace are based on input hidden_states dtype (bf16)
quant_out
=
_resize_cache
(
workspace13
.
view
(
dtype
=
torch
.
float8_e4m3fn
),
(
M
*
topk
,
N
)
)
mm2_out
=
_resize_cache
(
workspace2
,
(
M
*
topk
,
K
))
problem_sizes1
=
torch
.
empty
(
(
global_num_experts
,
3
),
dtype
=
torch
.
int32
,
device
=
device
)
problem_sizes2
=
torch
.
empty
(
(
global_num_experts
,
3
),
dtype
=
torch
.
int32
,
device
=
device
)
num_expert
=
global_num_experts
if
expert_map
is
None
else
expert_map
.
size
(
0
)
# permuted a1q reuses workspace2
a1q
,
a1q_scale
,
expert_offsets
,
inv_perm
,
_
=
moe_permute
(
a1q
,
a1q_scale
,
topk_ids
,
num_expert
,
local_E
,
expert_map
,
permuted_hidden_states
=
a1q_perm
,
)
expert_offsets
=
expert_offsets
[:
-
1
]
# For RS gemm SwapAB is always enabled (swap logical M, N in the problem shape)
ops
.
get_cutlass_moe_mm_problem_sizes
(
local_topk_ids
,
problem_sizes1
,
problem_sizes2
,
global_num_experts
,
N
,
K
,
force_swap_ab
=
True
,
)
ops
.
cutlass_w4a8_moe_mm
(
mm1_out
,
a1q
,
w1
,
a1q_scale
,
w1_chan_scale
,
w1_scale
,
group_size
,
expert_offsets
,
problem_sizes1
,
a_strides1
,
b_strides1
,
c_strides1
,
s_strides1
,
)
activation_callable
(
act_out
,
mm1_out
)
a2q
,
a2q_scale
=
ops
.
scaled_fp8_quant
(
act_out
,
a2_scale
,
use_per_token_if_dynamic
=
per_act_token
,
output
=
quant_out
)
if
expert_map
is
not
None
:
mm2_out
.
fill_
(
0
)
ops
.
cutlass_w4a8_moe_mm
(
mm2_out
,
a2q
,
w2
,
a2q_scale
,
w2_chan_scale
,
w2_scale
,
group_size
,
expert_offsets
,
problem_sizes2
,
a_strides2
,
b_strides2
,
c_strides2
,
s_strides2
,
)
# for non-chunking mode the output is resized from workspace13
# so we need to make sure mm2_out uses workspace2.
moe_unpermute
(
out
=
output
,
permuted_hidden_states
=
mm2_out
,
topk_weights
=
topk_weights
,
inv_permuted_idx
=
inv_perm
,
)
class
CutlassExpertsW4A8Fp8
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
out_dtype
:
torch
.
dtype
|
None
,
a_strides1
:
torch
.
Tensor
,
a_strides2
:
torch
.
Tensor
,
b_strides1
:
torch
.
Tensor
,
b_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
s_strides1
:
torch
.
Tensor
,
s_strides2
:
torch
.
Tensor
,
quant_config
:
FusedMoEQuantConfig
,
group_size
:
int
,
):
super
().
__init__
(
quant_config
)
self
.
out_dtype
=
out_dtype
self
.
a_strides1
=
a_strides1
self
.
a_strides2
=
a_strides2
self
.
b_strides1
=
b_strides1
self
.
b_strides2
=
b_strides2
self
.
c_strides1
=
c_strides1
self
.
c_strides2
=
c_strides2
self
.
s_strides1
=
s_strides1
self
.
s_strides2
=
s_strides2
self
.
group_size
=
group_size
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
return
(
mk
.
FusedMoEActivationFormat
.
Standard
,
mk
.
FusedMoEActivationFormat
.
Standard
,
)
def
supports_chunking
(
self
)
->
bool
:
return
True
def
supports_expert_map
(
self
)
->
bool
:
return
True
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
# topk weights and reduction are fused in moe_unpermute cuda kernel
return
TopKWeightAndReduceNoOP
()
def
workspace_dtype
(
self
,
act_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
return
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
act_dtype
def
workspace_shapes
(
self
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
workspace1
=
(
M
*
topk
,
max
(
N
,
K
))
workspace2
=
(
M
*
topk
,
max
(
N
//
2
,
K
))
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
)
def
apply
(
self
,
output
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
a1q_scale
:
torch
.
Tensor
|
None
,
a2_scale
:
torch
.
Tensor
|
None
,
workspace13
:
torch
.
Tensor
|
None
,
workspace2
:
torch
.
Tensor
|
None
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
apply_router_weight_on_input
:
bool
,
):
assert
self
.
w1_zp
is
None
,
"w1_zp is not supported in CUTLASS MoE"
assert
self
.
w2_zp
is
None
,
"w2_zp is not supported in CUTLASS MoE"
expert_num_tokens
=
None
activation_callable
=
lambda
o
,
i
:
self
.
activation
(
activation
,
o
,
i
)
use_batched_format
=
(
self
.
activation_formats
[
0
]
==
mk
.
FusedMoEActivationFormat
.
BatchedExperts
)
assert
not
use_batched_format
,
"batched format not supported"
in_dtype
=
hidden_states
.
dtype
run_cutlass_moe_w4a8_fp8
(
output
,
hidden_states
,
w1
,
w2
,
topk_ids
,
activation_callable
,
global_num_experts
,
expert_map
,
self
.
w1_scale
,
self
.
w2_scale
,
a1q_scale
,
a2_scale
,
self
.
g1_alphas
,
# per-channel scales
self
.
g2_alphas
,
# per-channel scales
self
.
a_strides1
,
self
.
a_strides2
,
self
.
b_strides1
,
self
.
b_strides2
,
self
.
c_strides1
,
self
.
c_strides2
,
self
.
s_strides1
,
self
.
s_strides2
,
workspace13
,
workspace2
,
expert_num_tokens
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
in_dtype
,
self
.
per_act_token_quant
,
self
.
per_out_ch_quant
,
use_batched_format
,
topk_weights
,
self
.
group_size
,
)
def
cutlass_moe_w4a8_fp8
(
a
:
torch
.
Tensor
,
w1_q
:
torch
.
Tensor
,
w2_q
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
a_strides1
:
torch
.
Tensor
,
a_strides2
:
torch
.
Tensor
,
b_strides1
:
torch
.
Tensor
,
b_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
s_strides1
:
torch
.
Tensor
,
s_strides2
:
torch
.
Tensor
,
quant_config
:
FusedMoEQuantConfig
,
activation
:
str
=
"silu"
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
group_size
:
int
=
128
,
)
->
torch
.
Tensor
:
"""
This function computes a w4a8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with CUTLASS
mixed-dtype grouped gemm.
Parameters:
- a (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
Shape: [num_experts, 2*N, K // packed_factor]
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
Shape: [num_experts, K, N // packed_factor]
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The token->expert mappings.
- a_strides1 (torch.Tensor): The input strides for the first gemm.
Shape: [num_experts]
- a_strides2 (torch.Tensor): The input strides for the second gemm.
Shape: [num_experts]
- b_strides1 (torch.Tensor): The packed layout for the first gemm weights.
Shape: [num_experts, 3]
dtype: torch.int32
- b_strides2 (torch.Tensor): The packed layout for the second gemm weights.
Shape: [num_experts, 3]
dtype: torch.int32
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- s_strides1 (torch.Tensor): strides for the group-wise scales for the first gemm.
Shape: [num_experts, 2]
dtype: torch.int64
- s_strides2 (torch.Tensor): strides for the group-wise scales for the second gemm.
Shape: [num_experts, 2]
dtype: torch.int64
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i]
is -1, it means that this Rank is not responsible for global
expert-id i.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is 1.
- global_num_experts (int): The total number of experts.
- group_size (int): The number of weights per scale factor
Returns:
- torch.Tensor: The bf16 output tensor after applying the MoE layer.
"""
assert
quant_config
is
not
None
num_experts
=
global_num_experts
if
global_num_experts
!=
-
1
else
w1_q
.
size
(
0
)
fn
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
CutlassExpertsW4A8Fp8
(
out_dtype
=
a
.
dtype
,
a_strides1
=
a_strides1
,
a_strides2
=
a_strides2
,
b_strides1
=
b_strides1
,
b_strides2
=
b_strides2
,
c_strides1
=
c_strides1
,
c_strides2
=
c_strides2
,
s_strides1
=
s_strides1
,
s_strides2
=
s_strides2
,
quant_config
=
quant_config
,
group_size
=
group_size
,
),
)
return
fn
(
a
,
w1_q
,
w2_q
,
topk_weights
,
topk_ids
,
activation
=
activation
,
global_num_experts
=
num_experts
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
f6227c22
...
...
@@ -367,7 +367,7 @@ class FusedMoEPrepareAndFinalize(ABC):
class
FusedMoEPermuteExpertsUnpermute
(
ABC
):
"""
An abstract base class for the [Permute-Experts-Unpermute] step described
above.
above.
"""
def
__init__
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
f6227c22
...
...
@@ -256,7 +256,7 @@ class CompressedTensorsConfig(QuantizationConfig):
if
format
is
not
None
else
is_activation_quantization_format
(
quant_format
)
)
#
TODO(czhu):
w4a8fp8 is in packed-quantized format
# w4a8fp8 is in packed-quantized format
# but needs input activation quantization
input_activations
=
quant_config
.
get
(
"input_activations"
)
if
act_quant_format
or
input_activations
:
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
f6227c22
...
...
@@ -33,6 +33,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig
,
fp8_w8a8_moe_quant_config
,
int4_w4a16_moe_quant_config
,
int4_w4afp8_moe_quant_config
,
int8_w8a8_moe_quant_config
,
int8_w8a16_moe_quant_config
,
nvfp4_moe_quant_config
,
...
...
@@ -79,7 +80,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
prepare_moe_fp8_layer_for_marlin
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
swizzle_blockscale
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
convert_bf16_scales_to_fp8
,
convert_packed_uint4b8_to_signed_int4_inplace
,
swizzle_blockscale
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
...
...
@@ -204,6 +209,11 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
return
CompressedTensorsW8A8Int8MoEMethod
(
weight_quant
,
input_quant
,
layer
.
moe_config
)
elif
quant_config
.
_is_fp8_w4a8_sm90
(
weight_quant
,
input_quant
):
logger
.
info_once
(
"Using CompressedTensorsW4A8Fp8MoEMethod"
)
return
CompressedTensorsW4A8Fp8MoEMethod
(
weight_quant
,
input_quant
,
layer
.
moe_config
)
elif
quant_config
.
_is_dynamic_token_w4a8_int
(
weight_quant
,
input_quant
):
return
CompressedTensorsW4A8Int8MoEMethod
(
weight_quant
,
input_quant
,
layer
.
moe_config
...
...
@@ -2428,3 +2438,331 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input
,
int
(
_act_kind
(
activation
)),
)
class
CompressedTensorsW4A8Fp8MoEMethod
(
CompressedTensorsMoEMethod
):
def
__init__
(
self
,
weight_quant
:
QuantizationArgs
,
input_quant
:
QuantizationArgs
,
moe
:
FusedMoEConfig
,
layer_name
:
str
|
None
=
None
,
):
super
().
__init__
(
moe
)
self
.
weight_quant
=
weight_quant
self
.
input_quant
=
input_quant
self
.
group_size
=
self
.
weight_quant
.
group_size
self
.
num_bits
=
self
.
weight_quant
.
num_bits
self
.
packed_factor
=
32
//
self
.
num_bits
assert
self
.
weight_quant
.
symmetric
,
(
"Only symmetric quantization is supported for W4A8 MoE"
)
assert
self
.
weight_quant
.
actorder
!=
"group"
assert
self
.
group_size
==
128
,
"Only group size 128 supported for W4A8 MoE"
self
.
disable_expert_map
=
False
self
.
layer_name
=
layer_name
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
)
self
.
quant_fp8
=
QuantFP8
(
static
=
False
,
group_shape
=
GroupShape
.
PER_TOKEN
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
layer
.
intermediate_size_per_partition
=
intermediate_size_per_partition
layer
.
hidden_size
=
hidden_size
layer
.
num_experts
=
num_experts
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
# requirement for CUTLASS reorder_tensor
assert
hidden_size
%
256
==
0
,
f
"
{
hidden_size
=
}
must be divisible by 256"
assert
intermediate_size_per_partition
%
256
==
0
,
(
f
"
{
intermediate_size_per_partition
=
}
must be divisible by 256"
)
# storage type, pack 8xint4 into int32
params_dtype
=
torch
.
int32
# WEIGHTS
w13_weight_packed
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
//
self
.
packed_factor
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight_packed"
,
w13_weight_packed
)
set_weight_attrs
(
w13_weight_packed
,
extra_weight_attrs
)
w2_weight_packed
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
//
self
.
packed_factor
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight_packed"
,
w2_weight_packed
)
set_weight_attrs
(
w2_weight_packed
,
extra_weight_attrs
)
# SCALES
# weight_scale refers to the group-wise scales
# they are initially loaded as bf16, we will convert to fp8
# after loading
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
//
self
.
group_size
,
dtype
=
layer
.
orig_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
//
self
.
group_size
,
dtype
=
layer
.
orig_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Add PER-GROUP quantization for FusedMoE.weight_loader.
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
GROUP
.
value
}
)
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
# weight shapes
w2_weight_shape
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_shape"
,
w2_weight_shape
)
set_weight_attrs
(
w2_weight_shape
,
extra_weight_attrs
)
w13_weight_shape
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_shape"
,
w13_weight_shape
)
set_weight_attrs
(
w13_weight_shape
,
extra_weight_attrs
)
# don't use input scales
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
):
device
=
layer
.
w13_weight_packed
.
device
# STRIDES
# A, C
self
.
a_strides1_c_strides2
=
torch
.
full
(
(
layer
.
local_num_experts
,),
layer
.
hidden_size
,
device
=
device
,
dtype
=
torch
.
int64
,
)
self
.
a_strides2
=
torch
.
full
(
(
layer
.
local_num_experts
,),
layer
.
intermediate_size_per_partition
,
device
=
device
,
dtype
=
torch
.
int64
,
)
self
.
c_strides1
=
torch
.
full
(
(
layer
.
local_num_experts
,),
2
*
layer
.
intermediate_size_per_partition
,
device
=
device
,
dtype
=
torch
.
int64
,
)
# S (group-wise scales)
# sizeof(StrideS) = 16 bytes, so we need to use 2xint64 to encode it
self
.
s_strides1
=
torch
.
zeros
(
(
layer
.
local_num_experts
,
2
),
device
=
device
,
dtype
=
torch
.
int64
)
self
.
s_strides1
[:,
0
]
=
2
*
layer
.
intermediate_size_per_partition
self
.
s_strides2
=
torch
.
zeros
(
(
layer
.
local_num_experts
,
2
),
device
=
device
,
dtype
=
torch
.
int64
)
self
.
s_strides2
[:,
0
]
=
layer
.
hidden_size
# encode and reorder weight tensors, and get the layout to pass to
# the grouped gemm kernel. `b_strides1/2` specifies the entire layout
convert_packed_uint4b8_to_signed_int4_inplace
(
layer
.
w13_weight_packed
)
w13_weight_shuffled
,
self
.
b_strides1
=
(
ops
.
cutlass_encode_and_reorder_int4b_grouped
(
layer
.
w13_weight_packed
)
)
replace_parameter
(
layer
,
"w13_weight_packed"
,
w13_weight_shuffled
)
convert_packed_uint4b8_to_signed_int4_inplace
(
layer
.
w2_weight_packed
)
w2_weight_shuffled
,
self
.
b_strides2
=
(
ops
.
cutlass_encode_and_reorder_int4b_grouped
(
layer
.
w2_weight_packed
)
)
replace_parameter
(
layer
,
"w2_weight_packed"
,
w2_weight_shuffled
)
# convert bf16 scales to (fp8_scales, channel_scales)
w13_weight_scale
,
w13_weight_chan_scale
=
convert_bf16_scales_to_fp8
(
self
.
quant_fp8
,
layer
.
w13_weight_scale
)
w2_weight_scale
,
w2_weight_chan_scale
=
convert_bf16_scales_to_fp8
(
self
.
quant_fp8
,
layer
.
w2_weight_scale
)
# register channel scales
layer
.
register_parameter
(
"w13_weight_chan_scale"
,
torch
.
nn
.
Parameter
(
w13_weight_chan_scale
,
requires_grad
=
False
),
)
layer
.
register_parameter
(
"w2_weight_chan_scale"
,
torch
.
nn
.
Parameter
(
w2_weight_chan_scale
,
requires_grad
=
False
),
)
# The scales are stored as (E, N, K // 128) but the kernel expects
# (E, K // 128, N) in row-major format, so we need to permute the last 2 dims
# and make it contiguous
w13_weight_scale_packed
=
ops
.
cutlass_pack_scale_fp8
(
w13_weight_scale
.
permute
(
0
,
2
,
1
).
contiguous
()
)
replace_parameter
(
layer
,
"w13_weight_scale"
,
w13_weight_scale_packed
)
w2_weight_scale_packed
=
ops
.
cutlass_pack_scale_fp8
(
w2_weight_scale
.
permute
(
0
,
2
,
1
).
contiguous
()
)
replace_parameter
(
layer
,
"w2_weight_scale"
,
w2_weight_scale_packed
)
def
maybe_make_prepare_finalize
(
self
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
)
->
mk
.
FusedMoEPrepareAndFinalize
|
None
:
return
super
().
maybe_make_prepare_finalize
(
routing_tables
)
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
# Store quantization scales; both per-group and per-channel
# Note we haven't specified the group size here because
# the quant config logic assumes group-wise scaling
# and channel-wise scaling are exclusive.
return
int4_w4afp8_moe_quant_config
(
w1_scale
=
layer
.
w13_weight_scale
,
# group scale
w2_scale
=
layer
.
w2_weight_scale
,
# group scale
g1_alphas
=
layer
.
w13_weight_chan_scale
,
g2_alphas
=
layer
.
w2_weight_chan_scale
,
per_act_token_quant
=
True
,
# always use dynamc per-token
per_out_ch_quant
=
True
,
# always use per-channel
)
def
select_gemm_impl
(
self
,
prepare_finalize
:
mk
.
FusedMoEPrepareAndFinalize
,
layer
:
torch
.
nn
.
Module
,
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
assert
self
.
moe_quant_config
is
not
None
assert
(
prepare_finalize
.
activation_format
==
FusedMoEActivationFormat
.
Standard
),
"BatchedExperts not supported"
from
vllm.model_executor.layers.fused_moe
import
CutlassExpertsW4A8Fp8
experts
:
FusedMoEPermuteExpertsUnpermute
logger
.
debug
(
"CutlassExpertsW4A8Fp8(%s)"
,
self
.
__class__
.
__name__
)
experts
=
CutlassExpertsW4A8Fp8
(
out_dtype
=
self
.
moe
.
in_dtype
,
a_strides1
=
self
.
a_strides1_c_strides2
,
a_strides2
=
self
.
a_strides2
,
b_strides1
=
self
.
b_strides1
,
b_strides2
=
self
.
b_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
a_strides1_c_strides2
,
s_strides1
=
self
.
s_strides1
,
s_strides2
=
self
.
s_strides2
,
quant_config
=
self
.
moe_quant_config
,
group_size
=
self
.
group_size
,
)
num_dispatchers
=
prepare_finalize
.
num_dispatchers
()
self
.
disable_expert_map
=
(
num_dispatchers
>
1
or
not
experts
.
supports_expert_map
()
)
return
experts
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
int
|
None
=
None
,
num_expert_group
:
int
|
None
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
custom_routing_function
:
Callable
|
None
=
None
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
torch
.
Tensor
|
None
=
None
,
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
):
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
)
assert
self
.
moe_quant_config
is
not
None
topk_weights
,
topk_ids
,
_
=
layer
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
)
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
cutlass_moe_w4a8_fp8
,
)
return
cutlass_moe_w4a8_fp8
(
x
,
layer
.
w13_weight_packed
,
layer
.
w2_weight_packed
,
topk_weights
,
topk_ids
,
quant_config
=
self
.
moe_quant_config
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
None
if
self
.
disable_expert_map
else
expert_map
,
a_strides1
=
self
.
a_strides1_c_strides2
,
a_strides2
=
self
.
a_strides2
,
b_strides1
=
self
.
b_strides1
,
b_strides2
=
self
.
b_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
a_strides1_c_strides2
,
s_strides1
=
self
.
s_strides1
,
s_strides2
=
self
.
s_strides2
,
group_size
=
self
.
group_size
,
)
@
property
def
supports_eplb
(
self
)
->
bool
:
return
False
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py
View file @
f6227c22
...
...
@@ -128,14 +128,15 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
),
)
# TODO(czhu): allocate the packed fp8 scales memory here?
# the scales will be expanded by 8x via `cutlass_pack_scale_fp8`
# After loading, we will transform bf16 -> fp8 ->
# expand by 8x via `cutlass_pack_scale_fp8`
# and construct per-channel fp32 scales.
weight_scale_args
=
{
"weight_loader"
:
weight_loader
,
"data"
:
torch
.
empty
(
output_size_per_partition
,
scales_and_zp_size
,
dtype
=
torch
.
float8_e4m3fn
,
dtype
=
params_dtype
,
),
}
...
...
@@ -152,17 +153,9 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
data
=
torch
.
empty
(
2
,
dtype
=
torch
.
int64
),
weight_loader
=
weight_loader
)
# per-channel scales
weight_chan_scale
=
ChannelQuantScaleParameter
(
data
=
torch
.
empty
((
output_size_per_partition
,
1
),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"weight_packed"
,
weight
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_shape"
,
weight_shape
)
layer
.
register_parameter
(
"weight_chan_scale"
,
weight_chan_scale
)
self
.
kernel
=
kernel_type
(
mp_linear_kernel_config
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment