Unverified Commit 85ed8e0a authored by Qi Yuhang's avatar Qi Yuhang Committed by GitHub
Browse files

Optimize nvfp4 block scaled gemm kernel when M is small. (#10101)

parent dd1e2689
......@@ -38,27 +38,74 @@ limitations under the License.
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// Kernel Perf config
// Config(half_t/bfloat16_t) for M <= 128
template <typename T>
struct KernelTraits {
struct KernelConfigM128 {
using OutputType = T;
using MmaTileShape = Shape<_128, _256, _256>;
using ClusterShape = Shape<int, int, _1>;
using EpilogueTile = Shape<_128, _64>; // Avoid register spilling
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100;
const static dim3 preferred_cluster;
const static dim3 fallback_cluster;
};
template <typename T>
const dim3 KernelConfigM128<T>::preferred_cluster(1, 4, 1);
template <typename T>
const dim3 KernelConfigM128<T>::fallback_cluster(1, 2, 1);
// Config(half_t/bfloat16_t) for M <= 256
template <typename T>
struct KernelConfigM256 {
using OutputType = T;
using MmaTileShape = Shape<_256, _256, _256>;
using ClusterShape = Shape<int, int, _1>;
using EpilogueTile = Shape<_128, _64>;
using EpilogueTile = Shape<_128, _64>; // Avoid register spilling
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100;
const static dim3 preferred_cluster;
const static dim3 fallback_cluster;
};
template <typename T>
const dim3 KernelConfigM256<T>::preferred_cluster(2, 4, 1);
template <typename T>
const dim3 KernelConfigM256<T>::fallback_cluster(2, 1, 1);
template <>
struct KernelTraits<float> {
// Default config(half_t/bfloat16_t) for M > 256
template <typename T>
struct KernelConfigDefault {
using OutputType = T;
using MmaTileShape = Shape<_256, _256, _256>;
using ClusterShape = Shape<int, int, _1>;
using EpilogueTile = Shape<_128, _64>; // Avoid register spilling
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100;
const static dim3 preferred_cluster;
const static dim3 fallback_cluster;
};
template <typename T>
const dim3 KernelConfigDefault<T>::preferred_cluster(4, 4, 1);
template <typename T>
const dim3 KernelConfigDefault<T>::fallback_cluster(2, 1, 1);
struct KernelConfigFp32 {
using OutputType = float;
using MmaTileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<int, int, _1>;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100;
const static dim3 preferred_cluster;
const static dim3 fallback_cluster;
};
const dim3 KernelConfigFp32::preferred_cluster = dim3(1, 4, 1);
const dim3 KernelConfigFp32::fallback_cluster = dim3(1, 2, 1);
template <typename T>
template <typename KernelConfig>
struct Fp4GemmSm100 {
using Config = KernelConfig; // For generating args
using OutputType = typename KernelConfig::OutputType;
// A matrix configuration
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using LayoutATag = cutlass::layout::RowMajor;
......@@ -70,8 +117,8 @@ struct Fp4GemmSm100 {
static constexpr int AlignmentB = 32;
// C/D matrix configuration
using ElementD = T;
using ElementC = T;
using ElementD = OutputType;
using ElementC = OutputType;
using LayoutCTag = cutlass::layout::RowMajor;
using LayoutDTag = cutlass::layout::RowMajor;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
......@@ -82,15 +129,15 @@ struct Fp4GemmSm100 {
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
// Kernel Perf config
using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
using ClusterShape = typename KernelTraits<T>::ClusterShape;
using EpilogueTile = typename KernelTraits<T>::EpilogueTile;
using EpilogueSchedule = typename KernelTraits<T>::EpilogueSchedule;
using MainloopSchedule = typename KernelTraits<T>::MainloopSchedule;
using MmaTileShape = typename KernelConfig::MmaTileShape;
using ClusterShape = typename KernelConfig::ClusterShape;
using EpilogueTile = typename KernelConfig::EpilogueTile;
using EpilogueSchedule = typename KernelConfig::EpilogueSchedule;
using MainloopSchedule = typename KernelConfig::MainloopSchedule;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
cutlass::arch::OpClassTensorOp,
OperatorClass,
MmaTileShape,
ClusterShape,
EpilogueTile,
......@@ -182,19 +229,15 @@ typename T::Gemm::Arguments args_from_options(
layout_SFB},
{ // Epilogue arguments
{}, // epilogue.thread
static_cast<ElementD const*>(D.data_ptr()),
nullptr,
stride_D,
static_cast<ElementD*>(D.data_ptr()),
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
if constexpr (std::is_same_v<T, float>) {
arguments.hw_info.cluster_shape = dim3(1, 4, 1);
arguments.hw_info.cluster_shape_fallback = dim3(1, 1, 1);
} else {
arguments.hw_info.cluster_shape = dim3(4, 4, 1);
arguments.hw_info.cluster_shape_fallback = dim3(2, 1, 1);
}
using KernelConfig = typename T::Config;
arguments.hw_info.cluster_shape = KernelConfig::preferred_cluster;
arguments.hw_info.cluster_shape_fallback = KernelConfig::fallback_cluster;
return arguments;
}
......@@ -210,11 +253,10 @@ void runGemm(
int64_t n,
int64_t k,
cudaStream_t stream) {
typename Fp4GemmSm100<T>::Gemm gemm;
auto arguments = args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);
typename T::Gemm gemm;
auto arguments = args_from_options<T>(D, A, B, A_sf, B_sf, alpha, m, n, k);
size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
size_t workspace_size = T::Gemm::get_workspace_size(arguments);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
......@@ -224,9 +266,51 @@ void runGemm(
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
}
// Dispatch function to select appropriate config based on M
template <typename OutType>
void cutlassFp4GemmDispatch(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha,
int64_t m,
int64_t n,
int64_t k,
cudaStream_t stream) {
if (m <= 128) {
// m in [1, 128]
runGemm<Fp4GemmSm100<KernelConfigM128<OutType>>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
} else if (m <= 256) {
// m in (128, 256]
runGemm<Fp4GemmSm100<KernelConfigM256<OutType>>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
} else {
// m in (256, inf)
runGemm<Fp4GemmSm100<KernelConfigDefault<OutType>>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
}
}
// Dispatch function to select appropriate config based on M
template <>
void cutlassFp4GemmDispatch<float>(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha,
int64_t m,
int64_t n,
int64_t k,
cudaStream_t stream) {
runGemm<Fp4GemmSm100<KernelConfigFp32>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
}
#else
template <typename T>
void runGemm(
void cutlassFp4GemmDispatch(
at::Tensor& D,
at::Tensor const& A,
at::Tensor const& B,
......@@ -358,11 +442,11 @@ void cutlass_scaled_fp4_mm_sm100a(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
if (out_dtype == at::ScalarType::Half) {
runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
cutlassFp4GemmDispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
} else if (out_dtype == at::ScalarType::BFloat16) {
runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
cutlassFp4GemmDispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
} else if (out_dtype == at::ScalarType::Float) {
runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
cutlassFp4GemmDispatch<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
} else {
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment