Unverified Commit 839b27c6 authored by leoneo's avatar leoneo Committed by GitHub
Browse files

[Kernel]Add streamK for block-quantized CUTLASS kernels (#12978)

parent 34ad27fe
...@@ -30,12 +30,18 @@ static inline cute::Shape<int, int, int, int> get_problem_shape( ...@@ -30,12 +30,18 @@ static inline cute::Shape<int, int, int, int> get_problem_shape(
} }
template <typename GemmKernel> template <typename GemmKernel>
void cutlass_gemm_caller(torch::Device device, void cutlass_gemm_caller(
cute::Shape<int, int, int, int> prob_shape, torch::Device device, cute::Shape<int, int, int, int> prob_shape,
typename GemmKernel::MainloopArguments mainloop_args, typename GemmKernel::MainloopArguments mainloop_args,
typename GemmKernel::EpilogueArguments epilogue_args) { typename GemmKernel::EpilogueArguments epilogue_args,
typename GemmKernel::TileSchedulerArguments scheduler = {}) {
cutlass::KernelHardwareInfo hw_info;
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
prob_shape, mainloop_args, epilogue_args}; prob_shape,
mainloop_args,
epilogue_args,
hw_info,
scheduler};
// Launch the CUTLASS GEMM kernel. // Launch the CUTLASS GEMM kernel.
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
......
...@@ -22,8 +22,9 @@ namespace vllm { ...@@ -22,8 +22,9 @@ namespace vllm {
using namespace cute; using namespace cute;
template <typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_, template <typename SchedulerType, typename OutType, int GroupSizeM_,
int TileSizeM_ = 128, class ClusterShape = Shape<_1, _2, _1>> int GroupSizeN_, int GroupSizeK_, int TileSizeM_ = 128,
class ClusterShape = Shape<_1, _2, _1>>
struct cutlass_3x_gemm_fp8_blockwise { struct cutlass_3x_gemm_fp8_blockwise {
using GroupSizeM = Int<GroupSizeM_>; using GroupSizeM = Int<GroupSizeM_>;
using GroupSizeN = Int<GroupSizeN_>; using GroupSizeN = Int<GroupSizeN_>;
...@@ -84,7 +85,7 @@ struct cutlass_3x_gemm_fp8_blockwise { ...@@ -84,7 +85,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal< using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
cutlass::gemm::PersistentScheduler>>; SchedulerType>>;
struct GemmKernel : public KernelType {}; struct GemmKernel : public KernelType {};
...@@ -150,8 +151,24 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, ...@@ -150,8 +151,24 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
typename GemmKernel::EpilogueArguments epilogue_args{ typename GemmKernel::EpilogueArguments epilogue_args{
{}, c_ptr, c_stride, c_ptr, c_stride}; {}, c_ptr, c_stride, c_ptr, c_stride};
typename GemmKernel::TileSchedulerArguments scheduler;
static constexpr bool UsesStreamKScheduler =
cute::is_same_v<typename GemmKernel::TileSchedulerTag,
cutlass::gemm::StreamKScheduler>;
if constexpr (UsesStreamKScheduler) {
using DecompositionMode = typename cutlass::gemm::kernel::detail::
PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
using ReductionMode = typename cutlass::gemm::kernel::detail::
PersistentTileSchedulerSm90StreamKParams::ReductionMode;
scheduler.decomposition_mode = DecompositionMode::StreamK;
scheduler.reduction_mode = ReductionMode::Nondeterministic;
}
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args, c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
epilogue_args); epilogue_args, scheduler);
} }
template <typename OutType> template <typename OutType>
...@@ -160,9 +177,18 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, ...@@ -160,9 +177,18 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::Tensor const& b_scales) {
cutlass_gemm_caller_blockwise< auto k = a.size(1);
cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128>>(out, a, b, a_scales, auto n = b.size(1);
b_scales);
if (k > 3 * n) {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(
out, a, b, a_scales, b_scales);
} else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
out, a, b, a_scales, b_scales);
}
} }
} // namespace vllm } // namespace vllm
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment