// Adapted from
// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
#pragma once

// clang-format will break include orders
// clang-format off
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAFunctions.h>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/util/packed_stride.hpp"
// clang-format on

/**
 * Helper macro for checking CUTLASS errors
 */
#define CUTLASS_CHECK(status)                             \
  {                                                       \
    cutlass::Status error = status;                       \
    TORCH_CHECK(error == cutlass::Status::kSuccess,       \
                cutlassGetStatusString(error));           \
  }

template <typename GemmKernel>
void cutlass_gemm_caller(
    torch::Device device, cute::Shape<int, int, int, int> prob_shape,
    typename GemmKernel::MainloopArguments mainloop_args,
    typename GemmKernel::EpilogueArguments epilogue_args,
    typename GemmKernel::TileSchedulerArguments scheduler = {}) {
  // Hardware info (device id and SM count) is used by the tile scheduler to
  // size its persistent grid.
  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = c10::cuda::current_device();
  hw_info.sm_count =
      at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  typename GemmKernel::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemm,
      prob_shape,
      mainloop_args,
      epilogue_args,
      hw_info,
      scheduler};

  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  // Allocate the kernel workspace as a uint8 torch tensor on the target
  // device so its lifetime is managed by the caching allocator.
  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(device);
  auto workspace = torch::empty(workspace_size, workspace_options);

  // Run on the current CUDA stream for the target device.
  auto stream = at::cuda::getCurrentCUDAStream(device.index());

  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}
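
// Usage sketch (illustrative only): a .cu file including this header would
// typically assemble a CUTLASS 3.x kernel type and dispatch through
// cutlass_gemm_caller. The CollectiveMainloop / CollectiveEpilogue aliases
// below are hypothetical placeholders for types produced by the CUTLASS
// collective builders; the argument names are likewise placeholders.
//
//   using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
//       cute::Shape<int, int, int, int>,  // problem shape (M, N, K, L)
//       CollectiveMainloop,               // built via cutlass::gemm::collective::CollectiveBuilder
//       CollectiveEpilogue>;              // built via cutlass::epilogue::collective::CollectiveBuilder
//
//   cutlass_gemm_caller<GemmKernel>(out.device(), {m, n, k, 1},
//                                   mainloop_args, epilogue_args);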