Unverified Commit 089cd4f0 authored by Lain's avatar Lain Committed by GitHub
Browse files

fix cutlass_3x_gemm_fp8_blockwise on sm103a (#32224)


Signed-off-by: default avatarSiyuan Fu <siyuanf@nvidia.com>
Co-authored-by: default avatarPavani Majety <pmajety@nvidia.com>
parent 0130223b
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include <climits> #include <climits>
#include "cuda_runtime.h" #include "cuda_runtime.h"
#include <iostream> #include <cstdio>
#include <cstdlib>
/** /**
* Helper function for checking CUTLASS errors * Helper function for checking CUTLASS errors
...@@ -31,12 +32,63 @@ int32_t get_sm_version_num(); ...@@ -31,12 +32,63 @@ int32_t get_sm_version_num();
* __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
* into code that will be executed on the device where it is defined. * into code that will be executed on the device where it is defined.
*/ */
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__
#if __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel::invoke(std::forward<Args>(args)...);
#else
printf("This kernel only supports sm[75, 80).\n");
asm("trap;");
#endif
#endif
}
};
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__
#if __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel::invoke(std::forward<Args>(args)...);
#else
printf("This kernel only supports sm[80, 89).\n");
asm("trap;");
#endif
#endif
}
};
template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__
#if __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel::invoke(std::forward<Args>(args)...);
#else
printf("This kernel only supports sm[89, 90).\n");
asm("trap;");
#endif
#endif
}
};
template <typename Kernel> template <typename Kernel>
struct enable_sm90_or_later : Kernel { struct enable_sm90_or_later : Kernel {
template <typename... Args> template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) { CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 #if defined __CUDA_ARCH__
#if __CUDA_ARCH__ >= 900
Kernel::operator()(std::forward<Args>(args)...); Kernel::operator()(std::forward<Args>(args)...);
#else
printf("This kernel only supports sm >= 90.\n");
asm("trap;");
#endif
#endif #endif
} }
}; };
...@@ -45,18 +97,43 @@ template <typename Kernel> ...@@ -45,18 +97,43 @@ template <typename Kernel>
struct enable_sm90_only : Kernel { struct enable_sm90_only : Kernel {
template <typename... Args> template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) { CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900 #if defined __CUDA_ARCH__
#if __CUDA_ARCH__ == 900
Kernel::operator()(std::forward<Args>(args)...);
#else
printf("This kernel only supports sm90.\n");
asm("trap;");
#endif
#endif
}
};
template <typename Kernel>
struct enable_sm100f_only : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__
#if __CUDA_ARCH__ == 1000 || __CUDA_ARCH__ == 1030
Kernel::operator()(std::forward<Args>(args)...); Kernel::operator()(std::forward<Args>(args)...);
#else
printf("This kernel only supports sm100f.\n");
asm("trap;");
#endif
#endif #endif
} }
}; };
template <typename Kernel> template <typename Kernel>
struct enable_sm100_only : Kernel { struct enable_sm100a_only : Kernel {
template <typename... Args> template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) { CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000 #if defined __CUDA_ARCH__
#if __CUDA_ARCH__ == 1000
Kernel::operator()(std::forward<Args>(args)...); Kernel::operator()(std::forward<Args>(args)...);
#else
printf("This kernel only supports sm100a.\n");
asm("trap;");
#endif
#endif #endif
} }
}; };
...@@ -65,8 +142,13 @@ template <typename Kernel> ...@@ -65,8 +142,13 @@ template <typename Kernel>
struct enable_sm120_only : Kernel { struct enable_sm120_only : Kernel {
template <typename... Args> template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) { CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1200 #if defined __CUDA_ARCH__
#if __CUDA_ARCH__ == 1200
Kernel::operator()(std::forward<Args>(args)...); Kernel::operator()(std::forward<Args>(args)...);
#else
printf("This kernel only supports sm120.\n");
asm("trap;");
#endif
#endif #endif
} }
}; };
...@@ -141,8 +141,8 @@ struct cutlass_3x_gemm_sm100 { ...@@ -141,8 +141,8 @@ struct cutlass_3x_gemm_sm100 {
sizeof(typename CollectiveEpilogue::SharedStorage))>, sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp; KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal< using GemmKernel = enable_sm100f_only<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>; Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
}; };
template <typename ElementAB_, typename ElementD_, template <typename ElementAB_, typename ElementD_,
...@@ -202,8 +202,8 @@ struct cutlass_3x_gemm_sm120 { ...@@ -202,8 +202,8 @@ struct cutlass_3x_gemm_sm120 {
sizeof(typename CollectiveEpilogue::SharedStorage))>, sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp; KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal< using GemmKernel = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>; Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
}; };
} // namespace vllm } // namespace vllm
...@@ -123,7 +123,7 @@ struct cutlass_3x_gemm_fp8_blockwise { ...@@ -123,7 +123,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
MainloopScheduler MainloopScheduler
>::CollectiveOp>; >::CollectiveOp>;
using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal< using KernelType = enable_sm100f_only<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>; Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
struct GemmKernel : public KernelType {}; struct GemmKernel : public KernelType {};
......
...@@ -90,8 +90,8 @@ struct cutlass_3x_gemm_sm100_fp8 { ...@@ -90,8 +90,8 @@ struct cutlass_3x_gemm_sm100_fp8 {
// ----------------------------------------------------------- // -----------------------------------------------------------
// Kernel definition // Kernel definition
// ----------------------------------------------------------- // -----------------------------------------------------------
using GemmKernel = cutlass::gemm::kernel::GemmUniversal< using GemmKernel = enable_sm100f_only<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>; Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
}; };
template <typename InType, typename OutType, bool EnableBias> template <typename InType, typename OutType, bool EnableBias>
......
...@@ -36,41 +36,6 @@ using namespace cute; ...@@ -36,41 +36,6 @@ using namespace cute;
*/ */
namespace vllm { namespace vllm {
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Arch, template <typename> typename ArchGuard, template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_, typename ElementAB_, typename ElementD_,
template <typename, typename> typename Epilogue_, typename TileShape, template <typename, typename> typename Epilogue_, typename TileShape,
......
...@@ -50,7 +50,7 @@ struct sm89_fp8_config_default { ...@@ -50,7 +50,7 @@ struct sm89_fp8_config_default {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>, InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -58,7 +58,7 @@ struct sm89_fp8_config_default { ...@@ -58,7 +58,7 @@ struct sm89_fp8_config_default {
using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>; using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>, InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -67,7 +67,7 @@ struct sm89_fp8_config_default { ...@@ -67,7 +67,7 @@ struct sm89_fp8_config_default {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>, InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -100,7 +100,7 @@ struct sm89_fp8_config_M256 { ...@@ -100,7 +100,7 @@ struct sm89_fp8_config_M256 {
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>; using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>, InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -108,7 +108,7 @@ struct sm89_fp8_config_M256 { ...@@ -108,7 +108,7 @@ struct sm89_fp8_config_M256 {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>, InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -141,7 +141,7 @@ struct sm89_fp8_config_M128 { ...@@ -141,7 +141,7 @@ struct sm89_fp8_config_M128 {
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>; using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>, InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -150,7 +150,7 @@ struct sm89_fp8_config_M128 { ...@@ -150,7 +150,7 @@ struct sm89_fp8_config_M128 {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>, InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -158,7 +158,7 @@ struct sm89_fp8_config_M128 { ...@@ -158,7 +158,7 @@ struct sm89_fp8_config_M128 {
using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>; using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>, InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -191,7 +191,7 @@ struct sm89_fp8_config_M64 { ...@@ -191,7 +191,7 @@ struct sm89_fp8_config_M64 {
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd; using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>, InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -201,7 +201,7 @@ struct sm89_fp8_config_M64 { ...@@ -201,7 +201,7 @@ struct sm89_fp8_config_M64 {
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>, InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -211,7 +211,7 @@ struct sm89_fp8_config_M64 { ...@@ -211,7 +211,7 @@ struct sm89_fp8_config_M64 {
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd; using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>, InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -244,7 +244,7 @@ struct sm89_fp8_config_M32 { ...@@ -244,7 +244,7 @@ struct sm89_fp8_config_M32 {
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>, InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -253,7 +253,7 @@ struct sm89_fp8_config_M32 { ...@@ -253,7 +253,7 @@ struct sm89_fp8_config_M32 {
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>; using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 4, FP8MathOperator>, InstructionShape, 4, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -262,7 +262,7 @@ struct sm89_fp8_config_M32 { ...@@ -262,7 +262,7 @@ struct sm89_fp8_config_M32 {
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>, InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -296,7 +296,7 @@ struct sm89_fp8_config_M16 { ...@@ -296,7 +296,7 @@ struct sm89_fp8_config_M16 {
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>; using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, MainLoopStages, InstructionShape, MainLoopStages,
FP8MathOperator>, FP8MathOperator>,
...@@ -305,7 +305,7 @@ struct sm89_fp8_config_M16 { ...@@ -305,7 +305,7 @@ struct sm89_fp8_config_M16 {
using TileShape = typename cutlass::gemm::GemmShape<16, 128, 64>; using TileShape = typename cutlass::gemm::GemmShape<16, 128, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, MainLoopStages, InstructionShape, MainLoopStages,
FP8MathOperator>, FP8MathOperator>,
...@@ -314,7 +314,7 @@ struct sm89_fp8_config_M16 { ...@@ -314,7 +314,7 @@ struct sm89_fp8_config_M16 {
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>; using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, MainLoopStages, InstructionShape, MainLoopStages,
FP8MathOperator>, FP8MathOperator>,
......
...@@ -48,7 +48,7 @@ struct sm89_int8_config_default { ...@@ -48,7 +48,7 @@ struct sm89_int8_config_default {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>; using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>, InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -56,7 +56,7 @@ struct sm89_int8_config_default { ...@@ -56,7 +56,7 @@ struct sm89_int8_config_default {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>; using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>, InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -64,7 +64,7 @@ struct sm89_int8_config_default { ...@@ -64,7 +64,7 @@ struct sm89_int8_config_default {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>; using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>, InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -72,7 +72,7 @@ struct sm89_int8_config_default { ...@@ -72,7 +72,7 @@ struct sm89_int8_config_default {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>; using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>, InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -104,7 +104,7 @@ struct sm89_int8_config_M256 { ...@@ -104,7 +104,7 @@ struct sm89_int8_config_M256 {
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>; using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>, InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -112,7 +112,7 @@ struct sm89_int8_config_M256 { ...@@ -112,7 +112,7 @@ struct sm89_int8_config_M256 {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>; using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>, InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -120,7 +120,7 @@ struct sm89_int8_config_M256 { ...@@ -120,7 +120,7 @@ struct sm89_int8_config_M256 {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>; using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>, InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -128,7 +128,7 @@ struct sm89_int8_config_M256 { ...@@ -128,7 +128,7 @@ struct sm89_int8_config_M256 {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>; using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>, InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -160,7 +160,7 @@ struct sm89_int8_config_M128 { ...@@ -160,7 +160,7 @@ struct sm89_int8_config_M128 {
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>, InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -169,7 +169,7 @@ struct sm89_int8_config_M128 { ...@@ -169,7 +169,7 @@ struct sm89_int8_config_M128 {
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>, InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -178,7 +178,7 @@ struct sm89_int8_config_M128 { ...@@ -178,7 +178,7 @@ struct sm89_int8_config_M128 {
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>, InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -210,7 +210,7 @@ struct sm89_int8_config_M64 { ...@@ -210,7 +210,7 @@ struct sm89_int8_config_M64 {
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>, InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -219,7 +219,7 @@ struct sm89_int8_config_M64 { ...@@ -219,7 +219,7 @@ struct sm89_int8_config_M64 {
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>, InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -251,7 +251,7 @@ struct sm89_int8_config_M32 { ...@@ -251,7 +251,7 @@ struct sm89_int8_config_M32 {
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>; using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>, InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -260,7 +260,7 @@ struct sm89_int8_config_M32 { ...@@ -260,7 +260,7 @@ struct sm89_int8_config_M32 {
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 4>, InstructionShape, 4>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -292,7 +292,7 @@ struct sm89_int8_config_M16 { ...@@ -292,7 +292,7 @@ struct sm89_int8_config_M16 {
using TileShape = cutlass::gemm::GemmShape<16, 64, 128>; using TileShape = cutlass::gemm::GemmShape<16, 64, 128>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>, InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
...@@ -300,7 +300,7 @@ struct sm89_int8_config_M16 { ...@@ -300,7 +300,7 @@ struct sm89_int8_config_M16 {
using TileShape = cutlass::gemm::GemmShape<16, 128, 128>; using TileShape = cutlass::gemm::GemmShape<16, 128, 128>;
return vllm::fallback_cutlass_gemm_caller< return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape, InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 4>, InstructionShape, 4>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment