Commit afd0da21 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.1' into v0.7.1-dev

parents 1a11f127 4f4d427a
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm
\ No newline at end of file
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace vllm {
using namespace cute;
template <typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
int TileSizeM_ = 128, class ClusterShape = Shape<_1, _2, _1>>
struct cutlass_3x_gemm_fp8_blockwise {
using GroupSizeM = Int<GroupSizeM_>;
using GroupSizeN = Int<GroupSizeN_>;
using GroupSizeK = Int<GroupSizeK_>;
using TileSizeM = Int<TileSizeM_>;
static_assert(TileSizeM_ % GroupSizeM_ == 0,
"TileSizeM must be a multiple of GroupSizeM");
using ElementAB = cutlass::float_e4m3_t;
using ElementA = ElementAB;
using LayoutA = cutlass::layout::RowMajor;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
using ElementB = ElementAB;
using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
using ElementD = OutType;
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
using ElementC = void;
using StrideC = StrideD;
static constexpr int AlignmentC = AlignmentD;
using ElementAccumulator = float;
using ElementBlockScale = float;
using ElementCompute = float;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using TileShape = Shape<TileSizeM, GroupSizeN, GroupSizeK>;
using KernelSchedule = cutlass::gemm::
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
GroupSizeM_>;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90AccFetch>;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
ElementAccumulator, ElementCompute, ElementC, StrideC, AlignmentC,
ElementD, StrideD, AlignmentD, EpilogueSchedule,
StoreEpilogueCompute>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB,
LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
cutlass::gemm::PersistentScheduler>>;
struct GemmKernel : public KernelType {};
using StrideA = typename GemmKernel::StrideA;
using StrideB = typename GemmKernel::StrideB;
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel;
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
auto prob_shape = c3x::get_problem_shape(a, b);
int32_t m = get<0>(prob_shape), n = get<1>(prob_shape),
k = get<2>(prob_shape);
int64_t lda = a.stride(0);
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);
using StrideA = Stride<int64_t, Int<1>, int64_t>;
using StrideB = Stride<int64_t, Int<1>, int64_t>;
using StrideC = typename Gemm::StrideC;
StrideA a_stride{lda, Int<1>{}, 0};
StrideB b_stride{ldb, Int<1>{}, 0};
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
// Check is the t is contiguous and is 1D or 2D with one of the dimensions
// being 1 (i.e. a row or column vector)
auto is_contiguous_vector = [](const torch::Tensor& t) {
auto t_sizes = t.sizes();
return t.is_contiguous() &&
(t.dim() == 1 ||
(t.dim() == 2 &&
*std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
};
// TODO(lucas): lets clean-up the kernel so that we pass in Strides so
// we don't have to deal with enforcing implicit layouts
TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value);
TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value);
TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales),
"a_scales must be M major");
TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value);
TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value);
TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales),
"b_scales must be K major");
typename GemmKernel::MainloopArguments mainloop_args{
a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr};
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
{}, c_ptr, c_stride, c_ptr, c_stride};
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
epilogue_args);
}
template <typename OutType>
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
cutlass_gemm_caller_blockwise<
cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128>>(out, a, b, a_scales,
b_scales);
}
} // namespace vllm
\ No newline at end of file
#pragma once
#include <torch/all.h>
namespace vllm {
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
} // namespace vllm
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm
#pragma once #pragma once
#include "scaled_mm_c3x.cuh" #include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
/** /**
* This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm * This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
...@@ -9,6 +10,8 @@ ...@@ -9,6 +10,8 @@
namespace vllm { namespace vllm {
using c3x::cutlass_gemm_caller;
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue> template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default { struct sm90_fp8_config_default {
...@@ -93,4 +96,25 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, ...@@ -93,4 +96,25 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
} }
} }
template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
} // namespace vllm } // namespace vllm
\ No newline at end of file
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm
#pragma once #pragma once
#include "scaled_mm_c3x.cuh" #include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
/** /**
* This file defines Gemm kernel configurations for SM90 (int8) based on the * This file defines Gemm kernel configurations for SM90 (int8) based on the
...@@ -9,6 +10,8 @@ ...@@ -9,6 +10,8 @@
namespace vllm { namespace vllm {
using c3x::cutlass_gemm_caller;
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue> template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_default { struct sm90_int8_config_default {
...@@ -137,4 +140,24 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, ...@@ -137,4 +140,24 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
} }
} }
template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_int8_epilogue(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
} // namespace vllm } // namespace vllm
\ No newline at end of file
...@@ -39,7 +39,7 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a, ...@@ -39,7 +39,7 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) { if (bias) {
...@@ -58,8 +58,8 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a, ...@@ -58,8 +58,8 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
torch::Tensor const& azp_adj, torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp, std::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
...@@ -94,7 +94,7 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a, ...@@ -94,7 +94,7 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) { if (bias) {
...@@ -113,8 +113,8 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a, ...@@ -113,8 +113,8 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
torch::Tensor const& azp_adj, torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp, std::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
...@@ -165,7 +165,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a, ...@@ -165,7 +165,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) { if (bias) {
...@@ -184,8 +184,8 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a, ...@@ -184,8 +184,8 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
torch::Tensor const& azp_adj, torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp, std::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
......
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#include "c3x/scaled_mm_kernels.hpp"
#if defined CUDA_VERSION && CUDA_VERSION >= 12000 #include "core/math.hpp"
#include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
#include "scaled_mm_c3x_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
using namespace vllm;
/* /*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper) or later. NVIDIA GPUs with sm90a (Hopper) or later.
*/ */
template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
if (a.dtype() == torch::kInt8) {
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
}
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) {
TORCH_CHECK(bias->dtype() == c.dtype(), using GroupShape = std::array<int64_t, 2>;
"currently bias dtype must match output dtype ", c.dtype());
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBias>( int M = a.size(0), N = b.size(1), K = a.size(1);
c, a, b, a_scales, b_scales, *bias);
GroupShape a_scale_group_shape = [&, &s = a_scales]() -> GroupShape {
if (s.numel() == 1) return {M, K}; // tensor-wise
if (s.dim() == 2)
return {ceil_div(a.size(0), s.size(0)), ceil_div(a.size(1), s.size(1))};
TORCH_CHECK(false, "Unsupported scale shape for scale_a");
}();
GroupShape b_scale_group_shape = [&, &s = b_scales]() -> GroupShape {
if (s.numel() == 1) return {K, N}; // tensor-wise
if (s.dim() == 2)
return {ceil_div(b.size(0), s.size(0)), ceil_div(b.size(1), s.size(1))};
TORCH_CHECK(false, "Unsupported scale shape for scale_b");
}();
if ((a_scale_group_shape == GroupShape{M, K} ||
a_scale_group_shape == GroupShape{1, K}) &&
(b_scale_group_shape == GroupShape{K, N} ||
b_scale_group_shape == GroupShape{K, 1})) {
// "standard per-tensor/per-token/per-channel" scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.dtype() == torch::kFloat8_e4m3fn) {
vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(a.dtype() == torch::kInt8);
vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
}
} else if (a_scale_group_shape == GroupShape{1, 128} &&
b_scale_group_shape == GroupShape{128, 128}) {
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn &&
b.dtype() == torch::kFloat8_e4m3fn,
"Currently only FP8 is supported for A group shape 1x128 and "
"B group shape 128x128");
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
} else { } else {
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogue>( TORCH_CHECK(false,
c, a, b, a_scales, b_scales); "Unsupported scale group shapes for CUTLASS 3.x GEMM.\n "
"a_scale_group_shape must be [1, 128], got: [",
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
"]\n"
"b_scale_group_shape must be [128, 128], got: [",
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
} }
} }
...@@ -70,18 +73,11 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, ...@@ -70,18 +73,11 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
torch::Tensor const& azp_adj, torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp, std::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) { vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>( azp, bias);
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
} }
#endif
...@@ -9,26 +9,26 @@ void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a, ...@@ -9,26 +9,26 @@ void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X #if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
#endif #endif
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
...@@ -36,24 +36,24 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, ...@@ -36,24 +36,24 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
torch::Tensor const& azp_adj, torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp, std::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
torch::Tensor const& azp_adj, torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp, std::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
torch::Tensor const& azp_adj, torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp, std::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
...@@ -61,8 +61,8 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a, ...@@ -61,8 +61,8 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
torch::Tensor const& azp_adj, torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp, std::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
#endif #endif
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) { bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
...@@ -81,23 +81,33 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) { ...@@ -81,23 +81,33 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
return false; return false;
} }
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
// CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
// and at least SM90 (Hopper)
#if defined CUDA_VERSION
if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000;
}
#endif
return false;
}
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
// Checks for conformality // Checks for conformality
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1)); b.size(1) == c.size(1));
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
// Check for strides and alignment // Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
TORCH_CHECK(b.stride(0) == 1); // Column-major TORCH_CHECK(b.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(0) % 16 == 0 && TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) { if (bias) {
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
...@@ -148,8 +158,8 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, ...@@ -148,8 +158,8 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
torch::Tensor const& azp_adj, torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp, std::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
// Checks for conformality // Checks for conformality
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
...@@ -215,4 +225,4 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, ...@@ -215,4 +225,4 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
"No compiled cutlass_scaled_mm_azp for a compute capability less than " "No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: ", "CUDA device capability: ",
version_num); version_num);
} }
\ No newline at end of file
...@@ -173,8 +173,8 @@ dequant<half, vllm::kU4B8.id()>(int q) { ...@@ -173,8 +173,8 @@ dequant<half, vllm::kU4B8.id()>(int q) {
const int HI = 0x00f000f0; const int HI = 0x00f000f0;
const int EX = 0x64006400; const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s. // Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`. // directly into `SUB` and `ADD`.
const int SUB = 0x64086408; const int SUB = 0x64086408;
...@@ -197,9 +197,9 @@ dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) { ...@@ -197,9 +197,9 @@ dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) {
// Guarantee that the `(a & b) | c` operations are LOP3s. // Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
q >>= 4; q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
typename ScalarType<nv_bfloat16>::FragB frag_b; typename ScalarType<nv_bfloat16>::FragB frag_b;
static constexpr uint32_t MUL = 0x3F803F80; static constexpr uint32_t MUL = 0x3F803F80;
...@@ -221,8 +221,8 @@ dequant<half, vllm::kU4.id()>(int q) { ...@@ -221,8 +221,8 @@ dequant<half, vllm::kU4.id()>(int q) {
const int HI = 0x00f000f0; const int HI = 0x00f000f0;
const int EX = 0x64006400; const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s. // Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
const int SUB = 0x64006400; const int SUB = 0x64006400;
const int MUL = 0x2c002c00; const int MUL = 0x2c002c00;
...@@ -244,9 +244,9 @@ dequant<nv_bfloat16, vllm::kU4.id()>(int q) { ...@@ -244,9 +244,9 @@ dequant<nv_bfloat16, vllm::kU4.id()>(int q) {
// Guarantee that the `(a & b) | c` operations are LOP3s. // Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
q >>= 4; q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
typename ScalarType<nv_bfloat16>::FragB frag_b; typename ScalarType<nv_bfloat16>::FragB frag_b;
static constexpr uint32_t MUL = 0x3F803F80; static constexpr uint32_t MUL = 0x3F803F80;
...@@ -834,6 +834,7 @@ __global__ void Marlin( ...@@ -834,6 +834,7 @@ __global__ void Marlin(
int4* sh_g_idx = sh_b + (stages * b_sh_stage); int4* sh_g_idx = sh_b + (stages * b_sh_stage);
int4* sh_zp = sh_g_idx + (stages * g_idx_stage); int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage); int4* sh_s = sh_zp + (stages * zp_sh_stage);
int4* sh_red = sh_s + (stages * s_sh_stage);
// Register storage for double buffer of shared memory reads. // Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks]; FragA frag_a[2][thread_m_blocks];
...@@ -932,11 +933,11 @@ __global__ void Marlin( ...@@ -932,11 +933,11 @@ __global__ void Marlin(
int4* sh_s_stage = sh_s + s_sh_stage * pipe; int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) { if constexpr (group_blocks >= thread_k_blocks) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
}
// Only fetch scales if this tile starts a new group // Only fetch scales if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) { if ((pipe + 1) % (group_blocks / thread_k_blocks) == 0) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta; s_gl_rd += s_gl_rd_delta;
} }
} else { } else {
...@@ -1038,9 +1039,7 @@ __global__ void Marlin( ...@@ -1038,9 +1039,7 @@ __global__ void Marlin(
// No act-order case // No act-order case
if constexpr (group_blocks != -1) { if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) { if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_s_stage = int4* sh_s_stage = sh_s + s_sh_stage * pipe;
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} else { } else {
int warp_id = threadIdx.x / 32; int warp_id = threadIdx.x / 32;
...@@ -1339,15 +1338,15 @@ __global__ void Marlin( ...@@ -1339,15 +1338,15 @@ __global__ void Marlin(
int red_sh_wr = int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i); red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) { if (i < red_off) {
float* c_rd = float* c_rd = reinterpret_cast<float*>(
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]); &sh_red[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]); float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
#pragma unroll #pragma unroll
for (int k = 0; k < 4; k++) for (int k = 0; k < 4; k++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
c_rd[k] + c_wr[k]; c_rd[k] + c_wr[k];
} }
sh[red_sh_wr] = sh_red[red_sh_wr] =
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j]; reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
} }
} }
...@@ -1357,7 +1356,7 @@ __global__ void Marlin( ...@@ -1357,7 +1356,7 @@ __global__ void Marlin(
#pragma unroll #pragma unroll
for (int i = 0; i < 4 * 2; i++) { for (int i = 0; i < 4 * 2; i++) {
float* c_rd = float* c_rd =
reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]); reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) for (int j = 0; j < 4; j++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
...@@ -1397,7 +1396,7 @@ __global__ void Marlin( ...@@ -1397,7 +1396,7 @@ __global__ void Marlin(
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) { for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred( cp_async4_pred(
&sh[c_sh_wr + c_sh_wr_delta * i], &sh_red[c_sh_wr + c_sh_wr_delta * i],
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
c_gl_wr_delta_i * (i % 2)], c_gl_wr_delta_i * (i % 2)],
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
...@@ -1410,7 +1409,7 @@ __global__ void Marlin( ...@@ -1410,7 +1409,7 @@ __global__ void Marlin(
for (int i = 0; i < thread_m_blocks * 4; i++) { for (int i = 0; i < thread_m_blocks * 4; i++) {
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
if (!first) { if (!first) {
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
#pragma unroll #pragma unroll
for (int j = 0; j < 2 * 4; j++) { for (int j = 0; j < 2 * 4; j++) {
reinterpret_cast<float*>( reinterpret_cast<float*>(
...@@ -1461,10 +1460,10 @@ __global__ void Marlin( ...@@ -1461,10 +1460,10 @@ __global__ void Marlin(
float* frag_c_ptr = reinterpret_cast<float*>(&frag_c); float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
#pragma unroll #pragma unroll
for (int k = 0; k < th_size; k++) { for (int k = 0; k < th_size; k++) {
sh[threadIdx.x] = sh_red[threadIdx.x] =
C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]); float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);
#pragma unroll #pragma unroll
for (int f = 0; f < 4; f++) { for (int f = 0; f < 4; f++) {
frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
...@@ -1515,7 +1514,7 @@ __global__ void Marlin( ...@@ -1515,7 +1514,7 @@ __global__ void Marlin(
res = __hmul2(res, s[0]); res = __hmul2(res, s[0]);
} }
((scalar_t2*)sh)[idx] = res; ((scalar_t2*)sh_red)[idx] = res;
}; };
if (threadIdx.x / 32 < thread_n_blocks / 4) { if (threadIdx.x / 32 < thread_n_blocks / 4) {
...@@ -1543,7 +1542,7 @@ __global__ void Marlin( ...@@ -1543,7 +1542,7 @@ __global__ void Marlin(
i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
i++) { i++) {
if (c_gl_wr < c_gl_wr_end) { if (c_gl_wr < c_gl_wr_end) {
C[c_gl_wr] = sh[c_sh_rd]; C[c_gl_wr] = sh_red[c_sh_rd];
c_gl_wr += c_gl_wr_delta; c_gl_wr += c_gl_wr_delta;
c_sh_rd += c_sh_rd_delta; c_sh_rd += c_sh_rd_delta;
} }
...@@ -1865,9 +1864,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, ...@@ -1865,9 +1864,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
float pipe_size = (a_size + b_size) * pipe_stages; float pipe_size = (a_size + b_size) * pipe_stages;
float reduce_size = max(th_config.num_threads * 32 * 4,
(tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); return pipe_size + reduce_size < 0.95f * (max_shared_mem - scales_cache_size);
} }
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
......
...@@ -63,7 +63,7 @@ torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) { ...@@ -63,7 +63,7 @@ torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
static inline std::optional<at::ScalarType> maybe_scalartype( static inline std::optional<at::ScalarType> maybe_scalartype(
c10::optional<at::Tensor> const& t) { std::optional<at::Tensor> const& t) {
if (!t) { if (!t) {
return std::nullopt; return std::nullopt;
} else { } else {
...@@ -189,7 +189,7 @@ using Kernel_{{type_sig}} = MacheteKernelTemplate< ...@@ -189,7 +189,7 @@ using Kernel_{{type_sig}} = MacheteKernelTemplate<
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT {{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT {{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT {{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, cutlass::gemm::KernelTmaWarpSpecializedCooperative,
Sch>; Sch>;
{% for sch in schs %} {% for sch in schs %}
...@@ -223,7 +223,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) { ...@@ -223,7 +223,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
{{DataTypeTag[t.convert]}}, // ElementConvert {{DataTypeTag[t.convert]}}, // ElementConvert
{{DataTypeTag[t.accumulator]}}, // Accumulator {{DataTypeTag[t.accumulator]}}, // Accumulator
cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor,
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput> cutlass::gemm::KernelTmaWarpSpecializedCooperative>
>(args.B); >(args.B);
} }
{%- endfor %} {%- endfor %}
...@@ -239,7 +239,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) { ...@@ -239,7 +239,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
}; // namespace machete }; // namespace machete
""" """
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
...@@ -300,7 +300,7 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str: ...@@ -300,7 +300,7 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
# mostly unique shorter sch_sig # mostly unique shorter sch_sig
def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str: def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
kernel_terse_names_replace = { kernel_terse_names_replace = {
"KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_", "KernelTmaWarpSpecializedCooperative": "TmaMI_",
"TmaWarpSpecializedCooperative_": "TmaCoop_", "TmaWarpSpecializedCooperative_": "TmaCoop_",
"StreamKScheduler": "streamK", "StreamKScheduler": "streamK",
} }
......
...@@ -18,16 +18,14 @@ struct VLLMCollectiveBuilder< ...@@ -18,16 +18,14 @@ struct VLLMCollectiveBuilder<
ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
KernelScheduleType, KernelScheduleType,
cute::enable_if_t<( cute::enable_if_t<(
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<KernelScheduleType, cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedMixedInput> || KernelTmaWarpSpecializedCooperative>)>> {
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedPingpongMixedInput> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeMixedInput>)>> {
using CollectiveOp = machete::MacheteCollectiveMma< using CollectiveOp = machete::MacheteCollectiveMma<
ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
StageCountType, KernelScheduleType>; StageCountType, KernelScheduleType>;
}; };
}; // namespace cutlass::gemm::collective }; // namespace cutlass::gemm::collective
\ No newline at end of file
...@@ -66,13 +66,11 @@ struct MacheteCollectiveMma { ...@@ -66,13 +66,11 @@ struct MacheteCollectiveMma {
using Schedule = KernelScheduleType; using Schedule = KernelScheduleType;
static_assert( static_assert(
cute::is_same_v<Schedule, KernelTmaWarpSpecialized> || cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedMixedInput> || cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> || cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<Schedule,
KernelTmaWarpSpecializedPingpongMixedInput> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative> || cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative> ||
cute::is_same_v<Schedule, cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative>,
KernelTmaWarpSpecializedCooperativeMixedInput>,
"KernelSchedule must be one of the warp specialized policies"); "KernelSchedule must be one of the warp specialized policies");
public: public:
...@@ -113,8 +111,7 @@ struct MacheteCollectiveMma { ...@@ -113,8 +111,7 @@ struct MacheteCollectiveMma {
// For coop schedules we have two warp groups cooperatively issuing wgmma // For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup) // instructions so we use 2 atoms along the M dim (one for each warpgroup)
using AtomLayoutMNK = cute::conditional_t< using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelScheduleType, cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>,
KernelTmaWarpSpecializedCooperativeMixedInput>,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>; Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma( using TiledMma = decltype(cute::make_tiled_mma(
...@@ -275,6 +272,10 @@ struct MacheteCollectiveMma { ...@@ -275,6 +272,10 @@ struct MacheteCollectiveMma {
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>; using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params; using PipelineParams = typename MainloopPipeline::Params;
// One threads per CTA are producers (1 for operand tile)
static constexpr int NumProducerThreadEvents = 1;
using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}),
shape<1>(SmemLayoutAtomScale{}))); shape<1>(SmemLayoutAtomScale{})));
......
...@@ -183,11 +183,11 @@ struct MacheteKernelTemplate { ...@@ -183,11 +183,11 @@ struct MacheteKernelTemplate {
torch::Tensor const& A, // MxK matrix torch::Tensor const& A, // MxK matrix
torch::Tensor const& B, // KxN prepacked matrix torch::Tensor const& B, // KxN prepacked matrix
torch::Tensor& D, // MxN matrix torch::Tensor& D, // MxN matrix
c10::optional<torch::Tensor> const& maybe_g_scales, // scale_KxN matrix std::optional<torch::Tensor> const& maybe_g_scales, // scale_KxN matrix
c10::optional<torch::Tensor> const& maybe_g_zeros, // scale_KxN matrix std::optional<torch::Tensor> const& maybe_g_zeros, // scale_KxN matrix
c10::optional<int64_t> maybe_group_size, std::optional<int64_t> maybe_group_size,
c10::optional<torch::Tensor> const& maybe_ch_scales, // len N vector std::optional<torch::Tensor> const& maybe_ch_scales, // len N vector
c10::optional<torch::Tensor> const& maybe_tok_scales) // len M vector std::optional<torch::Tensor> const& maybe_tok_scales) // len M vector
{ {
static_assert(!with_group_zeropoints || with_group_scales); static_assert(!with_group_zeropoints || with_group_scales);
......
...@@ -13,23 +13,23 @@ struct MMArgs { ...@@ -13,23 +13,23 @@ struct MMArgs {
torch::Tensor const& A; torch::Tensor const& A;
torch::Tensor const& B; torch::Tensor const& B;
vllm::ScalarType const& b_type; vllm::ScalarType const& b_type;
c10::optional<at::ScalarType> const& maybe_out_type; std::optional<at::ScalarType> const& maybe_out_type;
c10::optional<torch::Tensor> const& maybe_group_scales; std::optional<torch::Tensor> const& maybe_group_scales;
c10::optional<torch::Tensor> const& maybe_group_zeros; std::optional<torch::Tensor> const& maybe_group_zeros;
c10::optional<int64_t> maybe_group_size; std::optional<int64_t> maybe_group_size;
c10::optional<torch::Tensor> const& maybe_channel_scales; std::optional<torch::Tensor> const& maybe_channel_scales;
c10::optional<torch::Tensor> const& maybe_token_scales; std::optional<torch::Tensor> const& maybe_token_scales;
c10::optional<std::string> maybe_schedule; std::optional<std::string> maybe_schedule;
}; };
struct SupportedSchedulesArgs { struct SupportedSchedulesArgs {
at::ScalarType a_type; at::ScalarType a_type;
vllm::ScalarType b_type; vllm::ScalarType b_type;
c10::optional<at::ScalarType> maybe_group_scales_type; std::optional<at::ScalarType> maybe_group_scales_type;
c10::optional<at::ScalarType> maybe_group_zeros_type; std::optional<at::ScalarType> maybe_group_zeros_type;
c10::optional<at::ScalarType> maybe_channel_scales_type; std::optional<at::ScalarType> maybe_channel_scales_type;
c10::optional<at::ScalarType> maybe_token_scales_type; std::optional<at::ScalarType> maybe_token_scales_type;
c10::optional<at::ScalarType> maybe_out_type; std::optional<at::ScalarType> maybe_out_type;
}; };
torch::Tensor mm_dispatch(MMArgs args); torch::Tensor mm_dispatch(MMArgs args);
......
...@@ -10,7 +10,7 @@ struct PrepackBArgs { ...@@ -10,7 +10,7 @@ struct PrepackBArgs {
torch::Tensor const& B; torch::Tensor const& B;
at::ScalarType a_type; at::ScalarType a_type;
vllm::ScalarType b_type; vllm::ScalarType b_type;
c10::optional<at::ScalarType> maybe_group_scales_type; std::optional<at::ScalarType> maybe_group_scales_type;
}; };
template <typename PrepackedLayoutB> template <typename PrepackedLayoutB>
......
...@@ -98,8 +98,7 @@ struct PrepackedLayoutBTemplate { ...@@ -98,8 +98,7 @@ struct PrepackedLayoutBTemplate {
// For coop schedules we have two warp groups cooperatively issuing wgmma // For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup) // instructions so we use 2 atoms along the M dim (one for each warpgroup)
using AtomLayoutMNK = cute::conditional_t< using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelSchedule, cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
KernelTmaWarpSpecializedCooperativeMixedInput>,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>; Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma( using TiledMma = decltype(cute::make_tiled_mma(
...@@ -247,4 +246,4 @@ struct PrepackedLayoutBTemplate { ...@@ -247,4 +246,4 @@ struct PrepackedLayoutBTemplate {
} }
}; };
}; // namespace machete }; // namespace machete
\ No newline at end of file
...@@ -10,11 +10,11 @@ using namespace vllm; ...@@ -10,11 +10,11 @@ using namespace vllm;
std::vector<std::string> supported_schedules( std::vector<std::string> supported_schedules(
at::ScalarType a_type, int64_t b_type_id, at::ScalarType a_type, int64_t b_type_id,
c10::optional<at::ScalarType> maybe_group_scales_type, std::optional<at::ScalarType> maybe_group_scales_type,
c10::optional<at::ScalarType> maybe_group_zeros_type, std::optional<at::ScalarType> maybe_group_zeros_type,
c10::optional<at::ScalarType> maybe_channel_scales_type, std::optional<at::ScalarType> maybe_channel_scales_type,
c10::optional<at::ScalarType> maybe_token_scales_type, std::optional<at::ScalarType> maybe_token_scales_type,
c10::optional<at::ScalarType> maybe_out_type) { std::optional<at::ScalarType> maybe_out_type) {
ScalarType const b_type = ScalarType::from_id(b_type_id); ScalarType const b_type = ScalarType::from_id(b_type_id);
return supported_schedules_dispatch({ return supported_schedules_dispatch({
.a_type = a_type, .a_type = a_type,
...@@ -29,13 +29,13 @@ std::vector<std::string> supported_schedules( ...@@ -29,13 +29,13 @@ std::vector<std::string> supported_schedules(
torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B, torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
int64_t b_type_id, int64_t b_type_id,
c10::optional<at::ScalarType> const& maybe_out_type, std::optional<at::ScalarType> const& maybe_out_type,
c10::optional<torch::Tensor> const& maybe_group_scales, std::optional<torch::Tensor> const& maybe_group_scales,
c10::optional<torch::Tensor> const& maybe_group_zeros, std::optional<torch::Tensor> const& maybe_group_zeros,
c10::optional<int64_t> maybe_group_size, std::optional<int64_t> maybe_group_size,
c10::optional<torch::Tensor> const& maybe_channel_scales, std::optional<torch::Tensor> const& maybe_channel_scales,
c10::optional<torch::Tensor> const& maybe_token_scales, std::optional<torch::Tensor> const& maybe_token_scales,
c10::optional<std::string> maybe_schedule) { std::optional<std::string> maybe_schedule) {
ScalarType const b_type = ScalarType::from_id(b_type_id); ScalarType const b_type = ScalarType::from_id(b_type_id);
return mm_dispatch({.A = A, return mm_dispatch({.A = A,
.B = B, .B = B,
...@@ -51,7 +51,7 @@ torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B, ...@@ -51,7 +51,7 @@ torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
torch::Tensor prepack_B( torch::Tensor prepack_B(
torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id, torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id,
c10::optional<at::ScalarType> const& maybe_group_scales_type) { std::optional<at::ScalarType> const& maybe_group_scales_type) {
ScalarType const b_type = ScalarType::from_id(b_type_id); ScalarType const b_type = ScalarType::from_id(b_type_id);
return prepack_B_dispatch( return prepack_B_dispatch(
{.B = B, {.B = B,
......
...@@ -96,8 +96,8 @@ __device__ inline FragB dequant(int q) { ...@@ -96,8 +96,8 @@ __device__ inline FragB dequant(int q) {
const int HI = 0x00f000f0; const int HI = 0x00f000f0;
const int EX = 0x64006400; const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s. // Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`. // directly into `SUB` and `ADD`.
const int SUB = 0x64086408; const int SUB = 0x64086408;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment