Commit 711aa9d5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.0' into v0.10.0-dev

parents 751c492c 6d8d0a24
...@@ -18,28 +18,34 @@ using ProblemShape = ...@@ -18,28 +18,34 @@ using ProblemShape =
cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>; cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;
using ElementAccumulator = float; using ElementAccumulator = float;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp; using OperatorClass = cutlass::arch::OpClassTensorOp;
using LayoutA = cutlass::layout::RowMajor; using LayoutA = cutlass::layout::RowMajor;
using LayoutA_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB = cutlass::layout::ColumnMajor; using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor; using LayoutB_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutB>::type;
template <typename ElementAB_, typename ElementC_, using LayoutD = cutlass::layout::RowMajor;
using LayoutD_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
using LayoutC = LayoutD;
using LayoutC_Transpose = LayoutD_Transpose;
template <typename ElementAB_, typename ElementC_, typename ArchTag_,
template <typename, typename, typename> typename Epilogue_, template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule, typename TileShape, typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule> typename EpilogueSchedule, bool swap_ab_ = false>
struct cutlass_3x_group_gemm { struct cutlass_3x_group_gemm {
static constexpr bool swap_ab = swap_ab_;
using ElementAB = ElementAB_; using ElementAB = ElementAB_;
using ElementC = void; using ElementC = void;
using ElementD = ElementC_; using ElementD = ElementC_;
using ElementAccumulator = float; using ElementAccumulator = float;
using ArchTag = ArchTag_;
using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>; using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;
using StrideC =
cute::remove_pointer_t<cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;
static constexpr int AlignmentAB = static constexpr int AlignmentAB =
128 / cutlass::sizeof_bits<ElementAB>::value; 128 / cutlass::sizeof_bits<ElementAB>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value; static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
...@@ -50,21 +56,28 @@ struct cutlass_3x_group_gemm { ...@@ -50,21 +56,28 @@ struct cutlass_3x_group_gemm {
typename cutlass::epilogue::collective::CollectiveBuilder< typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, ArchTag, OperatorClass, TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, ElementAccumulator, ElementC,
LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp; conditional_t<swap_ab, LayoutC_Transpose*, LayoutC*>, AlignmentC,
ElementD, conditional_t<swap_ab, LayoutD_Transpose*, LayoutD*>,
AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp;
static constexpr size_t CEStorageSize = static constexpr size_t CEStorageSize =
sizeof(typename CollectiveEpilogue::SharedStorage); sizeof(typename CollectiveEpilogue::SharedStorage);
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(CEStorageSize)>; static_cast<int>(CEStorageSize)>;
using CollectiveMainloop = using CollectiveMainloop = conditional_t<
swap_ab,
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementAB, LayoutB_Transpose*, AlignmentAB,
ElementAB, LayoutA_Transpose*, AlignmentAB, ElementAccumulator,
TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp,
typename cutlass::gemm::collective::CollectiveBuilder< typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB, ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB,
LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape, LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
Stages, KernelSchedule>::CollectiveOp; Stages, KernelSchedule>::CollectiveOp>;
using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal< using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
ProblemShape, CollectiveMainloop, CollectiveEpilogue>>; ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;
struct GemmKernel : public KernelType {}; struct GemmKernel : public KernelType {};
...@@ -78,12 +91,12 @@ void cutlass_group_gemm_caller( ...@@ -78,12 +91,12 @@ void cutlass_group_gemm_caller(
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) { bool per_act_token, bool per_out_ch) {
static constexpr bool swap_ab = Gemm::swap_ab;
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD; using ElementD = typename Gemm::ElementD;
int num_experts = static_cast<int>(expert_offsets.size(0)); int num_experts = static_cast<int>(expert_offsets.size(0));
int k_size = a_tensors.size(1);
int n_size = out_tensors.size(1);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
...@@ -110,26 +123,47 @@ void cutlass_group_gemm_caller( ...@@ -110,26 +123,47 @@ void cutlass_group_gemm_caller(
problem_sizes.data_ptr()); problem_sizes.data_ptr());
ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr}; ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr};
typename GemmKernel::MainloopArguments mainloop_args{ typename GemmKernel::MainloopArguments mainloop_args;
static_cast<const ElementAB**>(a_ptrs.data_ptr()), if constexpr (swap_ab) {
static_cast<StrideA*>(a_strides.data_ptr()), mainloop_args = typename GemmKernel::MainloopArguments{
static_cast<const ElementAB**>(b_ptrs.data_ptr()), static_cast<const ElementAB**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(b_strides.data_ptr())}; static_cast<StrideB*>(b_strides.data_ptr()),
static_cast<const ElementAB**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(a_strides.data_ptr())};
} else {
mainloop_args = typename GemmKernel::MainloopArguments{
static_cast<const ElementAB**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(a_strides.data_ptr()),
static_cast<const ElementAB**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(b_strides.data_ptr())};
}
// Currently, we are only able to do broadcast on either all or none a_scales // Currently, we are only able to do broadcast on either all or none a_scales
// and on either all or none b_scales // and on either all or none b_scales
typename GemmKernel::EpilogueArguments epilogue_args{ typename GemmKernel::EpilogueArguments epilogue_args{
Gemm::Epilogue::prepare_args( Gemm::Epilogue::prepare_args(
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()), swap_ab ? static_cast<const ElementAccumulator**>(
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()), b_scales_ptrs.data_ptr())
per_act_token, per_out_ch), : static_cast<const ElementAccumulator**>(
a_scales_ptrs.data_ptr()),
swap_ab ? static_cast<const ElementAccumulator**>(
a_scales_ptrs.data_ptr())
: static_cast<const ElementAccumulator**>(
b_scales_ptrs.data_ptr()),
swap_ab ? per_out_ch : per_act_token,
swap_ab ? per_act_token : per_out_ch),
nullptr, static_cast<StrideC*>(c_strides.data_ptr()), nullptr, static_cast<StrideC*>(c_strides.data_ptr()),
static_cast<ElementD**>(out_ptrs.data_ptr()), static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(c_strides.data_ptr())}; static_cast<StrideC*>(c_strides.data_ptr())};
int device_id = a_tensors.device().index();
static const cutlass::KernelHardwareInfo hw_info{
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
device_id)};
typename GemmKernel::Arguments args{ typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args,
epilogue_args}; epilogue_args, hw_info};
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
GemmOp gemm_op; GemmOp gemm_op;
......
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"
using namespace cute;
namespace {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_default {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm100;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M64 {
// M in [1,64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm100;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule,
true>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_N8192 {
// N in [8192, inf)
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm100;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
"A tensors must be of type float8_e4m3fn.");
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
"B tensors must be of type float8_e4m3fn.");
using Cutlass3xGemmDefault = typename sm100_fp8_config_default<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmN8192 = typename sm100_fp8_config_N8192<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmM64 = typename sm100_fp8_config_M64<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
uint32_t const m = a_tensors.size(0);
uint32_t const n = out_tensors.size(1);
if (m <= 64) {
cutlass_group_gemm_caller<Cutlass3xGemmM64>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (n >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else {
cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
}
}
} // namespace
void dispatch_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.dtype() == torch::kBFloat16) {
run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else {
run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::half_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
}
}
void cutlass_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
}
...@@ -21,27 +21,49 @@ struct sm90_fp8_config_default { ...@@ -21,27 +21,49 @@ struct sm90_fp8_config_default {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_64, cute::_256, cute::_128>; using TileShape = cute::Shape<cute::_64, cute::_256, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>; using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
using ArchTag = cutlass::arch::Sm90;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
KernelSchedule, EpilogueSchedule>; ClusterShape, KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue> template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M16 { struct sm90_fp8_config_M4 {
// M in [1, 16] // M in [1, 4]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = using EpilogueSchedule =
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_64, cute::_64, cute::_128>; using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_4, cute::_1>; using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm90;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
KernelSchedule, EpilogueSchedule>; ClusterShape, KernelSchedule, EpilogueSchedule,
true>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
// M in (4, 64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule =
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_256>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm90;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule,
true>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
...@@ -55,10 +77,11 @@ struct sm90_fp8_config_K8192 { ...@@ -55,10 +77,11 @@ struct sm90_fp8_config_K8192 {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>; using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>; using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
using ArchTag = cutlass::arch::Sm90;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
KernelSchedule, EpilogueSchedule>; ClusterShape, KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
...@@ -72,10 +95,11 @@ struct sm90_fp8_config_N8192 { ...@@ -72,10 +95,11 @@ struct sm90_fp8_config_N8192 {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>; using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>; using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
using ArchTag = cutlass::arch::Sm90;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
KernelSchedule, EpilogueSchedule>; ClusterShape, KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType> template <typename InType, typename OutType>
...@@ -95,14 +119,13 @@ void run_cutlass_moe_mm_sm90( ...@@ -95,14 +119,13 @@ void run_cutlass_moe_mm_sm90(
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn, TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
"B tensors must be of type float8_e4m3fn."); "B tensors must be of type float8_e4m3fn.");
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192< using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192< using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmM16 = typename sm90_fp8_config_M16< using Cutlass3xGemmM4 = typename sm90_fp8_config_M4<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmM64 = typename sm90_fp8_config_M64<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmDefault = typename sm90_fp8_config_default< using Cutlass3xGemmDefault = typename sm90_fp8_config_default<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
...@@ -111,18 +134,24 @@ void run_cutlass_moe_mm_sm90( ...@@ -111,18 +134,24 @@ void run_cutlass_moe_mm_sm90(
uint32_t const n = out_tensors.size(1); uint32_t const n = out_tensors.size(1);
uint32_t const k = a_tensors.size(1); uint32_t const k = a_tensors.size(1);
if (n >= 8192) { // Use swap_ab for M <= 64 by default to reduce padding
cutlass_group_gemm_caller<Cutlass3xGemmN8192>( if (m <= 4) {
cutlass_group_gemm_caller<Cutlass3xGemmM4>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token, problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch); per_out_ch);
} else if (k >= 8192) { } else if (m <= 64) {
cutlass_group_gemm_caller<Cutlass3xGemmK8192>( cutlass_group_gemm_caller<Cutlass3xGemmM64>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token, problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch); per_out_ch);
} else if (m <= 16) { } else if (n >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmM16>( cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (k >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token, problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch); per_out_ch);
......
...@@ -6,8 +6,11 @@ ...@@ -6,8 +6,11 @@
#include <iostream> #include <iostream>
constexpr uint64_t THREADS_PER_EXPERT = 512; constexpr uint64_t THREADS_PER_EXPERT = 512;
// threshold must match the dispatch logic in run_cutlass_moe_mm_sm90()
constexpr int SWAP_AB_THRESHOLD = 64;
__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids, template <bool SWAP_AB>
__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
int32_t* problem_sizes1, int32_t* problem_sizes1,
int32_t* problem_sizes2, int32_t* problem_sizes2,
int32_t* atomic_buffer, int32_t* atomic_buffer,
...@@ -24,45 +27,58 @@ __global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids, ...@@ -24,45 +27,58 @@ __global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
int final_occurrences = atomic_buffer[expert_id]; int final_occurrences = atomic_buffer[expert_id];
problem_sizes1[expert_id * 3] = final_occurrences; if constexpr (!SWAP_AB) {
problem_sizes1[expert_id * 3 + 1] = 2 * n; problem_sizes1[expert_id * 3] = final_occurrences;
problem_sizes1[expert_id * 3 + 2] = k; problem_sizes1[expert_id * 3 + 1] = 2 * n;
problem_sizes2[expert_id * 3] = final_occurrences; problem_sizes1[expert_id * 3 + 2] = k;
problem_sizes2[expert_id * 3 + 1] = k; problem_sizes2[expert_id * 3] = final_occurrences;
problem_sizes2[expert_id * 3 + 2] = n; problem_sizes2[expert_id * 3 + 1] = k;
problem_sizes2[expert_id * 3 + 2] = n;
} else {
problem_sizes1[expert_id * 3] = 2 * n;
problem_sizes1[expert_id * 3 + 1] = final_occurrences;
problem_sizes1[expert_id * 3 + 2] = k;
problem_sizes2[expert_id * 3] = k;
problem_sizes2[expert_id * 3 + 1] = final_occurrences;
problem_sizes2[expert_id * 3 + 2] = n;
}
} }
} }
__global__ void compute_expert_offsets( __global__ void compute_expert_offsets(
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
int32_t* atomic_buffer, const int num_experts) { int32_t* atomic_buffer, const int num_experts, const int topk_length) {
int32_t tot_offset = 0; int32_t tot_offset = 0;
expert_offsets[0] = 0; expert_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) { for (int i = 0; i < num_experts; ++i) {
atomic_buffer[i] = tot_offset; atomic_buffer[i] = tot_offset;
tot_offset += problem_sizes1[i * 3]; tot_offset += topk_length > SWAP_AB_THRESHOLD ? problem_sizes1[i * 3]
: problem_sizes1[i * 3 + 1];
expert_offsets[i + 1] = tot_offset; expert_offsets[i + 1] = tot_offset;
} }
} }
__global__ void compute_expert_blockscale_offsets( __global__ void compute_expert_blockscale_offsets(
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
int32_t* blockscale_offsets, int32_t* atomic_buffer, int32_t* blockscale_offsets, int32_t* atomic_buffer, const int num_experts,
const int num_experts) { const int topk_length) {
int32_t tot_offset = 0; int32_t tot_offset = 0;
int32_t tot_offset_round = 0; int32_t tot_offset_round = 0;
expert_offsets[0] = 0; expert_offsets[0] = 0;
blockscale_offsets[0] = 0; blockscale_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) { for (int i = 0; i < num_experts; ++i) {
int32_t cur_offset = topk_length > SWAP_AB_THRESHOLD
? problem_sizes1[i * 3]
: problem_sizes1[i * 3 + 1];
atomic_buffer[i] = tot_offset; atomic_buffer[i] = tot_offset;
tot_offset += problem_sizes1[i * 3]; tot_offset += cur_offset;
expert_offsets[i + 1] = tot_offset; expert_offsets[i + 1] = tot_offset;
tot_offset_round += (problem_sizes1[i * 3] + (128 - 1)) / 128 * 128; tot_offset_round += (cur_offset + (128 - 1)) / 128 * 128;
blockscale_offsets[i + 1] = tot_offset_round; blockscale_offsets[i + 1] = tot_offset_round;
} }
} }
__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids, __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
const int32_t* __restrict__ expert_offsets, const int32_t* __restrict__ expert_offsets,
int32_t* input_permutation, int32_t* input_permutation,
int32_t* output_permutation, int32_t* output_permutation,
...@@ -102,25 +118,39 @@ void get_cutlass_moe_mm_data_caller( ...@@ -102,25 +118,39 @@ void get_cutlass_moe_mm_data_caller(
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
static_cast<const uint32_t*>(topk_ids.data_ptr()), if (topk_ids.numel() > SWAP_AB_THRESHOLD) {
static_cast<int32_t*>(problem_sizes1.data_ptr()), compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
static_cast<int32_t*>(problem_sizes2.data_ptr()), static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
k);
} else {
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
k);
}
if (blockscale_offsets.has_value()) { if (blockscale_offsets.has_value()) {
compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>( compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()), static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()), static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(blockscale_offsets.value().data_ptr()), static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts); static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
topk_ids.numel());
} else { } else {
compute_expert_offsets<<<1, 1, 0, stream>>>( compute_expert_offsets<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()), static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()), static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts); static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
topk_ids.numel());
} }
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>( compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const uint32_t*>(topk_ids.data_ptr()), static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(expert_offsets.data_ptr()), static_cast<const int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(input_permutation.data_ptr()), static_cast<int32_t*>(input_permutation.data_ptr()),
static_cast<int32_t*>(output_permutation.data_ptr()), static_cast<int32_t*>(output_permutation.data_ptr()),
...@@ -160,4 +190,4 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, ...@@ -160,4 +190,4 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
static_cast<int32_t*>(problem_sizes2.data_ptr()), static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n, static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
k); k);
} }
\ No newline at end of file
...@@ -41,6 +41,16 @@ void cutlass_moe_mm_sm90( ...@@ -41,6 +41,16 @@ void cutlass_moe_mm_sm90(
#endif #endif
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
void cutlass_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120 #if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
...@@ -130,10 +140,10 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { ...@@ -130,10 +140,10 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
// and at least SM90 (Hopper) // and at least SM90 (Hopper)
#if defined CUDA_VERSION #if defined CUDA_VERSION
if (cuda_device_capability >= 90 && cuda_device_capability < 100) { if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12000;
} else if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080; return CUDA_VERSION >= 12080;
} else if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000;
} }
#endif #endif
...@@ -141,11 +151,14 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { ...@@ -141,11 +151,14 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
} }
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) { bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
// CUTLASS grouped FP8 kernels need at least CUDA 12.3 // CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper)
// and SM90 (Hopper) // or CUDA 12.8 and SM100 (Blackwell)
#if defined CUDA_VERSION #if defined CUDA_VERSION
if (cuda_device_capability == 90) { if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
}
if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12030; return CUDA_VERSION >= 12030;
} }
#endif #endif
...@@ -234,16 +247,26 @@ void cutlass_moe_mm( ...@@ -234,16 +247,26 @@ void cutlass_moe_mm(
torch::Tensor const& b_strides, torch::Tensor const& c_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) { bool per_act_token, bool per_out_ch) {
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
if (version_num >= 100) {
cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
}
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, if (version_num >= 90) {
expert_offsets, problem_sizes, a_strides, b_strides, cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
c_strides, per_act_token, per_out_ch); expert_offsets, problem_sizes, a_strides, b_strides,
return; c_strides, per_act_token, per_out_ch);
return;
}
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED( TORCH_CHECK_NOT_IMPLEMENTED(
false, false,
"No compiled cutlass_scaled_mm for CUDA device capability: ", version_num, "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
". Required capability: 90"); ". Required capability: 90 or 100");
} }
void get_cutlass_moe_mm_data( void get_cutlass_moe_mm_data(
......
...@@ -30,35 +30,40 @@ ...@@ -30,35 +30,40 @@
#include "cutlass/util/packed_stride.hpp" #include "cutlass/util/packed_stride.hpp"
#include "core/math.hpp"
using namespace cute; using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// Kernel Perf config
template <typename T>
struct KernelTraits;
template <> // Configuration for M in (256, inf)
struct KernelTraits<float> { struct sm100_fp4_config_default {
using MmaTileShape = Shape<_128, _128, _256>; using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using ClusterShape = Shape<_1, _1, _1>; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using PerSmTileShape_MNK = Shape<_128, _128, _256>; using TileShape = Shape<_256, _256, _256>;
using ClusterShape = Shape<_2, _1, _1>;
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
}; };
template <> // Configuration for M in (16, 256]
struct KernelTraits<cutlass::half_t> { struct sm100_fp4_config_M256 {
using MmaTileShape = Shape<_256, _256, _256>; using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using ClusterShape = Shape<_4, _4, _1>; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using PerSmTileShape_MNK = Shape<_128, _256, _256>; using TileShape = Shape<_256, _128, _256>;
using ClusterShape = Shape<_2, _1, _1>;
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
}; };
template <> // Configuration for M in [1, 16]
struct KernelTraits<cutlass::bfloat16_t> { struct sm100_fp4_config_M16 {
using MmaTileShape = Shape<_256, _256, _256>; using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using ClusterShape = Shape<_4, _4, _1>; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using PerSmTileShape_MNK = Shape<_128, _256, _256>; using TileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<_1, _1, _1>;
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
}; };
template <typename T> template <typename Config, typename OutType>
struct Fp4GemmSm100 { struct Fp4GemmSm100 {
// A matrix configuration // A matrix configuration
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>; using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
...@@ -71,21 +76,22 @@ struct Fp4GemmSm100 { ...@@ -71,21 +76,22 @@ struct Fp4GemmSm100 {
static constexpr int AlignmentB = 32; static constexpr int AlignmentB = 32;
// C/D matrix configuration // C/D matrix configuration
using ElementD = T; using ElementD = OutType;
using ElementC = T; using ElementC = OutType;
using LayoutCTag = cutlass::layout::RowMajor; using LayoutCTag = cutlass::layout::RowMajor;
using LayoutDTag = cutlass::layout::RowMajor; using LayoutDTag = cutlass::layout::RowMajor;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
// Kernel functional config // Kernel functional config
using ElementAccumulator = float; using ElementAccumulator = float;
using ArchTag = cutlass::arch::Sm100; using ArchTag = cutlass::arch::Sm100;
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
// Kernel Perf config // Use config's tile shapes
using MmaTileShape = typename KernelTraits<T>::MmaTileShape; using MmaTileShape = typename Config::TileShape;
using ClusterShape = typename KernelTraits<T>::ClusterShape; using ClusterShape = typename Config::ClusterShape;
using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK; using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK;
using CollectiveEpilogue = using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder< typename cutlass::epilogue::collective::CollectiveBuilder<
...@@ -119,22 +125,22 @@ struct Fp4GemmSm100 { ...@@ -119,22 +125,22 @@ struct Fp4GemmSm100 {
using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{})); using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
}; };
template <typename T> template <typename Config>
typename T::Gemm::Arguments args_from_options( typename Config::Gemm::Arguments args_from_options(
at::Tensor& D, at::Tensor const& A, at::Tensor const& B, at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha, at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
int64_t M, int64_t N, int64_t K) { int64_t M, int64_t N, int64_t K) {
using ElementA = typename T::Gemm::ElementA; using ElementA = typename Config::Gemm::ElementA;
using ElementB = typename T::Gemm::ElementB; using ElementB = typename Config::Gemm::ElementB;
using ElementSFA = cutlass::float_ue4m3_t; using ElementSFA = cutlass::float_ue4m3_t;
using ElementSFB = cutlass::float_ue4m3_t; using ElementSFB = cutlass::float_ue4m3_t;
using ElementD = typename T::Gemm::ElementD; using ElementD = typename Config::Gemm::ElementD;
using ElementCompute = float; using ElementCompute = float;
using StrideA = typename T::StrideA; using StrideA = typename Config::StrideA;
using StrideB = typename T::StrideB; using StrideB = typename Config::StrideB;
using StrideD = typename T::StrideD; using StrideD = typename Config::StrideD;
using Sm100BlkScaledConfig = using Sm100BlkScaledConfig = typename Config::Gemm::GemmKernel::
typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; CollectiveMainloop::Sm1xxBlkScaledConfig;
int m = static_cast<int>(M); int m = static_cast<int>(M);
int n = static_cast<int>(N); int n = static_cast<int>(N);
...@@ -148,7 +154,7 @@ typename T::Gemm::Arguments args_from_options( ...@@ -148,7 +154,7 @@ typename T::Gemm::Arguments args_from_options(
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB( auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(
cute::make_shape(m, n, k, 1)); cute::make_shape(m, n, k, 1));
typename T::Gemm::Arguments arguments{ typename Config::Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm, cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, 1}, {m, n, k, 1},
{// Mainloop arguments {// Mainloop arguments
...@@ -167,17 +173,17 @@ typename T::Gemm::Arguments args_from_options( ...@@ -167,17 +173,17 @@ typename T::Gemm::Arguments args_from_options(
return arguments; return arguments;
} }
template <typename T> template <typename Config>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& A_sf, at::Tensor const& B_sf,
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k, at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
cudaStream_t stream) { cudaStream_t stream) {
typename Fp4GemmSm100<T>::Gemm gemm; typename Config::Gemm gemm;
auto arguments = auto arguments =
args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k); args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments); size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
auto const workspace_options = auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(A.device()); torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options); auto workspace = torch::empty(workspace_size, workspace_options);
...@@ -188,12 +194,40 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, ...@@ -188,12 +194,40 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream)); CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
} }
// Dispatch function to select appropriate config based on M
template <typename OutType>
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha, int64_t m, int64_t n,
int64_t k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 16) {
// m in [1, 16]
runGemm<Fp4GemmSm100<sm100_fp4_config_M16, OutType>>(
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
} else if (mp2 <= 256) {
// m in (16, 256]
runGemm<Fp4GemmSm100<sm100_fp4_config_M256, OutType>>(
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
} else {
// m in (256, inf)
runGemm<Fp4GemmSm100<sm100_fp4_config_default, OutType>>(
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
}
}
#else #else
template <typename T> template <typename OutType>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
at::Tensor const& A_sf, at::Tensor const& B_sf, torch::Tensor const& B,
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k, torch::Tensor const& A_sf,
cudaStream_t stream) { torch::Tensor const& B_sf,
torch::Tensor const& alpha, int64_t m, int64_t n,
int64_t k, cudaStream_t stream) {
TORCH_CHECK(false, TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to " "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support."); "a CUTLASS 3.8 source directory to enable support.");
...@@ -271,12 +305,13 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, ...@@ -271,12 +305,13 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
if (out_dtype == at::ScalarType::Half) { if (out_dtype == at::ScalarType::Half) {
runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
k, stream);
} else if (out_dtype == at::ScalarType::BFloat16) { } else if (out_dtype == at::ScalarType::BFloat16) {
runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
} else if (out_dtype == at::ScalarType::Float) { m, n, k, stream);
runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
} else { } else {
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm"); TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype,
")");
} }
} }
...@@ -88,6 +88,8 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] ...@@ -88,6 +88,8 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d] torch::Tensor const& input, // [..., d]
torch::Tensor const& scale) // [1] torch::Tensor const& scale) // [1]
{ {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
int const block_size = 256; int const block_size = 256;
int const num_tokens = input.numel() / input.size(-1); int const num_tokens = input.numel() / input.size(-1);
int const num_elems = input.numel(); int const num_elems = input.numel();
...@@ -111,6 +113,8 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] ...@@ -111,6 +113,8 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d] torch::Tensor const& input, // [..., d]
torch::Tensor& scale) // [1] torch::Tensor& scale) // [1]
{ {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
int const block_size = 256; int const block_size = 256;
int const num_tokens = input.numel() / input.size(-1); int const num_tokens = input.numel() / input.size(-1);
int const num_elems = input.numel(); int const num_elems = input.numel();
......
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <cmath>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <torch/all.h>
#include "../vectorization.cuh"
#include "../vectorization_utils.cuh"
#include "../../dispatch_utils.h"
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
unsigned mask = 0xffff;
val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
return val;
}
template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false,
bool SCALE_UE8M0 = false, typename scale_packed_t = float>
__global__ void per_token_group_quant_8bit_kernel(
const T* __restrict__ input, void* __restrict__ output_q,
scale_packed_t* __restrict__ output_s, const int group_size,
const int num_groups, const int groups_per_block, const float eps,
const float min_8bit, const float max_8bit, const int scale_num_rows = 0,
const int scale_stride = 0) {
const int threads_per_group = 16;
const int64_t local_group_id = threadIdx.x / threads_per_group;
const int lane_id = threadIdx.x % threads_per_group;
const int64_t block_group_id = blockIdx.x * groups_per_block;
const int64_t global_group_id = block_group_id + local_group_id;
const int64_t block_group_offset = global_group_id * group_size;
float local_absmax = eps;
using scale_element_t = float;
static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
const T* group_input = input + block_group_offset;
DST_DTYPE* group_output =
static_cast<DST_DTYPE*>(output_q) + block_group_offset;
scale_element_t* scale_output;
if constexpr (IS_COLUMN_MAJOR) {
const int num_elems_per_pack =
static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
const int scale_num_rows_element = scale_num_rows * num_elems_per_pack;
const int row_idx = global_group_id / scale_num_rows_element;
const int col_idx_raw = global_group_id % scale_num_rows_element;
const int col_idx = col_idx_raw / num_elems_per_pack;
const int pack_idx = col_idx_raw % num_elems_per_pack;
scale_output = reinterpret_cast<scale_element_t*>(output_s) +
(col_idx * scale_stride * num_elems_per_pack +
row_idx * num_elems_per_pack + pack_idx);
} else {
scale_output = output_s + global_group_id;
}
// shared memory to cache each group's data to avoid double DRAM reads.
extern __shared__ __align__(16) char smem_raw[];
T* smem = reinterpret_cast<T*>(smem_raw);
T* smem_group = smem + local_group_id * group_size;
constexpr int vec_size = 16 / sizeof(T);
using vec_t = vllm::vec_n_t<T, vec_size>;
// copy global -> shared & compute absmax
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
float abs_v = fabsf(static_cast<float>(src));
local_absmax = fmaxf(local_absmax, abs_v);
dst = src;
};
vllm::vectorize_with_alignment<vec_size>(
group_input, // in
smem_group, // out (shared)
group_size, // elements per group
lane_id, // thread id
threads_per_group, // stride in group
scalar_op_cache); // scalar handler
local_absmax = GroupReduceMax(local_absmax, lane_id);
float y_s = local_absmax / max_8bit;
if constexpr (SCALE_UE8M0) {
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
}
scale_element_t y_s_quant = y_s;
if (lane_id == 0) {
*scale_output = y_s_quant;
}
__syncthreads();
// quantize shared -> global 8-bit
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
dst = DST_DTYPE(q);
};
vllm::vectorize_with_alignment<vec_size>(
smem_group, // in (shared)
group_output, // out (global quant tensor)
group_size, // elements
lane_id, // tid
threads_per_group, // stride
scalar_op_quant); // scalar handler
}
void per_token_group_quant_8bit(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double min_8bit, double max_8bit,
bool scale_ue8m0 = false) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(output_q.is_contiguous());
const int num_groups = input.numel() / group_size;
TORCH_CHECK(input.numel() % group_size == 0);
TORCH_CHECK(output_s.dim() == 2);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
constexpr int THREADS_PER_GROUP = 16;
int groups_per_block = 1;
if (num_groups % 16 == 0) {
groups_per_block = 16;
} else if (num_groups % 8 == 0) {
groups_per_block = 8;
} else if (num_groups % 4 == 0) {
groups_per_block = 4;
} else if (num_groups % 2 == 0) {
groups_per_block = 2;
}
auto dst_type = output_q.scalar_type();
const int num_blocks = num_groups / groups_per_block;
const int num_threads = groups_per_block * THREADS_PER_GROUP;
const bool is_column_major = output_s.stride(0) < output_s.stride(1);
const int scale_num_rows = output_s.size(1);
const int scale_stride = output_s.stride(1);
#define LAUNCH_KERNEL(T, DST_DTYPE) \
do { \
dim3 grid(num_blocks); \
dim3 block(num_threads); \
size_t smem_bytes = \
static_cast<size_t>(groups_per_block) * group_size * sizeof(T); \
if (is_column_major) { \
if (scale_ue8m0) { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true> \
<<<grid, block, smem_bytes, stream>>>( \
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), group_size, \
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
(float)max_8bit, scale_num_rows, scale_stride); \
} else { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false> \
<<<grid, block, smem_bytes, stream>>>( \
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), group_size, \
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
(float)max_8bit, scale_num_rows, scale_stride); \
} \
} else { \
if (scale_ue8m0) { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, true> \
<<<grid, block, smem_bytes, stream>>>( \
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), group_size, \
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
(float)max_8bit); \
} else { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, false> \
<<<grid, block, smem_bytes, stream>>>( \
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), group_size, \
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
(float)max_8bit); \
} \
} \
} while (0)
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "per_token_group_quant_8bit", ([&] {
if (dst_type == at::ScalarType::Float8_e4m3fn) {
LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
}
}));
#undef LAUNCH_KERNEL
}
void per_token_group_quant_fp8(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_s,
int64_t group_size, double eps, double fp8_min,
double fp8_max, bool scale_ue8m0) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
fp8_min, fp8_max, scale_ue8m0);
}
...@@ -20,13 +20,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -20,13 +20,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops // vLLM custom ops
// //
// The default behavior in PyTorch 2.6 is "requires_contiguous", so we need // The default behavior in PyTorch 2.6 was changed to "requires_contiguous",
// so we need
// to override this for many GEMMs with the following tag. Otherwise, // to override this for many GEMMs with the following tag. Otherwise,
// torch.compile will force all input tensors to be contiguous(), which // torch.compile will force all input tensors to be contiguous(), which
// will break many custom ops that require column-major weight matrices. // will break many custom ops that require column-major weight matrices.
// TODO: remove this for PyTorch 2.8, when the default is planned to switch // This was a bug and PyTorch 2.7 has since fixed this.
// to match exact eager-mode strides. #if TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6
at::Tag stride_tag = at::Tag::needs_fixed_stride_order; #define stride_tag at::Tag::needs_fixed_stride_order
#else
#define stride_tag
#endif
ops.def("weak_ref_tensor(Tensor input) -> Tensor"); ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor); ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
...@@ -704,6 +708,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -704,6 +708,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor page_table, float scale) -> ()"); " Tensor page_table, float scale) -> ()");
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
// SM100 CUTLASS MLA decode
ops.def(
"sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
" Tensor page_table, Tensor workspace, float "
"scale,"
" int num_kv_splits) -> ()");
// conditionally compiled so impl in source file
// SM100 CUTLASS MLA workspace
ops.def(
"sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches,"
" int sm_count, int num_kv_splits) "
"-> int");
// conditionally compiled so impl in source file
// Compute NVFP4 block quantized tensor. // Compute NVFP4 block quantized tensor.
ops.def( ops.def(
"scaled_fp4_quant(Tensor! output, Tensor input," "scaled_fp4_quant(Tensor! output, Tensor input,"
...@@ -785,29 +805,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -785,29 +805,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int pad_slot_id) -> ()"); "int pad_slot_id) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
#ifndef USE_ROCM
// Compute per-token-group FP8 quantized tensor and scaling factor.
ops.def( ops.def(
"causal_conv1d_update(Tensor! x," "per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
"Tensor! conv_state," "output_s, "
"Tensor! weight," "int group_size, float eps, float fp8_min, float fp8_max, bool "
"Tensor? bias_," "scale_ue8m0) -> ()");
"bool silu_activation," ops.impl("per_token_group_fp8_quant", torch::kCUDA,
"Tensor? cache_seqlens_," &per_token_group_quant_fp8);
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
ops.def(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor!? conv_states,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#ifndef USE_ROCM
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
ops.def( ops.def(
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, " "rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
......
...@@ -63,7 +63,7 @@ ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL=https://download.pytorch.org/whl/nightly ...@@ -63,7 +63,7 @@ ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL=https://download.pytorch.org/whl/nightly
ARG PIP_KEYRING_PROVIDER=disabled ARG PIP_KEYRING_PROVIDER=disabled
ARG UV_KEYRING_PROVIDER=${PIP_KEYRING_PROVIDER} ARG UV_KEYRING_PROVIDER=${PIP_KEYRING_PROVIDER}
# Flag enables build-in KV-connector dependency libs into docker images # Flag enables built-in KV-connector dependency libs into docker images
ARG INSTALL_KV_CONNECTORS=false ARG INSTALL_KV_CONNECTORS=false
#################### BASE BUILD IMAGE #################### #################### BASE BUILD IMAGE ####################
...@@ -207,6 +207,19 @@ ARG SCCACHE_ENDPOINT ...@@ -207,6 +207,19 @@ ARG SCCACHE_ENDPOINT
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
ARG SCCACHE_REGION_NAME=us-west-2 ARG SCCACHE_REGION_NAME=us-west-2
ARG SCCACHE_S3_NO_CREDENTIALS=0 ARG SCCACHE_S3_NO_CREDENTIALS=0
# Flag to control whether to use pre-built vLLM wheels
ARG VLLM_USE_PRECOMPILED
# TODO: in setup.py VLLM_USE_PRECOMPILED is sensitive to truthiness, it will take =0 as "true", this should be fixed
ENV VLLM_USE_PRECOMPILED=""
RUN if [ "${VLLM_USE_PRECOMPILED}" = "1" ]; then \
export VLLM_USE_PRECOMPILED=1 && \
echo "Using precompiled wheels"; \
else \
unset VLLM_USE_PRECOMPILED && \
echo "Leaving VLLM_USE_PRECOMPILED unset to build wheels from source"; \
fi
# if USE_SCCACHE is set, use sccache to speed up compilation # if USE_SCCACHE is set, use sccache to speed up compilation
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=.git,target=.git \ --mount=type=bind,source=.git,target=.git \
...@@ -252,7 +265,7 @@ RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \ ...@@ -252,7 +265,7 @@ RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
#################### EXTENSION Build IMAGE #################### #################### EXTENSION Build IMAGE ####################
#################### DEV IMAGE #################### #################### DEV IMAGE ####################
FROM base as dev FROM base AS dev
ARG PIP_INDEX_URL UV_INDEX_URL ARG PIP_INDEX_URL UV_INDEX_URL
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
...@@ -375,48 +388,33 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist ...@@ -375,48 +388,33 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
# -rw-rw-r-- 1 mgoin mgoin 205M Jun 9 18:03 flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl # -rw-rw-r-- 1 mgoin mgoin 205M Jun 9 18:03 flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl
# $ # upload the wheel to a public location, e.g. https://wheels.vllm.ai/flashinfer/v0.2.6.post1/flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl # $ # upload the wheel to a public location, e.g. https://wheels.vllm.ai/flashinfer/v0.2.6.post1/flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl
# Allow specifying a version, Git revision or local .whl file # Install FlashInfer from source
ARG FLASHINFER_CUDA128_INDEX_URL="https://download.pytorch.org/whl/cu128/flashinfer"
ARG FLASHINFER_CUDA128_WHEEL="flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl"
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
ARG FLASHINFER_GIT_REF="v0.2.6.post1" ARG FLASHINFER_GIT_REF="v0.2.8rc1"
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
. /etc/environment . /etc/environment
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then git clone --depth 1 --recursive --shallow-submodules \
# FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use --branch ${FLASHINFER_GIT_REF} \
if [[ "$CUDA_VERSION" == 12.8* ]]; then ${FLASHINFER_GIT_REPO} flashinfer
uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL} # Exclude CUDA arches for older versions (11.x and 12.0-12.7)
else # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0' if [[ "${CUDA_VERSION}" == 11.* ]]; then
git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
# Needed to build AOT kernels elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
(cd flashinfer && \ FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
python3 -m flashinfer.aot && \ else
uv pip install --system --no-build-isolation . \ # CUDA 12.8+ supports 10.0a and 12.0
) FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
rm -rf flashinfer fi
echo "🏗️ Building FlashInfer for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
# Default arches (skipping 10.0a and 12.0 since these need 12.8) # Needed to build AOT kernels
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg. pushd flashinfer
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
if [[ "${CUDA_VERSION}" == 11.* ]]; then
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
fi
echo "🏗️ Building FlashInfer for arches: ${TORCH_CUDA_ARCH_LIST}"
git clone --depth 1 --recursive --shallow-submodules \
--branch v0.2.6.post1 \
https://github.com/flashinfer-ai/flashinfer.git flashinfer
pushd flashinfer
python3 -m flashinfer.aot python3 -m flashinfer.aot
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}" \ TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
uv pip install --system --no-build-isolation . uv pip install --system --no-build-isolation .
popd popd
rm -rf flashinfer
rm -rf flashinfer
fi \
fi
BASH BASH
COPY examples examples COPY examples examples
COPY benchmarks benchmarks COPY benchmarks benchmarks
...@@ -508,10 +506,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \ ...@@ -508,10 +506,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/kv_connectors.txt; \ uv pip install --system -r requirements/kv_connectors.txt; \
fi; \ fi; \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ BITSANDBYTES_VERSION="0.42.0"; \
else \ else \
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.46.1' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ BITSANDBYTES_VERSION="0.46.1"; \
fi fi; \
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]
ENV VLLM_USAGE_SOURCE production-docker-image ENV VLLM_USAGE_SOURCE production-docker-image
......
...@@ -95,7 +95,7 @@ WORKDIR /workspace/vllm ...@@ -95,7 +95,7 @@ WORKDIR /workspace/vllm
RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \ RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
cp requirements/test.in requirements/cpu-test.in && \ cp requirements/test.in requirements/cpu-test.in && \
sed -i '/mamba_ssm/d' requirements/cpu-test.in && \ sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
sed -i 's/torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \ sed -i 's/^torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \
sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \ sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \
sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \ sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \
uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu
......
FROM vault.habana.ai/gaudi-docker/1.20.1/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest
COPY ./ /workspace/vllm
WORKDIR /workspace/vllm
RUN pip install -v -r requirements/hpu.txt
ENV no_proxy=localhost,127.0.0.1
ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true
RUN VLLM_TARGET_DEVICE=hpu python3 setup.py install
# install development dependencies (for testing)
RUN python3 -m pip install -e tests/vllm_test_utils
WORKDIR /workspace/
RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
...@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ...@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa" ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="6487649" ARG AITER_BRANCH="916bf3c"
ARG AITER_REPO="https://github.com/ROCm/aiter.git" ARG AITER_REPO="https://github.com/ROCm/aiter.git"
FROM ${BASE_IMAGE} AS base FROM ${BASE_IMAGE} AS base
......
ARG NIGHTLY_DATE="20250124" ARG NIGHTLY_DATE="20250714"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.12_tpuvm_$NIGHTLY_DATE"
FROM $BASE_IMAGE FROM $BASE_IMAGE
WORKDIR /workspace/vllm WORKDIR /workspace/vllm
......
...@@ -47,7 +47,7 @@ FROM vllm-base AS vllm-openai ...@@ -47,7 +47,7 @@ FROM vllm-base AS vllm-openai
# install additional dependencies for openai api server # install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
pip install accelerate hf_transfer 'modelscope!=1.15.0' pip install accelerate hf_transfer pytest pytest_asyncio lm_eval[api] modelscope
ENV VLLM_USAGE_SOURCE production-docker-image \ ENV VLLM_USAGE_SOURCE production-docker-image \
TRITON_XPU_PROFILE 1 TRITON_XPU_PROFILE 1
......
...@@ -55,6 +55,7 @@ nav: ...@@ -55,6 +55,7 @@ nav:
- contributing/model/registration.md - contributing/model/registration.md
- contributing/model/tests.md - contributing/model/tests.md
- contributing/model/multimodal.md - contributing/model/multimodal.md
- CI: contributing/ci
- Design Documents: - Design Documents:
- V0: design - V0: design
- V1: design/v1 - V1: design/v1
......
...@@ -36,7 +36,7 @@ vLLM is flexible and easy to use with: ...@@ -36,7 +36,7 @@ vLLM is flexible and easy to use with:
- Seamless integration with popular HuggingFace models - Seamless integration with popular HuggingFace models
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more - High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
- Tensor parallelism and pipeline parallelism support for distributed inference - Tensor, pipeline, data and expert parallelism support for distributed inference
- Streaming outputs - Streaming outputs
- OpenAI-compatible API server - OpenAI-compatible API server
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators. - Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators.
...@@ -48,4 +48,4 @@ For more information, check out the following: ...@@ -48,4 +48,4 @@ For more information, check out the following:
- [vLLM announcing blog post](https://vllm.ai) (intro to PagedAttention) - [vLLM announcing blog post](https://vllm.ai) (intro to PagedAttention)
- [vLLM paper](https://arxiv.org/abs/2309.06180) (SOSP 2023) - [vLLM paper](https://arxiv.org/abs/2309.06180) (SOSP 2023)
- [How continuous batching enables 23x throughput in LLM inference while reducing p50 latency](https://www.anyscale.com/blog/continuous-batching-llm-inference) by Cade Daniel et al. - [How continuous batching enables 23x throughput in LLM inference while reducing p50 latency](https://www.anyscale.com/blog/continuous-batching-llm-inference) by Cade Daniel et al.
- [vLLM Meetups][meetups] - [vLLM Meetups](community/meetups.md)
...@@ -8,14 +8,12 @@ API documentation for vLLM's configuration classes. ...@@ -8,14 +8,12 @@ API documentation for vLLM's configuration classes.
- [vllm.config.ModelConfig][] - [vllm.config.ModelConfig][]
- [vllm.config.CacheConfig][] - [vllm.config.CacheConfig][]
- [vllm.config.TokenizerPoolConfig][]
- [vllm.config.LoadConfig][] - [vllm.config.LoadConfig][]
- [vllm.config.ParallelConfig][] - [vllm.config.ParallelConfig][]
- [vllm.config.SchedulerConfig][] - [vllm.config.SchedulerConfig][]
- [vllm.config.DeviceConfig][] - [vllm.config.DeviceConfig][]
- [vllm.config.SpeculativeConfig][] - [vllm.config.SpeculativeConfig][]
- [vllm.config.LoRAConfig][] - [vllm.config.LoRAConfig][]
- [vllm.config.PromptAdapterConfig][]
- [vllm.config.MultiModalConfig][] - [vllm.config.MultiModalConfig][]
- [vllm.config.PoolerConfig][] - [vllm.config.PoolerConfig][]
- [vllm.config.DecodingConfig][] - [vllm.config.DecodingConfig][]
...@@ -64,7 +62,7 @@ vLLM provides experimental support for multi-modal models through the [vllm.mult ...@@ -64,7 +62,7 @@ vLLM provides experimental support for multi-modal models through the [vllm.mult
Multi-modal inputs can be passed alongside text and token prompts to [supported models][supported-mm-models] Multi-modal inputs can be passed alongside text and token prompts to [supported models][supported-mm-models]
via the `multi_modal_data` field in [vllm.inputs.PromptType][]. via the `multi_modal_data` field in [vllm.inputs.PromptType][].
Looking to add your own multi-modal model? Please follow the instructions listed [here][supports-multimodal]. Looking to add your own multi-modal model? Please follow the instructions listed [here](../contributing/model/multimodal.md).
- [vllm.multimodal.MULTIMODAL_REGISTRY][] - [vllm.multimodal.MULTIMODAL_REGISTRY][]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment