change / sglang · Commits · ed1044ac
"git@developer.sourcefind.cn:change/sglang.git" did not exist on "b3e7a2cee4c811122db363bb4d8fd56121a59cf9"
Commit ed1044ac (Unverified), authored Oct 30, 2025 by AichenF, committed by GitHub Oct 29, 2025:
support cutlass fp4 kernel in sm120 (#11737)
Parent: d717e73e
Showing 6 changed files with 543 additions and 51 deletions (+543, -51)
sgl-kernel/csrc/gemm/nvfp4_quant.cuh              +2   -2
sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu         +6   -3
sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu       +5   -2
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_entry.cu     +27  -2
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu   +244 -10
sgl-kernel/csrc/moe/nvfp4_blockwise_moe.cu        +259 -32
sgl-kernel/csrc/gemm/nvfp4_quant.cuh
@@ -51,7 +51,7 @@ constexpr int CVT_FP4_SF_VEC_SIZE = 16;
 // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
 inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
   // PTX instructions used here requires >= sm100f.
-#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED || \
+#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED || CUTLASS_ARCH_MMA_SM120A_ENABLED || \
     (defined(__CUDA_ARCH_FAMILY_SPECIFIC__) && (__CUDA_ARCH_FAMILY_SPECIFIC__ > 1000))
   uint32_t val;
   asm volatile(
@@ -86,7 +86,7 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
 // Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
 inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
   // PTX instructions used here requires >= sm100f.
-#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED || \
+#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED || CUTLASS_ARCH_MMA_SM120A_ENABLED || \
     (defined(__CUDA_ARCH_FAMILY_SPECIFIC__) && (__CUDA_ARCH_FAMILY_SPECIFIC__ > 1000))
   uint32_t val;
   asm volatile(
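For reference, the e2m1 encoding that these guarded PTX intrinsics produce packs eight 4-bit values (1 sign, 2 exponent, 1 mantissa bit each) into one uint32_t; the representable magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4 and 6. A host-side sketch of the same packing, assuming round-to-nearest and low-nibble-first ordering (illustrative only, not the PTX path used above):

#include <cmath>
#include <cstdint>

static uint32_t fp32_vec_to_e2m1_ref(const float (&array)[8]) {
  // Magnitude table indexed by the 3-bit e2m1 code (exponent:mantissa).
  static const float kVals[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
  uint32_t packed = 0;
  for (int i = 0; i < 8; ++i) {
    uint32_t sign = std::signbit(array[i]) ? 0x8u : 0x0u;
    float mag = std::fabs(array[i]);
    int best = 0;  // nearest representable magnitude (tie handling simplified here)
    for (int j = 1; j < 8; ++j)
      if (std::fabs(mag - kVals[j]) < std::fabs(mag - kVals[best])) best = j;
    packed |= (sign | static_cast<uint32_t>(best)) << (4 * i);  // assumed nibble order
  }
  return packed;
}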
sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu
@@ -16,8 +16,11 @@ limitations under the License.
 #include <torch/all.h>

 #if defined ENABLE_NVFP4 && ENABLE_NVFP4
-void scaled_fp4_quant_sm100a(
-    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf);
+void scaled_fp4_quant_sm100a_sm120a(
+    torch::Tensor const& output,
+    torch::Tensor const& input,
+    torch::Tensor const& output_sf,
+    torch::Tensor const& input_sf);

 void scaled_fp4_experts_quant_sm100a(
     torch::Tensor& output,
@@ -40,7 +43,7 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
 void scaled_fp4_quant(
     torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) {
 #if defined ENABLE_NVFP4 && ENABLE_NVFP4
-  return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf);
+  return scaled_fp4_quant_sm100a_sm120a(output, input, output_sf, input_sf);
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization");
 }
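The public scaled_fp4_quant entry keeps its signature; only the backend it forwards to is renamed. As a reminder of what the packed outputs look like, here is a hedged libtorch sketch; the exact, possibly padded or swizzled scale-factor layout is defined by the kernel and is an assumption here:

#include <torch/torch.h>

void allocate_fp4_quant_outputs(const torch::Tensor& input /* [m, k], half/bf16, CUDA */,
                                const torch::Tensor& input_sf /* scale tensor expected by the kernel */) {
  const int64_t m = input.size(0);
  const int64_t k = input.size(1);
  auto opts = torch::TensorOptions().device(input.device());

  // Two e2m1 values are packed per byte, so the packed output has k/2 columns.
  torch::Tensor output = torch::empty({m, k / 2}, opts.dtype(torch::kUInt8));

  // One ue4m3 scale factor per 16-element block along k (CVT_FP4_SF_VEC_SIZE = 16).
  // Unpadded count shown; the kernel may expect a padded/swizzled layout.
  torch::Tensor output_sf = torch::empty({m, k / 16}, opts.dtype(at::ScalarType::Float8_e4m3fn));

  // scaled_fp4_quant(output, input, output_sf, input_sf);
}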
sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu
@@ -199,8 +199,11 @@ inline int getMultiProcessorCount() {
   return multi_processor_count;  // Return the cached value on subsequent calls
 }

-void scaled_fp4_quant_sm100a(
-    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) {
+void scaled_fp4_quant_sm100a_sm120a(
+    torch::Tensor const& output,
+    torch::Tensor const& input,
+    torch::Tensor const& output_sf,
+    torch::Tensor const& input_sf) {
   auto sm_version = getSMVersion();
   TORCH_CHECK(sm_version >= 100, "fp4_quant is only supported on sm100+");
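getSMVersion() comes from the kernel utils and is what the sm100/sm103/sm120 dispatch in this commit keys on. A stand-alone sketch of such a helper for readers outside the repo (illustrative; the real helper may differ and may cache its result):

#include <cuda_runtime.h>

inline int query_sm_version(int device_id = 0) {
  int major = 0, minor = 0;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id);
  return major * 10 + minor;  // 100, 103, 120, ... as used by the checks above
}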
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_entry.cu
@@ -16,13 +16,38 @@ limitations under the License.
 #include <torch/all.h>

 #if defined ENABLE_NVFP4 && ENABLE_NVFP4
-void cutlass_scaled_fp4_mm_sm100a(
+void cutlass_scaled_fp4_mm_sm100a_sm120a(
     torch::Tensor& D,
     torch::Tensor const& A,
     torch::Tensor const& B,
     torch::Tensor const& A_sf,
     torch::Tensor const& B_sf,
     torch::Tensor const& alpha);
+
+// SM120 specific dispatch functions
+void cutlass_fp4_bf16_gemm_dispatch_sm120(
+    torch::Tensor& D,
+    torch::Tensor const& A,
+    torch::Tensor const& B,
+    torch::Tensor const& A_sf,
+    torch::Tensor const& B_sf,
+    torch::Tensor const& alpha,
+    int m,
+    int n,
+    int k,
+    cudaStream_t stream);
+
+void cutlass_fp4_f16_gemm_dispatch_sm120(
+    torch::Tensor& D,
+    torch::Tensor const& A,
+    torch::Tensor const& B,
+    torch::Tensor const& A_sf,
+    torch::Tensor const& B_sf,
+    torch::Tensor const& alpha,
+    int m,
+    int n,
+    int k,
+    cudaStream_t stream);
 #endif

 void cutlass_scaled_fp4_mm(
@@ -33,7 +58,7 @@ void cutlass_scaled_fp4_mm(
     torch::Tensor const& B_sf,
     torch::Tensor const& alpha) {
 #if defined ENABLE_NVFP4 && ENABLE_NVFP4
-  return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
+  return cutlass_scaled_fp4_mm_sm100a_sm120a(D, A, B, A_sf, B_sf, alpha);
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel.");
 }
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu
@@ -17,6 +17,8 @@ limitations under the License.
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/all.h>

+#include "utils.h"
+
 // clang-format off
 #include "cutlass/cutlass.h"
 #include "cutlass/gemm/collective/collective_builder.hpp"
@@ -37,7 +39,20 @@ limitations under the License.
 using namespace cute;

-#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+// Helper function for next power of 2
+inline uint32_t next_pow_2(uint32_t x) {
+  if (x == 0) return 1;
+  x--;
+  x |= x >> 1;
+  x |= x >> 2;
+  x |= x >> 4;
+  x |= x >> 8;
+  x |= x >> 16;
+  return x + 1;
+}
+
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || \
+    defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)

 // Config(half_t/bfloat16_t) for M <= 128
 template <typename T>
 struct KernelConfigM128 {
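next_pow_2 is used by the SM120 dispatch functions added further down in this file: they compute mp2 = max(16, next_pow_2(m)) and pick sm120_fp4_config_M256 when mp2 <= 256, otherwise sm120_fp4_config_default. A quick self-contained check of that mapping (illustrative only):

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  auto next_pow_2 = [](uint32_t x) {
    if (x == 0) return 1u;
    x--;
    x |= x >> 1; x |= x >> 2; x |= x >> 4; x |= x >> 8; x |= x >> 16;
    return x + 1;
  };
  assert(std::max(16u, next_pow_2(3)) == 16);  // tiny m is clamped to 16 -> M256 config
  assert(next_pow_2(200) == 256);              // m = 200 -> mp2 = 256 -> M256 config
  assert(next_pow_2(257) == 512);              // m = 257 -> mp2 = 512 -> default config
  return 0;
}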
@@ -102,6 +117,19 @@ struct KernelConfigFp32 {
 const dim3 KernelConfigFp32::preferred_cluster = dim3(1, 4, 1);
 const dim3 KernelConfigFp32::fallback_cluster = dim3(1, 2, 1);

+// SM120 specific configurations
+struct sm120_fp4_config_M256 {
+  using ClusterShape = Shape<_1, _1, _1>;
+  using MmaTileShape = Shape<_128, _128, _128>;
+  using PerSmTileShape_MNK = Shape<_128, _128, _128>;
+};
+
+struct sm120_fp4_config_default {
+  using ClusterShape = Shape<_1, _1, _1>;
+  using MmaTileShape = Shape<_256, _128, _128>;
+  using PerSmTileShape_MNK = Shape<_256, _128, _128>;
+};
+
 template <typename KernelConfig>
 struct Fp4GemmSm100 {
   using Config = KernelConfig;  // For generating args
@@ -183,6 +211,70 @@ struct Fp4GemmSm100 {
   using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
 };

+// SM120 specific GEMM template
+template <typename Config, typename OutType>
+struct Fp4GemmSm120 {
+  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
+  using LayoutATag = cutlass::layout::RowMajor;
+  static constexpr int AlignmentA = 32;
+
+  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
+  using LayoutBTag = cutlass::layout::ColumnMajor;
+  static constexpr int AlignmentB = 32;
+
+  using ElementD = OutType;
+  using ElementC = OutType;
+  using LayoutCTag = cutlass::layout::RowMajor;
+  using LayoutDTag = cutlass::layout::RowMajor;
+  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+
+  using ElementAccumulator = float;
+  using ArchTag = cutlass::arch::Sm120;
+  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
+
+  using MmaTileShape = typename Config::MmaTileShape;
+  using ClusterShape = typename Config::ClusterShape;
+  using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag, OperatorClass, PerSmTileShape_MNK, ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator, ElementAccumulator,
+      ElementC, LayoutCTag, AlignmentC,
+      ElementD, LayoutDTag, AlignmentD,
+      cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag, OperatorClass,
+      ElementA, LayoutATag, AlignmentA,
+      ElementB, LayoutBTag, AlignmentB,
+      ElementAccumulator, MmaTileShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+};
+
 template <typename T>
 typename T::Gemm::Arguments args_from_options(
     at::Tensor& D,
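A note on the alignment constants in the new struct: CUTLASS counts alignment in elements, so AlignmentA/AlignmentB = 32 corresponds to 32 e2m1 values of 4 bits each, i.e. a 128-bit (16-byte) access, while the 16-bit C/D operands get 128 / sizeof_bits = 8 elements for the same width. The arithmetic as a compile-time check:

static_assert(32 * 4 == 128, "32 fp4 (e2m1) elements span 128 bits");
static_assert(128 / 16 == 8, "8 half/bfloat16 elements fit in the same 128 bits");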
@@ -267,6 +359,85 @@ void runGemm(
   CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
 }

+// SM120 specific args_from_options function
+template <typename Gemm>
+typename Gemm::Arguments args_from_options_sm120(
+    at::Tensor& D,
+    at::Tensor const& A,
+    at::Tensor const& B,
+    at::Tensor const& A_sf,
+    at::Tensor const& B_sf,
+    torch::Tensor const& alpha,
+    int M,
+    int N,
+    int K) {
+  using ElementA = typename Gemm::ElementA;
+  using ElementB = typename Gemm::ElementB;
+  using ElementD = typename Gemm::ElementD;
+  using ElementSFA = cutlass::float_ue4m3_t;
+  using ElementSFB = cutlass::float_ue4m3_t;
+  using ElementCompute = float;
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+  using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
+
+  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1});
+  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1});
+  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1});
+  auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
+  auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
+
+  typename Gemm::Arguments arguments{
+      cutlass::gemm::GemmUniversalMode::kGemm,
+      {M, N, K, 1},
+      {static_cast<ElementA const*>(A.data_ptr()),
+       stride_A,
+       static_cast<ElementB const*>(B.data_ptr()),
+       stride_B,
+       static_cast<ElementSFA const*>(A_sf.data_ptr()),
+       layout_SFA,
+       static_cast<ElementSFB const*>(B_sf.data_ptr()),
+       layout_SFB},
+      {{}, static_cast<ElementD const*>(D.data_ptr()), stride_D, static_cast<ElementD*>(D.data_ptr()), stride_D}};
+
+  auto& fusion_args = arguments.epilogue.thread;
+  fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
+  return arguments;
+}
+
+// SM120 specific runGemm function
+template <typename Gemm>
+void runGemmSm120(
+    at::Tensor& D,
+    at::Tensor const& A,
+    at::Tensor const& B,
+    at::Tensor const& A_sf,
+    at::Tensor const& B_sf,
+    torch::Tensor const& alpha,
+    int M,
+    int N,
+    int K,
+    cudaStream_t stream) {
+  Gemm gemm;
+  auto arguments = args_from_options_sm120<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K);
+
+  size_t workspace_size = Gemm::get_workspace_size(arguments);
+  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
+
+  CUTLASS_CHECK(gemm.can_implement(arguments));
+  CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
+  CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
+}
+
 // Dispatch function to select appropriate config based on M
 template <typename OutType>
 void cutlassFp4GemmDispatch(
@@ -308,6 +479,49 @@ void cutlassFp4GemmDispatch<float>(
   runGemm<Fp4GemmSm100<KernelConfigFp32>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
 }

+// SM120 specific dispatch functions
+void cutlass_fp4_bf16_gemm_dispatch_sm120(
+    torch::Tensor& D,
+    torch::Tensor const& A,
+    torch::Tensor const& B,
+    torch::Tensor const& A_sf,
+    torch::Tensor const& B_sf,
+    torch::Tensor const& alpha,
+    int m,
+    int n,
+    int k,
+    cudaStream_t stream) {
+  uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
+  if (mp2 <= 256) {
+    runGemmSm120<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
+        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else {
+    runGemmSm120<Fp4GemmSm120<sm120_fp4_config_default, cutlass::bfloat16_t>::Gemm>(
+        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  }
+}
+
+void cutlass_fp4_f16_gemm_dispatch_sm120(
+    torch::Tensor& D,
+    torch::Tensor const& A,
+    torch::Tensor const& B,
+    torch::Tensor const& A_sf,
+    torch::Tensor const& B_sf,
+    torch::Tensor const& alpha,
+    int m,
+    int n,
+    int k,
+    cudaStream_t stream) {
+  uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
+  if (mp2 <= 256) {
+    runGemmSm120<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
+        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else {
+    runGemmSm120<Fp4GemmSm120<sm120_fp4_config_default, cutlass::half_t>::Gemm>(
+        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  }
+}
+
 #else
 template <typename T>
 void cutlassFp4GemmDispatch(
@@ -326,7 +540,12 @@ void cutlassFp4GemmDispatch(
       "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
       "a CUTLASS 3.8 source directory to enable support.");
 }
-#endif  // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+#endif  // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) ||
+        // defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
+
+// Undefine macros from utils.h to redefine with custom signatures
+#undef CHECK_CONTIGUOUS
+#undef CHECK_INPUT

 #define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
 #define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
@@ -339,7 +558,7 @@ void cutlassFp4GemmDispatch(
 constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
 constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;

-void cutlass_scaled_fp4_mm_sm100a(
+void cutlass_scaled_fp4_mm_sm100a_sm120a(
     torch::Tensor& D,
     torch::Tensor const& A,
     torch::Tensor const& B,
@@ -441,13 +660,28 @@ void cutlass_scaled_fp4_mm_sm100a(
   at::cuda::CUDAGuard device_guard{(char)A.get_device()};
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());

-  if (out_dtype == at::ScalarType::Half) {
-    cutlassFp4GemmDispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
-  } else if (out_dtype == at::ScalarType::BFloat16) {
-    cutlassFp4GemmDispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
-  } else if (out_dtype == at::ScalarType::Float) {
-    cutlassFp4GemmDispatch<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
-  } else {
-    TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
+  // Check SM version and dispatch accordingly
+  auto sm_version = getSMVersion();
+
+  if (sm_version == 120) {
+    // Use SM120 specific dispatch
+    if (out_dtype == at::ScalarType::Half) {
+      cutlass_fp4_f16_gemm_dispatch_sm120(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+    } else if (out_dtype == at::ScalarType::BFloat16) {
+      cutlass_fp4_bf16_gemm_dispatch_sm120(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+    } else {
+      TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (", out_dtype, ")");
+    }
+  } else {
+    // Use SM100 dispatch for other architectures
+    if (out_dtype == at::ScalarType::Half) {
+      cutlassFp4GemmDispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+    } else if (out_dtype == at::ScalarType::BFloat16) {
+      cutlassFp4GemmDispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+    } else if (out_dtype == at::ScalarType::Float) {
+      cutlassFp4GemmDispatch<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+    } else {
+      TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
+    }
   }
 }
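For readers unfamiliar with the nvfp4 format used here: A and B are packed e2m1 tensors (FLOAT4_E2M1X2 = Byte, two values per byte), A_sf and B_sf carry one ue4m3 scale factor per 16-element block along K (SF_DTYPE = Float8_e4m3fn), and alpha is applied once to the fp32 accumulator. A plain, unoptimized reference of the math these kernels compute, assuming already-decoded float inputs and ignoring the scale-factor swizzle:

#include <vector>

// Row-major A [m, k], K-major B [n, k]; one scale per 16 elements of k.
std::vector<float> reference_scaled_fp4_mm(
    const std::vector<float>& a, const std::vector<float>& a_sf,
    const std::vector<float>& b, const std::vector<float>& b_sf,
    float alpha, int m, int n, int k) {
  std::vector<float> d(static_cast<size_t>(m) * n, 0.0f);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int kk = 0; kk < k; ++kk)
        acc += (a[i * k + kk] * a_sf[i * (k / 16) + kk / 16]) *
               (b[j * k + kk] * b_sf[j * (k / 16) + kk / 16]);
      d[static_cast<size_t>(i) * n + j] = alpha * acc;
    }
  return d;
}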
sgl-kernel/csrc/moe/nvfp4_blockwise_moe.cu
@@ -27,6 +27,7 @@
 #include "cutlass/util/reference/host/tensor_fill.h"
 #include "cutlass/util/reference/host/tensor_norm.h"
 #include "cutlass/util/tensor_view_io.h"
+#include "utils.h"

 using namespace cute;
@@ -178,8 +179,205 @@ void run_get_group_gemm_starts(
   }
 }

+void run_fp4_blockwise_scaled_group_mm_sm120(
+    torch::Tensor& output,
+    const torch::Tensor& a,
+    const torch::Tensor& b,
+    const torch::Tensor& a_blockscale,
+    const torch::Tensor& b_blockscales,
+    const torch::Tensor& alphas,
+    const torch::Tensor& ab_strides,
+    const torch::Tensor& c_strides,
+    const torch::Tensor& problem_sizes,
+    const torch::Tensor& expert_offsets,
+    const torch::Tensor& sf_offsets,
+    int M,
+    int N,
+    int K) {
+  using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
+  using ElementType = cutlass::float_e2m1_t;
+  using ElementSFType = cutlass::float_ue4m3_t;
+  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
+  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
+  using ElementC = cutlass::bfloat16_t;
+  using ElementD = cutlass::bfloat16_t;
+  using ElementAccumulator = float;
+
+  // Layout definitions
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using LayoutD = cutlass::layout::RowMajor;
+
+  // Alignment constraints
+  static constexpr int AlignmentA = 32;
+  static constexpr int AlignmentB = 32;
+  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+  // Architecture definitions
+  using ArchTag = cutlass::arch::Sm120;
+  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
+  using StageCountType = cutlass::gemm::collective::StageCountAuto;
+
+  using ThreadBlockShape = Shape<_128, _128, _128>;  // on the tile size
+  using ClusterShape = Shape<_1, _1, _1>;
+
+  using FusionOperation =
+      cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag, OperatorClass, ThreadBlockShape, ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator, ElementAccumulator,
+      ElementC, LayoutC*, AlignmentC,
+      ElementD, LayoutC*, AlignmentD,
+      cutlass::epilogue::collective::EpilogueScheduleAuto,
+      FusionOperation>::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag, OperatorClass,
+      ElementA, LayoutA*, AlignmentA,
+      ElementB, LayoutB*, AlignmentB,
+      ElementAccumulator, ThreadBlockShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong>::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
+  using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  using Gemm = Gemm1SM;
+
+  using StrideA = typename Gemm::GemmKernel::InternalStrideA;
+  using StrideB = typename Gemm::GemmKernel::InternalStrideB;
+  using StrideC = typename Gemm::GemmKernel::InternalStrideC;
+  using StrideD = typename Gemm::GemmKernel::InternalStrideD;
+  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
+  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
+  using ScaleConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
+  using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
+
+  int num_experts = static_cast<int>(expert_offsets.size(0));
+  auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
+  torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
+  torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
+
+  run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
+      a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
+      layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales,
+      alphas, expert_offsets, sf_offsets, problem_sizes, M, N, K);
+
+  // Create an instance of the GEMM
+  Gemm gemm_op;
+
+  // Initialize problem_sizes_as_shapes correctly
+  UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
+
+  // Set the Scheduler info
+  cutlass::KernelHardwareInfo hw_info;
+  using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
+  typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
+  scheduler.raster_order = RasterOrderOptions::AlongM;
+  hw_info.device_id = a.get_device();
+  static std::unordered_map<int, int> cached_sm_counts;
+  if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
+    cached_sm_counts[hw_info.device_id] =
+        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+  }
+  hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);
+
+  // Mainloop Arguments
+  typename GemmKernel::MainloopArguments mainloop_args{
+      static_cast<const ElementType**>(a_ptrs.data_ptr()),
+      static_cast<StrideA*>(ab_strides.data_ptr()),
+      static_cast<const ElementType**>(b_ptrs.data_ptr()),
+      static_cast<StrideB*>(ab_strides.data_ptr()),
+      static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
+      reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
+      static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
+      reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};
+
+  // Epilogue Arguments
+  typename GemmKernel::EpilogueArguments epilogue_args{
+      {},  // epilogue.thread
+      nullptr,
+      static_cast<StrideC*>(c_strides.data_ptr()),
+      static_cast<ElementD**>(out_ptrs.data_ptr()),
+      static_cast<StrideC*>(c_strides.data_ptr())};
+  auto& fusion_args = epilogue_args.thread;
+  fusion_args.alpha_ptr_array = reinterpret_cast<float**>(alpha_ptrs.data_ptr());
+  fusion_args.dAlpha = {_0{}, _0{}, 1};
+  fusion_args.beta = 0.0f;
+
+  // Gemm Arguments
+  typename GemmKernel::Arguments args{
+      cutlass::gemm::GemmUniversalMode::kGrouped,
+      {num_experts, problem_sizes_as_shapes, nullptr},
+      mainloop_args,
+      epilogue_args,
+      hw_info,
+      scheduler};
+
+  size_t workspace_size = Gemm::get_workspace_size(args);
+  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
+
+  auto can_implement_status = gemm_op.can_implement(args);
+  TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
+
+  // Run the GEMM
+  auto status = gemm_op.initialize(args, workspace.data_ptr());
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
+
+  status = gemm_op.run(args, workspace.data_ptr(), stream);
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
+}
+
 template <typename OutType>
-void run_fp4_blockwise_scaled_group_mm(
+void run_fp4_blockwise_scaled_group_mm_sm100(
     torch::Tensor& output,
     const torch::Tensor& a,
     const torch::Tensor& b,
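The int64 tensors created in the new function (a_ptrs, b_ptrs, out_ptrs, alpha_ptrs, ...) are per-expert pointer tables: each entry holds a device address that is later reinterpreted as ElementType** / float** when building MainloopArguments and the epilogue. The tables are filled on the GPU by run_get_group_gemm_starts; a host-side sketch of the idea (illustrative only, not the repo's code):

#include <cstdint>
#include <vector>

// One entry per expert; here every expert reads its alpha from a contiguous device array.
std::vector<int64_t> build_alpha_pointer_table(float* alphas_device_base, int num_experts) {
  std::vector<int64_t> table(static_cast<size_t>(num_experts));
  for (int e = 0; e < num_experts; ++e)
    table[e] = reinterpret_cast<int64_t>(alphas_device_base + e);
  return table;
}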
@@ -376,6 +574,10 @@ void run_fp4_blockwise_scaled_group_mm(
   TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
 }

+// Undefine macros from utils.h to redefine with custom signatures
+#undef CHECK_CONTIGUOUS
+#undef CHECK_INPUT
+
 #define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
 #define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
 #define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
@@ -428,38 +630,63 @@ void cutlass_fp4_group_mm(
   int E = static_cast<int>(b.size(0));
   int K = static_cast<int>(2 * b.size(2));

-  if (output.scalar_type() == torch::kBFloat16) {
-    run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
-        output, a, b, a_blockscale, b_blockscales, alphas, ab_strides,
-        c_strides, problem_sizes, expert_offsets, sf_offsets, M, N, K);
-  } else {
-    run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
-        output, a, b, a_blockscale, b_blockscales, alphas, ab_strides,
-        c_strides, problem_sizes, expert_offsets, sf_offsets, M, N, K);
-  }
+  auto sm_version = getSMVersion();
+  if (sm_version == 100 || sm_version == 103) {
+    if (output.scalar_type() == torch::kBFloat16) {
+      run_fp4_blockwise_scaled_group_mm_sm100<cutlass::bfloat16_t>(
+          output, a, b, a_blockscale, b_blockscales, alphas, ab_strides,
+          c_strides, problem_sizes, expert_offsets, sf_offsets, M, N, K);
+    } else {
+      run_fp4_blockwise_scaled_group_mm_sm100<cutlass::half_t>(
+          output, a, b, a_blockscale, b_blockscales, alphas, ab_strides,
+          c_strides, problem_sizes, expert_offsets, sf_offsets, M, N, K);
+    }
+  } else if (sm_version == 120) {
+    if (output.scalar_type() == torch::kBFloat16) {
+      run_fp4_blockwise_scaled_group_mm_sm120(
+          output, a, b, a_blockscale, b_blockscales, alphas, ab_strides,
+          c_strides, problem_sizes, expert_offsets, sf_offsets, M, N, K);
+    } else {
+      std::cout << "run_fp4_blockwise_scaled_group_mm_sm120 half no implementation" << std::endl;
+    }
+  } else {
+    TORCH_CHECK_NOT_IMPLEMENTED(false, "Unsupported SM version: " + std::to_string(sm_version));
+  }
 #else
   TORCH_CHECK_NOT_IMPLEMENTED(
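Note that on sm_120 the grouped MoE path in this commit only implements a bfloat16 output; a half-precision output just prints a message and returns without running the GEMM. A hypothetical caller-side guard (not part of the repo) that surfaces this as an error instead:

#include <torch/torch.h>

inline void check_fp4_group_mm_support(const torch::Tensor& output, int sm_version) {
  if (sm_version == 120) {
    TORCH_CHECK(output.scalar_type() == torch::kBFloat16,
                "sm_120 grouped nvfp4 GEMM currently supports only bfloat16 output");
  }
}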