"examples/vscode:/vscode.git/clone" did not exist on "dc14ce05faecb448f67229037041d2aadc8527ed"
Unverified Commit ab1a6a43 authored by mikaylagawarecki's avatar mikaylagawarecki Committed by GitHub
Browse files

[3/n] Migrate cutlass/scaled_mm_entry.cu torch stable ABI (#37221)


Signed-off-by: default avatarMikayla Gawarecki <mikaylagawarecki@gmail.com>
parent b5e60825
...@@ -3,15 +3,16 @@ ...@@ -3,15 +3,16 @@
namespace vllm { namespace vllm {
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm90_fp8(
torch::Tensor const& b, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) { std::optional<torch::stable::Tensor> const& bias) {
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm90_fp8_epilogue<true>(out, a, b, a_scales, return cutlass_scaled_mm_sm90_fp8_epilogue<true>(out, a, b, a_scales,
b_scales, *bias); b_scales, *bias);
} else { } else {
......
#pragma once #pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm.cuh" #include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh" #include "cutlass_gemm_caller.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
...@@ -235,8 +237,9 @@ struct sm90_fp8_config_M16_N8192 { ...@@ -235,8 +237,9 @@ struct sm90_fp8_config_M16_N8192 {
}; };
template <typename Gemm, typename... EpilogueArgs> template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller_sm90_fp8(torch::Tensor& out, torch::Tensor const& a, void cutlass_gemm_caller_sm90_fp8(torch::stable::Tensor& out,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_params) { EpilogueArgs&&... epilogue_params) {
static constexpr bool swap_ab = Gemm::swap_ab; static constexpr bool swap_ab = Gemm::swap_ab;
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
...@@ -280,15 +283,15 @@ void cutlass_gemm_caller_sm90_fp8(torch::Tensor& out, torch::Tensor const& a, ...@@ -280,15 +283,15 @@ void cutlass_gemm_caller_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
template <typename InType, typename OutType, bool EnableBias, template <typename InType, typename OutType, bool EnableBias,
typename... EpilogueArgs> typename... EpilogueArgs>
inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, inline void cutlass_gemm_sm90_fp8_dispatch(
torch::Tensor const& a, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& a_scales, torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) {
torch::Tensor const& b_scales,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using Cutlass3xGemmDefault = using Cutlass3xGemmDefault =
typename sm90_fp8_config_default<InType, OutType, typename sm90_fp8_config_default<InType, OutType,
...@@ -347,22 +350,24 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, ...@@ -347,22 +350,24 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
} }
template <bool EnableBias, typename... EpilogueArgs> template <bool EnableBias, typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::Tensor& out, void cutlass_scaled_mm_sm90_fp8_epilogue(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
torch::Tensor const& a_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales,
EpilogueArgs&&... epilogue_args) { EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
if (out.dtype() == torch::kBFloat16) { if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, EnableBias>( cutlass::bfloat16_t, EnableBias>(
out, a, b, a_scales, b_scales, out, a, b, a_scales, b_scales,
std::forward<EpilogueArgs>(epilogue_args)...); std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, EnableBias>( cutlass::half_t, EnableBias>(
out, a, b, a_scales, b_scales, out, a, b, a_scales, b_scales,
......
...@@ -4,15 +4,16 @@ ...@@ -4,15 +4,16 @@
namespace vllm { namespace vllm {
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm90_int8(
torch::Tensor const& b, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) { std::optional<torch::stable::Tensor> const& bias) {
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>( return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias); out, a, b, a_scales, b_scales, *bias);
} else { } else {
......
#pragma once #pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm.cuh" #include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh" #include "cutlass_gemm_caller.cuh"
...@@ -87,13 +89,13 @@ struct sm90_int8_config_M32_NSmall { ...@@ -87,13 +89,13 @@ struct sm90_int8_config_M32_NSmall {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue, template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, inline void cutlass_gemm_sm90_int8_dispatch(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
EpilogueArgs&&... args) { EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>()); static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8); STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
TORCH_CHECK(b.dtype() == torch::kInt8); STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
using Cutlass3xGemmDefault = using Cutlass3xGemmDefault =
typename sm90_int8_config_default<InType, OutType, typename sm90_int8_config_default<InType, OutType,
...@@ -142,19 +144,19 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, ...@@ -142,19 +144,19 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
template <template <typename, typename, typename> typename Epilogue, template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_int8_epilogue(torch::Tensor& out, void cutlass_scaled_mm_sm90_int8_epilogue(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) { EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kInt8); STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
TORCH_CHECK(b.dtype() == torch::kInt8); STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
if (out.dtype() == torch::kBFloat16) { if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t, return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>( Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>( return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} }
......
#pragma once #pragma once
#include <cuda.h> #include <cuda.h>
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include <c10/cuda/CUDAStream.h> #include <torch/headeronly/core/ScalarType.h>
#include "libtorch_stable/torch_utils.h"
#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h" #include "cutlass/bfloat16.h"
#include "cutlass/float8.h" #include "cutlass/float8.h"
...@@ -31,7 +31,7 @@ __global__ void get_group_gemm_starts( ...@@ -31,7 +31,7 @@ __global__ void get_group_gemm_starts(
} }
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ #define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \ get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
<<<1, num_experts, 0, stream>>>( \ <<<1, num_experts, 0, stream>>>( \
static_cast<int64_t*>(expert_offsets.data_ptr()), \ static_cast<int64_t*>(expert_offsets.data_ptr()), \
...@@ -51,32 +51,39 @@ __global__ void get_group_gemm_starts( ...@@ -51,32 +51,39 @@ __global__ void get_group_gemm_starts(
namespace { namespace {
void run_get_group_gemm_starts( void run_get_group_gemm_starts(
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs, torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs,
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs, torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs, torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs,
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::stable::Tensor const& a_tensors,
torch::Tensor& out_tensors, torch::Tensor const& a_scales, torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors,
torch::Tensor const& b_scales) { torch::stable::Tensor const& a_scales,
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); torch::stable::Tensor const& b_scales) {
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a_tensors.scalar_type() ==
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); torch::headeronly::ScalarType::Float8_e4m3fn);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); STD_TORCH_CHECK(b_tensors.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
// expect int64_t to avoid overflow during offset calculations // expect int64_t to avoid overflow during offset calculations
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64); STD_TORCH_CHECK(expert_offsets.scalar_type() ==
torch::headeronly::ScalarType::Long);
int num_experts = static_cast<int>(expert_offsets.size(0)); int num_experts = static_cast<int>(expert_offsets.size(0));
bool per_act_token = a_scales.numel() != 1; bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts; bool per_out_ch = b_scales.numel() != num_experts;
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); auto stream = get_current_cuda_stream(a_tensors.get_device_index());
if (false) { if (false) {
} }
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t) __CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16,
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half) cutlass::bfloat16_t)
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half)
else { else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
} }
} }
} // namespace } // namespace
\ No newline at end of file
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/device/gemm_universal_adapter.h"
#include <torch/csrc/stable/ops.h>
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/common.hpp" #include "cutlass_extensions/common.hpp"
#include "get_group_starts.cuh" #include "get_group_starts.cuh"
...@@ -84,13 +85,17 @@ struct cutlass_3x_group_gemm { ...@@ -84,13 +85,17 @@ struct cutlass_3x_group_gemm {
}; };
template <typename Gemm> template <typename Gemm>
void cutlass_group_gemm_caller( void cutlass_group_gemm_caller(torch::stable::Tensor& out_tensors,
torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::stable::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::stable::Tensor const& b_tensors,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::stable::Tensor const& a_scales,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::stable::Tensor const& b_scales,
torch::Tensor const& b_strides, torch::Tensor const& c_strides, torch::stable::Tensor const& expert_offsets,
bool per_act_token, bool per_out_ch) { torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
static constexpr bool swap_ab = Gemm::swap_ab; static constexpr bool swap_ab = Gemm::swap_ab;
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
...@@ -98,16 +103,20 @@ void cutlass_group_gemm_caller( ...@@ -98,16 +103,20 @@ void cutlass_group_gemm_caller(
int num_experts = static_cast<int>(expert_offsets.size(0)); int num_experts = static_cast<int>(expert_offsets.size(0));
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); auto stream = get_current_cuda_stream(a_tensors.get_device_index());
auto options_int = auto device = a_tensors.device();
torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int); torch::stable::Tensor a_ptrs = torch::stable::empty(
torch::Tensor b_ptrs = torch::empty(num_experts, options_int); {num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int); torch::stable::Tensor b_ptrs = torch::stable::empty(
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); {num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); torch::stable::Tensor out_ptrs = torch::stable::empty(
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor a_scales_ptrs = torch::stable::empty(
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_scales_ptrs = torch::stable::empty(
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs, run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors, a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors,
...@@ -156,7 +165,7 @@ void cutlass_group_gemm_caller( ...@@ -156,7 +165,7 @@ void cutlass_group_gemm_caller(
static_cast<ElementD**>(out_ptrs.data_ptr()), static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(c_strides.data_ptr())}; static_cast<StrideC*>(c_strides.data_ptr())};
int device_id = a_tensors.device().index(); int device_id = a_tensors.get_device_index();
static const cutlass::KernelHardwareInfo hw_info{ static const cutlass::KernelHardwareInfo hw_info{
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count( device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
device_id)}; device_id)};
...@@ -170,9 +179,9 @@ void cutlass_group_gemm_caller( ...@@ -170,9 +179,9 @@ void cutlass_group_gemm_caller(
CUTLASS_CHECK(gemm_op.can_implement(args)); CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args); size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options = auto workspace =
torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device()); torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
auto workspace = torch::empty(workspace_size, workspace_options); std::nullopt, device);
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status); CUTLASS_CHECK(status);
......
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h> #include "libtorch_stable/torch_utils.h"
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh" #include "grouped_mm_c3x.cuh"
...@@ -62,21 +63,27 @@ struct sm100_fp8_config_N8192 { ...@@ -62,21 +63,27 @@ struct sm100_fp8_config_N8192 {
}; };
template <typename InType, typename OutType> template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm100( void run_cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::stable::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::stable::Tensor const& b_tensors,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::stable::Tensor const& a_scales,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::stable::Tensor const& b_scales,
torch::Tensor const& b_strides, torch::Tensor const& c_strides, torch::stable::Tensor const& expert_offsets,
bool per_act_token, bool per_out_ch) { torch::stable::Tensor const& problem_sizes,
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); torch::stable::Tensor const& a_strides,
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided."); torch::stable::Tensor const& b_strides,
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn, STD_TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
"A tensors must be of type float8_e4m3fn."); STD_TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn, STD_TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
"B tensors must be of type float8_e4m3fn.");
STD_TORCH_CHECK(
a_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
"A tensors must be of type float8_e4m3fn.");
STD_TORCH_CHECK(
b_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
"B tensors must be of type float8_e4m3fn.");
using Cutlass3xGemmDefault = typename sm100_fp8_config_default< using Cutlass3xGemmDefault = typename sm100_fp8_config_default<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
...@@ -107,14 +114,18 @@ void run_cutlass_moe_mm_sm100( ...@@ -107,14 +114,18 @@ void run_cutlass_moe_mm_sm100(
} }
} // namespace } // namespace
void dispatch_moe_mm_sm100( void dispatch_moe_mm_sm100(torch::stable::Tensor& out_tensors,
torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::stable::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::stable::Tensor const& b_tensors,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::stable::Tensor const& a_scales,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::stable::Tensor const& b_scales,
torch::Tensor const& b_strides, torch::Tensor const& c_strides, torch::stable::Tensor const& expert_offsets,
bool per_act_token, bool per_out_ch) { torch::stable::Tensor const& problem_sizes,
if (out_tensors.dtype() == torch::kBFloat16) { torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>( run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token, problem_sizes, a_strides, b_strides, c_strides, per_act_token,
...@@ -127,13 +138,17 @@ void dispatch_moe_mm_sm100( ...@@ -127,13 +138,17 @@ void dispatch_moe_mm_sm100(
} }
} }
void cutlass_moe_mm_sm100( void cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::stable::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::stable::Tensor const& b_tensors,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::stable::Tensor const& a_scales,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::stable::Tensor const& b_scales,
torch::Tensor const& b_strides, torch::Tensor const& c_strides, torch::stable::Tensor const& expert_offsets,
bool per_act_token, bool per_out_ch) { torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales, dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides, expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch); c_strides, per_act_token, per_out_ch);
......
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h> #include "libtorch_stable/torch_utils.h"
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh" #include "grouped_mm_c3x.cuh"
...@@ -103,21 +104,27 @@ struct sm90_fp8_config_N8192 { ...@@ -103,21 +104,27 @@ struct sm90_fp8_config_N8192 {
}; };
template <typename InType, typename OutType> template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm90( void run_cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::stable::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::stable::Tensor const& b_tensors,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::stable::Tensor const& a_scales,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::stable::Tensor const& b_scales,
torch::Tensor const& b_strides, torch::Tensor const& c_strides, torch::stable::Tensor const& expert_offsets,
bool per_act_token, bool per_out_ch) { torch::stable::Tensor const& problem_sizes,
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); torch::stable::Tensor const& a_strides,
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided."); torch::stable::Tensor const& b_strides,
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn, STD_TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
"A tensors must be of type float8_e4m3fn."); STD_TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn, STD_TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
"B tensors must be of type float8_e4m3fn.");
STD_TORCH_CHECK(
a_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
"A tensors must be of type float8_e4m3fn.");
STD_TORCH_CHECK(
b_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
"B tensors must be of type float8_e4m3fn.");
using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192< using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
...@@ -163,14 +170,18 @@ void run_cutlass_moe_mm_sm90( ...@@ -163,14 +170,18 @@ void run_cutlass_moe_mm_sm90(
} }
} }
void dispatch_moe_mm_sm90( void dispatch_moe_mm_sm90(torch::stable::Tensor& out_tensors,
torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::stable::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::stable::Tensor const& b_tensors,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::stable::Tensor const& a_scales,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::stable::Tensor const& b_scales,
torch::Tensor const& b_strides, torch::Tensor const& c_strides, torch::stable::Tensor const& expert_offsets,
bool per_act_token, bool per_out_ch) { torch::stable::Tensor const& problem_sizes,
if (out_tensors.dtype() == torch::kBFloat16) { torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>( run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token, problem_sizes, a_strides, b_strides, c_strides, per_act_token,
...@@ -185,13 +196,17 @@ void dispatch_moe_mm_sm90( ...@@ -185,13 +196,17 @@ void dispatch_moe_mm_sm90(
} // namespace } // namespace
void cutlass_moe_mm_sm90( void cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::stable::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::stable::Tensor const& b_tensors,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::stable::Tensor const& a_scales,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::stable::Tensor const& b_scales,
torch::Tensor const& b_strides, torch::Tensor const& c_strides, torch::stable::Tensor const& expert_offsets,
bool per_act_token, bool per_out_ch) { torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides, expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch); c_strides, per_act_token, per_out_ch);
......
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h> #include "libtorch_stable/torch_utils.h"
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/ScalarType.h>
#include "dispatch_utils.h" #include "libtorch_stable/dispatch_utils.h"
#include <iostream> #include <iostream>
...@@ -110,19 +112,22 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids, ...@@ -110,19 +112,22 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
} }
namespace { namespace {
inline void launch_compute_problem_sizes( inline void launch_compute_problem_sizes(const torch::stable::Tensor& topk_ids,
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, torch::stable::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, torch::Tensor& atomic_buffer, torch::stable::Tensor& problem_sizes2,
int64_t num_experts, int64_t n, int64_t k, cudaStream_t stream, torch::stable::Tensor& atomic_buffer,
const bool swap_ab, const bool is_gated) { int64_t num_experts, int64_t n,
int64_t k, cudaStream_t stream,
const bool swap_ab,
const bool is_gated) {
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
auto const* topk_ptr = topk_ids.data_ptr<int32_t>(); auto const* topk_ptr = topk_ids.const_data_ptr<int32_t>();
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>(); auto* ps1_ptr = problem_sizes1.mutable_data_ptr<int32_t>();
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>(); auto* ps2_ptr = problem_sizes2.mutable_data_ptr<int32_t>();
auto* atomic_ptr = atomic_buffer.data_ptr<int32_t>(); auto* atomic_ptr = atomic_buffer.mutable_data_ptr<int32_t>();
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] { VLLM_STABLE_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>( compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr, topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
static_cast<int>(topk_ids.numel()), static_cast<int>(n), static_cast<int>(topk_ids.numel()), static_cast<int>(n),
...@@ -171,46 +176,53 @@ __global__ void compute_problem_sizes_from_expert_offsets( ...@@ -171,46 +176,53 @@ __global__ void compute_problem_sizes_from_expert_offsets(
} }
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller( void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
const torch::Tensor& expert_first_token_offset, const torch::stable::Tensor& expert_first_token_offset,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::stable::Tensor& problem_sizes1,
const int64_t n, const int64_t k, const bool swap_ab) { torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
TORCH_CHECK(expert_first_token_offset.is_cuda(), const bool swap_ab) {
"expert_first_token_offset must be a CUDA tensor"); STD_TORCH_CHECK(expert_first_token_offset.is_cuda(),
TORCH_CHECK(expert_first_token_offset.dtype() == torch::kInt64, "expert_first_token_offset must be a CUDA tensor");
"expert_first_token_offset must be int64"); STD_TORCH_CHECK(expert_first_token_offset.scalar_type() ==
torch::headeronly::ScalarType::Long,
TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(), "expert_first_token_offset must be int64");
"problem_sizes must be CUDA tensors");
TORCH_CHECK(problem_sizes1.dtype() == torch::kInt32 && STD_TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
problem_sizes2.dtype() == torch::kInt32, "problem_sizes must be CUDA tensors");
"problem_sizes must be int32"); STD_TORCH_CHECK(
TORCH_CHECK(problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(), problem_sizes1.scalar_type() == torch::headeronly::ScalarType::Int &&
"problem_sizes must be contiguous"); problem_sizes2.scalar_type() == torch::headeronly::ScalarType::Int,
TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2, "problem_sizes must be int32");
"problem_sizes must be 2D tensors"); STD_TORCH_CHECK(
TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3, problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
"problem_sizes second dim must be 3"); "problem_sizes must be contiguous");
TORCH_CHECK(problem_sizes1.sizes() == problem_sizes2.sizes(), STD_TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
"problem_sizes1 and problem_sizes2 must have same shape"); "problem_sizes must be 2D tensors");
STD_TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
"problem_sizes second dim must be 3");
STD_TORCH_CHECK(problem_sizes1.size(0) == problem_sizes2.size(0) &&
problem_sizes1.size(1) == problem_sizes2.size(1),
"problem_sizes1 and problem_sizes2 must have same shape");
int64_t const num_experts64 = problem_sizes1.size(0); int64_t const num_experts64 = problem_sizes1.size(0);
TORCH_CHECK(expert_first_token_offset.numel() == num_experts64 + 1, STD_TORCH_CHECK(
"expert_first_token_offset must have num_experts + 1 elements"); expert_first_token_offset.numel() == num_experts64 + 1,
TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32"); "expert_first_token_offset must have num_experts + 1 elements");
TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX, "n and k must fit in int32"); STD_TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
STD_TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX,
"n and k must fit in int32");
int const num_experts = static_cast<int>(num_experts64); int const num_experts = static_cast<int>(num_experts64);
auto stream = at::cuda::getCurrentCUDAStream( auto stream =
expert_first_token_offset.device().index()); get_current_cuda_stream(expert_first_token_offset.get_device_index());
int const threads = (num_experts < 256) ? num_experts : 256; int const threads = (num_experts < 256) ? num_experts : 256;
int const blocks = (num_experts + threads - 1) / threads; int const blocks = (num_experts + threads - 1) / threads;
auto const* offsets_ptr = expert_first_token_offset.data_ptr<int64_t>(); auto const* offsets_ptr = expert_first_token_offset.const_data_ptr<int64_t>();
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>(); auto* ps1_ptr = problem_sizes1.mutable_data_ptr<int32_t>();
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>(); auto* ps2_ptr = problem_sizes2.mutable_data_ptr<int32_t>();
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] { VLLM_STABLE_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
compute_problem_sizes_from_expert_offsets<SwapAB> compute_problem_sizes_from_expert_offsets<SwapAB>
<<<blocks, threads, 0, stream>>>(offsets_ptr, ps1_ptr, ps2_ptr, <<<blocks, threads, 0, stream>>>(offsets_ptr, ps1_ptr, ps2_ptr,
num_experts, static_cast<int>(n), num_experts, static_cast<int>(n),
...@@ -219,16 +231,19 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller( ...@@ -219,16 +231,19 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
} }
void get_cutlass_moe_mm_data_caller( void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, const torch::stable::Tensor& topk_ids,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::stable::Tensor& expert_offsets,
torch::Tensor& input_permutation, torch::Tensor& output_permutation, torch::stable::Tensor& problem_sizes1,
const int64_t num_experts, const int64_t n, const int64_t k, torch::stable::Tensor& problem_sizes2,
const std::optional<torch::Tensor>& blockscale_offsets, torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated) { const bool is_gated) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); auto device = topk_ids.device();
auto options_int32 = auto stream = get_current_cuda_stream(device.index());
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); torch::stable::Tensor atomic_buffer = torch::stable::new_zeros(
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); topk_ids, {num_experts}, torch::headeronly::ScalarType::Int);
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
...@@ -290,11 +305,13 @@ __global__ void compute_batched_moe_data( ...@@ -290,11 +305,13 @@ __global__ void compute_batched_moe_data(
} }
void get_cutlass_batched_moe_mm_data_caller( void get_cutlass_batched_moe_mm_data_caller(
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::stable::Tensor& expert_offsets,
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens, torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n, const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k) { const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index()); auto stream = get_current_cuda_stream(expert_offsets.get_device_index());
if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) { if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
compute_batched_moe_data<false><<<1, num_local_experts, 0, stream>>>( compute_batched_moe_data<false><<<1, num_local_experts, 0, stream>>>(
...@@ -311,4 +328,4 @@ void get_cutlass_batched_moe_mm_data_caller( ...@@ -311,4 +328,4 @@ void get_cutlass_batched_moe_mm_data_caller(
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n, static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
k); k);
} }
} }
\ No newline at end of file
#include <stddef.h> #include <stddef.h>
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include "scaled_mm_c2x.cuh" #include "scaled_mm_c2x.cuh"
...@@ -8,7 +9,7 @@ ...@@ -8,7 +9,7 @@
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh" #include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh" #include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp" #include "libtorch_stable/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
using namespace vllm; using namespace vllm;
...@@ -19,32 +20,37 @@ using namespace vllm; ...@@ -19,32 +20,37 @@ using namespace vllm;
// NOTE: this text is a side-by-side diff rendering. Each line below carries
// the pre-migration column (torch::Tensor / TORCH_CHECK) followed by the
// post-migration column (torch::stable::Tensor / STD_TORCH_CHECK).
//
// cutlass_scaled_mm_sm75_epilogue: checks that `a` and `b` are int8, then
// calls cutlass_gemm_sm75_dispatch<int8_t, OutType, Epilogue> where OutType is
// picked from the dtype of `out`. `epilogue_args` are passed through with
// std::forward.
template <template <typename, typename> typename Epilogue, template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm75_epilogue(torch::stable::Tensor& out,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) { EpilogueArgs&&... epilogue_args) {
// Both operands must be int8 (spelled ScalarType::Char in the right column).
TORCH_CHECK(a.dtype() == torch::kInt8); STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
TORCH_CHECK(b.dtype() == torch::kInt8); STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
// bfloat16 output selects the cutlass::bfloat16_t instantiation.
if (out.dtype() == torch::kBFloat16) { if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>( return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
// Any other output dtype must be float16; the check fails otherwise.
TORCH_CHECK(out.dtype() == torch::kFloat16); STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>( return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} }
} }
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm75(torch::stable::Tensor& out,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b,
torch::Tensor const& b_scales, torch::stable::Tensor const& a_scales,
std::optional<torch::Tensor> const& bias) { torch::stable::Tensor const& b_scales,
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); std::optional<torch::stable::Tensor> const& bias) {
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>( return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias); out, a, b, a_scales, b_scales, *bias);
} else { } else {
...@@ -53,15 +59,16 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a, ...@@ -53,15 +59,16 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
} }
} }
void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm75(
torch::Tensor const& b, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
torch::Tensor const& azp_adj, std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::Tensor> const& azp, std::optional<torch::stable::Tensor> const& bias) {
std::optional<torch::Tensor> const& bias) { STD_TORCH_CHECK(a_scales.scalar_type() ==
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); torch::headeronly::ScalarType::Float);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (azp) { if (azp) {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>( return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
...@@ -74,32 +81,37 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a, ...@@ -74,32 +81,37 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
// NOTE: this text is a side-by-side diff rendering. Each line below carries
// the pre-migration column (torch::Tensor / TORCH_CHECK) followed by the
// post-migration column (torch::stable::Tensor / STD_TORCH_CHECK).
//
// cutlass_scaled_mm_sm80_epilogue: checks that `a` and `b` are int8, then
// calls cutlass_gemm_sm80_dispatch<int8_t, OutType, Epilogue> where OutType is
// picked from the dtype of `out`. `epilogue_args` are passed through with
// std::forward.
template <template <typename, typename> typename Epilogue, template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm80_epilogue(torch::stable::Tensor& out,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) { EpilogueArgs&&... epilogue_args) {
// Both operands must be int8 (spelled ScalarType::Char in the right column).
TORCH_CHECK(a.dtype() == torch::kInt8); STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
TORCH_CHECK(b.dtype() == torch::kInt8); STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
// bfloat16 output selects the cutlass::bfloat16_t instantiation.
if (out.dtype() == torch::kBFloat16) { if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>( return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
// Any other output dtype must be float16; the check fails otherwise.
TORCH_CHECK(out.dtype() == torch::kFloat16); STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>( return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} }
} }
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm80(torch::stable::Tensor& out,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b,
torch::Tensor const& b_scales, torch::stable::Tensor const& a_scales,
std::optional<torch::Tensor> const& bias) { torch::stable::Tensor const& b_scales,
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); std::optional<torch::stable::Tensor> const& bias) {
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>( return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias); out, a, b, a_scales, b_scales, *bias);
} else { } else {
...@@ -108,15 +120,16 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a, ...@@ -108,15 +120,16 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
} }
} }
void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm80(
torch::Tensor const& b, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
torch::Tensor const& azp_adj, std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::Tensor> const& azp, std::optional<torch::stable::Tensor> const& bias) {
std::optional<torch::Tensor> const& bias) { STD_TORCH_CHECK(a_scales.scalar_type() ==
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); torch::headeronly::ScalarType::Float);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (azp) { if (azp) {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>( return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
...@@ -129,31 +142,34 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a, ...@@ -129,31 +142,34 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
template <template <typename, typename> typename Epilogue, template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm89_epilogue(torch::stable::Tensor& out,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) { EpilogueArgs&&... epilogue_args) {
if (a.dtype() == torch::kInt8) { if (a.scalar_type() == torch::headeronly::ScalarType::Char) {
TORCH_CHECK(b.dtype() == torch::kInt8); STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
if (out.dtype() == torch::kBFloat16) { if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t, return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>( Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
assert(out.dtype() == torch::kFloat16); assert(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>( return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} }
} else { } else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
if (out.dtype() == torch::kBFloat16) { if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>( cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>( cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
...@@ -161,16 +177,20 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, ...@@ -161,16 +177,20 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
} }
} }
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm89(torch::stable::Tensor& out,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b,
torch::Tensor const& b_scales, torch::stable::Tensor const& a_scales,
std::optional<torch::Tensor> const& bias) { torch::stable::Tensor const& b_scales,
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); std::optional<torch::stable::Tensor> const& bias) {
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>( return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias); out, a, b, a_scales, b_scales, *bias);
} else { } else {
...@@ -179,15 +199,16 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a, ...@@ -179,15 +199,16 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
} }
} }
void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm89(
torch::Tensor const& b, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
torch::Tensor const& azp_adj, std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::Tensor> const& azp, std::optional<torch::stable::Tensor> const& bias) {
std::optional<torch::Tensor> const& bias) { STD_TORCH_CHECK(a_scales.scalar_type() ==
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); torch::headeronly::ScalarType::Float);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (azp) { if (azp) {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>( return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
......
#pragma once #pragma once
#include <stddef.h> #include <stddef.h>
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <ATen/cuda/CUDAContext.h> #include "libtorch_stable/torch_utils.h"
// clang-format will break include orders // clang-format will break include orders
// clang-format off // clang-format off
...@@ -95,8 +96,9 @@ struct cutlass_2x_gemm { ...@@ -95,8 +96,9 @@ struct cutlass_2x_gemm {
}; };
template <typename Gemm, typename... EpilogueArgs> template <typename Gemm, typename... EpilogueArgs>
inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, inline void cutlass_gemm_caller(torch::stable::Tensor& out,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_params) { EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD; using ElementD = typename Gemm::ElementD;
...@@ -149,11 +151,12 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, ...@@ -149,11 +151,12 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
// Launch the CUTLASS GEMM kernel. // Launch the CUTLASS GEMM kernel.
typename Gemm::Op gemm_op; typename Gemm::Op gemm_op;
size_t workspace_size = gemm_op.get_workspace_size(args); size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options = auto device = a.device();
torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); auto workspace =
auto workspace = torch::empty(workspace_size, workspace_options); torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, device);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); auto stream = get_current_cuda_stream(device.index());
CUTLASS_CHECK(gemm_op.can_implement(args)); CUTLASS_CHECK(gemm_op.can_implement(args));
cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream); cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
...@@ -161,9 +164,9 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, ...@@ -161,9 +164,9 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
} }
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs> template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
inline void fallback_cutlass_gemm_caller(torch::Tensor& out, inline void fallback_cutlass_gemm_caller(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
EpilogueArgs&&... args) { EpilogueArgs&&... args) {
// In some cases, the GPU isn't able to accommodate the // In some cases, the GPU isn't able to accommodate the
// shared memory requirements of the Gemm. In such cases, use // shared memory requirements of the Gemm. In such cases, use
...@@ -180,8 +183,8 @@ inline void fallback_cutlass_gemm_caller(torch::Tensor& out, ...@@ -180,8 +183,8 @@ inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
return cutlass_gemm_caller<Gemm>(out, a, b, return cutlass_gemm_caller<Gemm>(out, a, b,
std::forward<EpilogueArgs>(args)...); std::forward<EpilogueArgs>(args)...);
} else { } else {
TORCH_CHECK(fallback_gemm_shared_mem_size <= STD_TORCH_CHECK(fallback_gemm_shared_mem_size <=
max_shared_mem_per_block_opt_in); max_shared_mem_per_block_opt_in);
return cutlass_gemm_caller<FallbackGemm>( return cutlass_gemm_caller<FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...); out, a, b, std::forward<EpilogueArgs>(args)...);
} }
......
#pragma once #pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm_c2x.cuh" #include "scaled_mm_c2x.cuh"
/** /**
...@@ -70,13 +72,13 @@ struct sm75_config_M32 { ...@@ -70,13 +72,13 @@ struct sm75_config_M32 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename> typename Epilogue, template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out, inline void cutlass_gemm_sm75_dispatch(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
EpilogueArgs&&... args) { EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>()); static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8); STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
TORCH_CHECK(b.dtype() == torch::kInt8); STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
using Cutlass2xGemmDefault = using Cutlass2xGemmDefault =
typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm; typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
......
#pragma once #pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm_c2x.cuh" #include "scaled_mm_c2x.cuh"
/** /**
...@@ -72,13 +74,13 @@ struct sm80_config_M16 { ...@@ -72,13 +74,13 @@ struct sm80_config_M16 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename> typename Epilogue, template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
inline void cutlass_gemm_sm80_dispatch(torch::Tensor& out, inline void cutlass_gemm_sm80_dispatch(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
EpilogueArgs&&... args) { EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>()); static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8); STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
TORCH_CHECK(b.dtype() == torch::kInt8); STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
using Cutlass2xGemmDefault = using Cutlass2xGemmDefault =
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm; typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
......
#pragma once #pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm_c2x.cuh" #include "scaled_mm_c2x.cuh"
#include "cutlass/float8.h" #include "cutlass/float8.h"
...@@ -34,10 +36,12 @@ struct sm89_fp8_config_default { ...@@ -34,10 +36,12 @@ struct sm89_fp8_config_default {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename> typename Epilogue, template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a, static void dispatch(torch::stable::Tensor& out,
torch::Tensor const& b, EpilogueArgs&&... args) { torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm = using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType, typename sm89_fp8_fallback_gemm<InType, OutType,
...@@ -84,10 +88,12 @@ struct sm89_fp8_config_M256 { ...@@ -84,10 +88,12 @@ struct sm89_fp8_config_M256 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename> typename Epilogue, template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a, static void dispatch(torch::stable::Tensor& out,
torch::Tensor const& b, EpilogueArgs&&... args) { torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm = using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType, typename sm89_fp8_fallback_gemm<InType, OutType,
...@@ -125,10 +131,12 @@ struct sm89_fp8_config_M128 { ...@@ -125,10 +131,12 @@ struct sm89_fp8_config_M128 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename> typename Epilogue, template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a, static void dispatch(torch::stable::Tensor& out,
torch::Tensor const& b, EpilogueArgs&&... args) { torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm = using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType, typename sm89_fp8_fallback_gemm<InType, OutType,
...@@ -173,10 +181,12 @@ struct sm89_fp8_config_M64 { ...@@ -173,10 +181,12 @@ struct sm89_fp8_config_M64 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename> typename Epilogue, template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a, static void dispatch(torch::stable::Tensor& out,
torch::Tensor const& b, EpilogueArgs&&... args) { torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm = using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType, typename sm89_fp8_fallback_gemm<InType, OutType,
...@@ -227,10 +237,12 @@ struct sm89_fp8_config_M32 { ...@@ -227,10 +237,12 @@ struct sm89_fp8_config_M32 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename> typename Epilogue, template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a, static void dispatch(torch::stable::Tensor& out,
torch::Tensor const& b, EpilogueArgs&&... args) { torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm = using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType, typename sm89_fp8_fallback_gemm<InType, OutType,
...@@ -280,10 +292,12 @@ struct sm89_fp8_config_M16 { ...@@ -280,10 +292,12 @@ struct sm89_fp8_config_M16 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename> typename Epilogue, template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a, static void dispatch(torch::stable::Tensor& out,
torch::Tensor const& b, EpilogueArgs&&... args) { torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm = using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType, typename sm89_fp8_fallback_gemm<InType, OutType,
...@@ -326,13 +340,15 @@ struct sm89_fp8_config_M16 { ...@@ -326,13 +340,15 @@ struct sm89_fp8_config_M16 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename> typename Epilogue, template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out, inline void cutlass_gemm_sm89_fp8_dispatch(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
EpilogueArgs&&... args) { EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
uint32_t const m = a.size(0); uint32_t const m = a.size(0);
uint32_t const mp2 = uint32_t const mp2 =
......
#pragma once #pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm_c2x.cuh" #include "scaled_mm_c2x.cuh"
/** /**
...@@ -32,10 +34,11 @@ struct sm89_int8_config_default { ...@@ -32,10 +34,11 @@ struct sm89_int8_config_default {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename> typename Epilogue, template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a, static void dispatch(torch::stable::Tensor& out,
torch::Tensor const& b, EpilogueArgs&&... args) { torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>()); static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8); STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm = using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType, typename sm89_int8_fallback_gemm<InType, OutType,
...@@ -88,10 +91,11 @@ struct sm89_int8_config_M256 { ...@@ -88,10 +91,11 @@ struct sm89_int8_config_M256 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename> typename Epilogue, template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a, static void dispatch(torch::stable::Tensor& out,
torch::Tensor const& b, EpilogueArgs&&... args) { torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>()); static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8); STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm = using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType, typename sm89_int8_fallback_gemm<InType, OutType,
...@@ -143,10 +147,11 @@ struct sm89_int8_config_M128 { ...@@ -143,10 +147,11 @@ struct sm89_int8_config_M128 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename> typename Epilogue, template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a, static void dispatch(torch::stable::Tensor& out,
torch::Tensor const& b, EpilogueArgs&&... args) { torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>()); static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8); STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm = using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType, typename sm89_int8_fallback_gemm<InType, OutType,
...@@ -193,10 +198,11 @@ struct sm89_int8_config_M64 { ...@@ -193,10 +198,11 @@ struct sm89_int8_config_M64 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename> typename Epilogue, template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a, static void dispatch(torch::stable::Tensor& out,
torch::Tensor const& b, EpilogueArgs&&... args) { torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>()); static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8); STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm = using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType, typename sm89_int8_fallback_gemm<InType, OutType,
...@@ -234,10 +240,11 @@ struct sm89_int8_config_M32 { ...@@ -234,10 +240,11 @@ struct sm89_int8_config_M32 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename> typename Epilogue, template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a, static void dispatch(torch::stable::Tensor& out,
torch::Tensor const& b, EpilogueArgs&&... args) { torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>()); static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8); STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm = using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType, typename sm89_int8_fallback_gemm<InType, OutType,
...@@ -276,10 +283,11 @@ struct sm89_int8_config_M16 { ...@@ -276,10 +283,11 @@ struct sm89_int8_config_M16 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename> typename Epilogue, template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a, static void dispatch(torch::stable::Tensor& out,
torch::Tensor const& b, EpilogueArgs&&... args) { torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>()); static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8); STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm = using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType, typename sm89_int8_fallback_gemm<InType, OutType,
...@@ -311,13 +319,13 @@ struct sm89_int8_config_M16 { ...@@ -311,13 +319,13 @@ struct sm89_int8_config_M16 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename> typename Epilogue, template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out, inline void cutlass_gemm_sm89_int8_dispatch(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
EpilogueArgs&&... args) { EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>()); static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8); STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
TORCH_CHECK(b.dtype() == torch::kInt8); STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
uint32_t const m = a.size(0); uint32_t const m = a.size(0);
uint32_t const mp2 = uint32_t const mp2 =
......
...@@ -8,11 +8,12 @@ ...@@ -8,11 +8,12 @@
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100 #if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm100(torch::stable::Tensor& c,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b,
torch::Tensor const& b_scales, torch::stable::Tensor const& a_scales,
std::optional<torch::Tensor> const& bias) { torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias, dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
vllm::cutlass_scaled_mm_sm100_fp8, vllm::cutlass_scaled_mm_sm100_fp8,
nullptr, // int8 not supported on SM100 nullptr, // int8 not supported on SM100
......
...@@ -8,11 +8,12 @@ ...@@ -8,11 +8,12 @@
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120 #if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm120(torch::stable::Tensor& c,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b,
torch::Tensor const& b_scales, torch::stable::Tensor const& a_scales,
std::optional<torch::Tensor> const& bias) { torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias, dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
vllm::cutlass_scaled_mm_sm120_fp8, vllm::cutlass_scaled_mm_sm120_fp8,
nullptr, // int8 not supported on SM120 nullptr, // int8 not supported on SM120
......
...@@ -8,26 +8,28 @@ ...@@ -8,26 +8,28 @@
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90 #if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b,
torch::Tensor const& b_scales, torch::stable::Tensor const& a_scales,
std::optional<torch::Tensor> const& bias) { torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias, dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
vllm::cutlass_scaled_mm_sm90_fp8, vllm::cutlass_scaled_mm_sm90_fp8,
vllm::cutlass_scaled_mm_sm90_int8, vllm::cutlass_scaled_mm_sm90_int8,
vllm::cutlass_scaled_mm_blockwise_sm90_fp8); vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
} }
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm90(
torch::Tensor const& b, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
torch::Tensor const& azp_adj, std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::Tensor> const& azp, std::optional<torch::stable::Tensor> const& bias) {
std::optional<torch::Tensor> const& bias) { STD_TORCH_CHECK(a_scales.scalar_type() ==
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); torch::headeronly::ScalarType::Float);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj, vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
azp, bias); azp, bias);
......
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h> #include <torch/csrc/stable/tensor.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp" #include "libtorch_stable/torch_utils.h"
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a, #include "cutlass_extensions/common.hpp"
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm75(torch::stable::Tensor& c,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b,
torch::Tensor const& b_scales, torch::stable::Tensor const& a_scales,
std::optional<torch::Tensor> const& bias); torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm80(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm89(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90 #if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b,
torch::Tensor const& b_scales, torch::stable::Tensor const& a_scales,
std::optional<torch::Tensor> const& bias); torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#endif #endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
void cutlass_moe_mm_sm90( void cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::stable::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::stable::Tensor const& b_tensors,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::stable::Tensor const& a_scales,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::stable::Tensor const& b_scales,
torch::Tensor const& b_strides, torch::Tensor const& c_strides, torch::stable::Tensor const& expert_offsets,
bool per_act_token, bool per_out_ch); torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif #endif
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100 #if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
void cutlass_moe_mm_sm100( void cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::stable::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::stable::Tensor const& b_tensors,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::stable::Tensor const& a_scales,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::stable::Tensor const& b_scales,
torch::Tensor const& b_strides, torch::Tensor const& c_strides, torch::stable::Tensor const& expert_offsets,
bool per_act_token, bool per_out_ch); torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif #endif
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120 #if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm120(torch::stable::Tensor& c,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b,
torch::Tensor const& b_scales, torch::stable::Tensor const& a_scales,
std::optional<torch::Tensor> const& bias); torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#endif #endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100 #if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm100(torch::stable::Tensor& c,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b,
torch::Tensor const& b_scales, torch::stable::Tensor const& a_scales,
std::optional<torch::Tensor> const& bias); torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#endif #endif
#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \ #if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
(defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \ (defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
(defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120) (defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
void get_cutlass_moe_mm_data_caller( void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, const torch::stable::Tensor& topk_ids,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::stable::Tensor& expert_offsets,
torch::Tensor& input_permutation, torch::Tensor& output_permutation, torch::stable::Tensor& problem_sizes1,
const int64_t num_experts, const int64_t n, const int64_t k, torch::stable::Tensor& problem_sizes2,
const std::optional<torch::Tensor>& blockscale_offsets, torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated); const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller( void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
const torch::Tensor& expert_first_token_offset, const torch::stable::Tensor& expert_first_token_offset,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::stable::Tensor& problem_sizes1,
const int64_t n, const int64_t k, const bool swap_ab); torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab);
void get_cutlass_batched_moe_mm_data_caller( void get_cutlass_batched_moe_mm_data_caller(
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::stable::Tensor& expert_offsets,
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens, torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n, const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k); const int64_t k);
#endif #endif
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm75(
torch::Tensor const& b, torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
torch::Tensor const& azp_adj, std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::Tensor> const& azp, std::optional<torch::stable::Tensor> const& bias);
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm80(
void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a, torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& a_scales, torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
torch::Tensor const& b_scales, std::optional<torch::stable::Tensor> const& azp,
torch::Tensor const& azp_adj, std::optional<torch::stable::Tensor> const& bias);
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias); void cutlass_scaled_mm_azp_sm89(
torch::stable::Tensor& c, torch::stable::Tensor const& a,
void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& b, torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
torch::Tensor const& a_scales, std::optional<torch::stable::Tensor> const& azp,
torch::Tensor const& b_scales, std::optional<torch::stable::Tensor> const& bias);
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias);
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90 #if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm90(
torch::Tensor const& b, torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
torch::Tensor const& azp_adj, std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::Tensor> const& azp, std::optional<torch::stable::Tensor> const& bias);
std::optional<torch::Tensor> const& bias);
#endif #endif
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) { bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
...@@ -171,27 +188,29 @@ bool cutlass_group_gemm_supported(int64_t cuda_device_capability) { ...@@ -171,27 +188,29 @@ bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
return false; return false;
} }
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm(torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales, torch::stable::Tensor const& b,
torch::Tensor const& b_scales, torch::stable::Tensor const& a_scales,
std::optional<torch::Tensor> const& bias) { torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
// Checks for conformality // Checks for conformality
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1)); b.size(1) == c.size(1));
// Check for strides and alignment // Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
TORCH_CHECK(b.stride(0) == 1); // Column-major STD_TORCH_CHECK(b.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(0) % 16 == 0 && STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment b.stride(1) % 16 == 0); // 16 Byte Alignment
if (bias) { if (bias) {
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
bias->dim() == 1); bias->dim() == 1);
} }
at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); const torch::stable::accelerator::DeviceGuard device_guard(
a.get_device_index());
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120 #if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
...@@ -237,20 +256,24 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, ...@@ -237,20 +256,24 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
} }
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED( STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, false,
"No compiled cutlass_scaled_mm for a compute capability less than " "No compiled cutlass_scaled_mm for a compute capability less than "
"CUDA device capability: ", "CUDA device capability: ",
version_num); version_num);
} }
void cutlass_moe_mm( void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::stable::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::stable::Tensor const& b_tensors,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::stable::Tensor const& a_scales,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::stable::Tensor const& b_scales,
torch::Tensor const& b_strides, torch::Tensor const& c_strides, torch::stable::Tensor const& expert_offsets,
bool per_act_token, bool per_out_ch) { torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides, bool per_act_token,
bool per_out_ch) {
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100 #if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
if (version_num >= 100 && version_num < 110) { if (version_num >= 100 && version_num < 110) {
...@@ -268,18 +291,21 @@ void cutlass_moe_mm( ...@@ -268,18 +291,21 @@ void cutlass_moe_mm(
return; return;
} }
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED( STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, false,
"No compiled cutlass_scaled_mm for CUDA device capability: ", version_num, "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
". Required capability: 90 or 100"); ". Required capability: 90 or 100");
} }
void get_cutlass_moe_mm_data( void get_cutlass_moe_mm_data(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, const torch::stable::Tensor& topk_ids,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::stable::Tensor& expert_offsets,
torch::Tensor& input_permutation, torch::Tensor& output_permutation, torch::stable::Tensor& problem_sizes1,
const int64_t num_experts, const int64_t n, const int64_t k, torch::stable::Tensor& problem_sizes2,
const std::optional<torch::Tensor>& blockscale_offsets, torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated) { const bool is_gated) {
// This function currently gets compiled only if we have a valid cutlass moe // This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for. // mm to run it for.
...@@ -293,7 +319,7 @@ void get_cutlass_moe_mm_data( ...@@ -293,7 +319,7 @@ void get_cutlass_moe_mm_data(
blockscale_offsets, is_gated); blockscale_offsets, is_gated);
return; return;
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED( STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, false,
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for " "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
"CUDA device capability: ", "CUDA device capability: ",
...@@ -301,9 +327,10 @@ void get_cutlass_moe_mm_data( ...@@ -301,9 +327,10 @@ void get_cutlass_moe_mm_data(
} }
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets( void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::Tensor& expert_first_token_offset, const torch::stable::Tensor& expert_first_token_offset,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::stable::Tensor& problem_sizes1,
const int64_t n, const int64_t k, const bool swap_ab) { torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab) {
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
...@@ -312,20 +339,20 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets( ...@@ -312,20 +339,20 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab); expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab);
return; return;
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED( STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, false,
"No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: " "No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
"no cutlass_scaled_mm kernel for CUDA device capability: ", "no cutlass_scaled_mm kernel for CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120"); version_num, ". Required capability: 90, 100, or 120");
} }
void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets, void get_cutlass_batched_moe_mm_data(
torch::Tensor& problem_sizes1, torch::stable::Tensor& expert_offsets,
torch::Tensor& problem_sizes2, torch::stable::Tensor& problem_sizes1,
const torch::Tensor& expert_num_tokens, torch::stable::Tensor& problem_sizes2,
const int64_t num_local_experts, const torch::stable::Tensor& expert_num_tokens,
const int64_t padded_m, const int64_t n, const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k) { const int64_t k) {
// This function currently gets compiled only if we have a valid cutlass moe // This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for. // mm to run it for.
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
...@@ -337,52 +364,56 @@ void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets, ...@@ -337,52 +364,56 @@ void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets,
num_local_experts, padded_m, n, k); num_local_experts, padded_m, n, k);
return; return;
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED(false, STD_TORCH_CHECK_NOT_IMPLEMENTED(
"No compiled get_cutlass_batched_moe_mm_data: no " false,
"cutlass_scaled_mm kernel " "No compiled get_cutlass_batched_moe_mm_data: no "
"for CUDA device capability: ", "cutlass_scaled_mm kernel "
version_num, "for CUDA device capability: ",
". Required capability: 90, 100, or 120"); version_num, ". Required capability: 90, 100, or 120");
} }
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp(torch::stable::Tensor& c,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b,
torch::Tensor const& b_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& azp_adj, torch::stable::Tensor const& b_scales,
std::optional<torch::Tensor> const& azp, torch::stable::Tensor const& azp_adj,
std::optional<torch::Tensor> const& bias) { std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
// Checks for conformality // Checks for conformality
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1)); b.size(1) == c.size(1));
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); STD_TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); STD_TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
// Check for strides and alignment // Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
TORCH_CHECK(b.stride(0) == 1); // Column-major STD_TORCH_CHECK(b.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(0) % 16 == 0 && STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
// bias, azp, azp_adj are all 1d // bias, azp, azp_adj are all 1d
// bias and azp_adj have n elements, azp has m elements // bias and azp_adj have n elements, azp has m elements
if (bias) { if (bias) {
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous()); STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
} }
if (azp) { if (azp) {
TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous()); STD_TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
} }
TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous()); STD_TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());
// azp & bias types // azp & bias types
TORCH_CHECK(azp_adj.dtype() == torch::kInt32); STD_TORCH_CHECK(azp_adj.scalar_type() == torch::headeronly::ScalarType::Int);
TORCH_CHECK(!azp || azp->dtype() == torch::kInt32); STD_TORCH_CHECK(!azp ||
TORCH_CHECK(!bias || bias->dtype() == c.dtype(), azp->scalar_type() == torch::headeronly::ScalarType::Int);
"currently bias dtype must match output dtype ", c.dtype()); STD_TORCH_CHECK(!bias || bias->scalar_type() == c.scalar_type(),
"currently bias dtype must match output dtype ",
c.scalar_type());
at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); const torch::stable::accelerator::DeviceGuard device_guard(
a.get_device_index());
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
...@@ -407,12 +438,12 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, ...@@ -407,12 +438,12 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
} }
// Turing // Turing
TORCH_CHECK(version_num >= 75); STD_TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias); cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return; return;
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED( STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, false,
"No compiled cutlass_scaled_mm_azp for a compute capability less than " "No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: ", "CUDA device capability: ",
......
...@@ -31,6 +31,78 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) { ...@@ -31,6 +31,78 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! " "per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
"output_s, int group_size, float eps, float int8_min, float int8_max) -> " "output_s, int group_size, float eps, float int8_min, float int8_max) -> "
"()"); "()");
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
// Check if cutlass scaled_mm is supported for CUDA devices of the given
// capability
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
// Check if cutlass grouped gemm is supported for CUDA devices of the given
// capability
ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool");
// CUTLASS w8a8 grouped GEMM
ops.def(
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor c_strides, bool per_act_token, "
" bool per_out_ch) -> ()");
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM. It takes topk_ids as an input, and computes expert_offsets
// (token start indices of each expert). In addition to this, it computes
// problem sizes for each expert's multiplication used by the two mms called
// from fused MoE operation, and arrays with permutations required to shuffle
// and de-shuffle the input/output of the fused operation.
ops.def(
"get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k, Tensor? blockscale_offsets, "
" bool is_gated) -> ()");
// compute per-expert problem sizes from expert_first_token_offset
// produced by vLLM's moe_permute kernel
ops.def(
"get_cutlass_moe_mm_problem_sizes_from_expert_offsets("
" Tensor expert_first_token_offset, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" int n, int k, bool swap_ab) -> ()");
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM in batched expert format. It takes expert_num_tokens
// as an input, and computes expert_offsets (token start indices of each
// expert). In addition to this, it computes problem sizes for each expert's
// multiplication used by the two mms called from fused MoE operation.
ops.def(
"get_cutlass_batched_moe_mm_data(Tensor! expert_offsets, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" Tensor expert_num_tokens, "
" int num_local_experts, int padded_m, "
" int n, int k) -> ()");
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
ops.def(
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
"bool");
#endif #endif
} }
...@@ -46,6 +118,31 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { ...@@ -46,6 +118,31 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
TORCH_BOX(&per_token_group_quant_8bit_packed)); TORCH_BOX(&per_token_group_quant_8bit_packed));
ops.impl("per_token_group_quant_int8", ops.impl("per_token_group_quant_int8",
TORCH_BOX(&per_token_group_quant_int8)); TORCH_BOX(&per_token_group_quant_int8));
// CUTLASS scaled_mm ops
ops.impl("cutlass_scaled_mm", TORCH_BOX(&cutlass_scaled_mm));
ops.impl("cutlass_scaled_mm_azp", TORCH_BOX(&cutlass_scaled_mm_azp));
ops.impl("cutlass_moe_mm", TORCH_BOX(&cutlass_moe_mm));
ops.impl("get_cutlass_moe_mm_data", TORCH_BOX(&get_cutlass_moe_mm_data));
ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets",
TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets));
ops.impl("get_cutlass_batched_moe_mm_data",
TORCH_BOX(&get_cutlass_batched_moe_mm_data));
#endif
}
// Registers the CUTLASS capability-query ops: cutlass_scaled_mm_supports_fp8,
// cutlass_group_gemm_supported and cutlass_scaled_mm_supports_block_fp8.
//
// These capability-check functions take only primitive args (no tensors), so
// there is no device to dispatch on. CompositeExplicitAutograd makes them
// available for all backends. This is the stable ABI equivalent of calling
// ops.impl("op_name", &func) without a dispatch key in the non-stable API.
STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
// The registrations below are compiled out when USE_ROCM is defined.
#ifndef USE_ROCM
ops.impl("cutlass_scaled_mm_supports_fp8",
TORCH_BOX(&cutlass_scaled_mm_supports_fp8));
ops.impl("cutlass_group_gemm_supported",
TORCH_BOX(&cutlass_group_gemm_supported));
ops.impl("cutlass_scaled_mm_supports_block_fp8",
TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8));
#endif #endif
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment