Unverified commit ed1044ac, authored by AichenF and committed by GitHub

support cutlass fp4 kernel in sm120 (#11737)

parent d717e73e
@@ -51,7 +51,7 @@ constexpr int CVT_FP4_SF_VEC_SIZE = 16;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
  // PTX instructions used here requires >= sm100f.
#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED || CUTLASS_ARCH_MMA_SM120A_ENABLED || \
    (defined(__CUDA_ARCH_FAMILY_SPECIFIC__) && (__CUDA_ARCH_FAMILY_SPECIFIC__ > 1000))
  uint32_t val;
  asm volatile(
@@ -86,7 +86,7 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
  // PTX instructions used here requires >= sm100f.
#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED || CUTLASS_ARCH_MMA_SM120A_ENABLED || \
    (defined(__CUDA_ARCH_FAMILY_SPECIFIC__) && (__CUDA_ARCH_FAMILY_SPECIFIC__ > 1000))
  uint32_t val;
  asm volatile(
...
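For reference, e2m1 (the nvfp4 element type packed by fp32_vec_to_e2m1 above) carries a sign bit, two exponent bits, and one mantissa bit, so its magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6}. The host-side decoder below is an illustrative sketch, not part of the commit; the nibble ordering inside the packed word is an assumption.

#include <array>
#include <cstdint>

// Decode one e2m1 nibble (bit 3 = sign, bits 2..1 = exponent, bit 0 = mantissa).
inline float e2m1_to_float(uint8_t nibble) {
  static constexpr std::array<float, 8> kMagnitude = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
  float mag = kMagnitude[nibble & 0x7];
  return (nibble & 0x8) ? -mag : mag;
}

// A packed uint32_t from fp32_vec_to_e2m1 holds 8 nibbles; lowest-nibble-first
// ordering is assumed here for illustration.
inline void unpack_e2m1x8(uint32_t packed, float (&out)[8]) {
  for (int i = 0; i < 8; ++i) {
    out[i] = e2m1_to_float(static_cast<uint8_t>((packed >> (4 * i)) & 0xF));
  }
}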
@@ -16,8 +16,11 @@ limitations under the License.
#include <torch/all.h>
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void scaled_fp4_quant_sm100a_sm120a(
torch::Tensor const& output,
torch::Tensor const& input,
torch::Tensor const& output_sf,
torch::Tensor const& input_sf);
void scaled_fp4_experts_quant_sm100a(
torch::Tensor& output,
@@ -40,7 +43,7 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
void scaled_fp4_quant(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
return scaled_fp4_quant_sm100a_sm120a(output, input, output_sf, input_sf);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization");
}
...
@@ -199,8 +199,11 @@ inline int getMultiProcessorCount() {
return multi_processor_count;  // Return the cached value on subsequent calls
}
void scaled_fp4_quant_sm100a_sm120a(
torch::Tensor const& output,
torch::Tensor const& input,
torch::Tensor const& output_sf,
torch::Tensor const& input_sf) {
auto sm_version = getSMVersion();
TORCH_CHECK(sm_version >= 100, "fp4_quant is only supported on sm100+");
...
@@ -16,13 +16,38 @@ limitations under the License.
#include <torch/all.h>
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void cutlass_scaled_fp4_mm_sm100a_sm120a(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
// SM120 specific dispatch functions
void cutlass_fp4_bf16_gemm_dispatch_sm120(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha,
int m,
int n,
int k,
cudaStream_t stream);
void cutlass_fp4_f16_gemm_dispatch_sm120(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha,
int m,
int n,
int k,
cudaStream_t stream);
#endif
void cutlass_scaled_fp4_mm(
@@ -33,7 +58,7 @@ void cutlass_scaled_fp4_mm(
torch::Tensor const& B_sf,
torch::Tensor const& alpha) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
return cutlass_scaled_fp4_mm_sm100a_sm120a(D, A, B, A_sf, B_sf, alpha);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel.");
}
@@ -17,6 +17,8 @@ limitations under the License.
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "utils.h"
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
@@ -37,7 +39,20 @@ limitations under the License.
using namespace cute;
// Helper function for next power of 2
inline uint32_t next_pow_2(uint32_t x) {
if (x == 0) return 1;
x--;
x |= x >> 1;
x |= x >> 2;
x |= x >> 4;
x |= x >> 8;
x |= x >> 16;
return x + 1;
}
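A few illustrative values for this helper (not from the commit), since it drives the tile-config bucketing used by the SM120 dispatchers further down:

#include <cassert>

// next_pow_2 rounds up to the nearest power of two, with 0 mapped to 1.
inline void next_pow_2_examples() {
  assert(next_pow_2(0) == 1);
  assert(next_pow_2(1) == 1);
  assert(next_pow_2(16) == 16);
  assert(next_pow_2(17) == 32);
  assert(next_pow_2(200) == 256);
  assert(next_pow_2(257) == 512);
}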
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || \
defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
// Config(half_t/bfloat16_t) for M <= 128
template <typename T>
struct KernelConfigM128 {
@@ -102,6 +117,19 @@ struct KernelConfigFp32 {
const dim3 KernelConfigFp32::preferred_cluster = dim3(1, 4, 1);
const dim3 KernelConfigFp32::fallback_cluster = dim3(1, 2, 1);
// SM120 specific configurations
struct sm120_fp4_config_M256 {
using ClusterShape = Shape<_1, _1, _1>;
using MmaTileShape = Shape<_128, _128, _128>;
using PerSmTileShape_MNK = Shape<_128, _128, _128>;
};
struct sm120_fp4_config_default {
using ClusterShape = Shape<_1, _1, _1>;
using MmaTileShape = Shape<_256, _128, _128>;
using PerSmTileShape_MNK = Shape<_256, _128, _128>;
};
template <typename KernelConfig>
struct Fp4GemmSm100 {
using Config = KernelConfig;  // For generating args
@@ -183,6 +211,70 @@ struct Fp4GemmSm100 {
using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
};
// SM120 specific GEMM template
template <typename Config, typename OutType>
struct Fp4GemmSm120 {
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using LayoutATag = cutlass::layout::RowMajor;
static constexpr int AlignmentA = 32;
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using LayoutBTag = cutlass::layout::ColumnMajor;
static constexpr int AlignmentB = 32;
using ElementD = OutType;
using ElementC = OutType;
using LayoutCTag = cutlass::layout::RowMajor;
using LayoutDTag = cutlass::layout::RowMajor;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
using ElementAccumulator = float;
using ArchTag = cutlass::arch::Sm120;
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
using MmaTileShape = typename Config::MmaTileShape;
using ClusterShape = typename Config::ClusterShape;
using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
PerSmTileShape_MNK,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementAccumulator,
ElementC,
LayoutCTag,
AlignmentC,
ElementD,
LayoutDTag,
AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
LayoutATag,
AlignmentA,
ElementB,
LayoutBTag,
AlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
template <typename T>
typename T::Gemm::Arguments args_from_options(
at::Tensor& D,
@@ -267,6 +359,85 @@ void runGemm(
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
}
// SM120 specific args_from_options function
template <typename Gemm>
typename Gemm::Arguments args_from_options_sm120(
at::Tensor& D,
at::Tensor const& A,
at::Tensor const& B,
at::Tensor const& A_sf,
at::Tensor const& B_sf,
torch::Tensor const& alpha,
int M,
int N,
int K) {
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementD = typename Gemm::ElementD;
using ElementSFA = cutlass::float_ue4m3_t;
using ElementSFB = cutlass::float_ue4m3_t;
using ElementCompute = float;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1});
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1});
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1});
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K, 1},
{static_cast<ElementA const*>(A.data_ptr()),
stride_A,
static_cast<ElementB const*>(B.data_ptr()),
stride_B,
static_cast<ElementSFA const*>(A_sf.data_ptr()),
layout_SFA,
static_cast<ElementSFB const*>(B_sf.data_ptr()),
layout_SFB},
{{}, static_cast<ElementD const*>(D.data_ptr()), stride_D, static_cast<ElementD*>(D.data_ptr()), stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
return arguments;
}
// SM120 specific runGemm function
template <typename Gemm>
void runGemmSm120(
at::Tensor& D,
at::Tensor const& A,
at::Tensor const& B,
at::Tensor const& A_sf,
at::Tensor const& B_sf,
torch::Tensor const& alpha,
int M,
int N,
int K,
cudaStream_t stream) {
Gemm gemm;
auto arguments = args_from_options_sm120<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K);
size_t workspace_size = Gemm::get_workspace_size(arguments);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
CUTLASS_CHECK(gemm.can_implement(arguments));
CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
}
// Dispatch function to select appropriate config based on M
template <typename OutType>
void cutlassFp4GemmDispatch(
@@ -308,6 +479,49 @@ void cutlassFp4GemmDispatch<float>(
runGemm<Fp4GemmSm100<KernelConfigFp32>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
}
// SM120 specific dispatch functions
void cutlass_fp4_bf16_gemm_dispatch_sm120(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha,
int m,
int n,
int k,
cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 256) {
runGemmSm120<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
} else {
runGemmSm120<Fp4GemmSm120<sm120_fp4_config_default, cutlass::bfloat16_t>::Gemm>(
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
}
}
void cutlass_fp4_f16_gemm_dispatch_sm120(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha,
int m,
int n,
int k,
cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 256) {
runGemmSm120<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
} else {
runGemmSm120<Fp4GemmSm120<sm120_fp4_config_default, cutlass::half_t>::Gemm>(
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
}
}
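As an illustration only (the enum and helper names below are hypothetical, not part of the commit), the selection rule shared by both SM120 dispatchers can be factored out as follows; every m up to 256 lands on the 128x128x128 MMA tile of sm120_fp4_config_M256, anything larger on the 256x128x128 tile of sm120_fp4_config_default.

#include <algorithm>
#include <cstdint>

enum class Sm120Fp4Config { kM256, kDefault };

// Mirrors the dispatch above; relies on next_pow_2() defined earlier in this file.
inline Sm120Fp4Config select_sm120_fp4_config(int m) {
  uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(static_cast<uint32_t>(m)));
  return (mp2 <= 256) ? Sm120Fp4Config::kM256 : Sm120Fp4Config::kDefault;
}

// select_sm120_fp4_config(1)   == kM256    (mp2 clamped up to 16)
// select_sm120_fp4_config(256) == kM256    (mp2 == 256)
// select_sm120_fp4_config(257) == kDefault (mp2 == 512)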
#else
template <typename T>
void cutlassFp4GemmDispatch(
@@ -326,7 +540,12 @@ void cutlassFp4GemmDispatch(
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
}
#endif  // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) ||
// defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
// Undefine macros from utils.h to redefine with custom signatures
#undef CHECK_CONTIGUOUS
#undef CHECK_INPUT
#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
@@ -339,7 +558,7 @@ void cutlassFp4GemmDispatch(
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
void cutlass_scaled_fp4_mm_sm100a_sm120a(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
@@ -441,13 +660,28 @@ void cutlass_scaled_fp4_mm_sm100a(
at::cuda::CUDAGuard device_guard{(char)A.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
// Check SM version and dispatch accordingly
auto sm_version = getSMVersion();
if (sm_version == 120) {
  // Use SM120 specific dispatch
  if (out_dtype == at::ScalarType::Half) {
    cutlass_fp4_f16_gemm_dispatch_sm120(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else if (out_dtype == at::ScalarType::BFloat16) {
    cutlass_fp4_bf16_gemm_dispatch_sm120(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else {
    TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (", out_dtype, ")");
  }
} else {
  // Use SM100 dispatch for other architectures
  if (out_dtype == at::ScalarType::Half) {
    cutlassFp4GemmDispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else if (out_dtype == at::ScalarType::BFloat16) {
    cutlassFp4GemmDispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else if (out_dtype == at::ScalarType::Float) {
    cutlassFp4GemmDispatch<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else {
    TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
  }
}
}
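A hypothetical host-side sketch (not part of the commit) of what this entry point is assumed to expect: A and B packed as two e2m1 values per byte (FLOAT4_E2M1X2 above is Byte), scale factors in float8_e4m3fn (SF_DTYPE), and a Half/BFloat16 output on SM120. The example function, the flat scale-factor sizes, and the forward declaration (assumed to mirror cutlass_scaled_fp4_mm_sm100a_sm120a) are illustrative; in practice A_sf/B_sf should come from the matching nvfp4 quantization kernel (scaled_fp4_quant), which produces the swizzled block-scale layout the GEMM consumes.

#include <torch/all.h>

// Assumed declaration of the public wrapper defined elsewhere in this extension.
void cutlass_scaled_fp4_mm(
    torch::Tensor& D,
    torch::Tensor const& A,
    torch::Tensor const& B,
    torch::Tensor const& A_sf,
    torch::Tensor const& B_sf,
    torch::Tensor const& alpha);

void example_nvfp4_mm(int64_t m, int64_t n, int64_t k) {
  auto u8 = torch::TensorOptions().dtype(at::ScalarType::Byte).device(torch::kCUDA);
  auto sf = torch::TensorOptions().dtype(at::ScalarType::Float8_e4m3fn).device(torch::kCUDA);

  torch::Tensor A = torch::empty({m, k / 2}, u8);         // packed e2m1 activations, 2 values per byte
  torch::Tensor B = torch::empty({n, k / 2}, u8);         // packed e2m1 weights, 2 values per byte
  torch::Tensor A_sf = torch::empty({(m * k) / 16}, sf);  // one scale per 16-value block (simplified layout)
  torch::Tensor B_sf = torch::empty({(n * k) / 16}, sf);
  torch::Tensor alpha = torch::ones({1}, torch::TensorOptions().dtype(at::ScalarType::Float).device(torch::kCUDA));

  // BFloat16 output is accepted on both the SM100 and SM120 paths; Float output is SM100-only.
  torch::Tensor D = torch::empty({m, n}, torch::TensorOptions().dtype(at::ScalarType::BFloat16).device(torch::kCUDA));
  cutlass_scaled_fp4_mm(D, A, B, A_sf, B_sf, alpha);
}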
@@ -27,6 +27,7 @@
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/tensor_view_io.h"
#include "utils.h"
using namespace cute;
@@ -178,8 +179,205 @@ void run_get_group_gemm_starts(
}
}
void run_fp4_blockwise_scaled_group_mm_sm120(
torch::Tensor& output,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& a_blockscale,
const torch::Tensor& b_blockscales,
const torch::Tensor& alphas,
const torch::Tensor& ab_strides,
const torch::Tensor& c_strides,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& sf_offsets,
int M,
int N,
int K) {
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t;
using ElementSFType = cutlass::float_ue4m3_t;
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using ElementC = cutlass::bfloat16_t;
using ElementD = cutlass::bfloat16_t;
using ElementAccumulator = float;
// Layout definitions
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
// Alignment constraints
static constexpr int AlignmentA = 32;
static constexpr int AlignmentB = 32;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// Architecture definitions
using ArchTag = cutlass::arch::Sm120;
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
using StageCountType = cutlass::gemm::collective::StageCountAuto;
using ThreadBlockShape = Shape<_128, _128, _128>;
// Cluster shape, chosen based on the tile size
using ClusterShape = Shape<_1, _1, _1>;
using FusionOperation =
cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ThreadBlockShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementAccumulator,
ElementC,
LayoutC*,
AlignmentC,
ElementD,
LayoutC*,
AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto,
FusionOperation>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
LayoutA*,
AlignmentA,
ElementB,
LayoutB*,
AlignmentB,
ElementAccumulator,
ThreadBlockShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using Gemm = Gemm1SM;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
using ScaleConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
alpha_ptrs,
layout_sfa,
layout_sfb,
a,
b,
output,
a_blockscale,
b_blockscales,
alphas,
expert_offsets,
sf_offsets,
problem_sizes,
M,
N,
K);
// Create an instance of the GEMM
Gemm gemm_op;
// Initialize problem_sizes_as_shapes correctly
UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
// Set the Scheduler info
cutlass::KernelHardwareInfo hw_info;
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device();
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);
// Mainloop Arguments
typename GemmKernel::MainloopArguments mainloop_args{
static_cast<const ElementType**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(ab_strides.data_ptr()),
static_cast<const ElementType**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(ab_strides.data_ptr()),
static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};
// Epilogue Arguments
typename GemmKernel::EpilogueArguments epilogue_args{
{}, // epilogue.thread
nullptr,
static_cast<StrideC*>(c_strides.data_ptr()),
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(c_strides.data_ptr())};
auto& fusion_args = epilogue_args.thread;
fusion_args.alpha_ptr_array = reinterpret_cast<float**>(alpha_ptrs.data_ptr());
fusion_args.dAlpha = {_0{}, _0{}, 1};
fusion_args.beta = 0.0f;
// Gemm Arguments
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
mainloop_args,
epilogue_args,
hw_info,
scheduler};
size_t workspace_size = Gemm::get_workspace_size(args);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
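For context, a hypothetical sketch (not part of the commit) of how the per-expert problem sizes consumed above are assumed to be laid out: one (m_e, n, k) triple per expert in a row-major [num_experts, 3] int32 tensor, which the kernel reinterprets as CUTLASS Shape<int32_t, int32_t, int32_t> entries. The helper name and the host-side construction are illustrative only.

#include <torch/all.h>
#include <vector>

// Illustrative helper: build the [num_experts, 3] problem_sizes tensor on the
// host; the caller is assumed to move it to the GPU together with a, b, the
// block scales, and the offset tensors before invoking the grouped GEMM.
torch::Tensor make_problem_sizes(const std::vector<int32_t>& m_per_expert, int32_t n, int32_t k) {
  const int64_t num_experts = static_cast<int64_t>(m_per_expert.size());
  auto sizes = torch::empty({num_experts, 3}, torch::dtype(torch::kInt32));
  auto acc = sizes.accessor<int32_t, 2>();
  for (int64_t e = 0; e < num_experts; ++e) {
    acc[e][0] = m_per_expert[e];  // rows routed to expert e
    acc[e][1] = n;                // shared output width
    acc[e][2] = k;                // shared reduction dim (unpacked element count)
  }
  return sizes;
}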
template <typename OutType>
void run_fp4_blockwise_scaled_group_mm_sm100(
torch::Tensor& output,
const torch::Tensor& a,
const torch::Tensor& b,
@@ -376,6 +574,10 @@ void run_fp4_blockwise_scaled_group_mm(
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
// Undefine macros from utils.h to redefine with custom signatures
#undef CHECK_CONTIGUOUS
#undef CHECK_INPUT
#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
@@ -428,38 +630,63 @@ void cutlass_fp4_group_mm(
int E = static_cast<int>(b.size(0));
int K = static_cast<int>(2 * b.size(2));
auto sm_version = getSMVersion();
if (sm_version == 100 || sm_version == 103) {
  if (output.scalar_type() == torch::kBFloat16) {
    run_fp4_blockwise_scaled_group_mm_sm100<cutlass::bfloat16_t>(
        output, a, b, a_blockscale, b_blockscales, alphas, ab_strides, c_strides,
        problem_sizes, expert_offsets, sf_offsets, M, N, K);
  } else {
    run_fp4_blockwise_scaled_group_mm_sm100<cutlass::half_t>(
        output, a, b, a_blockscale, b_blockscales, alphas, ab_strides, c_strides,
        problem_sizes, expert_offsets, sf_offsets, M, N, K);
  }
} else if (sm_version == 120) {
  if (output.scalar_type() == torch::kBFloat16) {
    run_fp4_blockwise_scaled_group_mm_sm120(
        output, a, b, a_blockscale, b_blockscales, alphas, ab_strides, c_strides,
        problem_sizes, expert_offsets, sf_offsets, M, N, K);
  } else {
    std::cout << "run_fp4_blockwise_scaled_group_mm_sm120: half output is not implemented" << std::endl;
  }
} else {
  TORCH_CHECK_NOT_IMPLEMENTED(false, "Unsupported SM version: " + std::to_string(sm_version));
}
#else
TORCH_CHECK_NOT_IMPLEMENTED(
...