sglang / Commits / ed1044ac

Commit ed1044ac (unverified)
Authored Oct 30, 2025 by AichenF · Committed by GitHub Oct 29, 2025

support cutlass fp4 kernel in sm120 (#11737)

Parent: d717e73e
Changes: 6 files changed, 543 additions and 51 deletions (+543 -51)
sgl-kernel/csrc/gemm/nvfp4_quant.cuh             +2    -2
sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu        +6    -3
sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu      +5    -2
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_entry.cu    +27   -2
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu  +244  -10
sgl-kernel/csrc/moe/nvfp4_blockwise_moe.cu       +259  -32
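Editor's note: the six diffs below implement one change, giving the CUTLASS FP4 (nvfp4) kernels an SM120 path. The e2m1 conversion intrinsics in nvfp4_quant.cuh gain an SM120a guard, nvfp4_scaled_mm_kernels.cu grows SM120-specific tile configurations, GEMM templates and dispatchers, and both the dense scaled-mm and the blockwise MoE entry points now route by SM version at runtime. The sketch below is a distilled, self-contained illustration of that routing; the enum and function names are illustrative, not the repository's, and the version checks are taken from the hunks. Note that the dense GEMM entry falls back to the SM100 dispatch for other versions, while the MoE entry raises; this sketch follows the stricter MoE behaviour.

#include <stdexcept>
#include <string>

// Illustrative only: mirrors the sm_version branches added in this commit.
enum class Fp4KernelPath { kSm100, kSm120 };

inline Fp4KernelPath select_fp4_kernel_path(int sm_version) {
  if (sm_version == 120) return Fp4KernelPath::kSm120;                        // new SM120 kernels
  if (sm_version == 100 || sm_version == 103) return Fp4KernelPath::kSm100;   // existing SM100 kernels
  throw std::runtime_error("Unsupported SM version: " + std::to_string(sm_version));
}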
sgl-kernel/csrc/gemm/nvfp4_quant.cuh  (view file @ ed1044ac)

@@ -51,7 +51,7 @@ constexpr int CVT_FP4_SF_VEC_SIZE = 16;
 // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
 inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
   // PTX instructions used here requires >= sm100f.
-#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED || \
+#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED || CUTLASS_ARCH_MMA_SM120A_ENABLED || \
     (defined(__CUDA_ARCH_FAMILY_SPECIFIC__) && (__CUDA_ARCH_FAMILY_SPECIFIC__ > 1000))
   uint32_t val;
   asm volatile(

@@ -86,7 +86,7 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
 // Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
 inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
   // PTX instructions used here requires >= sm100f.
-#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED || \
+#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED || CUTLASS_ARCH_MMA_SM120A_ENABLED || \
     (defined(__CUDA_ARCH_FAMILY_SPECIFIC__) && (__CUDA_ARCH_FAMILY_SPECIFIC__ > 1000))
   uint32_t val;
   asm volatile(
sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu  (view file @ ed1044ac)

@@ -16,8 +16,11 @@ limitations under the License.
 #include <torch/all.h>

 #if defined ENABLE_NVFP4 && ENABLE_NVFP4
-void scaled_fp4_quant_sm100a(
-    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf);
+void scaled_fp4_quant_sm100a_sm120a(
+    torch::Tensor const& output,
+    torch::Tensor const& input,
+    torch::Tensor const& output_sf,
+    torch::Tensor const& input_sf);
 void scaled_fp4_experts_quant_sm100a(
     torch::Tensor& output,

@@ -40,7 +43,7 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
 void scaled_fp4_quant(
     torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) {
 #if defined ENABLE_NVFP4 && ENABLE_NVFP4
-  return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf);
+  return scaled_fp4_quant_sm100a_sm120a(output, input, output_sf, input_sf);
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization");
 }
sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu  (view file @ ed1044ac)

@@ -199,8 +199,11 @@ inline int getMultiProcessorCount() {
   return multi_processor_count;  // Return the cached value on subsequent calls
 }

-void scaled_fp4_quant_sm100a(
-    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) {
+void scaled_fp4_quant_sm100a_sm120a(
+    torch::Tensor const& output,
+    torch::Tensor const& input,
+    torch::Tensor const& output_sf,
+    torch::Tensor const& input_sf) {
   auto sm_version = getSMVersion();
   TORCH_CHECK(sm_version >= 100, "fp4_quant is only supported on sm100+");
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_entry.cu  (view file @ ed1044ac)

@@ -16,13 +16,38 @@ limitations under the License.
 #include <torch/all.h>

 #if defined ENABLE_NVFP4 && ENABLE_NVFP4
-void cutlass_scaled_fp4_mm_sm100a(
+void cutlass_scaled_fp4_mm_sm100a_sm120a(
     torch::Tensor& D,
     torch::Tensor const& A,
     torch::Tensor const& B,
     torch::Tensor const& A_sf,
     torch::Tensor const& B_sf,
     torch::Tensor const& alpha);
+
+// SM120 specific dispatch functions
+void cutlass_fp4_bf16_gemm_dispatch_sm120(
+    torch::Tensor& D,
+    torch::Tensor const& A,
+    torch::Tensor const& B,
+    torch::Tensor const& A_sf,
+    torch::Tensor const& B_sf,
+    torch::Tensor const& alpha,
+    int m,
+    int n,
+    int k,
+    cudaStream_t stream);
+
+void cutlass_fp4_f16_gemm_dispatch_sm120(
+    torch::Tensor& D,
+    torch::Tensor const& A,
+    torch::Tensor const& B,
+    torch::Tensor const& A_sf,
+    torch::Tensor const& B_sf,
+    torch::Tensor const& alpha,
+    int m,
+    int n,
+    int k,
+    cudaStream_t stream);
 #endif

 void cutlass_scaled_fp4_mm(

@@ -33,7 +58,7 @@ void cutlass_scaled_fp4_mm(
     torch::Tensor const& B_sf,
     torch::Tensor const& alpha) {
 #if defined ENABLE_NVFP4 && ENABLE_NVFP4
-  return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
+  return cutlass_scaled_fp4_mm_sm100a_sm120a(D, A, B, A_sf, B_sf, alpha);
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel.");
 }
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu  (view file @ ed1044ac)

@@ -17,6 +17,8 @@ limitations under the License.
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/all.h>

+#include "utils.h"
+
 // clang-format off
 #include "cutlass/cutlass.h"
 #include "cutlass/gemm/collective/collective_builder.hpp"

@@ -37,7 +39,20 @@ limitations under the License.
 using namespace cute;

-#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+// Helper function for next power of 2
+inline uint32_t next_pow_2(uint32_t x) {
+  if (x == 0) return 1;
+  x--;
+  x |= x >> 1;
+  x |= x >> 2;
+  x |= x >> 4;
+  x |= x >> 8;
+  x |= x >> 16;
+  return x + 1;
+}
+
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || \
+    defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)

 // Config(half_t/bfloat16_t) for M <= 128
 template <typename T>
 struct KernelConfigM128 {

@@ -102,6 +117,19 @@ struct KernelConfigFp32 {
 const dim3 KernelConfigFp32::preferred_cluster = dim3(1, 4, 1);
 const dim3 KernelConfigFp32::fallback_cluster = dim3(1, 2, 1);

+// SM120 specific configurations
+struct sm120_fp4_config_M256 {
+  using ClusterShape = Shape<_1, _1, _1>;
+  using MmaTileShape = Shape<_128, _128, _128>;
+  using PerSmTileShape_MNK = Shape<_128, _128, _128>;
+};
+
+struct sm120_fp4_config_default {
+  using ClusterShape = Shape<_1, _1, _1>;
+  using MmaTileShape = Shape<_256, _128, _128>;
+  using PerSmTileShape_MNK = Shape<_256, _128, _128>;
+};
+
 template <typename KernelConfig>
 struct Fp4GemmSm100 {
   using Config = KernelConfig;  // For generating args

@@ -183,6 +211,70 @@ struct Fp4GemmSm100 {
   using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
 };

+// SM120 specific GEMM template
+template <typename Config, typename OutType>
+struct Fp4GemmSm120 {
+  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
+  using LayoutATag = cutlass::layout::RowMajor;
+  static constexpr int AlignmentA = 32;
+
+  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
+  using LayoutBTag = cutlass::layout::ColumnMajor;
+  static constexpr int AlignmentB = 32;
+
+  using ElementD = OutType;
+  using ElementC = OutType;
+  using LayoutCTag = cutlass::layout::RowMajor;
+  using LayoutDTag = cutlass::layout::RowMajor;
+  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+
+  using ElementAccumulator = float;
+  using ArchTag = cutlass::arch::Sm120;
+  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
+
+  using MmaTileShape = typename Config::MmaTileShape;
+  using ClusterShape = typename Config::ClusterShape;
+  using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      PerSmTileShape_MNK,
+      ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator,
+      ElementAccumulator,
+      ElementC,
+      LayoutCTag,
+      AlignmentC,
+      ElementD,
+      LayoutDTag,
+      AlignmentD,
+      cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      ElementA,
+      LayoutATag,
+      AlignmentA,
+      ElementB,
+      LayoutBTag,
+      AlignmentB,
+      ElementAccumulator,
+      MmaTileShape,
+      ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<
+          static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
+
+  using GemmKernel =
+      cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+};
+
 template <typename T>
 typename T::Gemm::Arguments args_from_options(
     at::Tensor& D,

@@ -267,6 +359,85 @@ void runGemm(
   CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
 }

+// SM120 specific args_from_options function
+template <typename Gemm>
+typename Gemm::Arguments args_from_options_sm120(
+    at::Tensor& D,
+    at::Tensor const& A,
+    at::Tensor const& B,
+    at::Tensor const& A_sf,
+    at::Tensor const& B_sf,
+    torch::Tensor const& alpha,
+    int M,
+    int N,
+    int K) {
+  using ElementA = typename Gemm::ElementA;
+  using ElementB = typename Gemm::ElementB;
+  using ElementD = typename Gemm::ElementD;
+  using ElementSFA = cutlass::float_ue4m3_t;
+  using ElementSFB = cutlass::float_ue4m3_t;
+  using ElementCompute = float;
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+  using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
+
+  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1});
+  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1});
+  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1});
+  auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
+  auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
+
+  typename Gemm::Arguments arguments{
+      cutlass::gemm::GemmUniversalMode::kGemm,
+      {M, N, K, 1},
+      {static_cast<ElementA const*>(A.data_ptr()),
+       stride_A,
+       static_cast<ElementB const*>(B.data_ptr()),
+       stride_B,
+       static_cast<ElementSFA const*>(A_sf.data_ptr()),
+       layout_SFA,
+       static_cast<ElementSFB const*>(B_sf.data_ptr()),
+       layout_SFB},
+      {{}, static_cast<ElementD const*>(D.data_ptr()), stride_D, static_cast<ElementD*>(D.data_ptr()), stride_D}};
+
+  auto& fusion_args = arguments.epilogue.thread;
+  fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
+  return arguments;
+}
+
+// SM120 specific runGemm function
+template <typename Gemm>
+void runGemmSm120(
+    at::Tensor& D,
+    at::Tensor const& A,
+    at::Tensor const& B,
+    at::Tensor const& A_sf,
+    at::Tensor const& B_sf,
+    torch::Tensor const& alpha,
+    int M,
+    int N,
+    int K,
+    cudaStream_t stream) {
+  Gemm gemm;
+  auto arguments = args_from_options_sm120<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K);
+
+  size_t workspace_size = Gemm::get_workspace_size(arguments);
+  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
+
+  CUTLASS_CHECK(gemm.can_implement(arguments));
+  CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
+  CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
+}
+
 // Dispatch function to select appropriate config based on M
 template <typename OutType>
 void cutlassFp4GemmDispatch(

@@ -308,6 +479,49 @@ void cutlassFp4GemmDispatch<float>(
   runGemm<Fp4GemmSm100<KernelConfigFp32>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
 }

+// SM120 specific dispatch functions
+void cutlass_fp4_bf16_gemm_dispatch_sm120(
+    torch::Tensor& D,
+    torch::Tensor const& A,
+    torch::Tensor const& B,
+    torch::Tensor const& A_sf,
+    torch::Tensor const& B_sf,
+    torch::Tensor const& alpha,
+    int m,
+    int n,
+    int k,
+    cudaStream_t stream) {
+  uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
+  if (mp2 <= 256) {
+    runGemmSm120<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
+        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else {
+    runGemmSm120<Fp4GemmSm120<sm120_fp4_config_default, cutlass::bfloat16_t>::Gemm>(
+        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  }
+}
+
+void cutlass_fp4_f16_gemm_dispatch_sm120(
+    torch::Tensor& D,
+    torch::Tensor const& A,
+    torch::Tensor const& B,
+    torch::Tensor const& A_sf,
+    torch::Tensor const& B_sf,
+    torch::Tensor const& alpha,
+    int m,
+    int n,
+    int k,
+    cudaStream_t stream) {
+  uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
+  if (mp2 <= 256) {
+    runGemmSm120<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
+        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else {
+    runGemmSm120<Fp4GemmSm120<sm120_fp4_config_default, cutlass::half_t>::Gemm>(
+        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  }
+}
+
 #else
 template <typename T>
 void cutlassFp4GemmDispatch(
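Editor's note: the two *_dispatch_sm120 helpers above pick a tile configuration from the GEMM's M dimension: M is rounded up to the next power of two (clamped to at least 16), and anything at or below 256 uses the 128x128x128 sm120_fp4_config_M256 tile, while larger M takes the 256x128x128 default. Below is a small, self-contained check of that bucketing; next_pow_2 is copied from the hunk above, and pick_sm120_config is a hypothetical name used only for illustration.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <string>

// Copied from the hunk above: round x up to the next power of two.
inline uint32_t next_pow_2(uint32_t x) {
  if (x == 0) return 1;
  x--;
  x |= x >> 1;
  x |= x >> 2;
  x |= x >> 4;
  x |= x >> 8;
  x |= x >> 16;
  return x + 1;
}

// Hypothetical helper mirroring the branch in cutlass_fp4_*_gemm_dispatch_sm120.
inline std::string pick_sm120_config(int m) {
  uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(static_cast<uint32_t>(m)));
  return mp2 <= 256 ? "sm120_fp4_config_M256" : "sm120_fp4_config_default";
}

int main() {
  assert(pick_sm120_config(1) == "sm120_fp4_config_M256");       // clamped up to 16
  assert(pick_sm120_config(200) == "sm120_fp4_config_M256");     // 200 rounds to 256
  assert(pick_sm120_config(257) == "sm120_fp4_config_default");  // 257 rounds to 512
  assert(pick_sm120_config(4096) == "sm120_fp4_config_default");
  return 0;
}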
@@ -326,7 +540,12 @@ void cutlassFp4GemmDispatch(
       "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
       "a CUTLASS 3.8 source directory to enable support.");
 }
-#endif  // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+#endif  // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) ||
+        // defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
+
+// Undefine macros from utils.h to redefine with custom signatures
+#undef CHECK_CONTIGUOUS
+#undef CHECK_INPUT

 #define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
 #define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")

@@ -339,7 +558,7 @@ void cutlassFp4GemmDispatch(
 constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
 constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;

-void cutlass_scaled_fp4_mm_sm100a(
+void cutlass_scaled_fp4_mm_sm100a_sm120a(
     torch::Tensor& D,
     torch::Tensor const& A,
     torch::Tensor const& B,

@@ -441,13 +660,28 @@ void cutlass_scaled_fp4_mm_sm100a(
   at::cuda::CUDAGuard device_guard{(char)A.get_device()};
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());

-  if (out_dtype == at::ScalarType::Half) {
-    cutlassFp4GemmDispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
-  } else if (out_dtype == at::ScalarType::BFloat16) {
-    cutlassFp4GemmDispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
-  } else if (out_dtype == at::ScalarType::Float) {
-    cutlassFp4GemmDispatch<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
-  } else {
-    TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
+  // Check SM version and dispatch accordingly
+  auto sm_version = getSMVersion();
+
+  if (sm_version == 120) {
+    // Use SM120 specific dispatch
+    if (out_dtype == at::ScalarType::Half) {
+      cutlass_fp4_f16_gemm_dispatch_sm120(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+    } else if (out_dtype == at::ScalarType::BFloat16) {
+      cutlass_fp4_bf16_gemm_dispatch_sm120(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+    } else {
+      TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (", out_dtype, ")");
+    }
+  } else {
+    // Use SM100 dispatch for other architectures
+    if (out_dtype == at::ScalarType::Half) {
+      cutlassFp4GemmDispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+    } else if (out_dtype == at::ScalarType::BFloat16) {
+      cutlassFp4GemmDispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+    } else if (out_dtype == at::ScalarType::Float) {
+      cutlassFp4GemmDispatch<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+    } else {
+      TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
+    }
   }
 }
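Editor's note: both the dense and the MoE entry points key their dispatch on getSMVersion(), which already exists in these files but is not shown in this diff. The sketch below is an assumed, stand-alone illustration of how such a query is typically done with the CUDA runtime; the repository's actual getSMVersion() may differ, so the name here is deliberately distinct.

#include <cuda_runtime.h>

// Assumed sketch of an SM-version query (e.g. returns 120 for compute capability 12.0).
inline int getSMVersionSketch() {
  int device = 0;
  cudaGetDevice(&device);
  int major = 0, minor = 0;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
  return major * 10 + minor;
}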
sgl-kernel/csrc/moe/nvfp4_blockwise_moe.cu  (view file @ ed1044ac)

@@ -27,6 +27,7 @@
 #include "cutlass/util/reference/host/tensor_fill.h"
 #include "cutlass/util/reference/host/tensor_norm.h"
 #include "cutlass/util/tensor_view_io.h"
+#include "utils.h"

 using namespace cute;

@@ -178,8 +179,205 @@ void run_get_group_gemm_starts(
   }
 }

+void run_fp4_blockwise_scaled_group_mm_sm120(
+    torch::Tensor& output,
+    const torch::Tensor& a,
+    const torch::Tensor& b,
+    const torch::Tensor& a_blockscale,
+    const torch::Tensor& b_blockscales,
+    const torch::Tensor& alphas,
+    const torch::Tensor& ab_strides,
+    const torch::Tensor& c_strides,
+    const torch::Tensor& problem_sizes,
+    const torch::Tensor& expert_offsets,
+    const torch::Tensor& sf_offsets,
+    int M,
+    int N,
+    int K) {
+  using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
+  using ElementType = cutlass::float_e2m1_t;
+  using ElementSFType = cutlass::float_ue4m3_t;
+  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
+  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
+  using ElementC = cutlass::bfloat16_t;
+  using ElementD = cutlass::bfloat16_t;
+  using ElementAccumulator = float;
+
+  // Layout definitions
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using LayoutD = cutlass::layout::RowMajor;
+
+  // Alignment constraints
+  static constexpr int AlignmentA = 32;
+  static constexpr int AlignmentB = 32;
+  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+  // Architecture definitions
+  using ArchTag = cutlass::arch::Sm120;
+  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
+  using StageCountType = cutlass::gemm::collective::StageCountAuto;
+  using ThreadBlockShape = Shape<_128, _128, _128>;  // on the tile size
+  using ClusterShape = Shape<_1, _1, _1>;
+
+  using FusionOperation =
+      cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      ThreadBlockShape,
+      ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator,
+      ElementAccumulator,
+      ElementC,
+      LayoutC*,
+      AlignmentC,
+      ElementD,
+      LayoutC*,
+      AlignmentD,
+      cutlass::epilogue::collective::EpilogueScheduleAuto,
+      FusionOperation>::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      ElementA,
+      LayoutA*,
+      AlignmentA,
+      ElementB,
+      LayoutB*,
+      AlignmentB,
+      ElementAccumulator,
+      ThreadBlockShape,
+      ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<
+          static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong>::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
+  using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  using Gemm = Gemm1SM;
+
+  using StrideA = typename Gemm::GemmKernel::InternalStrideA;
+  using StrideB = typename Gemm::GemmKernel::InternalStrideB;
+  using StrideC = typename Gemm::GemmKernel::InternalStrideC;
+  using StrideD = typename Gemm::GemmKernel::InternalStrideD;
+
+  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
+  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
+  using ScaleConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
+
+  using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
+
+  int num_experts = static_cast<int>(expert_offsets.size(0));
+  auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
+
+  torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
+  torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
+
+  run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
+      a_ptrs,
+      b_ptrs,
+      out_ptrs,
+      a_scales_ptrs,
+      b_scales_ptrs,
+      alpha_ptrs,
+      layout_sfa,
+      layout_sfb,
+      a,
+      b,
+      output,
+      a_blockscale,
+      b_blockscales,
+      alphas,
+      expert_offsets,
+      sf_offsets,
+      problem_sizes,
+      M,
+      N,
+      K);
+
+  // Create an instance of the GEMM
+  Gemm gemm_op;
+
+  // Initialize problem_sizes_as_shapes correctly
+  UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
+
+  // Set the Scheduler info
+  cutlass::KernelHardwareInfo hw_info;
+  using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
+  typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
+  scheduler.raster_order = RasterOrderOptions::AlongM;
+  hw_info.device_id = a.get_device();
+  static std::unordered_map<int, int> cached_sm_counts;
+  if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
+    cached_sm_counts[hw_info.device_id] =
+        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+  }
+  hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);
+
+  // Mainloop Arguments
+  typename GemmKernel::MainloopArguments mainloop_args{
+      static_cast<const ElementType**>(a_ptrs.data_ptr()),
+      static_cast<StrideA*>(ab_strides.data_ptr()),
+      static_cast<const ElementType**>(b_ptrs.data_ptr()),
+      static_cast<StrideB*>(ab_strides.data_ptr()),
+      static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
+      reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
+      static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
+      reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};
+
+  // Epilogue Arguments
+  typename GemmKernel::EpilogueArguments epilogue_args{
+      {},  // epilogue.thread
+      nullptr,
+      static_cast<StrideC*>(c_strides.data_ptr()),
+      static_cast<ElementD**>(out_ptrs.data_ptr()),
+      static_cast<StrideC*>(c_strides.data_ptr())};
+
+  auto& fusion_args = epilogue_args.thread;
+  fusion_args.alpha_ptr_array = reinterpret_cast<float**>(alpha_ptrs.data_ptr());
+  fusion_args.dAlpha = {_0{}, _0{}, 1};
+  fusion_args.beta = 0.0f;
+
+  // Gemm Arguments
+  typename GemmKernel::Arguments args{
+      cutlass::gemm::GemmUniversalMode::kGrouped,
+      {num_experts, problem_sizes_as_shapes, nullptr},
+      mainloop_args,
+      epilogue_args,
+      hw_info,
+      scheduler};
+
+  size_t workspace_size = Gemm::get_workspace_size(args);
+  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
+
+  auto can_implement_status = gemm_op.can_implement(args);
+  TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
+
+  // Run the GEMM
+  auto status = gemm_op.initialize(args, workspace.data_ptr());
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
+
+  status = gemm_op.run(args, workspace.data_ptr(), stream);
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
+}
+
 template <typename OutType>
-void run_fp4_blockwise_scaled_group_mm(
+void run_fp4_blockwise_scaled_group_mm_sm100(
     torch::Tensor& output,
     const torch::Tensor& a,
     const torch::Tensor& b,

@@ -376,6 +574,10 @@ void run_fp4_blockwise_scaled_group_mm(
   TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
 }

+// Undefine macros from utils.h to redefine with custom signatures
+#undef CHECK_CONTIGUOUS
+#undef CHECK_INPUT
+
 #define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
 #define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
 #define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")

@@ -428,38 +630,63 @@ void cutlass_fp4_group_mm(
   int E = static_cast<int>(b.size(0));
   int K = static_cast<int>(2 * b.size(2));

-  if (output.scalar_type() == torch::kBFloat16) {
-    run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
-        output, a, b, a_blockscale, b_blockscales, alphas, ab_strides, c_strides,
-        problem_sizes, expert_offsets, sf_offsets, M, N, K);
-  } else {
-    run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
-        output, a, b, a_blockscale, b_blockscales, alphas, ab_strides, c_strides,
-        problem_sizes, expert_offsets, sf_offsets, M, N, K);
+  auto sm_version = getSMVersion();
+  if (sm_version == 100 || sm_version == 103) {
+    if (output.scalar_type() == torch::kBFloat16) {
+      run_fp4_blockwise_scaled_group_mm_sm100<cutlass::bfloat16_t>(
+          output, a, b, a_blockscale, b_blockscales, alphas, ab_strides, c_strides,
+          problem_sizes, expert_offsets, sf_offsets, M, N, K);
+    } else {
+      run_fp4_blockwise_scaled_group_mm_sm100<cutlass::half_t>(
+          output, a, b, a_blockscale, b_blockscales, alphas, ab_strides, c_strides,
+          problem_sizes, expert_offsets, sf_offsets, M, N, K);
+    }
+  } else if (sm_version == 120) {
+    if (output.scalar_type() == torch::kBFloat16) {
+      run_fp4_blockwise_scaled_group_mm_sm120(
+          output, a, b, a_blockscale, b_blockscales, alphas, ab_strides, c_strides,
+          problem_sizes, expert_offsets, sf_offsets, M, N, K);
+    } else {
+      std::cout << "run_fp4_blockwise_scaled_group_mm_sm120 half no implementation" << std::endl;
+    }
+  } else {
+    TORCH_CHECK_NOT_IMPLEMENTED(false, "Unsupported SM version: " + std::to_string(sm_version));
   }
 #else
   TORCH_CHECK_NOT_IMPLEMENTED(
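Editor's note: on SM120 the grouped MoE path above only implements a bfloat16 output; a half-precision request merely prints a message rather than raising. If a caller wanted to fail fast before reaching the kernel, a pre-check could look like the hedged sketch below. The dtype check mirrors the branch in the hunk above, but the function name is illustrative and not part of the repository.

#include <torch/all.h>

// Illustrative pre-check for the SM120 grouped FP4 MoE path: only bfloat16
// output is implemented in this commit, so reject anything else up front.
inline void check_sm120_fp4_group_mm_dtype(const torch::Tensor& output) {
  TORCH_CHECK(
      output.scalar_type() == torch::kBFloat16,
      "run_fp4_blockwise_scaled_group_mm_sm120 only supports bfloat16 output");
}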