Unverified Commit 7c080dd3 authored by mikaylagawarecki's avatar mikaylagawarecki Committed by GitHub
Browse files

[4/n] Migrate FP4/W4A8 CUTLASS kernels to torch stable ABI (#37503)


Signed-off-by: default avatarMikayla Gawarecki <mikaylagawarecki@gmail.com>
parent 0dd25a44
...@@ -14,10 +14,9 @@ ...@@ -14,10 +14,9 @@
* limitations under the License. * limitations under the License.
*/ */
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include <ATen/cuda/CUDAContext.h> #include "libtorch_stable/torch_utils.h"
#include <c10/cuda/CUDAGuard.h>
#include "cutlass_extensions/common.hpp" #include "cutlass_extensions/common.hpp"
...@@ -35,18 +34,19 @@ ...@@ -35,18 +34,19 @@
using namespace cute; using namespace cute;
#define CHECK_TYPE(x, st, m) \ #define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m) STD_TORCH_CHECK(x.scalar_type() == st, \
": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \ #define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor") STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \ #define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous") STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
#define CHECK_INPUT(x, st, m) \ #define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \ CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \ CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m) CHECK_TYPE(x, st, m)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
struct sm120_fp4_config_M256 { struct sm120_fp4_config_M256 {
using ClusterShape = Shape<_1, _1, _1>; using ClusterShape = Shape<_1, _1, _1>;
...@@ -109,12 +109,13 @@ struct Fp4GemmSm120 { ...@@ -109,12 +109,13 @@ struct Fp4GemmSm120 {
}; };
template <typename Gemm> template <typename Gemm>
typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A, typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
at::Tensor const& B, torch::stable::Tensor const& A,
at::Tensor const& A_sf, torch::stable::Tensor const& B,
at::Tensor const& B_sf, torch::stable::Tensor const& A_sf,
torch::Tensor const& alpha, int M, torch::stable::Tensor const& B_sf,
int N, int K) { torch::stable::Tensor const& alpha,
int M, int N, int K) {
using ElementA = typename Gemm::ElementA; using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB; using ElementB = typename Gemm::ElementB;
using ElementD = typename Gemm::ElementD; using ElementD = typename Gemm::ElementD;
...@@ -158,18 +159,19 @@ typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A, ...@@ -158,18 +159,19 @@ typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A,
} }
template <typename Gemm> template <typename Gemm>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
at::Tensor const& A_sf, at::Tensor const& B_sf, torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
torch::Tensor const& alpha, int M, int N, int K, torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int M, int N, int K,
cudaStream_t stream) { cudaStream_t stream) {
Gemm gemm; Gemm gemm;
auto arguments = args_from_options<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K); auto arguments = args_from_options<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K);
size_t workspace_size = Gemm::get_workspace_size(arguments); size_t workspace_size = Gemm::get_workspace_size(arguments);
auto const workspace_options = auto workspace =
torch::TensorOptions().dtype(torch::kUInt8).device(A.device()); torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
auto workspace = torch::empty(workspace_size, workspace_options); std::nullopt, A.device());
CUTLASS_CHECK(gemm.can_implement(arguments)); CUTLASS_CHECK(gemm.can_implement(arguments));
...@@ -178,12 +180,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, ...@@ -178,12 +180,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream)); CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
} }
void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D,
torch::Tensor const& B, torch::stable::Tensor const& A,
torch::Tensor const& A_sf, torch::stable::Tensor const& B,
torch::Tensor const& B_sf, torch::stable::Tensor const& A_sf,
torch::Tensor const& alpha, int m, int n, torch::stable::Tensor const& B_sf,
int k, cudaStream_t stream) { torch::stable::Tensor const& alpha, int m,
int n, int k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m)); uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 256) { if (mp2 <= 256) {
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>( runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
...@@ -194,12 +197,13 @@ void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, ...@@ -194,12 +197,13 @@ void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
} }
} }
void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D,
torch::Tensor const& B, torch::stable::Tensor const& A,
torch::Tensor const& A_sf, torch::stable::Tensor const& B,
torch::Tensor const& B_sf, torch::stable::Tensor const& A_sf,
torch::Tensor const& alpha, int m, int n, torch::stable::Tensor const& B_sf,
int k, cudaStream_t stream) { torch::stable::Tensor const& alpha, int m,
int n, int k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m)); uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 256) { if (mp2 <= 256) {
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>( runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
...@@ -210,11 +214,12 @@ void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, ...@@ -210,11 +214,12 @@ void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
} }
} }
void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A, void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
torch::Tensor const& B, torch::stable::Tensor const& A,
torch::Tensor const& A_sf, torch::stable::Tensor const& B,
torch::Tensor const& B_sf, torch::stable::Tensor const& A_sf,
torch::Tensor const& alpha) { torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha) {
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) #if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
CHECK_INPUT(A, FLOAT4_E2M1X2, "a"); CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
CHECK_INPUT(B, FLOAT4_E2M1X2, "b"); CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
...@@ -222,24 +227,25 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A, ...@@ -222,24 +227,25 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a"); CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b"); CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha"); CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
TORCH_CHECK(A.dim() == 2, "a must be a matrix"); STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
TORCH_CHECK(B.dim() == 2, "b must be a matrix"); STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
TORCH_CHECK(A.sizes()[1] == B.sizes()[1], STD_TORCH_CHECK(A.size(1) == B.size(1),
"a and b shapes cannot be multiplied (", A.sizes()[0], "x", "a and b shapes cannot be multiplied (", A.size(0), "x",
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")"); A.size(1), " and ", B.size(0), "x", B.size(1), ")");
auto const m = A.sizes()[0]; auto const m = A.size(0);
auto const n = B.sizes()[0]; auto const n = B.size(0);
auto const k = A.sizes()[1] * 2; auto const k = A.size(1) * 2;
constexpr int alignment = 32; constexpr int alignment = 32;
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment, STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1], alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
"), k: ", k, "."); "), k: ", k, ".");
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment, STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ")."); alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
").");
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; }; auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
int rounded_m = round_up(m, 128); int rounded_m = round_up(m, 128);
...@@ -248,37 +254,38 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A, ...@@ -248,37 +254,38 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
// integer. // integer.
int rounded_k = round_up(k / 16, 4); int rounded_k = round_up(k / 16, 4);
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix"); STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix"); STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1], STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
"scale_a and scale_b shapes cannot be multiplied (", "scale_a and scale_b shapes cannot be multiplied (",
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0], A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
"x", B_sf.sizes()[1], ")"); B_sf.size(1), ")");
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k, STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m, "scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x", "x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
A_sf.sizes()[1], ")"); A_sf.size(1), ")");
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k, STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n, "scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x", "x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
B_sf.sizes()[1], ")"); B_sf.size(1), ")");
auto out_dtype = D.dtype(); auto out_dtype = D.scalar_type();
const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); const torch::stable::accelerator::DeviceGuard device_guard(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); A.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
if (out_dtype == at::ScalarType::BFloat16) { if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
return cutlass_fp4_bf16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, return cutlass_fp4_bf16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
stream); stream);
} else if (out_dtype == at::ScalarType::Half) { } else if (out_dtype == torch::headeronly::ScalarType::Half) {
return cutlass_fp4_f16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, return cutlass_fp4_f16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
stream); stream);
} else { } else {
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (", STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
out_dtype, ")"); out_dtype, ")");
} }
#else #else
TORCH_CHECK(false, STD_TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to " "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support."); "a CUTLASS 3.8 source directory to enable support.");
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) #endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <utility> #include <utility>
#include "../../cuda_vec_utils.cuh" #include "cuda_vec_utils.cuh"
#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \ #if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \
CUDA_VERSION >= 12090 CUDA_VERSION >= 12090
......
...@@ -103,6 +103,102 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) { ...@@ -103,6 +103,102 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
ops.def( ops.def(
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> " "cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
"bool"); "bool");
// CUTLASS nvfp4 block scaled GEMM
ops.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
" Tensor alpha) -> ()");
// cutlass nvfp4 block scaled group GEMM
ops.def(
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
// Out variant
// TODO: Add out_variant tag once PyTorch supports it (added in 2.11)
// This registration is now migrated to stable ABI
// at::Tag::out_variant is not available in the stable ABI (enum_tag.h is not
// yet in torch/headeronly), the tag should be applied from Python
// via torch.library.Library.define(..., tags=(torch.Tag.out_variant,))
// with the .impl remaining in C++.
// See pytorch/pytorch#176117.
ops.def(
"scaled_fp4_quant.out(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
"-> ()");
// Compute NVFP4 experts quantization.
ops.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
// Fused SiLU+Mul+NVFP4 experts quantization.
ops.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
"output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
// Fused SiLU+Mul+NVFP4 quantization.
ops.def(
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
"Tensor input, Tensor input_global_scale) -> ()");
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
// CUTLASS w4a8 GEMM
ops.def(
"cutlass_w4a8_mm("
" Tensor A,"
" Tensor B,"
" Tensor group_scales,"
" int group_size,"
" Tensor channel_scales,"
" Tensor token_scales,"
" ScalarType? out_type,"
" str? maybe_schedule"
") -> Tensor");
// pack scales
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
// encode and reorder weight matrix
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
// CUTLASS w4a8 grouped GEMM
ops.def(
"cutlass_w4a8_moe_mm("
" Tensor! out_tensors,"
" Tensor a_tensors,"
" Tensor b_tensors,"
" Tensor a_scales,"
" Tensor b_scales,"
" Tensor b_group_scales,"
" int b_group_size,"
" Tensor expert_offsets,"
" Tensor problem_sizes,"
" Tensor a_strides,"
" Tensor b_strides,"
" Tensor c_strides,"
" Tensor group_scale_strides,"
" str? maybe_schedule"
") -> ()");
ops.def(
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
"Tensor)");
#endif #endif
} }
...@@ -128,6 +224,18 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { ...@@ -128,6 +224,18 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets)); TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets));
ops.impl("get_cutlass_batched_moe_mm_data", ops.impl("get_cutlass_batched_moe_mm_data",
TORCH_BOX(&get_cutlass_batched_moe_mm_data)); TORCH_BOX(&get_cutlass_batched_moe_mm_data));
// FP4/NVFP4 ops
ops.impl("cutlass_scaled_fp4_mm", TORCH_BOX(&cutlass_scaled_fp4_mm));
ops.impl("scaled_fp4_quant", TORCH_BOX(&scaled_fp4_quant_func));
ops.impl("scaled_fp4_quant.out", TORCH_BOX(&scaled_fp4_quant_out));
ops.impl("scaled_fp4_experts_quant", TORCH_BOX(&scaled_fp4_experts_quant));
ops.impl("silu_and_mul_scaled_fp4_experts_quant",
TORCH_BOX(&silu_and_mul_scaled_fp4_experts_quant));
ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant));
// W4A8 ops: impl registrations are in the source files
// (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu)
#endif #endif
} }
...@@ -143,6 +251,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) { ...@@ -143,6 +251,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
TORCH_BOX(&cutlass_group_gemm_supported)); TORCH_BOX(&cutlass_group_gemm_supported));
ops.impl("cutlass_scaled_mm_supports_block_fp8", ops.impl("cutlass_scaled_mm_supports_block_fp8",
TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8)); TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8));
ops.impl("cutlass_scaled_mm_supports_fp4",
TORCH_BOX(&cutlass_scaled_mm_supports_fp4));
#endif #endif
} }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h> #include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/accelerator.h> #include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h> #include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/shim_utils.h> #include <torch/headeronly/util/shim_utils.h>
......
...@@ -152,12 +152,6 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input); ...@@ -152,12 +152,6 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale); torch::Tensor& scale);
#ifndef USE_ROCM
void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch::Tensor& output_block_scale,
torch::Tensor& input,
torch::Tensor& input_global_scale);
#endif
void persistent_masked_m_silu_mul_quant( void persistent_masked_m_silu_mul_quant(
const at::Tensor& input, // (E, T, 2*H) const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E) const at::Tensor& counts, // (E)
...@@ -225,44 +219,6 @@ torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W, ...@@ -225,44 +219,6 @@ torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
int64_t ggml_moe_get_block_size(int64_t type); int64_t ggml_moe_get_block_size(int64_t type);
#ifndef USE_ROCM
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B, torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
torch::Tensor const& input, torch::Tensor const& input_scale,
bool is_sf_swizzled_layout);
void scaled_fp4_quant_out(torch::Tensor const& input,
torch::Tensor const& input_scale,
bool is_sf_swizzled_layout, torch::Tensor& output,
torch::Tensor& output_scale);
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale, torch::Tensor const& scale,
std::optional<torch::Tensor> const& azp); std::optional<torch::Tensor> const& azp);
......
...@@ -109,13 +109,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -109,13 +109,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
#ifndef USE_ROCM
ops.def(
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
"Tensor input, Tensor input_global_scale) -> ()");
ops.impl("silu_and_mul_nvfp4_quant", torch::kCUDA, &silu_and_mul_nvfp4_quant);
#endif
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu); ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
...@@ -332,47 +325,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -332,47 +325,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? qzeros_or_none, bool inplace) -> Tensor"); "Tensor? qzeros_or_none, bool inplace) -> Tensor");
// conditionally compiled so impl registrations are in source file // conditionally compiled so impl registrations are in source file
// CUTLASS w4a8 GEMM
ops.def(
"cutlass_w4a8_mm("
" Tensor A,"
" Tensor B,"
" Tensor group_scales,"
" int group_size,"
" Tensor channel_scales,"
" Tensor token_scales,"
" ScalarType? out_type,"
" str? maybe_schedule"
") -> Tensor");
// pack scales
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
// encode and reorder weight matrix
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
// conditionally compiled so impl registration is in source file
// CUTLASS w4a8 grouped GEMM
ops.def(
"cutlass_w4a8_moe_mm("
" Tensor! out_tensors,"
" Tensor a_tensors,"
" Tensor b_tensors,"
" Tensor a_scales,"
" Tensor b_scales,"
" Tensor b_group_scales,"
" int b_group_size,"
" Tensor expert_offsets,"
" Tensor problem_sizes,"
" Tensor a_strides,"
" Tensor b_strides,"
" Tensor c_strides,"
" Tensor group_scale_strides,"
" str? maybe_schedule"
") -> ()");
ops.def(
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
"Tensor)");
// conditionally compiled so impl registration is in source file
#endif #endif
// Dequantization for GGML. // Dequantization for GGML.
...@@ -409,20 +361,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -409,20 +361,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size); ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
#ifndef USE_ROCM #ifndef USE_ROCM
// CUTLASS nvfp4 block scaled GEMM
ops.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
" Tensor alpha) -> ()");
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
// cutlass nvfp4 block scaled group GEMM
ops.def(
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
// conditionally compiled so impl registration is in source file
// Expert-specialization mxfp8 blockscaled grouped quantization (SM100+). // Expert-specialization mxfp8 blockscaled grouped quantization (SM100+).
ops.def( ops.def(
"mxfp8_experts_quant(" "mxfp8_experts_quant("
...@@ -455,44 +393,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -455,44 +393,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"-> int"); "-> int");
// conditionally compiled so impl in source file // conditionally compiled so impl in source file
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant_func);
// Out variant
// TODO: Add {at::Tag::out_variant} tag and update all call sites
// to use the functional variant once vLLM upgrades PyTorch.
// See pytorch/pytorch#176117.
ops.def(
"scaled_fp4_quant.out(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
"-> ()");
ops.impl("scaled_fp4_quant.out", torch::kCUDA, &scaled_fp4_quant_out);
// Compute NVFP4 experts quantization.
ops.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
// Fused SiLU+Mul+NVFP4 experts quantization.
ops.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
"output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
ops.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA,
&silu_and_mul_scaled_fp4_experts_quant);
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
ops.impl("cutlass_scaled_mm_supports_fp4", &cutlass_scaled_mm_supports_fp4);
#endif #endif
// Quantized GEMM for GPTQ. // Quantized GEMM for GPTQ.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment