Unverified commit 85ed8e0a, authored by Qi Yuhang, committed by GitHub

Optimize the nvfp4 block-scaled GEMM kernel when M is small. (#10101)

parent dd1e2689
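In outline, the commit replaces the single per-dtype KernelTraits struct with several kernel-config structs bucketed by M, then dispatches to one of them at runtime. A minimal sketch of that dispatch pattern, with simplified stand-in configs (the names and tile_m values below are illustrative only; the real configs carry the CUTLASS tile, schedule, and cluster-shape settings shown in the diff):

#include <cstdint>
#include <cstdio>

// Simplified stand-ins for the commit's config structs: each M bucket
// carries its own compile-time kernel parameters.
struct ConfigM128 { static constexpr int tile_m = 128; };     // M in [1, 128]
struct ConfigM256 { static constexpr int tile_m = 256; };     // M in (128, 256]
struct ConfigDefault { static constexpr int tile_m = 256; };  // M in (256, inf)

// Stand-in for runGemm<Fp4GemmSm100<Config>>: one instantiation per config.
template <typename Config>
void run_gemm(int64_t m) {
  std::printf("m=%lld dispatched to tile_m=%d\n", static_cast<long long>(m), Config::tile_m);
}

// Mirrors cutlassFp4GemmDispatch: branch on M once, then run a kernel whose
// tile and cluster shapes were fixed at compile time.
void dispatch(int64_t m) {
  if (m <= 128) {
    run_gemm<ConfigM128>(m);
  } else if (m <= 256) {
    run_gemm<ConfigM256>(m);
  } else {
    run_gemm<ConfigDefault>(m);
  }
}

int main() {
  dispatch(64);  // small-M path: 128-row MMA tile and a 1-SM schedule in the real kernel
  dispatch(200);
  dispatch(4096);
}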
@@ -38,27 +38,74 @@ limitations under the License.
 using namespace cute;

 #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
-// Kernel Perf config
+// Config(half_t/bfloat16_t) for M <= 128
 template <typename T>
-struct KernelTraits {
+struct KernelConfigM128 {
+  using OutputType = T;
+  using MmaTileShape = Shape<_128, _256, _256>;
+  using ClusterShape = Shape<int, int, _1>;
+  using EpilogueTile = Shape<_128, _64>;  // Avoid register spilling
+  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
+  using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100;
+  const static dim3 preferred_cluster;
+  const static dim3 fallback_cluster;
+};
+template <typename T>
+const dim3 KernelConfigM128<T>::preferred_cluster(1, 4, 1);
+template <typename T>
+const dim3 KernelConfigM128<T>::fallback_cluster(1, 2, 1);
+
+// Config(half_t/bfloat16_t) for M <= 256
+template <typename T>
+struct KernelConfigM256 {
+  using OutputType = T;
   using MmaTileShape = Shape<_256, _256, _256>;
   using ClusterShape = Shape<int, int, _1>;
-  using EpilogueTile = Shape<_128, _64>;
+  using EpilogueTile = Shape<_128, _64>;  // Avoid register spilling
   using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
   using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100;
+  const static dim3 preferred_cluster;
+  const static dim3 fallback_cluster;
 };
+template <typename T>
+const dim3 KernelConfigM256<T>::preferred_cluster(2, 4, 1);
+template <typename T>
+const dim3 KernelConfigM256<T>::fallback_cluster(2, 1, 1);
+
+// Default config(half_t/bfloat16_t) for M > 256
+template <typename T>
+struct KernelConfigDefault {
+  using OutputType = T;
+  using MmaTileShape = Shape<_256, _256, _256>;
+  using ClusterShape = Shape<int, int, _1>;
+  using EpilogueTile = Shape<_128, _64>;  // Avoid register spilling
+  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
+  using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100;
+  const static dim3 preferred_cluster;
+  const static dim3 fallback_cluster;
+};
+template <typename T>
+const dim3 KernelConfigDefault<T>::preferred_cluster(4, 4, 1);
+template <typename T>
+const dim3 KernelConfigDefault<T>::fallback_cluster(2, 1, 1);

-template <>
-struct KernelTraits<float> {
+struct KernelConfigFp32 {
+  using OutputType = float;
   using MmaTileShape = Shape<_128, _128, _256>;
   using ClusterShape = Shape<int, int, _1>;
   using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
   using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
   using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100;
+  const static dim3 preferred_cluster;
+  const static dim3 fallback_cluster;
 };
+const dim3 KernelConfigFp32::preferred_cluster = dim3(1, 4, 1);
+const dim3 KernelConfigFp32::fallback_cluster = dim3(1, 2, 1);

-template <typename T>
+template <typename KernelConfig>
 struct Fp4GemmSm100 {
+  using Config = KernelConfig;  // For generating args
+  using OutputType = typename KernelConfig::OutputType;
   // A matrix configuration
   using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
   using LayoutATag = cutlass::layout::RowMajor;
@@ -70,8 +117,8 @@ struct Fp4GemmSm100 {
   static constexpr int AlignmentB = 32;

   // C/D matrix configuration
-  using ElementD = T;
-  using ElementC = T;
+  using ElementD = OutputType;
+  using ElementC = OutputType;
   using LayoutCTag = cutlass::layout::RowMajor;
   using LayoutDTag = cutlass::layout::RowMajor;
   static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
@@ -82,15 +129,15 @@ struct Fp4GemmSm100 {
   using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;

   // Kernel Perf config
-  using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
-  using ClusterShape = typename KernelTraits<T>::ClusterShape;
-  using EpilogueTile = typename KernelTraits<T>::EpilogueTile;
-  using EpilogueSchedule = typename KernelTraits<T>::EpilogueSchedule;
-  using MainloopSchedule = typename KernelTraits<T>::MainloopSchedule;
+  using MmaTileShape = typename KernelConfig::MmaTileShape;
+  using ClusterShape = typename KernelConfig::ClusterShape;
+  using EpilogueTile = typename KernelConfig::EpilogueTile;
+  using EpilogueSchedule = typename KernelConfig::EpilogueSchedule;
+  using MainloopSchedule = typename KernelConfig::MainloopSchedule;

   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
       ArchTag,
-      cutlass::arch::OpClassTensorOp,
+      OperatorClass,
       MmaTileShape,
       ClusterShape,
       EpilogueTile,
@@ -182,19 +229,15 @@ typename T::Gemm::Arguments args_from_options(
        layout_SFB},
       {// Epilogue arguments
        {},  // epilogue.thread
-       static_cast<ElementD const*>(D.data_ptr()),
+       nullptr,
        stride_D,
        static_cast<ElementD*>(D.data_ptr()),
        stride_D}};
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
-  if constexpr (std::is_same_v<T, float>) {
-    arguments.hw_info.cluster_shape = dim3(1, 4, 1);
-    arguments.hw_info.cluster_shape_fallback = dim3(1, 1, 1);
-  } else {
-    arguments.hw_info.cluster_shape = dim3(4, 4, 1);
-    arguments.hw_info.cluster_shape_fallback = dim3(2, 1, 1);
-  }
+  using KernelConfig = typename T::Config;
+  arguments.hw_info.cluster_shape = KernelConfig::preferred_cluster;
+  arguments.hw_info.cluster_shape_fallback = KernelConfig::fallback_cluster;

   return arguments;
 }
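Note the design choice in the hunk above: the preferred/fallback cluster shapes now live in the config structs, so args_from_options no longer special-cases float, and adding another M bucket would not touch the argument-building code at all. A hypothetical extension that drops into this file next to the other configs (KernelConfigM64 and all of its numbers are made up for illustration, not part of this commit, and would need tuning):

// Hypothetical: an extra bucket for very small M, shaped like the configs
// above. The tile and cluster values here are placeholders, not tuned.
template <typename T>
struct KernelConfigM64 {
  using OutputType = T;
  using MmaTileShape = Shape<_64, _256, _256>;
  using ClusterShape = Shape<int, int, _1>;
  using EpilogueTile = Shape<_64, _64>;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
  using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100;
  const static dim3 preferred_cluster;
  const static dim3 fallback_cluster;
};
template <typename T>
const dim3 KernelConfigM64<T>::preferred_cluster(1, 2, 1);
template <typename T>
const dim3 KernelConfigM64<T>::fallback_cluster(1, 1, 1);
// ...plus one more branch in cutlassFp4GemmDispatch: if (m <= 64) { ... }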
@@ -210,11 +253,10 @@ void runGemm(
     int64_t n,
     int64_t k,
     cudaStream_t stream) {
-  typename Fp4GemmSm100<T>::Gemm gemm;
+  typename T::Gemm gemm;

-  auto arguments = args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);
+  auto arguments = args_from_options<T>(D, A, B, A_sf, B_sf, alpha, m, n, k);

-  size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
+  size_t workspace_size = T::Gemm::get_workspace_size(arguments);
   auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
   auto workspace = torch::empty(workspace_size, workspace_options);
@@ -224,9 +266,51 @@ void runGemm(
   CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
 }

+// Dispatch function to select the appropriate config based on M
+template <typename OutType>
+void cutlassFp4GemmDispatch(
+    torch::Tensor& D,
+    torch::Tensor const& A,
+    torch::Tensor const& B,
+    torch::Tensor const& A_sf,
+    torch::Tensor const& B_sf,
+    torch::Tensor const& alpha,
+    int64_t m,
+    int64_t n,
+    int64_t k,
+    cudaStream_t stream) {
+  if (m <= 128) {
+    // m in [1, 128]
+    runGemm<Fp4GemmSm100<KernelConfigM128<OutType>>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else if (m <= 256) {
+    // m in (128, 256]
+    runGemm<Fp4GemmSm100<KernelConfigM256<OutType>>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else {
+    // m in (256, inf)
+    runGemm<Fp4GemmSm100<KernelConfigDefault<OutType>>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  }
+}
+
+// Specialization for float output: a single config regardless of M
+template <>
+void cutlassFp4GemmDispatch<float>(
+    torch::Tensor& D,
+    torch::Tensor const& A,
+    torch::Tensor const& B,
+    torch::Tensor const& A_sf,
+    torch::Tensor const& B_sf,
+    torch::Tensor const& alpha,
+    int64_t m,
+    int64_t n,
+    int64_t k,
+    cudaStream_t stream) {
+  runGemm<Fp4GemmSm100<KernelConfigFp32>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+}
+
 #else
 template <typename T>
-void runGemm(
+void cutlassFp4GemmDispatch(
     at::Tensor& D,
     at::Tensor const& A,
     at::Tensor const& B,
@@ -358,11 +442,11 @@ void cutlass_scaled_fp4_mm_sm100a(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());

   if (out_dtype == at::ScalarType::Half) {
-    runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+    cutlassFp4GemmDispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
   } else if (out_dtype == at::ScalarType::BFloat16) {
-    runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+    cutlassFp4GemmDispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
   } else if (out_dtype == at::ScalarType::Float) {
-    runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+    cutlassFp4GemmDispatch<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
   } else {
     TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
   }
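Since the dispatch thresholds are inclusive upper bounds, the bucket boundaries are easy to get wrong off by one; a small host-side sketch that pins them down (the Bucket enum is a stand-in for the real config instantiations):

#include <cassert>
#include <cstdint>

// Mirrors the thresholds in cutlassFp4GemmDispatch.
enum class Bucket { M128, M256, Default };

Bucket bucket_for(int64_t m) {
  if (m <= 128) return Bucket::M128;  // m in [1, 128]
  if (m <= 256) return Bucket::M256;  // m in (128, 256]
  return Bucket::Default;             // m in (256, inf)
}

int main() {
  assert(bucket_for(1) == Bucket::M128);
  assert(bucket_for(128) == Bucket::M128);  // upper bound is inclusive
  assert(bucket_for(129) == Bucket::M256);
  assert(bucket_for(256) == Bucket::M256);
  assert(bucket_for(257) == Bucket::Default);
  return 0;
}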